|
|
@@ -76,7 +76,8 @@ def _create_hp_predictor(
|
|
|
|
|
|
|
|
|
def create_predictor(
|
|
|
- model: str,
|
|
|
+ model_name: str,
|
|
|
+ model_dir: Optional[str] = None,
|
|
|
device=None,
|
|
|
pp_option=None,
|
|
|
use_hpip: bool = False,
|
|
|
@@ -84,9 +85,16 @@ def create_predictor(
|
|
|
*args,
|
|
|
**kwargs,
|
|
|
) -> BasePredictor:
|
|
|
- model_dir = check_model(model)
|
|
|
+ if model_dir is None:
|
|
|
+ model_dir = check_model(model_name)
|
|
|
+ else:
|
|
|
+ assert Path(model_dir).exists(), f"{model_dir} is not exists!"
|
|
|
+ model_dir = Path(model_dir)
|
|
|
config = BasePredictor.load_config(model_dir)
|
|
|
- model_name = config["Global"]["model_name"]
|
|
|
+ assert (
|
|
|
+ model_name == config["Global"]["model_name"]
|
|
|
+ ), f"Model name mismatch,please input the correct model dir."
|
|
|
+
|
|
|
if use_hpip:
|
|
|
return _create_hp_predictor(
|
|
|
model_name=model_name,
|