|
|
@@ -17,6 +17,7 @@ from abc import abstractmethod
|
|
|
import lazy_paddle as paddle
|
|
|
import numpy as np
|
|
|
|
|
|
+from ....utils.flags import FLAGS_json_format_model
|
|
|
from ....utils import logging
|
|
|
from ...utils.pp_option import PaddlePredictorOption
|
|
|
from ..utils.mixin import PPEngineMixin
|
|
|
@@ -53,10 +54,7 @@ class BasePaddlePredictor(BaseComponent, PPEngineMixin):
|
|
|
"""_create"""
|
|
|
from lazy_paddle.inference import Config, create_predictor
|
|
|
|
|
|
- use_pir = (
|
|
|
- hasattr(paddle.framework, "use_pir_api") and paddle.framework.use_pir_api()
|
|
|
- )
|
|
|
- model_postfix = ".json" if use_pir else ".pdmodel"
|
|
|
+ model_postfix = ".json" if FLAGS_json_format_model else ".pdmodel"
|
|
|
model_file = (self.model_dir / f"{self.model_prefix}{model_postfix}").as_posix()
|
|
|
params_file = (self.model_dir / f"{self.model_prefix}.pdiparams").as_posix()
|
|
|
config = Config(model_file, params_file)
|