zhangyubo0722 10 mesi fa
parent
commit
ea2bbf17d9
1 ha cambiato i file con 8 aggiunte e 3 eliminazioni
  1. 8 3
      paddlex/model.py

+ 8 - 3
paddlex/model.py

@@ -22,6 +22,7 @@ from .modules import (
     build_evaluater,
     build_exportor,
 )
+from .utils.flags import NEW_PREDICTOR
 
 
 # TODO(gaotingquan): support _ModelBasedConfig
@@ -82,15 +83,19 @@ class _ModelBasedConfig(_BaseModel):
         predict_kwargs = deepcopy(self._config.Predict)
 
         model_dir = predict_kwargs.pop("model_dir", None)
-        # if model_dir is None, using official
-        model = self._model_name if model_dir is None else model_dir
 
         device = self._config.Global.get("device")
         kernel_option = predict_kwargs.pop("kernel_option", {})
         kernel_option.update({"device": device})
 
         pp_option = PaddlePredictorOption(self._model_name, **kernel_option)
-        predictor = create_predictor(model, pp_option=pp_option)
+        if NEW_PREDICTOR:
+            predictor = create_predictor(
+                self._model_name, model_dir, pp_option=pp_option
+            )
+        else:
+            model = self._model_name if model_dir is None else model_dir
+            predictor = create_predictor(model, pp_option=pp_option)
         assert "input" in predict_kwargs
         return predict_kwargs, predictor