Przeglądaj źródła

add hpi installation (#2875)

* add hpi installation

* opt

* update hpi deps find links

* fix
zhang-prog 9 miesięcy temu
rodzic
commit
28531bb015
1 zmienionych plików z 44 dodań i 0 usunięć
  1. 44 0
      paddlex/paddlex_cli.py

+ 44 - 0
paddlex/paddlex_cli.py

@@ -209,6 +209,32 @@ def install(args):
                 [sys.executable, "-m", "pip", "install", "-r", str(req_file)]
             )
 
+    def _install_hpi_deps(device_type):
+        support_device_type = ["cpu", "gpu"]
+        if device_type not in support_device_type:
+            logging.error(
+                "HPI installation failed!\n"
+                "Supported device_type: %s. Your input device_type: %s.\n"
+                "Please ensure the device_type is correct.",
+                support_device_type,
+                device_type,
+            )
+            sys.exit(2)
+
+        if device_type == "cpu":
+            packages = ["ultra_infer_python", "paddlex_hpi"]
+        elif device_type == "gpu":
+            packages = ["ultra_infer_gpu_python", "paddlex_hpi"]
+
+        return subprocess.check_call(
+            [sys.executable, "-m", "pip", "install"]
+            + packages
+            + [
+                "--find-links",
+                "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/pipeline_deploy/high_performance_inference.md",
+            ]
+        )
+
     # Enable debug info
     os.environ["PADDLE_PDX_DEBUG"] = "True"
     # Disable eager initialization
@@ -232,6 +258,24 @@ def install(args):
         _install_paddle2onnx_deps()
         return
 
+    hpi_plugins = list(filter(lambda name: name.startswith("hpi-"), plugins))
+    if hpi_plugins:
+        for i in hpi_plugins:
+            plugins.remove(i)
+        if plugins:
+            logging.error("`hpi` cannot be used together with other plugins.")
+            sys.exit(2)
+        if len(hpi_plugins) > 1 or len(hpi_plugins[0].split("-")) != 2:
+            logging.error(
+                "Invalid HPI plugin installation format detected.\n"
+                "Correct format: paddlex --install hpi-<device_type>\n"
+                "Example: paddlex --install hpi-gpu"
+            )
+            sys.exit(2)
+        device_type = hpi_plugins[0].split("-")[1]
+        _install_hpi_deps(device_type=device_type)
+        return
+
     if plugins:
         repo_names = plugins
     elif len(plugins) == 0: