Ver código fonte

improve shitu rec inference

zhangyubo0722 1 ano atrás
pai
commit
ed3f0780fc

+ 1 - 1
paddlex/inference/components/task_related/__init__.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .clas import Topk, MultiLabelThreshOutput
+from .clas import Topk, MultiLabelThreshOutput, NormalizeFeatures
 from .text_det import DetResizeForTest, NormalizeImage, DBPostProcess, CropByPolys
 from .text_rec import OCRReisizeNormImg, CTCLabelDecode
 from .table_rec import TableLabelDecode, TableMasterLabelDecode

+ 8 - 8
paddlex/inference/components/task_related/clas.py

@@ -112,13 +112,13 @@ class MultiLabelThreshOutput(BaseComponent):
 class NormalizeFeatures(BaseComponent):
     """Normalize Features Transform"""
 
-    INPUT_KEYS = ["cls_pred"]
-    OUTPUT_KEYS = ["cls_res"]
-    DEAULT_INPUTS = {"cls_res": "cls_res"}
-    DEAULT_OUTPUTS = {"cls_pred": "cls_pred"}
+    INPUT_KEYS = ["pred"]
+    OUTPUT_KEYS = ["rec_feature"]
+    DEAULT_INPUTS = {"pred": "pred"}
+    DEAULT_OUTPUTS = {"rec_feature": "rec_feature"}
 
-    def apply(self, cls_pred):
+    def apply(self, pred):
         """apply"""
-        feas_norm = np.sqrt(np.sum(np.square(cls_pred), axis=0, keepdims=True))
-        cls_res = np.divide(cls_pred, feas_norm)
-        return {"cls_res": cls_res}
+        feas_norm = np.sqrt(np.sum(np.square(pred[0]), axis=0, keepdims=True))
+        rec_feature = np.divide(pred[0], feas_norm)
+        return {"rec_feature": rec_feature}

+ 1 - 0
paddlex/inference/pipelines/__init__.py

@@ -17,3 +17,4 @@ from .ocr import OCRPipeline
 from .object_detection import DetPipeline
 from .instance_segmentation import InstanceSegPipeline
 from .semantic_segmentation import SegPipeline
+from .general_recognition import ShiTuRecPipeline

+ 33 - 0
paddlex/inference/pipelines/general_recognition.py

@@ -0,0 +1,33 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BasePipeline
+from ..predictors import create_predictor
+
+
+class ShiTuRecPipeline(BasePipeline):
+    """Pipeline wrapping a single ShiTu recognition predictor."""
+
+    entities = "general_recognition"
+
+    def __init__(self, model, batch_size=1, device="gpu"):
+        """Create the underlying predictor through ``create_predictor``.
+
+        Args:
+            model: Forwarded unchanged to ``create_predictor``.
+            batch_size: Forwarded as the ``batch_size`` keyword. Defaults to 1.
+            device: Forwarded as the ``device`` keyword. Defaults to "gpu".
+        """
+        super().__init__()
+        predictor_options = {"batch_size": batch_size, "device": device}
+        self._predict = create_predictor(model, **predictor_options)
+
+    def _check_input(self, x):
+        """Validation hook called by ``predict``; currently accepts anything."""
+        return None
+
+    def predict(self, x):
+        """Run the input check, then lazily yield what the predictor yields."""
+        self._check_input(x)
+        yield from self._predict(x)

+ 1 - 0
paddlex/inference/predictors/__init__.py

@@ -23,6 +23,7 @@ from .table_recognition import TablePredictor
 from .object_detection import DetPredictor
 from .instance_segmentation import InstanceSegPredictor
 from .semantic_segmentation import SegPredictor
+from .general_recognition import ShiTuRecPredictor
 from .official_models import official_models
 
 

+ 109 - 0
paddlex/inference/predictors/general_recognition.py

@@ -0,0 +1,109 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...utils.func_register import FuncRegister
+from ...modules.general_recognition.model_list import MODELS
+from ..components import *
+from ..results import BaseResult
+from ..utils.process_hook import batchable_method
+from .base import BasicPredictor
+
+
+class ShiTuRecPredictor(BasicPredictor):
+    """Predictor for general-recognition models.
+
+    The component chain is assembled from the model config: an image reader,
+    the ops listed under ``PreProcess.transform_ops``, the inference
+    predictor, and the ops listed under ``PostProcess``. Op names found in
+    the config are resolved to builder methods through ``_FUNC_MAP``. Packed
+    results carry the ``img_path`` and ``rec_feature`` entries.
+    """
+
+    entities = MODELS
+
+    # Maps a config op name to its builder method; entries are added through
+    # the ``register`` decorator applied to the ``build_*`` methods below.
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def _check_args(self, kwargs):
+        """Check that only supported keyword arguments were given.
+
+        Args:
+            kwargs: Mapping of keyword arguments; only ``batch_size`` is
+                allowed.
+
+        Returns:
+            The same ``kwargs`` mapping, unchanged.
+
+        Raises:
+            AssertionError: If any key other than ``batch_size`` is present.
+        """
+        unexpected = set(kwargs) - {"batch_size"}
+        assert not unexpected, f"Unsupported arguments: {sorted(unexpected)}"
+        return kwargs
+
+    def _build_components(self):
+        """Build the ordered mapping of component name to component.
+
+        Returns:
+            A dict whose insertion order is: ``ReadImage``, the configured
+            pre-processing ops, ``predictor``, the configured post-processing
+            ops.
+
+        Raises:
+            ValueError: If the config names an op with no registered builder.
+        """
+        ops = {}
+        ops["ReadImage"] = ReadImage(
+            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
+        )
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            # Each entry is read as a mapping whose first key is the op name
+            # and whose value holds that op's arguments.
+            tf_key = list(cfg)[0]
+            ops[tf_key] = self._build_op(tf_key, cfg.get(tf_key, {}))
+
+        ops["predictor"] = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        post_processes = self.config["PostProcess"]
+        for key in post_processes:
+            ops[key] = self._build_op(key, post_processes.get(key, {}))
+        return ops
+
+    def _build_op(self, name, args):
+        """Instantiate the component registered under ``name``.
+
+        Args:
+            name: Op name as written in the config.
+            args: Keyword arguments for the builder; a falsy value (``None``
+                or an empty mapping) means the builder is called with no
+                arguments.
+
+        Returns:
+            The component returned by the registered builder.
+
+        Raises:
+            ValueError: If no builder is registered for ``name``.
+        """
+        func = self._FUNC_MAP.get(name)
+        if func is None:
+            # Fail with the offending name instead of the opaque
+            # "'NoneType' object is not callable".
+            raise ValueError(
+                f"Unsupported op {name!r} for {type(self).__name__}; "
+                f"registered ops: {sorted(self._FUNC_MAP)}"
+            )
+        return func(self, **args) if args else func(self)
+
+    @register("ResizeImage")
+    # TODO(gaotingquan): backend & interpolation
+    def build_resize(
+        self,
+        resize_short=None,
+        size=None,
+        backend="cv2",
+        interpolation="LINEAR",
+        return_numpy=False,
+    ):
+        """Build the resize op for a ``ResizeImage`` config entry.
+
+        Args:
+            resize_short: Target length of the short edge; takes precedence
+                over ``size`` when truthy.
+            size: Target size, used only when ``resize_short`` is falsy.
+            backend: Accepted but currently ignored (see TODO above).
+            interpolation: Accepted but currently ignored; short-edge
+                resizing always uses "LINEAR" (see TODO above).
+            return_numpy: Accepted but currently ignored.
+
+        Returns:
+            A ``ResizeByShort`` op when ``resize_short`` is given, otherwise
+            a ``Resize`` op.
+
+        Raises:
+            AssertionError: If neither ``resize_short`` nor ``size`` is given.
+        """
+        assert resize_short or size, "ResizeImage needs `resize_short` or `size`"
+        if resize_short:
+            return ResizeByShort(
+                target_short_edge=resize_short, size_divisor=None, interp="LINEAR"
+            )
+        return Resize(target_size=size)
+
+    @register("CropImage")
+    def build_crop(self, size=224):
+        """Build the ``Crop`` op for a ``CropImage`` config entry."""
+        return Crop(crop_size=size)
+
+    @register("NormalizeImage")
+    def build_normalize(
+        self,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+        scale=1 / 255,
+        order="",
+        channel_num=3,
+    ):
+        """Build the ``Normalize`` op for a ``NormalizeImage`` config entry.
+
+        Args:
+            mean: Per-channel mean forwarded to ``Normalize``.
+            std: Per-channel standard deviation forwarded to ``Normalize``.
+            scale: Accepted but not forwarded to ``Normalize``.
+            order: Accepted but not forwarded to ``Normalize``.
+            channel_num: Must be 3.
+
+        Raises:
+            AssertionError: If ``channel_num`` is not 3.
+        """
+        assert channel_num == 3, "NormalizeImage only supports channel_num == 3"
+        return Normalize(mean=mean, std=std)
+
+    @register("ToCHWImage")
+    def build_to_chw(self):
+        """Build the ``ToCHWImage`` op."""
+        return ToCHWImage()
+
+    @register("NormalizeFeatures")
+    def build_normalize_features(self):
+        """Build the ``NormalizeFeatures`` post-processing op."""
+        return NormalizeFeatures()
+
+    @batchable_method
+    def _pack_res(self, data):
+        """Wrap the ``img_path`` and ``rec_feature`` entries in a ``BaseResult``.
+
+        Args:
+            data: Mapping that contains at least ``img_path`` and
+                ``rec_feature``.
+
+        Returns:
+            ``{"result": BaseResult(...)}`` holding only those two entries.
+        """
+        keys = ["img_path", "rec_feature"]
+        return {"result": BaseResult({key: data[key] for key in keys})}

+ 2 - 0
paddlex/inference/predictors/object_detection.py

@@ -82,6 +82,8 @@ class DetPredictor(BasicPredictor):
             scale = 1.0 / 255.0
         else:
             scale = 1
+        if not norm_type or norm_type == "none":
+            norm_type = "mean_std"
         if norm_type != "mean_std":
             mean = 0
             std = 1

+ 1 - 0
paddlex/inference/results/__init__.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .base import BaseResult
 from .topk import TopkResult
 from .text_det import TextDetResult
 from .text_rec import TextRecResult

+ 0 - 2
paddlex/inference/results/instance_seg.py

@@ -81,7 +81,5 @@ class InstanceSegResult(BaseResult):
         image = self._img_reader.read(img_path)
         image = draw_mask(image, boxes, masks, labels)
         image = draw_box(image, boxes, labels=labels)
-        self["boxes"] = boxes.tolist()
-        self["masks"] = masks.tolist()
 
         return image