فهرست منبع

separation model_dir (#2857)

zhangyubo0722 10 ماه پیش
والد
کامیت
664b86af1a
2فایلهای تغییر یافته به همراه15 افزوده شده و 5 حذف شده
  1. 11 3
      paddlex/inference/models_new/__init__.py
  2. 4 2
      paddlex/model.py

+ 11 - 3
paddlex/inference/models_new/__init__.py

@@ -76,7 +76,8 @@ def _create_hp_predictor(
 
 
 def create_predictor(
-    model: str,
+    model_name: str,
+    model_dir: Optional[str] = None,
     device=None,
     pp_option=None,
     use_hpip: bool = False,
@@ -84,9 +85,16 @@ def create_predictor(
     *args,
     **kwargs,
 ) -> BasePredictor:
-    model_dir = check_model(model)
+    if model_dir is None:
+        model_dir = check_model(model_name)
+    else:
+        assert Path(model_dir).exists(), f"{model_dir} is not exists!"
+        model_dir = Path(model_dir)
     config = BasePredictor.load_config(model_dir)
-    model_name = config["Global"]["model_name"]
+    assert (
+        model_name == config["Global"]["model_name"]
+    ), f"Model name mismatch,please input the correct model dir."
+
     if use_hpip:
         return _create_hp_predictor(
             model_name=model_name,

+ 4 - 2
paddlex/model.py

@@ -25,8 +25,10 @@ from .modules import (
 
 
 # TODO(gaotingquan): support _ModelBasedConfig
-def create_model(model=None, *args, **kwargs):
-    return _ModelBasedInference(model, *args, **kwargs)
+def create_model(model_name, model_dir=None, *args, **kwargs):
+    return _ModelBasedInference(
+        model_name=model_name, model_dir=model_dir, *args, **kwargs
+    )
 
 
 class _BaseModel: