gaotingquan 1 жил өмнө
parent
commit
7b9e113980

+ 1 - 0
paddlex/inference/models/base/__init__.py

@@ -14,3 +14,4 @@
 
 from .base_predictor import BasePredictor, BasicPredictor
 from .cv_predictor import CVPredictor
+from .ts_predictor import TSPredictor

+ 3 - 0
paddlex/inference/models/base/base_predictor.py

@@ -130,6 +130,9 @@ class BasicPredictor(
 
     def set_predict(self, **kwargs):
         for k in kwargs:
+            assert (
+                k in self._pred_set_func_map
+            ), f"The arg({k}) is not supported to specify in predict() func! Only supports: {self._pred_set_func_map.keys()}"
             self._pred_set_func_map[k](kwargs[k])
 
     @abstractmethod

+ 2 - 2
paddlex/inference/models/base/cv_predictor.py

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..utils.predict_set import BatchSetMixin
+from ..utils.predict_set import BatchSizeSetMixin
 from .base_predictor import BasicPredictor
 
 
-class CVPredictor(BasicPredictor, BatchSetMixin):
+class CVPredictor(BasicPredictor, BatchSizeSetMixin):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._pred_set_register("batch_size")(self.set_batch_size)

+ 20 - 0
paddlex/inference/models/base/ts_predictor.py

@@ -0,0 +1,20 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base_predictor import BasicPredictor
+
+
+class TSPredictor(BasicPredictor):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

+ 0 - 2
paddlex/inference/models/general_recognition.py

@@ -18,7 +18,6 @@ from ...utils.func_register import FuncRegister
 from ...modules.general_recognition.model_list import MODELS
 from ..components import *
 from ..results import BaseResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 
@@ -95,7 +94,6 @@ class ShiTuRecPredictor(CVPredictor):
     def build_normalize_features(self):
         return NormalizeFeatures()
 
-    @batchable_method
     def _pack_res(self, data):
         keys = ["img_path", "rec_feature"]
         return BaseResult({key: data[key] for key in keys})

+ 0 - 1
paddlex/inference/models/image_classification.py

@@ -19,7 +19,6 @@ from ...modules.image_classification.model_list import MODELS
 from ...modules.multilabel_classification.model_list import MODELS as ML_MODELS
 from ..components import *
 from ..results import TopkResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 

+ 0 - 2
paddlex/inference/models/image_unwarping.py

@@ -15,7 +15,6 @@
 from ...modules.image_unwarping.model_list import MODELS
 from ..components import *
 from ..results import DocTrResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 
@@ -39,7 +38,6 @@ class WarpPredictor(CVPredictor):
         )
         self._add_component([("Predictor", predictor), DocTrPostProcess()])
 
-    @batchable_method
     def _pack_res(self, single):
         keys = ["img_path", "doctr_img"]
         return DocTrResult({key: single[key] for key in keys})

+ 0 - 1
paddlex/inference/models/instance_segmentation.py

@@ -19,7 +19,6 @@ from ...utils.func_register import FuncRegister
 from ...modules.instance_segmentation.model_list import MODELS
 from ..components import *
 from ..results import InstanceSegResult
-from ..utils.process_hook import batchable_method
 
 
 class InstanceSegPredictor(DetPredictor):

+ 0 - 1
paddlex/inference/models/object_detection.py

@@ -18,7 +18,6 @@ from ...utils.func_register import FuncRegister
 from ...modules.object_detection.model_list import MODELS
 from ..components import *
 from ..results import DetResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 

+ 0 - 1
paddlex/inference/models/semantic_segmentation.py

@@ -18,7 +18,6 @@ from ...utils.func_register import FuncRegister
 from ...modules.semantic_segmentation.model_list import MODELS
 from ..components import *
 from ..results import SegResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 

+ 0 - 1
paddlex/inference/models/table_recognition.py

@@ -19,7 +19,6 @@ from ...utils.func_register import FuncRegister
 from ...modules.table_recognition.model_list import MODELS
 from ..components import *
 from ..results import TableRecResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 

+ 0 - 1
paddlex/inference/models/text_detection.py

@@ -18,7 +18,6 @@ from ...utils.func_register import FuncRegister
 from ...modules.text_detection.model_list import MODELS
 from ..components import *
 from ..results import TextDetResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 

+ 0 - 1
paddlex/inference/models/text_recognition.py

@@ -18,7 +18,6 @@ from ...utils.func_register import FuncRegister
 from ...modules.text_recognition.model_list import MODELS
 from ..components import *
 from ..results import TextRecResult
-from ..utils.process_hook import batchable_method
 from .base import CVPredictor
 
 

+ 23 - 33
paddlex/inference/models/ts_ad.py

@@ -17,60 +17,50 @@ import os
 from ...modules.ts_anomaly_detection.model_list import MODELS
 from ..components import *
 from ..results import TSAdResult
-from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import TSPredictor
 
 
-class TSAdPredictor(BasicPredictor):
+class TSAdPredictor(TSPredictor):
 
     entities = MODELS
 
     def _build_components(self):
-        preprocess = self._build_preprocess()
-        predictor = TSPPPredictor(
-            model_dir=self.model_dir,
-            model_prefix=self.MODEL_FILE_PREFIX,
-            option=self.pp_option,
-        )
-        postprocess = self._build_postprocess()
-        return {**preprocess, "predictor": predictor, **postprocess}
-
-    def _build_preprocess(self):
         if not self.config.get("info_params", None):
             raise Exception("info_params is not found in config file")
 
-        ops = {}
-        ops["ReadTS"] = ReadTS()
-        ops["TSCutOff"] = TSCutOff(self.config["size"])
+        self._add_component(ReadTS())
+        self._add_component(TSCutOff(self.config["size"]))
 
         if self.config.get("scale", None):
             scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
             if not os.path.exists(scaler_file_path):
                 raise Exception(f"Cannot find scaler file: {scaler_file_path}")
-            ops["TSNormalize"] = TSNormalize(
-                scaler_file_path, self.config["info_params"]
+            self._add_component(
+                TSNormalize(scaler_file_path, self.config["info_params"])
             )
 
-        ops["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+        self._add_component(BuildTSDataset(self.config["info_params"]))
 
         if self.config.get("time_feat", None):
-            ops["TimeFeature"] = TimeFeature(
-                self.config["info_params"],
-                self.config["size"],
-                self.config["holiday"],
+            self._add_component(
+                TimeFeature(
+                    self.config["info_params"],
+                    self.config["size"],
+                    self.config["holiday"],
+                )
             )
-        ops["TStoArray"] = TStoArray(self.config["input_data"])
-        return ops
+        self._add_component(TStoArray(self.config["input_data"]))
 
-    def _build_postprocess(self):
-        if not self.config.get("info_params", None):
-            raise Exception("info_params is not found in config file")
-        ops = {}
-        ops["GetAnomaly"] = GetAnomaly(
-            self.config["model_threshold"], self.config["info_params"]
+        predictor = TSPPPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        self._add_component(("Predictor", predictor))
+
+        self._add_component(
+            GetAnomaly(self.config["model_threshold"], self.config["info_params"])
         )
-        return ops
 
-    @batchable_method
     def _pack_res(self, single):
         return TSAdResult({"ts_path": single["ts_path"], "anomaly": single["anomaly"]})

+ 15 - 20
paddlex/inference/models/ts_cls.py

@@ -16,42 +16,37 @@ import os
 from ...modules.ts_classification.model_list import MODELS
 from ..components import *
 from ..results import TSClsResult
-from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import TSPredictor
 
 
-class TSClsPredictor(BasicPredictor):
+class TSClsPredictor(TSPredictor):
 
     entities = MODELS
 
     def _build_components(self):
-        preprocess = self._build_preprocess()
-        predictor = TSPPPredictor(
-            model_dir=self.model_dir,
-            model_prefix=self.MODEL_FILE_PREFIX,
-            option=self.pp_option,
-        )
-        return {**preprocess, "predictor": predictor, "GetCls": GetCls()}
-
-    def _build_preprocess(self):
         if not self.config.get("info_params", None):
             raise Exception("info_params is not found in config file")
 
-        ops = {}
-        ops["ReadTS"] = ReadTS()
+        self._add_component(ReadTS())
         if self.config.get("scale", None):
             scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
             if not os.path.exists(scaler_file_path):
                 raise Exception(f"Cannot find scaler file: {scaler_file_path}")
-            ops["TSNormalize"] = TSNormalize(
-                scaler_file_path, self.config["info_params"]
+            self._add_component(
+                TSNormalize(scaler_file_path, self.config["info_params"])
             )
 
-        ops["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
-        ops["BuildPadMask"] = BuildPadMask(self.config["input_data"])
-        ops["TStoArray"] = TStoArray(self.config["input_data"])
+        self._add_component(BuildTSDataset(self.config["info_params"]))
+        self._add_component(BuildPadMask(self.config["input_data"]))
+        self._add_component(TStoArray(self.config["input_data"]))
 
-        return ops
+        predictor = TSPPPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        self._add_component(("Predictor", predictor))
+        self._add_component(GetCls())
 
     def _pack_res(self, single):
         return TSClsResult(

+ 23 - 33
paddlex/inference/models/ts_fc.py

@@ -17,65 +17,55 @@ import os
 from ...modules.ts_forecast.model_list import MODELS
 from ..components import *
 from ..results import TSFcResult
-from ..utils.process_hook import batchable_method
-from .base import BasicPredictor
+from .base import TSPredictor
 
 
-class TSFcPredictor(BasicPredictor):
+class TSFcPredictor(TSPredictor):
 
     entities = MODELS
 
     def _build_components(self):
-        preprocess = self._build_preprocess()
-        predictor = TSPPPredictor(
-            model_dir=self.model_dir,
-            model_prefix=self.MODEL_FILE_PREFIX,
-            option=self.pp_option,
-        )
-        postprocess = self._build_postprocess()
-        return {**preprocess, "predictor": predictor, **postprocess}
-
-    def _build_preprocess(self):
         if not self.config.get("info_params", None):
             raise Exception("info_params is not found in config file")
 
-        ops = {}
-        ops["ReadTS"] = ReadTS()
-        ops["TSCutOff"] = TSCutOff(self.config["size"])
+        self._add_component(ReadTS())
+        self._add_component(TSCutOff(self.config["size"]))
 
         if self.config.get("scale", None):
             scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
             if not os.path.exists(scaler_file_path):
                 raise Exception(f"Cannot find scaler file: {scaler_file_path}")
-            ops["TSNormalize"] = TSNormalize(
-                scaler_file_path, self.config["info_params"]
+            self._add_component(
+                TSNormalize(scaler_file_path, self.config["info_params"])
             )
 
-        ops["BuildTSDataset"] = BuildTSDataset(self.config["info_params"])
+        self._add_component(BuildTSDataset(self.config["info_params"]))
 
         if self.config.get("time_feat", None):
-            ops["TimeFeature"] = TimeFeature(
-                self.config["info_params"],
-                self.config["size"],
-                self.config["holiday"],
+            self._add_component(
+                TimeFeature(
+                    self.config["info_params"],
+                    self.config["size"],
+                    self.config["holiday"],
+                )
             )
-        ops["TStoArray"] = TStoArray(self.config["input_data"])
-        return ops
+        self._add_component(TStoArray(self.config["input_data"]))
 
-    def _build_postprocess(self):
-        if not self.config.get("info_params", None):
-            raise Exception("info_params is not found in config file")
+        predictor = TSPPPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        self._add_component(("Predictor", predictor))
 
-        ops = {}
-        ops["ArraytoTS"] = ArraytoTS(self.config["info_params"])
+        self._add_component(ArraytoTS(self.config["info_params"]))
         if self.config.get("scale", None):
             scaler_file_path = os.path.join(self.model_dir, "scaler.pkl")
             if not os.path.exists(scaler_file_path):
                 raise Exception(f"Cannot find scaler file: {scaler_file_path}")
-            ops["TSDeNormalize"] = TSDeNormalize(
-                scaler_file_path, self.config["info_params"]
+            self._add_component(
+                TSDeNormalize(scaler_file_path, self.config["info_params"])
             )
-        return ops
 
     def _pack_res(self, single):
         return TSFcResult({"ts_path": single["ts_path"], "forecast": single["pred"]})

+ 1 - 1
paddlex/inference/models/utils/predict_set.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-class BatchSetMixin:
+class BatchSizeSetMixin:
     def set_batch_size(self, batch_size):
         self.components["ReadImage"].set_batch_size(batch_size)
 

+ 2 - 2
paddlex/paddlex_cli.py

@@ -91,10 +91,10 @@ def install(args):
     return
 
 
-def pipeline_predict(pipeline, input_path, device=None, save_dir=None):
+def pipeline_predict(pipeline, input, device=None, save_dir=None):
     """pipeline predict"""
     pipeline = create_pipeline(pipeline)
-    result = pipeline(input_path, device=device)
+    result = pipeline(input, device=device)
     for res in result:
         res.print(json_format=False)
         # TODO(gaotingquan): support to save all