Browse Source

support download ts pretrained

zhangyubo0722 1 year ago
parent
commit
9239ac279d

+ 2 - 2
paddlex/modules/base/exportor.py

@@ -20,7 +20,7 @@ from .build_model import build_model
 from ...utils.device import update_device_num, set_env_for_device
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
-from ...utils.logging import *
+from ...utils import logging
 
 
 def build_exportor(config: AttrDict) -> "BaseExportor":
@@ -71,7 +71,7 @@ class BaseExportor(ABC, metaclass=AutoRegisterABCMetaClass):
 
         config_path = Path(weight_path).parent / "config.yaml"
         if not config_path.exists():
-            warning(
+            logging.warning(
                 f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
             )
             config_path = None

+ 6 - 11
paddlex/modules/ts_anomaly_detection/exportor.py

@@ -17,6 +17,7 @@ from pathlib import Path
 
 from ..base import BaseExportor
 from .model_list import MODELS
+from ...utils import logging
 
 
 class TSADExportor(BaseExportor):
@@ -35,16 +36,10 @@ class TSADExportor(BaseExportor):
             config_path (str): The path to the config
 
         """
-        self.uncompress_tar_file()
         config_path = Path(self.export_config.weight_path).parent.parent / "config.yaml"
-        return config_path
-
-    def uncompress_tar_file(self):
-        """unpackage the tar file containing training outputs and update weight path"""
-        if tarfile.is_tarfile(self.export_config.weight_path):
-            dest_path = Path(self.export_config.weight_path).parent
-            with tarfile.open(self.export_config.weight_path, "r") as tar:
-                tar.extractall(path=dest_path)
-            self.export_config.weight_path = dest_path.joinpath(
-                "best_accuracy.pdparams/best_model/model.pdparams"
+        if not config_path.exists():
+            logging.warning(
+                f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
             )
+            config_path = None
+        return config_path

+ 6 - 11
paddlex/modules/ts_classification/exportor.py

@@ -17,6 +17,7 @@ from pathlib import Path
 
 from ..base import BaseExportor
 from .model_list import MODELS
+from ...utils import logging
 
 
 class TSCLSExportor(BaseExportor):
@@ -35,16 +36,10 @@ class TSCLSExportor(BaseExportor):
             config_path (str): The path to the config
 
         """
-        self.uncompress_tar_file()
         config_path = Path(self.export_config.weight_path).parent.parent / "config.yaml"
-        return config_path
-
-    def uncompress_tar_file(self):
-        """unpackage the tar file containing training outputs and update weight path"""
-        if tarfile.is_tarfile(self.export_config.weight_path):
-            dest_path = Path(self.export_config.weight_path).parent
-            with tarfile.open(self.export_config.weight_path, "r") as tar:
-                tar.extractall(path=dest_path)
-            self.export_config.weight_path = dest_path.joinpath(
-                "best_accuracy.pdparams/best_model/model.pdparams"
+        if not config_path.exists():
+            logging.warning(
+                f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
             )
+            config_path = None
+        return config_path

+ 6 - 11
paddlex/modules/ts_forecast/exportor.py

@@ -17,6 +17,7 @@ from pathlib import Path
 
 from ..base import BaseExportor
 from .model_list import MODELS
+from ...utils import logging
 
 
 class TSFCExportor(BaseExportor):
@@ -35,16 +36,10 @@ class TSFCExportor(BaseExportor):
             config_path (str): The path to the config
 
         """
-        self.uncompress_tar_file()
         config_path = Path(self.export_config.weight_path).parent.parent / "config.yaml"
-        return config_path
-
-    def uncompress_tar_file(self):
-        """unpackage the tar file containing training outputs and update weight path"""
-        if tarfile.is_tarfile(self.export_config.weight_path):
-            dest_path = Path(self.export_config.weight_path).parent
-            with tarfile.open(self.export_config.weight_path, "r") as tar:
-                tar.extractall(path=dest_path)
-            self.export_config.weight_path = dest_path.joinpath(
-                "best_accuracy.pdparams/best_model/model.pdparams"
+        if not config_path.exists():
+            logging.warning(
+                f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
             )
+            config_path = None
+        return config_path

+ 2 - 2
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -221,11 +221,11 @@ class TSModel(BaseModel):
         self, weight_path: str, save_dir: str = None, device: str = "gpu", **kwargs
     ):
         """export"""
-        weight_path = abspath(weight_path)
+        if not weight_path.startswith(("http://", "https://")):
+            weight_path = abspath(weight_path)
         save_dir = abspath(save_dir)
         cli_args = []
 
-        weight_path = abspath(weight_path)
         cli_args.append(CLIArgument("--checkpoints", weight_path))
         if save_dir is not None:
             save_dir = abspath(save_dir)