|
|
@@ -22,6 +22,7 @@ from .modules import (
|
|
|
build_evaluater,
|
|
|
build_exportor,
|
|
|
)
|
|
|
+from .utils.flags import NEW_PREDICTOR
|
|
|
|
|
|
|
|
|
# TODO(gaotingquan): support _ModelBasedConfig
|
|
|
@@ -82,15 +83,19 @@ class _ModelBasedConfig(_BaseModel):
|
|
|
predict_kwargs = deepcopy(self._config.Predict)
|
|
|
|
|
|
model_dir = predict_kwargs.pop("model_dir", None)
|
|
|
- # if model_dir is None, using official
|
|
|
- model = self._model_name if model_dir is None else model_dir
|
|
|
|
|
|
device = self._config.Global.get("device")
|
|
|
kernel_option = predict_kwargs.pop("kernel_option", {})
|
|
|
kernel_option.update({"device": device})
|
|
|
|
|
|
pp_option = PaddlePredictorOption(self._model_name, **kernel_option)
|
|
|
- predictor = create_predictor(model, pp_option=pp_option)
|
|
|
+ if NEW_PREDICTOR:
|
|
|
+ predictor = create_predictor(
|
|
|
+ self._model_name, model_dir, pp_option=pp_option
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ model = self._model_name if model_dir is None else model_dir
|
|
|
+ predictor = create_predictor(model, pp_option=pp_option)
|
|
|
assert "input" in predict_kwargs
|
|
|
return predict_kwargs, predictor
|
|
|
|