Эх сурвалжийг харах

improve new pdx inference

gaotingquan 1 жил өмнө
parent
commit
7ca0769ee0

+ 1 - 1
paddlex/inference/components/paddle_predictor/__init__.py

@@ -13,4 +13,4 @@
 # limitations under the License.
 
 from .option import PaddlePredictorOption
-from .image_predictor import ImagePredictor
+from .predictor import ImagePredictor

+ 0 - 26
paddlex/inference/components/paddle_predictor/image_predictor.py

@@ -1,26 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-
-from .base_predictor import BasePaddlePredictor
-
-
-class ImagePredictor(BasePaddlePredictor):
-
-    def to_batch(self, imgs):
-        return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]
-
-    def format_output(self, output):
-        return [{"pred": np.array(res)} for res in output[0].tolist()]

+ 9 - 4
paddlex/inference/components/paddle_predictor/option.py

@@ -83,6 +83,12 @@ class PaddlePredictorOption(object):
     @register("device")
     def set_device(self, device_setting: str):
         """set device"""
+        self._cfg["device"], self._cfg["device_id"] = self.parse_device_setting(
+            device_setting
+        )
+
+    @classmethod
+    def parse_device_setting(cls, device_setting):
         if len(device_setting.split(":")) == 1:
             device = device_setting.split(":")[0]
             device_id = 0
@@ -92,13 +98,12 @@ class PaddlePredictorOption(object):
             device_id = device_setting.split(":")[1].split(",")[0]
             logging.warning(f"The device id has been set to {device_id}.")
 
-        if device.lower() not in self.SUPPORT_DEVICE:
-            support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
+        if device.lower() not in cls.SUPPORT_DEVICE:
+            support_run_mode_str = ", ".join(cls.SUPPORT_DEVICE)
             raise ValueError(
                 f"`device` must be {support_run_mode_str}, but received {repr(device)}."
             )
-        self._cfg["device"] = device.lower()
-        self._cfg["device_id"] = int(device_id)
+        return device.lower(), int(device_id)
 
     @register("min_subgraph_size")
     def set_min_subgraph_size(self, min_subgraph_size: int):

+ 15 - 11
paddlex/inference/components/paddle_predictor/base_predictor.py → paddlex/inference/components/paddle_predictor/predictor.py

@@ -17,6 +17,7 @@ from abc import abstractmethod
 
 import paddle
 from paddle.inference import Config, create_predictor
+import numpy as np
 
 from ..base import BaseComponent
 from ....utils import logging
@@ -25,9 +26,9 @@ from ....utils import logging
 class BasePaddlePredictor(BaseComponent):
     """Predictor based on Paddle Inference"""
 
-    INPUT_KEYS = "imgs"
+    INPUT_KEYS = "batch_data"
     OUTPUT_KEYS = "pred"
-    DEAULT_INPUTS = {"x": "x"}
+    DEAULT_INPUTS = {"batch_data": "batch_data"}
     DEAULT_OUTPUTS = {"pred": "pred"}
     ENABLE_BATCH = True
 
@@ -146,8 +147,8 @@ No need to generate again."
         """get input names"""
         return self.input_names
 
-    def apply(self, imgs):
-        x = self.to_batch(imgs)
+    def apply(self, batch_data):
+        x = self.to_batch(batch_data)
         for idx in range(len(x)):
             self.input_handlers[idx].reshape(x[idx].shape)
             self.input_handlers[idx].copy_from_cpu(x[idx])
@@ -156,15 +157,18 @@ No need to generate again."
 
         output = []
         for out_tensor in self.output_handlers:
-            out_arr = out_tensor.copy_to_cpu()
-            output.append(out_arr)
+            batch = out_tensor.copy_to_cpu()
+            output.append(batch)
 
-        return self.format_output(output)
+        return [{"pred": res} for res in zip(*output)]
 
     @abstractmethod
-    def to_batch(self, imgs):
+    def to_batch(self):
         raise NotImplementedError
 
-    @abstractmethod
-    def format_output(self, output):
-        raise NotImplementedError
+
+class ImagePredictor(BasePaddlePredictor):
+    DEAULT_INPUTS = {"batch_data": "img"}
+
+    def to_batch(self, imgs):
+        return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]

+ 11 - 10
paddlex/inference/components/task_related/clas.py

@@ -15,7 +15,6 @@
 import numpy as np
 
 from ....utils import logging
-from ...results import TopkResult
 from ..base import BaseComponent
 
 
@@ -33,10 +32,14 @@ def _parse_class_id_map(class_ids):
 class Topk(BaseComponent):
     """Topk Transform"""
 
-    INPUT_KEYS = ["pred", "img_path"]
-    OUTPUT_KEYS = ["topk_res"]
-    DEAULT_INPUTS = {"pred": "pred", "img_path": "img_path"}
-    DEAULT_OUTPUTS = {"topk_res": "topk_res"}
+    INPUT_KEYS = ["pred"]
+    OUTPUT_KEYS = [["class_ids", "scores"], ["class_ids", "scores", "label_names"]]
+    DEAULT_INPUTS = {"pred": "pred"}
+    DEAULT_OUTPUTS = {
+        "class_ids": "class_ids",
+        "scores": "scores",
+        "label_names": "label_names",
+    }
 
     def __init__(self, topk, class_ids=None):
         super().__init__()
@@ -44,9 +47,9 @@ class Topk(BaseComponent):
         self.topk = topk
         self.class_id_map = _parse_class_id_map(class_ids)
 
-    def apply(self, pred, img_path):
+    def apply(self, pred):
         """apply"""
-        cls_pred = pred
+        cls_pred = pred[0]
         class_id_map = self.class_id_map
 
         index = cls_pred.argsort(axis=0)[-self.topk :][::-1].astype("int32")
@@ -59,14 +62,12 @@ class Topk(BaseComponent):
             if class_id_map is not None:
                 label_name_list.append(class_id_map[i.item()])
         result = {
-            "img_path": img_path,
             "class_ids": clas_id_list,
             "scores": np.around(score_list, decimals=5).tolist(),
         }
         if label_name_list is not None:
             result["label_names"] = label_name_list
-
-        return {"topk_res": TopkResult(result)}
+        return result
 
 
 class NormalizeFeatures(BaseComponent):

+ 7 - 11
paddlex/inference/components/task_related/text_det.py

@@ -25,7 +25,6 @@ from shapely.geometry import Polygon
 
 from ...utils.io import ImageReader
 from ....utils import logging
-from ...results import TextDetResult
 from ..base import BaseComponent
 
 
@@ -206,10 +205,10 @@ class DBPostProcess(BaseComponent):
     The post process for Differentiable Binarization (DB).
     """
 
-    INPUT_KEYS = ["pred", "img_shape", "img_path"]
-    OUTPUT_KEYS = ["text_det_res"]
-    DEAULT_INPUTS = {"pred": "pred", "img_shape": "img_shape", "img_path": "img_path"}
-    DEAULT_OUTPUTS = {"text_det_res": "text_det_res"}
+    INPUT_KEYS = ["pred", "img_shape"]
+    OUTPUT_KEYS = ["dt_polys", "dt_scores"]
+    DEAULT_INPUTS = {"pred": "pred", "img_shape": "img_shape"}
+    DEAULT_OUTPUTS = {"dt_polys": "dt_polys", "dt_scores": "dt_scores"}
 
     def __init__(
         self,
@@ -392,9 +391,9 @@ class DBPostProcess(BaseComponent):
         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
         return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
 
-    def apply(self, pred, img_shape, img_path):
+    def apply(self, pred, img_shape):
         """apply"""
-        pred = pred[0, :, :]
+        pred = pred[0][0, :, :]
         segmentation = pred > self.thresh
 
         src_h, src_w, ratio_h, ratio_w = img_shape
@@ -412,10 +411,7 @@ class DBPostProcess(BaseComponent):
         else:
             raise ValueError("box_type can only be one of ['quad', 'poly']")
 
-        text_det_res = TextDetResult(
-            {"img_path": img_path, "dt_polys": boxes, "dt_scores": scores}
-        )
-        return {"text_det_res": text_det_res}
+        return {"dt_polys": boxes, "dt_scores": scores}
 
 
 class CropByPolys(BaseComponent):

+ 9 - 24
paddlex/inference/components/task_related/text_rec.py

@@ -27,7 +27,6 @@ import tempfile
 from tokenizers import Tokenizer as TokenizerFast
 
 from ....utils import logging
-from ...results import TextRecResult
 from ..base import BaseComponent
 
 __all__ = [
@@ -192,10 +191,10 @@ class OCRReisizeNormImg(BaseComponent):
 class BaseRecLabelDecode(BaseComponent):
     """Convert between text-label and text-index"""
 
-    INPUT_KEYS = ["pred", "img_path"]
-    OUTPUT_KEYS = ["text_rec_res"]
-    DEAULT_INPUTS = {"pred": "pred", "img_path": "img_path"}
-    DEAULT_OUTPUTS = {"text_rec_res": "text_rec_res"}
+    INPUT_KEYS = ["pred"]
+    OUTPUT_KEYS = ["rec_text", "rec_score"]
+    DEAULT_INPUTS = {"pred": "pred"}
+    DEAULT_OUTPUTS = {"rec_text": "rec_text", "rec_score": "rec_score"}
 
     ENABLE_BATCH = True
 
@@ -271,7 +270,7 @@ class BaseRecLabelDecode(BaseComponent):
         """get_ignored_tokens"""
         return [0]  # for ctc blank
 
-    def apply(self, pred, img_path):
+    def apply(self, pred):
         """apply"""
         preds = np.array(pred)
         if isinstance(preds, tuple) or isinstance(preds, list):
@@ -279,14 +278,7 @@ class BaseRecLabelDecode(BaseComponent):
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
-        return [
-            {
-                "text_rec_res": TextRecResult(
-                    {"img_path": path, "rec_text": t[0], "rec_score": t[1]}
-                )
-            }
-            for path, t in zip(img_path, text)
-        ]
+        return [{"rec_text": t[0], "rec_score": t[1]} for t in text]
 
 
 class CTCLabelDecode(BaseRecLabelDecode):
@@ -295,20 +287,13 @@ class CTCLabelDecode(BaseRecLabelDecode):
     def __init__(self, character_list=None, use_space_char=True):
         super().__init__(character_list, use_space_char=use_space_char)
 
-    def apply(self, pred, img_path):
+    def apply(self, pred):
         """apply"""
-        preds = np.array(pred)
+        preds = np.array(pred[0])
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
-        return [
-            {
-                "text_rec_res": TextRecResult(
-                    {"img_path": path, "rec_text": t[0], "rec_score": t[1]}
-                )
-            }
-            for path, t in zip(img_path, text)
-        ]
+        return [{"rec_text": t[0], "rec_score": t[1]} for t in text]
 
     def add_special_char(self, character_list):
         """add_special_char"""

+ 25 - 0
paddlex/inference/predictors/__init__.py

@@ -12,6 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
+from pathlib import Path
+
+from .base import BasePredictor
 from .image_classification import ClasPredictor
 from .text_detection import TextDetPredictor
 from .text_recognition import TextRecPredictor
+from .official_models import official_models
+
+
+def create_predictor(model: str, device: str, *args, **kwargs) -> BasePredictor:
+    model_dir = check_model(model)
+    config = BasePredictor.load_config(model_dir)
+    model_name = config["Global"]["model_name"]
+    return BasePredictor.get(model_name)(
+        model_dir=model_dir, config=config, device=device, *args, **kwargs
+    )
+
+
+def check_model(model):
+    if Path(model).exists():
+        return Path(model)
+    elif model in official_models:
+        return official_models[model]
+    else:
+        raise Exception(
+            f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
+        )

+ 17 - 17
paddlex/inference/predictors/base.py

@@ -19,7 +19,7 @@ from abc import abstractmethod
 
 from ...utils.subclass_register import AutoRegisterABCMetaClass
 from ..components.base import BaseComponent, ComponentsEngine
-from .official_models import official_models
+from ..utils.process_hook import generatorable_method
 
 
 class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
@@ -32,36 +32,36 @@ class BasePredictor(BaseComponent, metaclass=AutoRegisterABCMetaClass):
 
     MODEL_FILE_PREFIX = "inference"
 
-    def __init__(self, model, **kwargs):
+    def __init__(self, model_dir, config=None, device="gpu", **kwargs):
         super().__init__()
-        self.model_dir = self._check_model(model)
+        self.model_dir = Path(model_dir)
+        self.config = config if config else self.load_config(self.model_dir)
+        self.device = device
         self.kwargs = kwargs
-        self.config = self._load_config()
         self.components = self._build_components()
         self.engine = ComponentsEngine(self.components)
         # alias predict() to the __call__()
         self.predict = self.__call__
 
-    def _check_model(self, model):
-        if Path(model).exists():
-            return Path(model)
-        elif model in official_models:
-            return official_models[model]
-        else:
-            raise Exception(
-                f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
-            )
-
-    def _load_config(self):
-        config_path = self.model_dir / f"{self.MODEL_FILE_PREFIX}.yml"
+    @classmethod
+    def load_config(cls, model_dir):
+        config_path = model_dir / f"{cls.MODEL_FILE_PREFIX}.yml"
         with codecs.open(config_path, "r", "utf-8") as file:
             dic = yaml.load(file, Loader=yaml.FullLoader)
         return dic
 
     def apply(self, x):
         """predict"""
-        yield from self.engine(x)
+        yield from self._generate_res(self.engine(x))
+
+    @generatorable_method
+    def _generate_res(self, data):
+        return self._pack_res(data)
 
     @abstractmethod
     def _build_components(self):
         raise NotImplementedError
+
+    @abstractmethod
+    def _pack_res(self, data):
+        raise NotImplementedError

+ 10 - 2
paddlex/inference/predictors/image_classification.py

@@ -17,6 +17,8 @@ import numpy as np
 from ...utils.func_register import FuncRegister
 from ...modules.image_classification.model_list import MODELS
 from ..components import *
+from ..results import TopkResult
+from ..utils.process_hook import batchable_method
 from .base import BasePredictor
 
 
@@ -43,13 +45,12 @@ class ClasPredictor(BasePredictor):
             ops[tf_key] = op
 
         kernel_option = PaddlePredictorOption()
-        # kernel_option.set_device(self.device)
+        kernel_option.set_device(self.device)
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=kernel_option,
         )
-        predictor.set_inputs({"imgs": "img"})
         ops["predictor"] = predictor
 
         post_processes = self.config["PostProcess"]
@@ -95,3 +96,10 @@ class ClasPredictor(BasePredictor):
     @register("Topk")
     def build_topk(self, topk, label_list=None):
         return Topk(topk=int(topk), class_ids=label_list)
+
+    @batchable_method
+    def _pack_res(self, data):
+        keys = ["img_path", "class_ids", "scores"]
+        if "label_names" in data:
+            keys.append("label_names")
+        return {"topk_res": TopkResult({key: data[key] for key in keys})}

+ 8 - 2
paddlex/inference/predictors/text_detection.py

@@ -17,6 +17,8 @@ import numpy as np
 from ...utils.func_register import FuncRegister
 from ...modules.text_detection.model_list import MODELS
 from ..components import *
+from ..results import TextDetResult
+from ..utils.process_hook import batchable_method
 from .base import BasePredictor
 
 
@@ -43,13 +45,12 @@ class TextDetPredictor(BasePredictor):
                 ops[tf_key] = op
 
         kernel_option = PaddlePredictorOption()
-        # kernel_option.set_device(self.device)
+        kernel_option.set_device(self.device)
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=kernel_option,
         )
-        predictor.set_inputs({"imgs": "img"})
         ops["predictor"] = predictor
 
         key, op = self.build_postprocess(**self.config["PostProcess"])
@@ -104,3 +105,8 @@ class TextDetPredictor(BasePredictor):
     @register("KeepKeys")
     def foo(self, *args, **kwargs):
         return None
+
+    @batchable_method
+    def _pack_res(self, data):
+        keys = ["img_path", "dt_polys", "dt_scores"]
+        return {"text_det_res": TextDetResult({key: data[key] for key in keys})}

+ 8 - 2
paddlex/inference/predictors/text_recognition.py

@@ -17,6 +17,8 @@ import numpy as np
 from ...utils.func_register import FuncRegister
 from ...modules.text_recognition.model_list import MODELS
 from ..components import *
+from ..results import TextRecResult
+from ..utils.process_hook import batchable_method
 from .base import BasePredictor
 
 
@@ -44,13 +46,12 @@ class TextRecPredictor(BasePredictor):
                 ops[tf_key] = op
 
         kernel_option = PaddlePredictorOption()
-        # kernel_option.set_device(self.device)
+        kernel_option.set_device(self.device)
         predictor = ImagePredictor(
             model_dir=self.model_dir,
             model_prefix=self.MODEL_FILE_PREFIX,
             option=kernel_option,
         )
-        predictor.set_inputs({"imgs": "img"})
         ops["predictor"] = predictor
 
         key, op = self.build_postprocess(**self.config["PostProcess"])
@@ -81,3 +82,8 @@ class TextRecPredictor(BasePredictor):
     @register("KeepKeys")
     def foo(self, *args, **kwargs):
         return None
+
+    @batchable_method
+    def _pack_res(self, data):
+        keys = ["img_path", "rec_text", "rec_score"]
+        return {"text_rec_res": TextRecResult({key: data[key] for key in keys})}

+ 19 - 0
paddlex/utils/logging.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import inspect
 import logging
 import sys
 
@@ -36,6 +37,24 @@ _logger = logging.getLogger(LOGGER_NAME)
 
 def debug(msg, *args, **kwargs):
     """debug"""
+    if DEBUG:
+        frame = inspect.currentframe()
+        caller_frame = frame.f_back
+        caller_func_name = caller_frame.f_code.co_name
+
+        if "self" in caller_frame.f_locals:
+            caller_class_name = caller_frame.f_locals["self"].__class__.__name__
+        elif "cls" in caller_frame.f_locals:
+            caller_class_name = caller_frame.f_locals["cls"].__name__
+        else:
+            caller_class_name = None
+
+        if caller_class_name:
+            caller_info = f"{caller_class_name}::{caller_func_name}"
+        else:
+            caller_info = f"{caller_func_name}"
+        msg = f"【{caller_info}】{msg}"
+
     _logger.debug(msg, *args, **kwargs)