Prechádzať zdrojové kódy

update hpi PIR_TRT code (#3864)

zhang-prog 7 mesiacov pred
rodič
commit
ee5dfab725
2 zmenil súbory, kde vykonal 10 pridanie a 5 odobranie
  1. 1 1
      paddlex/.version
  2. 9 4
      paddlex/inference/utils/hpi.py

+ 1 - 1
paddlex/.version

@@ -1 +1 @@
-3.0.0.rc0
+3.0.0.rc1

+ 9 - 4
paddlex/inference/utils/hpi.py

@@ -174,14 +174,19 @@ def suggest_inference_backend_and_config(
         return None, f"{repr(hpi_config.pdx_model_name)} is not a known model."
     supported_pseudo_backends = hpi_model_info_collection_for_env[
         hpi_config.pdx_model_name
-    ]
+    ].copy()
 
     # XXX
-    if (
+    if not (
         USE_PIR_TRT
         and importlib.util.find_spec("tensorrt")
         and ctypes.util.find_library("nvinfer")
     ):
+        if (
+            "paddle_tensorrt" in supported_pseudo_backends
+            or "paddle_tensorrt_fp16" in supported_pseudo_backends
+        ):
+            supported_pseudo_backends.append("paddle")
         if "paddle_tensorrt" in supported_pseudo_backends:
             supported_pseudo_backends.remove("paddle_tensorrt")
         if "paddle_tensorrt_fp16" in supported_pseudo_backends:
@@ -220,10 +225,10 @@ def suggest_inference_backend_and_config(
         pseudo_backend = backend_to_pseudo_backend["paddle"]
         assert pseudo_backend in (
             "paddle",
-            "paddle_tensorrt_fp32",
+            "paddle_tensorrt",
             "paddle_tensorrt_fp16",
         ), pseudo_backend
-        if pseudo_backend == "paddle_tensorrt_fp32":
+        if pseudo_backend == "paddle_tensorrt":
             suggested_backend_config.update({"run_mode": "trt_fp32"})
         elif pseudo_backend == "paddle_tensorrt_fp16":
             # TODO: Check if the target device supports FP16.