|
|
@@ -17,6 +17,7 @@ from pathlib import Path
|
|
|
|
|
|
from ..base import BaseExportor
|
|
|
from .model_list import MODELS
|
|
|
+from ...utils import logging
|
|
|
|
|
|
|
|
|
class TSCLSExportor(BaseExportor):
|
|
|
@@ -35,16 +36,10 @@ class TSCLSExportor(BaseExportor):
|
|
|
config_path (str): The path to the config
|
|
|
|
|
|
"""
|
|
|
- self.uncompress_tar_file()
|
|
|
config_path = Path(self.export_config.weight_path).parent.parent / "config.yaml"
|
|
|
- return config_path
|
|
|
-
|
|
|
- def uncompress_tar_file(self):
|
|
|
- """unpackage the tar file containing training outputs and update weight path"""
|
|
|
- if tarfile.is_tarfile(self.export_config.weight_path):
|
|
|
- dest_path = Path(self.export_config.weight_path).parent
|
|
|
- with tarfile.open(self.export_config.weight_path, "r") as tar:
|
|
|
- tar.extractall(path=dest_path)
|
|
|
- self.export_config.weight_path = dest_path.joinpath(
|
|
|
- "best_accuracy.pdparams/best_model/model.pdparams"
|
|
|
+ if not config_path.exists():
|
|
|
+ logging.warning(
|
|
|
+ f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
|
|
|
)
|
|
|
+ config_path = None
|
|
|
+ return config_path
|