gaotingquan 1 рік тому
батько
коміт
1569dfd7e4

+ 4 - 1
paddlex/configs/image_classification/PP-LCNet_x0_25.yaml

@@ -30,9 +30,12 @@ Evaluate:
   weight_path: "output/best_model.pdparams"
   log_interval: 1
 
+Export:
+  weight_path: "output/best_model.pdparams"
+
 Predict:
   model_dir: "output/best_model"
   input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_classification_001.jpg"
   kernel_option:
     run_mode: paddle
-    batch_size: 1
+    batch_size: 1

+ 3 - 0
paddlex/configs/image_classification/PP-LCNet_x1_0.yaml

@@ -30,6 +30,9 @@ Evaluate:
   weight_path: "output/best_model.pdparams"
   log_interval: 1
 
+Export:
+  weight_path: "output/best_model.pdparams"
+
 Predict:
   model_dir: "output/best_model"
   input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_classification_001.jpg"

+ 4 - 4
paddlex/engine.py

@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
-from .modules.base import build_dataset_checker, build_trainer, build_evaluater, build_predictor
+from .modules.base import build_dataset_checker, build_trainer, build_evaluater, build_exportor, build_predictor
 from .utils.result_saver import try_except_decorator
 from .utils import config
 from .utils.errors import raise_unsupported_api_error
@@ -44,10 +43,11 @@ class Engine(object):
             evaluator = build_evaluater(self.config)
             return evaluator.evaluate()
         elif self.config.Global.mode == "export":
-            raise_unsupported_api_error("export", self.__class__)
+            exportor = build_exportor(self.config)
+            return exportor.export()
         elif self.config.Global.mode == "predict":
             predictor = build_predictor(self.config)
             return predictor.predict()
         else:
             raise_unsupported_api_error(f"{self.config.Global.mode}",
-                                        self.__class__)
+                                        self.__class__)

+ 8 - 8
paddlex/modules/__init__.py

@@ -13,16 +13,16 @@
 # limitations under the License.
 
 
-from .base import build_dataset_checker, build_trainer, build_evaluater, build_predictor, create_model, \
+from .base import build_dataset_checker, build_trainer, build_evaluater, build_exportor, build_predictor, create_model, \
 PaddleInferenceOption
-from .image_classification import ClsDatasetChecker, ClsTrainer, ClsEvaluator, ClsPredictor
-from .object_detection import COCODatasetChecker, DetTrainer, DetEvaluator, DetPredictor
-from .text_detection import TextDetDatasetChecker, TextDetTrainer, TextDetEvaluator, TextDetPredictor
-from .text_recognition import TextRecDatasetChecker, TextRecTrainer, TextRecEvaluator, TextRecPredictor
-from .table_recognition import TableRecDatasetChecker, TableRecTrainer, TableRecEvaluator, TableRecPredictor
-from .semantic_segmentation import SegDatasetChecker, SegTrainer, SegEvaluator, SegPredictor
+from .image_classification import ClsDatasetChecker, ClsTrainer, ClsEvaluator, ClsExportor, ClsPredictor
+from .object_detection import COCODatasetChecker, DetTrainer, DetEvaluator, DetExportor, DetPredictor
+from .text_detection import TextDetDatasetChecker, TextDetTrainer, TextDetEvaluator, TextDetExportor, TextDetPredictor
+from .text_recognition import TextRecDatasetChecker, TextRecTrainer, TextRecEvaluator, TextRecExportor, TextRecPredictor
+from .table_recognition import TableRecDatasetChecker, TableRecTrainer, TableRecEvaluator, TableRecExportor, TableRecPredictor
+from .semantic_segmentation import SegDatasetChecker, SegTrainer, SegEvaluator, SegExportor, SegPredictor
 from .instance_segmentation import COCOInstSegDatasetChecker, InstanceSegTrainer, InstanceSegEvaluator, \
-InstanceSegPredictor
+InstanceSegExportor, InstanceSegPredictor
 from .ts_anomaly_detection import TSADDatasetChecker, TSADTrainer, TSADEvaluator, TSADPredictor
 from .ts_classification import TSCLSDatasetChecker, TSCLSTrainer, TSCLSEvaluator, TSCLSPredictor
 from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator, TSFCPredictor

+ 1 - 2
paddlex/modules/base/__init__.py

@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
 from .dataset_checker import build_dataset_checker, BaseDatasetChecker
 from .trainer import build_trainer, BaseTrainer, BaseTrainDeamon
 from .evaluator import build_evaluater, BaseEvaluator
+from .exportor import build_exportor, BaseExportor
 from .predictor import build_predictor, BasePredictor, BaseTransform, PaddleInferenceOption, create_model

+ 7 - 8
paddlex/modules/base/evaluator.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 from pathlib import Path
 from abc import ABC, abstractmethod
@@ -53,11 +52,7 @@ class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
         self.eval_config = config.Evaluate
 
         config_path = self.get_config_path(self.eval_config.weight_path)
-        if not config_path.exists():
-            warning(
-                f"The config file(`{config_path}`) related to weight file(`{self.eval_config.weight_path}`) is not exist, use default instead."
-            )
-            config_path = None
+
         self.pdx_config, self.pdx_model = build_model(
             self.global_config.model, config_path=config_path)
 
@@ -74,7 +69,11 @@ class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
         """
 
         config_path = Path(weight_path).parent / "config.yaml"
-
+        if not config_path.exists():
+            warning(
+                f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
+            )
+            config_path = None
         return config_path
 
     def check_return(self, metrics: dict) -> bool:
@@ -95,7 +94,7 @@ class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
         return True
 
     def evaluate(self) -> dict:
-        """execute model training
+        """execute model evaluating
 
         Returns:
             dict: the evaluation metrics

+ 118 - 0
paddlex/modules/base/exportor.py

@@ -0,0 +1,118 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from pathlib import Path
+from abc import ABC, abstractmethod
+
+from .build_model import build_model
+from ...utils.device import get_device
+from ...utils.misc import AutoRegisterABCMetaClass
+from ...utils.config import AttrDict
+from ...utils.logging import *
+
+
+def build_exportor(config: AttrDict) -> "BaseExportor":
+    """build model exportor
+
+    Args:
+        config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
+
+    Returns:
+        BaseExportor: the exportor, which is subclass of BaseExportor.
+    """
+    model_name = config.Global.model
+    return BaseExportor.get(model_name)(config)
+
+
+class BaseExportor(ABC, metaclass=AutoRegisterABCMetaClass):
+    """ Base Model Exportor """
+
+    __is_base = True
+
+    def __init__(self, config):
+        """Initialize the instance.
+
+        Args:
+            config (AttrDict):  PaddleX pipeline config, which is loaded from pipeline yaml file.
+        """
+        super().__init__()
+        self.global_config = config.Global
+        self.export_config = config.Export
+
+        config_path = self.get_config_path(self.export_config.weight_path)
+
+        self.pdx_config, self.pdx_model = build_model(
+            self.global_config.model, config_path=config_path)
+
+    def get_config_path(self, weight_path):
+        """
+        get config path
+
+        Args:
+            weight_path (str): The path to the weight
+
+        Returns:
+            config_path (str): The path to the config
+
+        """
+
+        config_path = Path(weight_path).parent / "config.yaml"
+        if not config_path.exists():
+            warning(
+                f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
+            )
+            config_path = None
+
+        return config_path
+
+    def export(self) -> dict:
+        """execute model exporting
+
+        Returns:
+            dict: the export metrics
+        """
+        self.update_config()
+        export_result = self.pdx_model.export(**self.get_export_kwargs())
+        assert export_result.returncode == 0, f"Encountered an unexpected error({export_result.returncode}) in \
+exporting!"
+
+        return None
+
+    def get_device(self, using_device_number: int=None) -> str:
+        """get device setting from config
+
+        Args:
+            using_device_number (int, optional): specify device number to use.
+                Defaults to None, means that base on config setting.
+
+        Returns:
+            str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
+        """
+        # return get_device(
+        #     self.global_config.device, using_device_number=using_device_number)
+        return get_device("cpu")
+
+    def update_config(self):
+        """update export config
+        """
+        pass
+
+    def get_export_kwargs(self):
+        """get key-value arguments of model export function
+        """
+        return {
+            "weight_path": self.export_config.weight_path,
+            "save_dir": self.global_config.output
+        }

+ 1 - 1
paddlex/modules/image_classification/__init__.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from .trainer import ClsTrainer
 from .dataset_checker import ClsDatasetChecker
 from .evaluator import ClsEvaluator
+from .exportor import ClsExportor
 from .predictor import ClsPredictor, transforms

+ 21 - 0
paddlex/modules/image_classification/exportor.py

@@ -0,0 +1,21 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class ClsExportor(BaseExportor):
+    """ Image Classification Model Exportor """
+    entities = MODELS

+ 1 - 1
paddlex/modules/instance_segmentation/__init__.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from .dataset_checker import COCOInstSegDatasetChecker
 from .trainer import InstanceSegTrainer
 from .evaluator import InstanceSegEvaluator
 from .predictor import InstanceSegPredictor, transforms
+from .exportor import InstanceSegExportor

+ 21 - 0
paddlex/modules/instance_segmentation/exportor.py

@@ -0,0 +1,21 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class InstanceSegExportor(BaseExportor):
+    """ Instance Segmentation Model Exportor """
+    entities = MODELS

+ 1 - 1
paddlex/modules/object_detection/__init__.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from .trainer import DetTrainer
 from .dataset_checker import COCODatasetChecker
 from .evaluator import DetEvaluator
 from .predictor import DetPredictor, transforms
+from .exportor import DetExportor

+ 21 - 0
paddlex/modules/object_detection/exportor.py

@@ -0,0 +1,21 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class DetExportor(BaseExportor):
+    """ Object Detection Model Exportor """
+    entities = MODELS

+ 1 - 1
paddlex/modules/semantic_segmentation/__init__.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from .dataset_checker import SegDatasetChecker
 from .trainer import SegTrainer
 from .evaluator import SegEvaluator
 from .predictor import SegPredictor, transforms
+from .exportor import SegExportor

+ 21 - 0
paddlex/modules/semantic_segmentation/exportor.py

@@ -0,0 +1,21 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class SegExportor(BaseExportor):
+    """ Semantic Segmentation Model Exportor """
+    entities = MODELS

+ 1 - 1
paddlex/modules/table_recognition/__init__.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from .dataset_checker import TableRecDatasetChecker
 from .trainer import TableRecTrainer
 from .evaluator import TableRecEvaluator
 from .predictor import TableRecPredictor, transforms
+from .exportor import TableRecExportor

+ 21 - 0
paddlex/modules/table_recognition/exportor.py

@@ -0,0 +1,21 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class TableRecExportor(BaseExportor):
+    """ Table Recognition Model Exportor """
+    entities = MODELS

+ 1 - 1
paddlex/modules/text_detection/__init__.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from .dataset_checker import TextDetDatasetChecker
 from .trainer import TextDetTrainer
 from .evaluator import TextDetEvaluator
 from .predictor import TextDetPredictor, transforms
+from .exportor import TextDetExportor

+ 21 - 0
paddlex/modules/text_detection/exportor.py

@@ -0,0 +1,21 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class TextDetExportor(BaseExportor):
+    """ Text Detection Model Exportor """
+    entities = MODELS

+ 1 - 1
paddlex/modules/text_recognition/__init__.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from .dataset_checker import TextRecDatasetChecker
 from .trainer import TextRecTrainer
 from .evaluator import TextRecEvaluator
 from .predictor import TextRecPredictor, transforms
+from .exportor import TextRecExportor

+ 21 - 0
paddlex/modules/text_recognition/exportor.py

@@ -0,0 +1,21 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class TextRecExportor(BaseExportor):
+    """ Text Recognition Model Exportor """
+    entities = MODELS

+ 2 - 4
paddlex/utils/errors/others.py

@@ -1,5 +1,5 @@
 # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-
 import json
 import signal
 from typing import Union
@@ -120,7 +118,7 @@ def raise_class_not_found_error(cls_name, base_cls, all_entities=None):
 def raise_no_entity_registered_error(base_cls):
     """ raise no entity registered error """
     base_cls_name = base_cls.__name__
-    msg = f"There no entity register on {base_cls_name}."
+    msg = f"There no entity register on {base_cls_name}. Hint: Maybe the subclass is not imported."
     raise NoEntityRegisteredException(msg)