Explorar o código

add url input for ts (#1857)

Sunflower7788 hai 1 ano
pai
achega
158d9a76e7
Modificáronse 1 ficheiros con 12 adicións e 0 borrados
  1. 12 0
      paddlex/modules/ts_forecast/predictor.py

+ 12 - 0
paddlex/modules/ts_forecast/predictor.py

@@ -22,6 +22,8 @@ from ..base.build_model import build_model
 from ..base.predictor import BasePredictor
 from ...utils.errors import raise_unsupported_api_error, raise_model_not_found_error
 from .model_list import MODELS
+from ...utils.download import download
+from ...utils.cache import CACHE_DIR
 
 
 class TSFCPredictor(BasePredictor):
@@ -31,6 +33,7 @@ class TSFCPredictor(BasePredictor):
     def __init__(self, model_name, model_dir, kernel_option, output):
         """initialize
         """
+        model_dir = self._download_from_url(model_dir)
         self.model_dir = self.uncompress_tar_file(model_dir)
 
         self.device = kernel_option.get_device()
@@ -69,10 +72,19 @@ is not exist, use default instead.")
             raise_model_not_found_error(self.model_dir)
         return None
 
+    def _download_from_url(self, in_path):
+        if in_path.startswith("http"):
+            file_name = Path(in_path).name
+            save_path = Path(CACHE_DIR) / "predict_input" / file_name
+            download(in_path, save_path, overwrite=True)
+            return save_path.as_posix()
+        return in_path
+    
     def predict(self, input):
         """execute model predict
         """
         # self.update_config()
+        input['input_path'] = self._download_from_url(input['input_path'])
         result = self.pdx_model.predict(**input, **self.get_predict_kwargs())
         assert result.returncode == 0, f"Encountered an unexpected error({result.returncode}) in predicting!"
         return result