|
|
@@ -22,6 +22,8 @@ from ..base.build_model import build_model
|
|
|
from ..base.predictor import BasePredictor
|
|
|
from ...utils.errors import raise_unsupported_api_error, raise_model_not_found_error
|
|
|
from .model_list import MODELS
|
|
|
+from ...utils.download import download
|
|
|
+from ...utils.cache import CACHE_DIR
|
|
|
|
|
|
|
|
|
class TSFCPredictor(BasePredictor):
|
|
|
@@ -31,6 +33,7 @@ class TSFCPredictor(BasePredictor):
|
|
|
def __init__(self, model_name, model_dir, kernel_option, output):
|
|
|
"""initialize
|
|
|
"""
|
|
|
+ model_dir = self._download_from_url(model_dir)
|
|
|
self.model_dir = self.uncompress_tar_file(model_dir)
|
|
|
|
|
|
self.device = kernel_option.get_device()
|
|
|
@@ -69,10 +72,19 @@ is not exist, use default instead.")
|
|
|
raise_model_not_found_error(self.model_dir)
|
|
|
return None
|
|
|
|
|
|
+ def _download_from_url(self, in_path):
|
|
|
+ if in_path.startswith("http"):
|
|
|
+ file_name = Path(in_path).name
|
|
|
+ save_path = Path(CACHE_DIR) / "predict_input" / file_name
|
|
|
+ download(in_path, save_path, overwrite=True)
|
|
|
+ return save_path.as_posix()
|
|
|
+ return in_path
|
|
|
+
|
|
|
def predict(self, input):
|
|
|
"""execute model predict
|
|
|
"""
|
|
|
# self.update_config()
|
|
|
+ input['input_path'] = self._download_from_url(input['input_path'])
|
|
|
result = self.pdx_model.predict(**input, **self.get_predict_kwargs())
|
|
|
assert result.returncode == 0, f"Encountered an unexpected error({result.returncode}) in predicting!"
|
|
|
return result
|