瀏覽代碼

upgrade PaddleX Inference

1. update API create_predictor;
2. support ML Clas;
3. support pass pp_option;
4. unify key to 'result';
gaotingquan 1 年之前
父節點
當前提交
d327137b91

+ 15 - 0
paddlex/inference/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .components.paddle_predictor.option import PaddlePredictorOption

+ 3 - 3
paddlex/inference/components/paddle_predictor/option.py

@@ -39,12 +39,12 @@ class PaddlePredictorOption(object):
 
     def _init_option(self, **kwargs):
         for k, v in kwargs.items():
-            if k not in self._REGISTER_MAP:
+            if k not in self._FUNC_MAP:
                 raise Exception(
                     f"{k} is not supported to set! The supported option is: \
-{list(self._REGISTER_MAP.keys())}"
+{list(self._FUNC_MAP.keys())}"
                 )
-            self._REGISTER_MAP.get(k)(self, v)
+            self._FUNC_MAP.get(k)(self, v)
         for k, v in self._get_default_config().items():
             self._cfg.setdefault(k, v)
 

+ 1 - 1
paddlex/inference/components/task_related/__init__.py

@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .clas import Topk
+from .clas import Topk, MultiLabelThreshOutput
 from .text_det import DetResizeForTest, NormalizeImage, DBPostProcess, CropByPolys
 from .text_rec import OCRReisizeNormImg, CTCLabelDecode

+ 43 - 4
paddlex/inference/components/task_related/clas.py

@@ -50,8 +50,6 @@ class Topk(BaseComponent):
     def apply(self, pred):
         """apply"""
         cls_pred = pred[0]
-        class_id_map = self.class_id_map
-
         index = cls_pred.argsort(axis=0)[-self.topk :][::-1].astype("int32")
         clas_id_list = []
         score_list = []
@@ -59,8 +57,49 @@ class Topk(BaseComponent):
         for i in index:
             clas_id_list.append(i.item())
             score_list.append(cls_pred[i].item())
-            if class_id_map is not None:
-                label_name_list.append(class_id_map[i.item()])
+            if self.class_id_map is not None:
+                label_name_list.append(self.class_id_map[i.item()])
+        result = {
+            "class_ids": clas_id_list,
+            "scores": np.around(score_list, decimals=5).tolist(),
+        }
+        if label_name_list is not None:
+            result["label_names"] = label_name_list
+        return result
+
+
+class MultiLabelThreshOutput(BaseComponent):
+
+    INPUT_KEYS = ["pred"]
+    OUTPUT_KEYS = [["class_ids", "scores"], ["class_ids", "scores", "label_names"]]
+    DEAULT_INPUTS = {"pred": "pred"}
+    DEAULT_OUTPUTS = {
+        "class_ids": "class_ids",
+        "scores": "scores",
+        "label_names": "label_names",
+    }
+
+    def __init__(self, threshold=0.5, class_ids=None, delimiter=None):
+        super().__init__()
+        assert isinstance(threshold, (float,))
+        self.threshold = threshold
+        self.delimiter = delimiter if delimiter is not None else " "
+        self.class_id_map = _parse_class_id_map(class_ids)
+
+    def apply(self, pred):
+        """apply"""
+        y = []
+        x = pred[0]
+        pred_index = np.where(x >= self.threshold)[0].astype("int32")
+        index = pred_index[np.argsort(x[pred_index])][::-1]
+        clas_id_list = []
+        score_list = []
+        label_name_list = []
+        for i in index:
+            clas_id_list.append(i.item())
+            score_list.append(x[i].item())
+            if self.class_id_map is not None:
+                label_name_list.append(self.class_id_map[i.item()])
         result = {
             "class_ids": clas_id_list,
             "scores": np.around(score_list, decimals=5).tolist(),

+ 2 - 2
paddlex/inference/pipelines/image_classification.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import ClasPredictor
+from ..predictors import create_predictor
 
 
 class ClasPipeline(BasePipeline):
@@ -23,7 +23,7 @@ class ClasPipeline(BasePipeline):
 
     def __init__(self, model, batch_size=1, device="gpu"):
         super().__init__()
-        self._predict = ClasPredictor(model, batch_size=batch_size)
+        self._predict = create_predictor(model, batch_size=batch_size, device=device)
 
     def predict(self, x):
         self._check_input(x)

+ 6 - 8
paddlex/inference/pipelines/ocr.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from .base import BasePipeline
-from ..predictors import TextDetPredictor, TextRecPredictor
+from ..predictors import create_predictor
 from ..components import CropByPolys
 from ..results import OCRResult
 
@@ -24,8 +24,8 @@ class OCRPipeline(BasePipeline):
     entities = "ocr"
 
     def __init__(self, det_model, rec_model, det_batch_size, rec_batch_size, **kwargs):
-        self._det_predict = TextDetPredictor(det_model, batch_size=det_batch_size)
-        self._rec_predict = TextRecPredictor(rec_model, batch_size=rec_batch_size)
+        self._det_predict = create_predictor(det_model, batch_size=det_batch_size)
+        self._rec_predict = create_predictor(rec_model, batch_size=rec_batch_size)
         # TODO: foo
         self._crop_by_polys = CropByPolys(det_box_type="foo")
 
@@ -33,17 +33,15 @@ class OCRPipeline(BasePipeline):
         batch_ocr_res = []
         for batch_det_res in self._det_predict(x):
             for det_res in batch_det_res:
-                single_img_res = det_res["text_det_res"]
+                single_img_res = det_res["result"]
                 single_img_res["rec_text"] = []
                 single_img_res["rec_score"] = []
                 all_subs_of_img = list(self._crop_by_polys(single_img_res))
                 for batch_rec_res in self._rec_predict(all_subs_of_img):
                     for rec_res in batch_rec_res:
-                        single_img_res["rec_text"].append(
-                            rec_res["text_rec_res"]["rec_text"]
-                        )
+                        single_img_res["rec_text"].append(rec_res["result"]["rec_text"])
                         single_img_res["rec_score"].append(
-                            rec_res["text_rec_res"]["rec_score"]
+                            rec_res["result"]["rec_score"]
                         )
                 # TODO(gaotingquan): using "ocr_res" or new a component or dict only?
                 batch_ocr_res.append({"ocr_res": OCRResult(single_img_res)})

+ 9 - 2
paddlex/inference/predictors/__init__.py

@@ -22,12 +22,19 @@ from .text_recognition import TextRecPredictor
 from .official_models import official_models
 
 
-def create_predictor(model: str, device: str, *args, **kwargs) -> BasePredictor:
+def create_predictor(
+    model: str, device: str = None, pp_option=None, *args, **kwargs
+) -> BasePredictor:
     model_dir = check_model(model)
     config = BasePredictor.load_config(model_dir)
     model_name = config["Global"]["model_name"]
     return BasePredictor.get(model_name)(
-        model_dir=model_dir, config=config, device=device, *args, **kwargs
+        model_dir=model_dir,
+        config=config,
+        device=device,
+        pp_option=pp_option,
+        *args,
+        **kwargs,
     )
 
 

+ 20 - 4
paddlex/inference/predictors/base.py

@@ -18,7 +18,9 @@ from pathlib import Path
 from abc import abstractmethod
 
 from ...utils.subclass_register import AutoRegisterABCMetaClass
+from ...utils import logging
 from ..components.base import BaseComponent, ComponentsEngine
+from ..components.paddle_predictor.option import PaddlePredictorOption
 from ..utils.process_hook import generatorable_method
 
 
@@ -26,23 +28,34 @@ class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
     __is_base = True
 
     INPUT_KEYS = "x"
-    OUTPUT_KEYS = None
+    DEAULT_INPUTS = {"x": "x"}
+    OUTPUT_KEYS = "result"
+    DEAULT_OUTPUTS = {"result": "result"}
 
     KEEP_INPUT = False
 
     MODEL_FILE_PREFIX = "inference"
 
-    def __init__(self, model_dir, config=None, device="gpu", **kwargs):
+    def __init__(self, model_dir, config=None, device=None, pp_option=None, **kwargs):
         super().__init__()
         self.model_dir = Path(model_dir)
         self.config = config if config else self.load_config(self.model_dir)
-        self.device = device
-        self.kwargs = kwargs
+        self.kwargs = self._check_args(kwargs)
+
+        self.pp_option = PaddlePredictorOption() if pp_option is None else pp_option
+        if device is not None:
+            self.pp_option.set_device(device)
+
         self.components = self._build_components()
         self.engine = ComponentsEngine(self.components)
+
         # alias predict() to the __call__()
         self.predict = self.__call__
 
+        logging.debug(
+            f"-------------------- {self.__class__.__name__} --------------------\nModel: {self.model_dir}\nEnv: {self.pp_option}"
+        )
+
     @classmethod
     def load_config(cls, model_dir):
         config_path = model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"
@@ -58,6 +71,9 @@ class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
     def _generate_res(self, data):
         return self._pack_res(data)
 
+    def _check_args(self, kwargs):
+        return kwargs
+
     @abstractmethod
     def _build_components(self):
         raise NotImplementedError

+ 14 - 10
paddlex/inference/predictors/image_classification.py

@@ -26,14 +26,13 @@ class ClasPredictor(BasePredictor):
 
     entities = MODELS
 
-    INPUT_KEYS = "x"
-    OUTPUT_KEYS = "topk_res"
-    DEAULT_INPUTS = {"x": "x"}
-    DEAULT_OUTPUTS = {"topk_res": "topk_res"}
-
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
+    def _check_args(self, kwargs):
+        assert set(kwargs.keys()).issubset(set(["batch_size"]))
+        return kwargs
+
     def _build_components(self):
         ops = {}
         ops["ReadImage"] = ReadImage(batch_size=self.kwargs.get("batch_size", 1))
@@ -44,12 +43,10 @@ class ClasPredictor(BasePredictor):
             op = func(self, **args) if args else func(self)
             ops[tf_key] = op
 
-        kernel_option = PaddlePredictorOption()
-        kernel_option.set_device(self.device)
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
-            option=kernel_option,
+            option=self.pp_option,
         )
         ops["predictor"] = predictor
 
@@ -62,7 +59,10 @@ class ClasPredictor(BasePredictor):
         return ops
 
     @register("ResizeImage")
-    def build_resize(self, resize_short=None, size=None):
+    # TODO(gaotingquan): backend & interpolation
+    def build_resize(
+        self, resize_short=None, size=None, backend="cv2", interpolation="LINEAR"
+    ):
         assert resize_short or size
         if resize_short:
             op = ResizeByShort(
@@ -97,9 +97,13 @@ class ClasPredictor(BasePredictor):
     def build_topk(self, topk, label_list=None):
         return Topk(topk=int(topk), class_ids=label_list)
 
+    @register("MultiLabelThreshOutput")
+    def build_threshoutput(self, threshold, label_list=None):
+        return MultiLabelThreshOutput(threshold=float(threshold), class_ids=label_list)
+
     @batchable_method
     def _pack_res(self, data):
         keys = ["img_path", "class_ids", "scores"]
         if "label_names" in data:
             keys.append("label_names")
-        return {"topk_res": TopkResult({key: data[key] for key in keys})}
+        return {"result": TopkResult({key: data[key] for key in keys})}

+ 2 - 9
paddlex/inference/predictors/text_detection.py

@@ -26,11 +26,6 @@ class TextDetPredictor(BasePredictor):
 
     entities = MODELS
 
-    INPUT_KEYS = "x"
-    OUTPUT_KEYS = "text_det_res"
-    DEAULT_INPUTS = {"x": "x"}
-    DEAULT_OUTPUTS = {"text_det_res": "text_det_res"}
-
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
@@ -44,12 +39,10 @@ class TextDetPredictor(BasePredictor):
             if op:
                 ops[tf_key] = op
 
-        kernel_option = PaddlePredictorOption()
-        kernel_option.set_device(self.device)
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
-            option=kernel_option,
+            option=self.pp_option,
         )
         ops["predictor"] = predictor
 
@@ -109,4 +102,4 @@ class TextDetPredictor(BasePredictor):
     @batchable_method
     def _pack_res(self, data):
         keys = ["img_path", "dt_polys", "dt_scores"]
-        return {"text_det_res": TextDetResult({key: data[key] for key in keys})}
+        return {"result": TextDetResult({key: data[key] for key in keys})}

+ 2 - 9
paddlex/inference/predictors/text_recognition.py

@@ -26,11 +26,6 @@ class TextRecPredictor(BasePredictor):
 
     entities = MODELS
 
-    INPUT_KEYS = "x"
-    OUTPUT_KEYS = "text_rec_res"
-    DEAULT_INPUTS = {"x": "x"}
-    DEAULT_OUTPUTS = {"text_rec_res": "text_rec_res"}
-
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
@@ -45,12 +40,10 @@ class TextRecPredictor(BasePredictor):
             if op:
                 ops[tf_key] = op
 
-        kernel_option = PaddlePredictorOption()
-        kernel_option.set_device(self.device)
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
-            option=kernel_option,
+            option=self.pp_option,
         )
         ops["predictor"] = predictor
 
@@ -86,4 +79,4 @@ class TextRecPredictor(BasePredictor):
     @batchable_method
     def _pack_res(self, data):
         keys = ["img_path", "rec_text", "rec_score"]
-        return {"text_rec_res": TextRecResult({key: data[key] for key in keys})}
+        return {"result": TextRecResult({key: data[key] for key in keys})}