Browse Source

update ts export and infer models

Sunflower7788 1 năm trước cách đây
mục cha
commit
3391cda2f4
27 tập tin đã thay đổi với 309 bổ sung63 xóa
  1. 5 2
      paddlex/configs/ts_anomaly_detection/AutoEncoder_ad.yaml
  2. 5 2
      paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml
  3. 5 2
      paddlex/configs/ts_anomaly_detection/Nonstationary_ad.yaml
  4. 5 2
      paddlex/configs/ts_anomaly_detection/PatchTST_ad.yaml
  5. 5 2
      paddlex/configs/ts_anomaly_detection/TimesNet_ad.yaml
  6. 5 2
      paddlex/configs/ts_classification/TimesNet_cls.yaml
  7. 5 2
      paddlex/configs/ts_forecast/DLinear.yaml
  8. 5 2
      paddlex/configs/ts_forecast/NLinear.yaml
  9. 5 2
      paddlex/configs/ts_forecast/Nonstationary.yaml
  10. 5 2
      paddlex/configs/ts_forecast/PatchTST.yaml
  11. 5 2
      paddlex/configs/ts_forecast/RLinear.yaml
  12. 5 2
      paddlex/configs/ts_forecast/TiDE.yaml
  13. 5 2
      paddlex/configs/ts_forecast/TimesNet.yaml
  14. 9 1
      paddlex/modules/__init__.py
  15. 13 0
      paddlex/modules/base/predictor/utils/official_models.py
  16. 2 1
      paddlex/modules/ts_anomaly_detection/__init__.py
  17. 15 5
      paddlex/modules/ts_anomaly_detection/evaluator.py
  18. 51 0
      paddlex/modules/ts_anomaly_detection/exportor.py
  19. 6 6
      paddlex/modules/ts_anomaly_detection/trainer.py
  20. 2 1
      paddlex/modules/ts_classification/__init__.py
  21. 15 5
      paddlex/modules/ts_classification/evaluator.py
  22. 51 0
      paddlex/modules/ts_classification/exportor.py
  23. 6 6
      paddlex/modules/ts_classification/trainer.py
  24. 2 1
      paddlex/modules/ts_forecast/__init__.py
  25. 15 5
      paddlex/modules/ts_forecast/evaluator.py
  26. 51 0
      paddlex/modules/ts_forecast/exportor.py
  27. 6 6
      paddlex/modules/ts_forecast/trainer.py

+ 5 - 2
paddlex/configs/ts_anomaly_detection/AutoEncoder_ad.yaml

@@ -27,6 +27,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/best_accuracy.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_ad/ts_anomaly_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_ad.csv"

+ 5 - 2
paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml

@@ -27,6 +27,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/best_accuracy.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_ad/ts_anomaly_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_ad.csv"

+ 5 - 2
paddlex/configs/ts_anomaly_detection/Nonstationary_ad.yaml

@@ -27,6 +27,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/best_accuracy.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_ad/ts_anomaly_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_ad.csv"

+ 5 - 2
paddlex/configs/ts_anomaly_detection/PatchTST_ad.yaml

@@ -27,6 +27,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/best_accuracy.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_ad/ts_anomaly_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_ad.csv"

+ 5 - 2
paddlex/configs/ts_anomaly_detection/TimesNet_ad.yaml

@@ -27,6 +27,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/TimesNet_ad.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_ad/ts_anomaly_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_ad.csv"

+ 5 - 2
paddlex/configs/ts_classification/TimesNet_cls.yaml

@@ -27,6 +27,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/TimesNet_cls.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_cls/ts_classify_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_cls.csv"

+ 5 - 2
paddlex/configs/ts_forecast/DLinear.yaml

@@ -28,6 +28,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/DLinear.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_fc/ts_dataset_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_fc.csv"

+ 5 - 2
paddlex/configs/ts_forecast/NLinear.yaml

@@ -28,6 +28,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/NLinear.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_fc/ts_dataset_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_fc.csv"

+ 5 - 2
paddlex/configs/ts_forecast/Nonstationary.yaml

@@ -28,6 +28,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/Nonstationary.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_fc/ts_dataset_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_fc.csv"

+ 5 - 2
paddlex/configs/ts_forecast/PatchTST.yaml

@@ -28,6 +28,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/PatchTST.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_fc/ts_dataset_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_fc.csv"

+ 5 - 2
paddlex/configs/ts_forecast/RLinear.yaml

@@ -28,6 +28,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/RLinear.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_fc/ts_dataset_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_fc.csv"

+ 5 - 2
paddlex/configs/ts_forecast/TiDE.yaml

@@ -28,6 +28,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/TiDE.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_fc/ts_dataset_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_fc.csv"

+ 5 - 2
paddlex/configs/ts_forecast/TimesNet.yaml

@@ -28,6 +28,9 @@ Train:
 Evaluate:
   weight_path: "output/best_accuracy.pdparams.tar"
 
+Export:
+  weight_path: https://paddlets.bj.bcebos.com/dygraph/TimesNet.pdparams.tar
+
 Predict:
-  model_dir: "output/best_accuracy.pdparams.tar"
-  input_path: "/paddle/dataset/paddlex/ts_fc/ts_dataset_examples/test.csv"
+  model_dir: "output/inference"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/ts/demo_ts/ts_fc.csv"

+ 9 - 1
paddlex/modules/__init__.py

@@ -75,15 +75,23 @@ from .ts_anomaly_detection import (
     TSADDatasetChecker,
     TSADTrainer,
     TSADEvaluator,
+    TSADExportor,
     TSADPredictor,
 )
 from .ts_classification import (
     TSCLSDatasetChecker,
     TSCLSTrainer,
     TSCLSEvaluator,
+    TSCLSExportor,
     TSCLSPredictor,
 )
-from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator, TSFCPredictor
+from .ts_forecast import (
+    TSFCDatasetChecker,
+    TSFCTrainer,
+    TSFCEvaluator,
+    TSFCExportor,
+    TSFCPredictor,
+)
 
 from .base.predictor.transforms import image_common
 from .image_classification import transforms as cls_transforms

+ 13 - 0
paddlex/modules/base/predictor/utils/official_models.py

@@ -168,6 +168,19 @@ openatom_rec_svtrv2_ch_infer.tar",
     "PicoDet_layout_1x": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet-L_layout_infer.tar",
     "SLANet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/SLANet_infer.tar",
     "LaTeX_OCR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/LaTeX_OCR_rec_infer.tar",
+    "DLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/DLinear_infer.tar",
+    "NLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/NLinear_infer.tar",
+    "RLinear": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/RLinear_infer.tar",
+    "Nonstationary": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/Nonstationary_infer.tar",
+    "TimesNet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/TimesNet_infer.tar",
+    "TiDE": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/TiDE_infer.tar",
+    "PatchTST": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/PatchTST_infer.tar",
+    "DLinear_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/DLinear_ad_infer.tar",
+    "AutoEncoder_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/AutoEncoder_ad_infer.tar",
+    "Nonstationary_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/Nonstationary_ad_infer.tar",
+    "PatchTST_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/PatchTST_ad_infer.tar",
+    "TimesNet_ad": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/TimesNet_ad_infer.tar",
+    "TimesNet_cls": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/TimesNet_cls_infer.tar",
 }
 
 

+ 2 - 1
paddlex/modules/ts_anomaly_detection/__init__.py

@@ -16,4 +16,5 @@
 from .dataset_checker import TSADDatasetChecker
 from .trainer import TSADTrainer
 from .evaluator import TSADEvaluator
-from .predictor import TSADPredictor
+from .predictor import TSADPredictor, transforms
+from .exportor import TSADExportor

+ 15 - 5
paddlex/modules/ts_anomaly_detection/evaluator.py

@@ -25,6 +25,21 @@ class TSADEvaluator(BaseEvaluator):
 
     entities = MODELS
 
+    def get_config_path(self, weight_path):
+        """
+        get config path
+
+        Args:
+            weight_path (str): The path to the weight
+
+        Returns:
+            config_path (str): The path to the config
+
+        """
+        self.uncompress_tar_file()
+        config_path = Path(self.eval_config.weight_path).parent.parent / "config.yaml"
+        return config_path
+
     def update_config(self):
         """update evalution config"""
         self.pdx_config.update_dataset(self.global_config.dataset_dir, "TSADDataset")
@@ -40,11 +55,6 @@ class TSADEvaluator(BaseEvaluator):
                 "best_accuracy.pdparams/best_model/model.pdparams"
             )
 
-    def evaluate(self):
-        """firstly, update evaluation config, then evaluate model, finally return the evaluation result"""
-        self.uncompress_tar_file()
-        return super().evaluate()
-
     def get_eval_kwargs(self) -> dict:
         """get key-value arguments of model evalution function
 

+ 51 - 0
paddlex/modules/ts_anomaly_detection/exportor.py

@@ -0,0 +1,51 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tarfile
+from pathlib import Path
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class TSADExportor(BaseExportor):
+    """Image Classification Model Exportor"""
+
+    entities = MODELS
+
+    def get_config_path(self, weight_path):
+        """
+        get config path
+
+        Args:
+            weight_path (str): The path to the weight
+
+        Returns:
+            config_path (str): The path to the config
+
+        """
+        self.uncompress_tar_file()
+        config_path = Path(self.export_config.weight_path).parent.parent / "config.yaml"
+        return config_path
+
+    def uncompress_tar_file(self):
+        """unpackage the tar file containing training outputs and update weight path"""
+        if tarfile.is_tarfile(self.export_config.weight_path):
+            dest_path = Path(self.export_config.weight_path).parent
+            with tarfile.open(self.export_config.weight_path, "r") as tar:
+                tar.extractall(path=dest_path)
+            self.export_config.weight_path = dest_path.joinpath(
+                "best_accuracy.pdparams/best_model/model.pdparams"
+            )
+            

+ 6 - 6
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -180,12 +180,12 @@ class TSADTrainDeamon(BaseTrainDeamon):
                 "pdiparams": pdparams,
                 "pdiparams.info": "",
             }
-        self.update_inference_model(
-            model,
-            train_output,
-            train_output.joinpath(f"inference"),
-            result["models"][model_key],
-        )
+            self.update_inference_model(
+                model,
+                train_output,
+                train_output.joinpath(f"inference"),
+                result["models"][model_key],
+            )
 
     def update_inference_model(
         self, model, weight_path, export_save_dir, result_the_model

+ 2 - 1
paddlex/modules/ts_classification/__init__.py

@@ -16,4 +16,5 @@
 from .dataset_checker import TSCLSDatasetChecker
 from .trainer import TSCLSTrainer
 from .evaluator import TSCLSEvaluator
-from .predictor import TSCLSPredictor
+from .predictor import TSCLSPredictor, transforms
+from .exportor import TSCLSExportor

+ 15 - 5
paddlex/modules/ts_classification/evaluator.py

@@ -25,6 +25,21 @@ class TSCLSEvaluator(BaseEvaluator):
 
     entities = MODELS
 
+    def get_config_path(self, weight_path):
+        """
+        get config path
+
+        Args:
+            weight_path (str): The path to the weight
+
+        Returns:
+            config_path (str): The path to the config
+
+        """
+        self.uncompress_tar_file()
+        config_path = Path(self.eval_config.weight_path).parent.parent / "config.yaml"
+        return config_path
+
     def update_config(self):
         """update evalution config"""
         self.pdx_config.update_dataset(self.global_config.dataset_dir, "TSCLSDataset")
@@ -49,8 +64,3 @@ class TSCLSEvaluator(BaseEvaluator):
             self.eval_config.weight_path = dest_path.joinpath(
                 "best_accuracy.pdparams/best_model/model.pdparams"
             )
-
-    def evaluate(self):
-        """firstly, update evaluation config, then evaluate model, finally return the evaluation result"""
-        self.uncompress_tar_file()
-        return super().evaluate()

+ 51 - 0
paddlex/modules/ts_classification/exportor.py

@@ -0,0 +1,51 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tarfile
+from pathlib import Path
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class TSCLSExportor(BaseExportor):
+    """Image Classification Model Exportor"""
+
+    entities = MODELS
+
+    def get_config_path(self, weight_path):
+        """
+        get config path
+
+        Args:
+            weight_path (str): The path to the weight
+
+        Returns:
+            config_path (str): The path to the config
+
+        """
+        self.uncompress_tar_file()
+        config_path = Path(self.export_config.weight_path).parent.parent / "config.yaml"
+        return config_path
+    
+    def uncompress_tar_file(self):
+        """unpackage the tar file containing training outputs and update weight path"""
+        if tarfile.is_tarfile(self.export_config.weight_path):
+            dest_path = Path(self.export_config.weight_path).parent
+            with tarfile.open(self.export_config.weight_path, "r") as tar:
+                tar.extractall(path=dest_path)
+            self.export_config.weight_path = dest_path.joinpath(
+                "best_accuracy.pdparams/best_model/model.pdparams"
+            )
+    

+ 6 - 6
paddlex/modules/ts_classification/trainer.py

@@ -175,12 +175,12 @@ class TSCLSTrainDeamon(BaseTrainDeamon):
                 "pdiparams": pdparams,
                 "pdiparams.info": "",
             }
-        self.update_inference_model(
-            model,
-            train_output,
-            train_output.joinpath(f"inference"),
-            result["models"][model_key],
-        )
+            self.update_inference_model(
+                model,
+                train_output,
+                train_output.joinpath(f"inference"),
+                result["models"][model_key],
+            )
 
     def update_inference_model(
         self, model, weight_path, export_save_dir, result_the_model

+ 2 - 1
paddlex/modules/ts_forecast/__init__.py

@@ -16,4 +16,5 @@
 from .dataset_checker import TSFCDatasetChecker
 from .trainer import TSFCTrainer
 from .evaluator import TSFCEvaluator
-from .predictor import TSFCPredictor
+from .predictor import TSFCPredictor, transforms
+from .exportor import TSFCExportor

+ 15 - 5
paddlex/modules/ts_forecast/evaluator.py

@@ -25,6 +25,21 @@ class TSFCEvaluator(BaseEvaluator):
 
     entities = MODELS
 
+    def get_config_path(self, weight_path):
+        """
+        get config path
+
+        Args:
+            weight_path (str): The path to the weight
+
+        Returns:
+            config_path (str): The path to the config
+
+        """
+        self.uncompress_tar_file()
+        config_path = Path(self.eval_config.weight_path).parent.parent / "config.yaml"
+        return config_path
+
     def update_config(self):
         """update evalution config"""
         self.pdx_config.update_dataset(self.global_config.dataset_dir, "TSDataset")
@@ -49,8 +64,3 @@ class TSFCEvaluator(BaseEvaluator):
             self.eval_config.weight_path = dest_path.joinpath(
                 "best_accuracy.pdparams/best_model/model.pdparams"
             )
-
-    def evaluate(self):
-        """firstly, update evaluation config, then evaluate model, finally return the evaluation result"""
-        self.uncompress_tar_file()
-        return super().evaluate()

+ 51 - 0
paddlex/modules/ts_forecast/exportor.py

@@ -0,0 +1,51 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tarfile
+from pathlib import Path
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class TSFCExportor(BaseExportor):
+    """Image Classification Model Exportor"""
+
+    entities = MODELS
+
+    def get_config_path(self, weight_path):
+        """
+        get config path
+
+        Args:
+            weight_path (str): The path to the weight
+
+        Returns:
+            config_path (str): The path to the config
+
+        """
+        self.uncompress_tar_file()
+        config_path = Path(self.export_config.weight_path).parent.parent / "config.yaml"
+        return config_path
+
+    def uncompress_tar_file(self):
+        """unpackage the tar file containing training outputs and update weight path"""
+        if tarfile.is_tarfile(self.export_config.weight_path):
+            dest_path = Path(self.export_config.weight_path).parent
+            with tarfile.open(self.export_config.weight_path, "r") as tar:
+                tar.extractall(path=dest_path)
+            self.export_config.weight_path = dest_path.joinpath(
+                "best_accuracy.pdparams/best_model/model.pdparams"
+            )
+    

+ 6 - 6
paddlex/modules/ts_forecast/trainer.py

@@ -175,12 +175,12 @@ class TSFCTrainDeamon(BaseTrainDeamon):
                 "pdiparams": pdparams,
                 "pdiparams.info": "",
             }
-        self.update_inference_model(
-            model,
-            train_output,
-            train_output.joinpath(f"inference"),
-            result["models"][model_key],
-        )
+            self.update_inference_model(
+                model,
+                train_output,
+                train_output.joinpath(f"inference"),
+                result["models"][model_key],
+            )
 
     def update_inference_model(
         self, model, weight_path, export_save_dir, result_the_model