Переглянути джерело

improve new pdx inference

gaotingquan 1 рік тому
батько
коміт
fdfdc6a84a

+ 12 - 29
paddlex/inference/components/paddle_predictor/option.py

@@ -12,27 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from functools import wraps, partial
-
+from ....utils.func_register import FuncRegister
 from ....utils import logging
 
 
-def register(register_map, key):
-    """register the option setting func"""
-
-    def decorator(func):
-        register_map[key] = func
-
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            return func(self, *args, **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
 class PaddlePredictorOption(object):
     """Paddle Inference Engine Option"""
 
@@ -45,9 +28,9 @@ class PaddlePredictorOption(object):
         "mkldnn_bf16",
     )
     SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu")
-    _REGISTER_MAP = {}
 
-    register2self = partial(register, _REGISTER_MAP)
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
 
     def __init__(self, **kwargs):
         super().__init__()
@@ -80,7 +63,7 @@ class PaddlePredictorOption(object):
             "delete_pass": [],
         }
 
-    @register2self("run_mode")
+    @register("run_mode")
     def set_run_mode(self, run_mode: str):
         """set run mode"""
         if run_mode not in self.SUPPORT_RUN_MODE:
@@ -90,14 +73,14 @@ class PaddlePredictorOption(object):
             )
         self._cfg["run_mode"] = run_mode
 
-    @register2self("batch_size")
+    @register("batch_size")
     def set_batch_size(self, batch_size: int):
         """set batch size"""
         if not isinstance(batch_size, int) or batch_size < 1:
             raise Exception()
         self._cfg["batch_size"] = batch_size
 
-    @register2self("device")
+    @register("device")
     def set_device(self, device_setting: str):
         """set device"""
         if len(device_setting.split(":")) == 1:
@@ -117,36 +100,36 @@ class PaddlePredictorOption(object):
         self._cfg["device"] = device.lower()
         self._cfg["device_id"] = int(device_id)
 
-    @register2self("min_subgraph_size")
+    @register("min_subgraph_size")
     def set_min_subgraph_size(self, min_subgraph_size: int):
         """set min subgraph size"""
         if not isinstance(min_subgraph_size, int):
             raise Exception()
         self._cfg["min_subgraph_size"] = min_subgraph_size
 
-    @register2self("shape_info_filename")
+    @register("shape_info_filename")
     def set_shape_info_filename(self, shape_info_filename: str):
         """set shape info filename"""
         self._cfg["shape_info_filename"] = shape_info_filename
 
-    @register2self("trt_calib_mode")
+    @register("trt_calib_mode")
     def set_trt_calib_mode(self, trt_calib_mode):
         """set trt calib mode"""
         self._cfg["trt_calib_mode"] = trt_calib_mode
 
-    @register2self("cpu_threads")
+    @register("cpu_threads")
     def set_cpu_threads(self, cpu_threads):
         """set cpu threads"""
         if not isinstance(cpu_threads, int) or cpu_threads < 1:
             raise Exception()
         self._cfg["cpu_threads"] = cpu_threads
 
-    @register2self("trt_use_static")
+    @register("trt_use_static")
     def set_trt_use_static(self, trt_use_static):
         """set trt use static"""
         self._cfg["trt_use_static"] = trt_use_static
 
-    @register2self("delete_pass")
+    @register("delete_pass")
     def set_delete_pass(self, delete_pass):
         self._cfg["delete_pass"] = delete_pass
 

+ 1 - 1
paddlex/inference/components/transforms/image/__init__.py

@@ -133,7 +133,7 @@ class ReadImage(BaseComponent):
     def _get_image_list(self, img_file):
         imgs_lists = []
         if img_file is None or not os.path.exists(img_file):
-            raise Exception("not found any img file in {}".format(img_file))
+            raise Exception(f"Not found any img file in path: {img_file}")
 
         if os.path.isfile(img_file) and img_file.split(".")[-1] in self.SUFFIX:
             imgs_lists.append(img_file)

+ 1 - 1
paddlex/inference/pipelines/base.py

@@ -14,7 +14,7 @@
 
 from abc import ABC, abstractmethod
 
-from ...utils.misc import AutoRegisterABCMetaClass
+from ...utils.subclass_register import AutoRegisterABCMetaClass
 
 
 def create_pipeline(

+ 1 - 1
paddlex/inference/predictors/base.py

@@ -17,7 +17,7 @@ import codecs
 from pathlib import Path
 from abc import abstractmethod
 
-from ...utils.misc import AutoRegisterABCMetaClass
+from ...utils.subclass_register import AutoRegisterABCMetaClass
 from ..components.base import BaseComponent, ComponentsEngine
 from .official_models import official_models
 

+ 10 - 26
paddlex/inference/predictors/image_classification.py

@@ -12,30 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import numpy as np
-from functools import partial, wraps
 
+from ...utils.func_register import FuncRegister
 from ...modules.image_classification.model_list import MODELS
 from ..components import *
 from .base import BasePredictor
 
 
-def register(register_map, key):
-    """register the option setting func"""
-
-    def decorator(func):
-        register_map[key] = func
-
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            return func(self, *args, **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
 class ClasPredictor(BasePredictor):
 
     entities = MODELS
@@ -45,15 +29,15 @@ class ClasPredictor(BasePredictor):
     DEAULT_INPUTS = {"x": "x"}
     DEAULT_OUTPUTS = {"topk_res": "topk_res"}
 
-    _REGISTER_MAP = {}
-    register2self = partial(register, _REGISTER_MAP)
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
 
     def _build_components(self):
         ops = {}
         ops["ReadImage"] = ReadImage(batch_size=self.kwargs.get("batch_size", 1))
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
-            func = self._REGISTER_MAP.get(tf_key)
+            func = self._FUNC_MAP.get(tf_key)
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             ops[tf_key] = op
@@ -70,13 +54,13 @@ class ClasPredictor(BasePredictor):
 
         post_processes = self.config["PostProcess"]
         for key in post_processes:
-            func = self._REGISTER_MAP.get(key)
+            func = self._FUNC_MAP.get(key)
             args = post_processes.get(key, {})
             op = func(self, **args) if args else func(self)
             ops[key] = op
         return ops
 
-    @register2self("ResizeImage")
+    @register("ResizeImage")
     def build_resize(self, resize_short=None, size=None):
         assert resize_short or size
         if resize_short:
@@ -87,11 +71,11 @@ class ClasPredictor(BasePredictor):
             op = Resize(target_size=size)
         return op
 
-    @register2self("CropImage")
+    @register("CropImage")
     def build_crop(self, size=224):
         return Crop(crop_size=size)
 
-    @register2self("NormalizeImage")
+    @register("NormalizeImage")
     def build_normalize(
         self,
         mean=[0.485, 0.456, 0.406],
@@ -104,10 +88,10 @@ class ClasPredictor(BasePredictor):
         assert order == ""
         return Normalize(mean=mean, std=std)
 
-    @register2self("ToCHWImage")
+    @register("ToCHWImage")
     def build_to_chw(self):
         return ToCHWImage()
 
-    @register2self("Topk")
+    @register("Topk")
     def build_topk(self, topk, label_list=None):
         return Topk(topk=int(topk), class_ids=label_list)

+ 4 - 0
paddlex/inference/predictors/official_models.py

@@ -14,6 +14,7 @@
 
 from pathlib import Path
 
+from ...utils import logging
 from ...utils.cache import CACHE_DIR
 from ...utils.download import download_and_extract
 
@@ -177,6 +178,9 @@ class OfficialModelsDict(dict):
     def __getitem__(self, key):
         url = super().__getitem__(key)
         save_dir = Path(CACHE_DIR) / "official_models"
+        logging.info(
+            f"Using official model ({key}), the model files will be be automatically downloaded and saved in {save_dir}."
+        )
         download_and_extract(url, save_dir, f"{key}", overwrite=False)
         return save_dir / f"{key}"
 

+ 10 - 27
paddlex/inference/predictors/text_detection.py

@@ -12,31 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import numpy as np
-from functools import partial, wraps
 
+from ...utils.func_register import FuncRegister
 from ...modules.text_detection.model_list import MODELS
-
 from ..components import *
 from .base import BasePredictor
 
 
-def register(register_map, key):
-    """register the option setting func"""
-
-    def decorator(func):
-        register_map[key] = func
-
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            return func(self, *args, **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
 class TextDetPredictor(BasePredictor):
 
     entities = MODELS
@@ -46,14 +29,14 @@ class TextDetPredictor(BasePredictor):
     DEAULT_INPUTS = {"x": "x"}
     DEAULT_OUTPUTS = {"text_det_res": "text_det_res"}
 
-    _REGISTER_MAP = {}
-    register2self = partial(register, _REGISTER_MAP)
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
 
     def _build_components(self):
         ops = {}
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
-            func = self._REGISTER_MAP.get(tf_key)
+            func = self._FUNC_MAP.get(tf_key)
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op:
@@ -73,16 +56,16 @@ class TextDetPredictor(BasePredictor):
         ops[key] = op
         return ops
 
-    @register2self("DecodeImage")
+    @register("DecodeImage")
     def build_readimg(self, channel_first, img_mode):
         assert channel_first == False
         return ReadImage(format=img_mode, batch_size=self.kwargs.get("batch_size", 1))
 
-    @register2self("DetResizeForTest")
+    @register("DetResizeForTest")
     def build_resize(self, resize_long=960):
         return DetResizeForTest(limit_side_len=resize_long, limit_type="max")
 
-    @register2self("NormalizeImage")
+    @register("NormalizeImage")
     def build_normalize(
         self,
         mean=[0.485, 0.456, 0.406],
@@ -95,7 +78,7 @@ class TextDetPredictor(BasePredictor):
             mean=mean, std=std, scale=scale, order=order, channel_num=channel_num
         )
 
-    @register2self("ToCHWImage")
+    @register("ToCHWImage")
     def build_to_chw(self):
         return ToCHWImage()
 
@@ -114,10 +97,10 @@ class TextDetPredictor(BasePredictor):
         else:
             raise Exception()
 
-    @register2self("DetLabelEncode")
+    @register("DetLabelEncode")
     def foo(self, *args, **kwargs):
         return None
 
-    @register2self("KeepKeys")
+    @register("KeepKeys")
     def foo(self, *args, **kwargs):
         return None

+ 9 - 26
paddlex/inference/predictors/text_recognition.py

@@ -12,31 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import numpy as np
-from functools import partial, wraps
 
+from ...utils.func_register import FuncRegister
 from ...modules.text_recognition.model_list import MODELS
-
 from ..components import *
 from .base import BasePredictor
 
 
-def register(register_map, key):
-    """register the option setting func"""
-
-    def decorator(func):
-        register_map[key] = func
-
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            return func(self, *args, **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
 class TextRecPredictor(BasePredictor):
 
     entities = MODELS
@@ -46,15 +29,15 @@ class TextRecPredictor(BasePredictor):
     DEAULT_INPUTS = {"x": "x"}
     DEAULT_OUTPUTS = {"text_rec_res": "text_rec_res"}
 
-    _REGISTER_MAP = {}
-    register2self = partial(register, _REGISTER_MAP)
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
 
     def _build_components(self):
         ops = {}
         for cfg in self.config["PreProcess"]["transform_ops"]:
             tf_key = list(cfg.keys())[0]
-            assert tf_key in self._REGISTER_MAP
-            func = self._REGISTER_MAP.get(tf_key)
+            assert tf_key in self._FUNC_MAP
+            func = self._FUNC_MAP.get(tf_key)
             args = cfg.get(tf_key, {})
             op = func(self, **args) if args else func(self)
             if op:
@@ -74,12 +57,12 @@ class TextRecPredictor(BasePredictor):
         ops[key] = op
         return ops
 
-    @register2self("DecodeImage")
+    @register("DecodeImage")
     def build_readimg(self, channel_first, img_mode):
         assert channel_first == False
         return ReadImage(format=img_mode, batch_size=self.kwargs.get("batch_size", 1))
 
-    @register2self("RecResizeImg")
+    @register("RecResizeImg")
     def build_resize(self, image_shape):
         return OCRReisizeNormImg(rec_image_shape=image_shape)
 
@@ -91,10 +74,10 @@ class TextRecPredictor(BasePredictor):
         else:
             raise Exception()
 
-    @register2self("MultiLabelEncode")
+    @register("MultiLabelEncode")
     def foo(self, *args, **kwargs):
         return None
 
-    @register2self("KeepKeys")
+    @register("KeepKeys")
     def foo(self, *args, **kwargs):
         return None

+ 51 - 0
paddlex/inference/results/base.py

@@ -0,0 +1,51 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+from pathlib import Path
+import json
+
+from ...utils import logging
+from ..utils.io import JsonWriter, ImageReader, ImageWriter
+
+
+class BaseResult(dict):
+    def __init__(self, data):
+        super().__init__(data)
+        self._json_writer = JsonWriter()
+        self._img_reader = ImageReader(backend="opencv")
+        self._img_writer = ImageWriter(backend="opencv")
+
+    def save_to_json(self, save_path, indent=4, ensure_ascii=False):
+        if not save_path.endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
+        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
+
+    def save_to_img(self, save_path):
+        if not save_path.lower().endswith((".jpg", ".png")):
+            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
+        res_img = self._get_res_img()
+        if res_img is not None:
+            self._img_writer.write(save_path.as_posix(), res_img)
+            logging.info(f"The result has been saved in {save_path}.")
+
+    def print(self, json_format=True, indent=4, ensure_ascii=False):
+        str_ = self
+        if json_format:
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        logging.info(str_)
+
+    @abstractmethod
+    def _get_res_img(self):
+        raise NotImplementedError

+ 8 - 37
paddlex/inference/results/ocr.py

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-import json
 import math
 import random
 import numpy as np
@@ -21,50 +19,23 @@ import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
-
-from ...utils import logging
+from .base import BaseResult
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
-from ..utils.io import JsonWriter, ImageWriter, ImageReader
-
-
-class OCRResult(dict):
-    def __init__(self, data):
-        super().__init__(data)
-        self._json_writer = JsonWriter()
-        self._img_reader = ImageReader(backend="opencv")
-        self._img_writer = ImageWriter(backend="opencv")
-
-    def save_json(self, save_path, indent=4, ensure_ascii=False):
-        if not save_path.endswith(".json"):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
-        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
+from ..utils.io import ImageReader
 
-    def save_img(self, save_path):
-        if not save_path.lower().endswith((".jpg", ".png")):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
-        res_img = self._draw_ocr_box_txt(
-            self["img_path"], self["dt_polys"], self["rec_text"], self["rec_score"]
-        )
-        self._img_writer.write(save_path.as_posix(), res_img)
-        logging.info(f"The result has been saved in {save_path}.")
 
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
-        str_ = self
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
+class OCRResult(BaseResult):
 
-    def _draw_ocr_box_txt(
+    def _get_res_img(
         self,
-        img_path,
-        boxes,
-        txts=None,
-        scores=None,
         drop_score=0.5,
         font_path=PINGFANG_FONT_FILE_PATH,
     ):
         """draw ocr result"""
-        img = self._img_reader.read(img_path)
+        boxes = self["dt_polys"]
+        txts = (self["rec_text"],)
+        scores = self["rec_score"]
+        img = self._img_reader.read(self["img_path"])
         image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
         h, w = image.height, image.width
         img_left = image.copy()

+ 9 - 33
paddlex/inference/results/text_det.py

@@ -12,45 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-import json
 import numpy as np
 import cv2
 
-from ...utils import logging
-from ..utils.io import JsonWriter, ImageWriter, ImageReader
+from ..utils.io import ImageReader
+from .base import BaseResult
 
 
-class TextDetResult(dict):
-    def __init__(self, data):
-        super().__init__(data)
-        self._json_writer = JsonWriter()
-        self._img_reader = ImageReader(backend="opencv")
-        self._img_writer = ImageWriter(backend="opencv")
+class TextDetResult(BaseResult):
 
-    def save_json(self, save_path, indent=4, ensure_ascii=False):
-        if not save_path.endswith(".json"):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
-        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
-
-    def save_img(self, save_path):
-        if not save_path.lower().endswith((".jpg", ".png")):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
-        res_img = self._draw_rectangle(self["img_path"], self["dt_polys"])
-        self._img_writer.write(save_path.as_posix(), res_img)
-
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
-        str_ = self
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
-
-    def _draw_rectangle(self, img_path, boxes):
+    def _get_res_img(self):
         """draw rectangle"""
-        boxes = np.array(boxes)
-        img = self._img_reader.read(img_path)
-        img_show = img.copy()
+        boxes = np.array(self["dt_polys"])
+        img = self._img_reader.read(self["img_path"])
+        res_img = img.copy()
         for box in boxes.astype(int):
             box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
-            cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
-        return img_show
+            cv2.polylines(res_img, [box], True, (0, 0, 255), 2)
+        return res_img

+ 4 - 24
paddlex/inference/results/text_rec.py

@@ -12,32 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-import json
-import numpy as np
-import cv2
+from .base import BaseResult
 
-from ...utils import logging
-from ..utils.io import JsonWriter, ImageWriter, ImageReader
 
-
-class TextRecResult(dict):
+class TextRecResult(BaseResult):
     def __init__(self, data):
         super().__init__(data)
-        self._json_writer = JsonWriter()
-        self._img_reader = ImageReader(backend="opencv")
-        self._img_writer = ImageWriter(backend="opencv")
-
-    def save_json(self, save_path, indent=4, ensure_ascii=False):
-        if not save_path.endswith(".json"):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
-        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
-
-    def save_img(self, save_path):
-        raise Exception()
 
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
-        str_ = self
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
+    def _get_res_img(self, save_path):
+        raise Exception("Don't support to save Text Rec result to img!")

+ 9 - 50
paddlex/inference/results/topk.py

@@ -12,48 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-import json
+
 import PIL
 from PIL import ImageDraw, ImageFont
 import numpy as np
 
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
-from ...utils import logging
-from ..utils.io import JsonWriter, ImageWriter, ImageReader
 from ..utils.color_map import get_colormap
+from .base import BaseResult
 
 
-class TopkResult(dict):
+class TopkResult(BaseResult):
     def __init__(self, data):
         super().__init__(data)
-        self._json_writer = JsonWriter()
-        self._img_reader = ImageReader(backend="pil")
-        self._img_writer = ImageWriter(backend="pillow")
-
-    def save_json(self, save_path, indent=4, ensure_ascii=False):
-        if not save_path.endswith(".json"):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
-        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
-
-    def save_img(self, save_path):
-        if not save_path.lower().endswith((".jpg", ".png")):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
-        labels = self.get("label_names", self["class_ids"])
-        res_img = self._draw_label(self["img_path"], self["scores"], labels)
-        self._img_writer.write(save_path, res_img)
-
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
-        str_ = self
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
+        self._img_reader.set_backend("pillow")
+        self._img_writer.set_backend("pillow")
 
-    def _draw_label(self, img_path, scores, class_ids):
+    def _get_res_img(self):
         """Draw label on image"""
-        label_str = f"{class_ids[0]} {scores[0]:.2f}"
+        labels = self.get("label_names", self["class_ids"])
+        label_str = f"{labels[0]} {self['scores'][0]:.2f}"
 
-        image = self._img_reader.read(img_path)
+        image = self._img_reader.read(self["img_path"])
         image = image.convert("RGB")
         image_size = image.size
         draw = ImageDraw.Draw(image)
@@ -104,24 +84,3 @@ class TopkResult(dict):
             return light.astype("int32")
         else:
             return dark.astype("int32")
-
-
-# class SaveClsResults(BaseComponent):
-
-#     INPUT_KEYS = ["img_path", "cls_pred"]
-#     OUTPUT_KEYS = None
-#     DEAULT_INPUTS = {"img_path": "img_path", "cls_pred": "cls_pred"}
-#     DEAULT_OUTPUTS = {}
-
-#     def __init__(self, save_dir, class_ids=None):
-#         super().__init__()
-#         self.save_dir = save_dir
-#         self.class_id_map = _parse_class_id_map(class_ids)
-#         self._json_writer = ImageWriter(backend="pillow")
-
-
-#     def _write_image(self, path, image):
-#         """write image"""
-#         if os.path.exists(path):
-#             logging.warning(f"{path} already exists. Overwriting it.")
-#         self._json_writer.write(path, image)

+ 6 - 1
paddlex/inference/utils/io/readers.py

@@ -50,6 +50,11 @@ class _BaseReader(object):
             bk_args = self.bk_args
         return self._init_backend(self.bk_type, bk_args)
 
+    def set_backend(self, backend, **bk_args):
+        """Switch this reader to another backend.
+
+        Args:
+            backend: backend type name, stored as ``self.bk_type``.
+            **bk_args: backend arguments, stored as ``self.bk_args``
+                (replacing the previous ones).
+        """
+        # Assign both attributes first: get_backend() builds the backend
+        # from self.bk_type and may fall back to self.bk_args.
+        self.bk_type = backend
+        self.bk_args = bk_args
+        self._backend = self.get_backend()
+
     def _init_backend(self, bk_type, bk_args):
         """init backend"""
         raise NotImplementedError
@@ -78,7 +83,7 @@ class ImageReader(_BaseReader):
         """init backend"""
         if bk_type == "opencv":
             return OpenCVImageReaderBackend(**bk_args)
-        elif bk_type == "pil":
+        elif bk_type == "pil" or bk_type == "pillow":
             return PILImageReaderBackend(**bk_args)
         else:
             raise ValueError("Unsupported backend type")

+ 6 - 1
paddlex/inference/utils/io/writers.py

@@ -55,6 +55,11 @@ class _BaseWriter(object):
             bk_args = self.bk_args
         return self._init_backend(self.bk_type, bk_args)
 
+    def set_backend(self, backend, **bk_args):
+        """Switch this writer to another backend.
+
+        Args:
+            backend: backend type name, stored as ``self.bk_type``.
+            **bk_args: backend arguments, stored as ``self.bk_args``
+                (replacing the previous ones).
+        """
+        # Assign both attributes first: get_backend() builds the backend
+        # from self.bk_type and may fall back to self.bk_args.
+        self.bk_type = backend
+        self.bk_args = bk_args
+        self._backend = self.get_backend()
+
     def _init_backend(self, bk_type, bk_args):
         """init backend"""
         raise NotImplementedError
@@ -82,7 +87,7 @@ class ImageWriter(_BaseWriter):
         """init backend"""
         if bk_type == "opencv":
             return OpenCVImageWriterBackend(**bk_args)
-        elif bk_type == "pillow":
+        elif bk_type == "pil" or bk_type == "pillow":
             return PILImageWriterBackend(**bk_args)
         else:
             raise ValueError("Unsupported backend type")

+ 40 - 0
paddlex/utils/func_register.py

@@ -0,0 +1,40 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import wraps
+
+from . import logging
+
+
+class FuncRegister(object):
+    def __init__(self, register_map):
+        assert isinstance(register_map, dict)
+        self._register_map = register_map
+
+    def __call__(self, key):
+        """register the decoratored func as key in dict"""
+
+        def decorator(func):
+            self._register_map[key] = func
+            logging.debug(
+                f"The func ({func.__name__}) has been registered as key ({key})."
+            )
+
+            @wraps(func)
+            def wrapper(self, *args, **kwargs):
+                return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator

+ 1 - 0
paddlex/utils/misc.py

@@ -98,6 +98,7 @@ class Singleton(type):
         return cls._insts[cls]
 
 
+# TODO(gaotingquan): has been mv to subclass_register.py
 class AutoRegisterMetaClass(type):
     """meta class that automatically registry subclass to its baseclass
 

+ 101 - 0
paddlex/utils/subclass_register.py

@@ -0,0 +1,101 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABCMeta
+
+from . import logging
+from .errors import (
+    raise_class_not_found_error,
+    raise_no_entity_registered_error,
+    DuplicateRegistrationError,
+)
+
+
+class AutoRegisterMetaClass(type):
+    """Metaclass that automatically registers subclasses on their base class.
+
+    When a class is created, each of its bases is searched (recursively) for
+    a class flagged as a base class. If one is found, the new class is
+    recorded in a dict stored on that base class, keyed by the names in the
+    new class's ``entities`` attribute, or by its class name when it has no
+    such attribute.
+
+    Args:
+        type (class): type
+
+    Returns:
+        class: meta class
+    """
+
+    # These attribute NAMES are mangled by Python to
+    # _AutoRegisterMetaClass__<name>; their string VALUES are used verbatim.
+    __model_type_attr_name = "entities"
+    __base_class_flag = "__is_base"
+    __registered_map_name = "__registered_map"
+
+    def __new__(mcs, name, bases, attrs):
+        """Create the class, then register it on a flagged base class."""
+        cls = super().__new__(mcs, name, bases, attrs)
+        mcs.__register_model_entity(bases, cls, attrs)
+        return cls
+
+    @classmethod
+    def __register_model_entity(mcs, bases, cls, attrs):
+        """Register ``cls`` on the flagged base found through each direct base.
+
+        ``attrs`` is accepted but not used here.
+        """
+        if bases:
+            for base in bases:
+                base_cls = mcs.__find_base_class(base)
+                if base_cls:
+                    mcs.__register_to_base_class(base_cls, cls)
+
+    @classmethod
+    def __find_base_class(mcs, cls):
+        """Return ``cls`` or the first ancestor flagged as a base class, else None."""
+        is_base_flag = mcs.__base_class_flag
+        if is_base_flag.startswith("__"):
+            # Rebuild the name a class-body attribute ``__is_base`` gets after
+            # private-name mangling in ``cls`` (``_<ClassName>__is_base``).
+            # NOTE(review): Python strips leading underscores from the class
+            # name when mangling; this code does not - confirm whether
+            # underscore-prefixed base classes matter.
+            is_base_flag = f"_{cls.__name__}" + is_base_flag
+        if getattr(cls, is_base_flag, False):
+            return cls
+        # Depth-first search through the bases, in declaration order.
+        for base in cls.__bases__:
+            base_cls = mcs.__find_base_class(base)
+            if base_cls:
+                return base_cls
+        return None
+
+    @classmethod
+    def __register_to_base_class(mcs, base, cls):
+        """Record ``cls`` in the registry dict stored on ``base``.
+
+        Raises:
+            DuplicateRegistrationError: a name is already registered to a
+                different class.
+        """
+        cls_entity_name = getattr(cls, mcs.__model_type_attr_name, cls.__name__)
+        # A single string counts as one name, not as an iterable of characters.
+        if isinstance(cls_entity_name, str):
+            cls_entity_name = [cls_entity_name]
+
+        # A new dict is used when ``base`` has no registry attribute yet.
+        records = getattr(base, mcs.__registered_map_name, {})
+        for name in cls_entity_name:
+            if name in records and records[name] is not cls:
+                raise DuplicateRegistrationError(
+                    f"The name(`{name}`) duplicated registration! The class entities are: `{cls.__name__}` and \
+`{records[name].__name__}`."
+                )
+            records[name] = cls
+            logging.debug(
+                f"The class entity({cls.__name__}) has been register as name(`{name}`)."
+            )
+        setattr(base, mcs.__registered_map_name, records)
+
+    def all(cls):
+        """Return the registry dict (name -> class) reachable from ``cls``.
+
+        Calls ``raise_no_entity_registered_error(cls)`` when ``cls`` has no
+        registry attribute.
+        """
+        if not hasattr(cls, type(cls).__registered_map_name):
+            raise_no_entity_registered_error(cls)
+        return getattr(cls, type(cls).__registered_map_name)
+
+    def get(cls, name: str):
+        """Return the class registered under ``name``.
+
+        Calls ``raise_class_not_found_error(name, cls, all_entities)`` when
+        ``name`` is not registered.
+        """
+        all_entities = cls.all()
+        if name not in all_entities:
+            raise_class_not_found_error(name, cls, all_entities)
+        return all_entities[name]
+
+
+class AutoRegisterABCMetaClass(ABCMeta, AutoRegisterMetaClass):
+    """Metaclass combining ``ABCMeta`` with ``AutoRegisterMetaClass``.
+
+    It adds no behaviour of its own.
+    """
+
+    pass