Ver código fonte

new_pipeline PP-ShiTuV2 (#2761)

* restructure PP-ShiTuV2

* support to specify topk, rec_threshold, hamming_radius, det_threshold
Tingquan Gao 10 meses atrás
pai
commit
d26f505b5d

+ 27 - 0
api_examples/pipelines/test_shitu.py

@@ -0,0 +1,27 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+# Build the PP-ShiTuV2 pipeline from its named configuration.
+shitu = create_pipeline(pipeline="PP-ShiTuV2")
+
+# Build a retrieval index from the labelled gallery images and save it.
+gallery_index = shitu.build_index(
+    gallery_imgs="drink_dataset_v2.0/",
+    gallery_label="drink_dataset_v2.0/gallery.txt",
+)
+gallery_index.save("drink_index")
+
+# Predict on the test images with the in-memory index; print every result
+# and save its visualization.
+for result in shitu.predict(
+    "./drink_dataset_v2.0/test_images/", index=gallery_index
+):
+    result.print()
+    result.save_to_img("./output/")

+ 18 - 0
paddlex/configs/pipelines/PP-ShiTuV2.yaml

@@ -0,0 +1,18 @@
+pipeline_name: PP-ShiTuV2
+
+# Index to load when the pipeline is created. Use YAML `null` for "no index":
+# a bare `None` is parsed as the string "None", which is truthy.
+index: null
+det_threshold: 0.5
+rec_threshold: 0.5
+rec_topk: 5
+
+SubModules:
+  Detection:
+    module_name: text_detection
+    model_name: PP-ShiTuV2_det
+    model_dir: null
+    batch_size: 1
+  Recognition:
+    module_name: text_recognition
+    model_name: PP-ShiTuV2_rec
+    model_dir: null
+    batch_size: 1

+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -32,6 +32,7 @@ from .anomaly_detection import AnomalyDetectionPipeline
 from .ts_forecasting import TSFcPipeline
 from .ts_anomaly_detection import TSAnomalyDetPipeline
 from .ts_classification import TSClsPipeline
+from .pp_shitu_v2 import ShiTuV2Pipeline
 from .attribute_recognition import (
     PedestrianAttributeRecPipeline,
     VehicleAttributeRecPipeline,

+ 1 - 0
paddlex/inference/pipelines_new/components/__init__.py

@@ -19,3 +19,4 @@ from .utils.mixin import HtmlMixin, XlsxMixin
 from .chat_server.base import BaseChat
 from .retriever.base import BaseRetriever
 from .prompt_engeering.base import BaseGeneratePrompt
+from .faisser import FaissBuilder, FaissIndexer

+ 342 - 0
paddlex/inference/pipelines_new/components/faisser.py

@@ -0,0 +1,342 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+from pathlib import Path
+import faiss
+import numpy as np
+
+from ....utils import logging
+from ...utils.io import YAMLWriter, YAMLReader
+
+
+class IndexData:
+    """Bundle a Faiss index with the metadata needed to save and reload it.
+
+    The metadata is the id-to-label map, the metric type and the index type.
+    On disk an index is a directory holding ``vector.index`` (the Faiss
+    index) and ``id_map.yaml`` (the metadata).
+    """
+
+    VECTOR_FN = "vector"
+    VECTOR_SUFFIX = ".index"
+    IDMAP_FN = "id_map"
+    IDMAP_SUFFIX = ".yaml"
+
+    def __init__(self, index, index_info):
+        """
+        Args:
+            index: The Faiss index object.
+            index_info (dict): Must contain the keys ``id_map``,
+                ``metric_type`` and ``index_type``.
+        """
+        self._index = index
+        self._index_info = index_info
+        self._id_map = index_info["id_map"]
+        self._metric_type = index_info["metric_type"]
+        self._index_type = index_info["index_type"]
+
+    @property
+    def index(self):
+        """The wrapped Faiss index object."""
+        return self._index
+
+    @property
+    def index_bytes(self):
+        """The index serialized with ``faiss.serialize_index``."""
+        return faiss.serialize_index(self._index)
+
+    @property
+    def id_map(self):
+        """Mapping from vector id to its label."""
+        return self._id_map
+
+    @property
+    def metric_type(self):
+        """Metric type name the index was built with."""
+        return self._metric_type
+
+    @property
+    def index_type(self):
+        """Index type name the index was built with."""
+        return self._index_type
+
+    @property
+    def index_info(self):
+        """Metadata dict with ids as plain ``int`` and labels as ``str``.
+
+        This is the form written to the YAML file by ``save``.
+        """
+        return {
+            "index_type": self.index_type,
+            "metric_type": self.metric_type,
+            "id_map": self._convert_int(self.id_map),
+        }
+
+    @staticmethod
+    def _convert_int(id_map):
+        """Return a copy of ``id_map`` with ``int`` keys and ``str`` values."""
+        return {int(k): str(v) for k, v in id_map.items()}
+
+    @staticmethod
+    def _convert_int64(id_map):
+        """Return a copy of ``id_map`` with ``np.int64`` keys and ``str`` values."""
+        return {np.int64(k): str(v) for k, v in id_map.items()}
+
+    def save(self, save_dir):
+        """Write the index and its metadata into ``save_dir``.
+
+        The directory is created if needed.
+
+        Args:
+            save_dir (str | os.PathLike): Target directory.
+        """
+        save_dir = Path(save_dir)
+        save_dir.mkdir(parents=True, exist_ok=True)
+        vector_path = (save_dir / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}").as_posix()
+        index_info_path = (save_dir / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}").as_posix()
+
+        # Binary (hamming) indexes use a separate Faiss writer.
+        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
+            faiss.write_index_binary(self.index, vector_path)
+        else:
+            faiss.write_index(self.index, vector_path)
+
+        yaml_writer = YAMLWriter()
+        yaml_writer.write(
+            index_info_path,
+            self.index_info,
+            default_flow_style=False,
+            allow_unicode=True,
+        )
+
+    @classmethod
+    def load(cls, index):
+        """Resolve ``index`` into its Faiss index and metadata.
+
+        Args:
+            index (str | os.PathLike | IndexData): A directory written by
+                ``save``, or an in-memory ``IndexData``.
+
+        Returns:
+            tuple: ``(faiss_index, id_map, metric_type, index_type)``.
+
+        Raises:
+            AssertionError: If a file is missing, the metadata lacks a
+                required key, the vector count differs from the id-map size,
+                or ``index`` is of an unsupported type.
+        """
+        # Accept pathlib.Path (and any other os.PathLike) as well as str.
+        if isinstance(index, (str, os.PathLike)):
+            index_root = Path(index)
+            vector_path = index_root / f"{cls.VECTOR_FN}{cls.VECTOR_SUFFIX}"
+            index_info_path = index_root / f"{cls.IDMAP_FN}{cls.IDMAP_SUFFIX}"
+
+            assert (
+                vector_path.exists()
+            ), f"Not found the {cls.VECTOR_FN}{cls.VECTOR_SUFFIX} file in {index}!"
+            assert (
+                index_info_path.exists()
+            ), f"Not found the {cls.IDMAP_FN}{cls.IDMAP_SUFFIX} file in {index}!"
+
+            yaml_reader = YAMLReader()
+            index_info = yaml_reader.read(index_info_path)
+            assert (
+                "id_map" in index_info
+                and "metric_type" in index_info
+                and "index_type" in index_info
+            ), f"The index_info file({index_info_path}) may have been damaged, `id_map` or `metric_type` or `index_type` not found in `index_info`."
+            id_map = cls._convert_int64(index_info["id_map"])
+
+            if index_info["metric_type"] in FaissBuilder.BINARY_METRIC_TYPE:
+                index = faiss.read_index_binary(vector_path.as_posix())
+            else:
+                index = faiss.read_index(vector_path.as_posix())
+            assert index.ntotal == len(
+                id_map
+            ), "data number in index is not equal in in id_map"
+
+            return index, id_map, index_info["metric_type"], index_info["index_type"]
+        else:
+            assert isinstance(index, IndexData)
+            return index.index, index.id_map, index.metric_type, index.index_type
+
+
+class FaissIndexer:
+    """Search a Faiss index and map the returned ids to gallery labels."""
+
+    def __init__(
+        self,
+        index,
+    ):
+        """
+        Args:
+            index: Anything accepted by ``IndexData.load`` (an index
+                directory or an in-memory ``IndexData``).
+        """
+        super().__init__()
+        self._indexer, self.id_map, self.metric_type, _ = IndexData.load(index)
+
+    def __call__(self, feature, score_thres, hamming_radius, topk):
+        """Retrieve the ``topk`` nearest gallery entries for every feature.
+
+        Args:
+            feature: Sequence of feature vectors, converted with ``np.array``.
+            score_thres: For non-binary metrics, a query whose best score is
+                below this value is rejected.
+            hamming_radius: For binary metrics, a query whose best distance
+                is above this value is rejected. ``None`` disables rejection.
+            topk (int): Number of neighbours requested from the index.
+
+        Returns:
+            list[dict]: One ``{"score": ..., "label": ...}`` dict per query.
+            ``score`` is the row of scores returned by the search and
+            ``label`` the labels of the valid ids in that row. For rejected
+            queries both values are ``None``.
+        """
+        scores_list, ids_list = self._indexer.search(np.array(feature), topk)
+        preds = []
+        for scores, ids in zip(scores_list, ids_list):
+            # Faiss pads missing neighbours with -1. Id 0 is a real gallery
+            # entry (ids are assigned starting at 0), so it must be kept.
+            labels = [self.id_map[idx] for idx in ids if idx >= 0]
+            preds.append({"score": scores, "label": labels})
+
+        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
+            if hamming_radius is None:
+                # No radius configured: comparing against None would raise,
+                # so reject nothing.
+                rejected = []
+            else:
+                rejected = np.where(scores_list[:, 0] > hamming_radius)[0]
+        else:
+            rejected = np.where(scores_list[:, 0] < score_thres)[0]
+        for idx in rejected:
+            preds[idx] = {"score": None, "label": None}
+        return preds
+
+
+class FaissBuilder:
+    """Build, extend and prune Faiss indexes of gallery features."""
+
+    SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
+    SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
+    BINARY_METRIC_TYPE = ("hamming",)
+    BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")
+
+    @classmethod
+    def _get_index_type(cls, metric_type, index_type, num=None):
+        """Translate an index type name into a Faiss factory string.
+
+        Args:
+            metric_type (str): Metric name; binary metrics get a ``B`` prefix.
+            index_type (str): Index type name, e.g. ``"Flat"`` or ``"IVF"``.
+            num (int, optional): Number of gallery vectors. Required for
+                ``"IVF"``, where it determines the number of lists.
+
+        Returns:
+            str: The factory description string.
+
+        Raises:
+            AssertionError: If a binary metric is combined with an index
+                type outside ``BINARY_SUPPORT_INDEX_TYPE``.
+        """
+        is_binary = metric_type in cls.BINARY_METRIC_TYPE
+
+        # Validate the plain name before it is expanded below; checking after
+        # the IVF expansion would reject every binary IVF index.
+        if is_binary:
+            assert (
+                index_type in cls.BINARY_SUPPORT_INDEX_TYPE
+            ), f"The metric type({metric_type}) only support {cls.BINARY_SUPPORT_INDEX_TYPE} index types!"
+
+        # if IVF method, calculate the number of IVF lists automatically
+        if index_type == "IVF":
+            # Keep at least one list so galleries with fewer than 8 vectors
+            # do not produce "IVF0".
+            nlist = max(1, min(int(num // 8), 65536))
+            index_type = f"IVF{nlist}," + ("BFlat" if is_binary else "Flat")
+
+        # for binary index, add B at head of index_type
+        if is_binary:
+            index_type = "B" + index_type
+
+        if index_type == "HNSW32":
+            logging.warning("The HNSW32 method dose not support 'remove' operation")
+
+        return index_type
+
+    @classmethod
+    def _get_metric_type(cls, metric_type):
+        """Map a metric name to the Faiss metric constant.
+
+        Returns ``None`` for a name that is not handled here.
+        """
+        if metric_type == "hamming":
+            return faiss.METRIC_Hamming
+        elif metric_type == "jaccard":
+            return faiss.METRIC_Jaccard
+        elif metric_type == "IP":
+            return faiss.METRIC_INNER_PRODUCT
+        elif metric_type == "L2":
+            return faiss.METRIC_L2
+
+    @classmethod
+    def build(
+        cls,
+        gallery_imgs,
+        gallery_label,
+        predict_func,
+        metric_type="IP",
+        index_type="HNSW32",
+    ):
+        """Create a new index from gallery images.
+
+        Args:
+            gallery_imgs: Image root directory when ``gallery_label`` is a
+                label-file path; otherwise the images themselves.
+            gallery_label: Path of a label file (``str``), or the labels
+                matching ``gallery_imgs``.
+            predict_func: Callable that takes the images and yields results
+                with a ``"feature"`` entry.
+            metric_type (str): One of ``SUPPORT_METRIC_TYPE``.
+            index_type (str): One of ``SUPPORT_INDEX_TYPE``.
+
+        Returns:
+            IndexData: The new index and its metadata.
+        """
+        assert (
+            index_type in cls.SUPPORT_INDEX_TYPE
+        ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
+
+        assert (
+            metric_type in cls.SUPPORT_METRIC_TYPE
+        ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
+
+        if isinstance(gallery_label, str):
+            gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
+        else:
+            gallery_docs, gallery_list = gallery_label, gallery_imgs
+
+        features = [res["feature"] for res in predict_func(gallery_list)]
+        dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
+        features = np.array(features).astype(dtype)
+        vector_num, vector_dim = features.shape
+
+        factory_spec = cls._get_index_type(metric_type, index_type, vector_num)
+        faiss_metric = cls._get_metric_type(metric_type)
+        if metric_type in cls.BINARY_METRIC_TYPE:
+            index = faiss.index_binary_factory(vector_dim, factory_spec, faiss_metric)
+        else:
+            index = faiss.index_factory(vector_dim, factory_spec, faiss_metric)
+            # Wrap so vectors can be added under explicit ids.
+            index = faiss.IndexIDMap2(index)
+        ids = {}
+
+        # calculate id for new data
+        index, ids = cls._add_gallery(
+            metric_type, index, ids, features, gallery_docs, mode="new"
+        )
+        return IndexData(
+            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
+        )
+
+    @classmethod
+    def remove(
+        cls,
+        remove_ids,
+        index,
+    ):
+        """Remove entries from an index.
+
+        Args:
+            remove_ids: Ids to remove, or the path (``str``) of a text file
+                whose lines each start with an id followed by a space.
+            index: Anything accepted by ``IndexData.load``.
+
+        Returns:
+            IndexData: The index without the removed entries.
+
+        Raises:
+            RuntimeError: If the index type is ``HNSW32``.
+        """
+        index, ids, metric_type, index_type = IndexData.load(index)
+        if index_type == "HNSW32":
+            raise RuntimeError(
+                "The index_type: HNSW32 dose not support 'remove' operation"
+            )
+        if isinstance(remove_ids, str):
+            with open(remove_ids) as f:
+                # The first space-separated token of each line is the id;
+                # blank lines are skipped instead of failing in int("").
+                remove_ids = [
+                    int(line.strip().split(" ")[0]) for line in f if line.strip()
+                ]
+        # Force int64 so the dtype does not depend on the platform default.
+        remove_ids = np.asarray(remove_ids, dtype=np.int64)
+
+        # remove ids in id_map, remove index data in faiss index
+        index.remove_ids(remove_ids)
+        # A set gives O(1) membership tests instead of scanning the array
+        # once per id-map key.
+        removed = set(remove_ids.ravel().tolist())
+        ids = {k: v for k, v in ids.items() if k not in removed}
+        return IndexData(
+            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
+        )
+
+    @classmethod
+    def append(cls, gallery_imgs, gallery_label, predict_func, index):
+        """Add gallery images to an existing index.
+
+        Args:
+            gallery_imgs: Image root directory when ``gallery_label`` is a
+                label-file path; otherwise the images themselves.
+            gallery_label: Path of a label file (``str``), or the labels
+                matching ``gallery_imgs``.
+            predict_func: Callable that takes the images and yields results
+                with a ``"feature"`` entry.
+            index: Anything accepted by ``IndexData.load``.
+
+        Returns:
+            IndexData: The extended index.
+        """
+        index, ids, metric_type, index_type = IndexData.load(index)
+        assert (
+            metric_type in cls.SUPPORT_METRIC_TYPE
+        ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
+
+        if isinstance(gallery_label, str):
+            gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
+        else:
+            gallery_docs, gallery_list = gallery_label, gallery_imgs
+
+        features = [res["feature"] for res in predict_func(gallery_list)]
+        dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
+        features = np.array(features).astype(dtype)
+
+        # calculate id for new data
+        index, ids = cls._add_gallery(
+            metric_type, index, ids, features, gallery_docs, mode="append"
+        )
+        return IndexData(
+            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
+        )
+
+    @classmethod
+    def _add_gallery(
+        cls, metric_type, index, ids, gallery_features, gallery_docs, mode
+    ):
+        """Add features to ``index`` and record their labels in ``ids``.
+
+        New ids continue after the largest existing id, or start at 0 when
+        ``ids`` is empty. ``ids`` is updated in place.
+
+        Returns:
+            tuple: ``(index, ids)``.
+        """
+        start_id = max(ids.keys()) + 1 if ids else 0
+        ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
+
+        # only train when new index file
+        if mode == "new":
+            if metric_type in cls.BINARY_METRIC_TYPE:
+                index.add(gallery_features)
+            else:
+                index.train(gallery_features)
+
+        if metric_type not in cls.BINARY_METRIC_TYPE:
+            index.add_with_ids(gallery_features, ids_now)
+        # TODO(gaotingquan): how append when using hamming metric type
+        # NOTE: in "append" mode a binary index gets no new vectors here,
+        # although the labels are still recorded in `ids` below.
+
+        for i, d in zip(list(ids_now), gallery_docs):
+            ids[i] = d
+        return index, ids
+
+    @classmethod
+    def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
+        """Read a gallery label file.
+
+        Each non-blank line holds an image path and a label separated by
+        ``delimiter``. Blank lines are ignored.
+
+        Args:
+            gallery_label_path (str): Path of the UTF-8 label file.
+            gallery_imgs_root (str): Directory prepended to every image path.
+            delimiter (str): Separator between path and label.
+
+        Returns:
+            tuple: ``(labels, files)``, two lists of equal length.
+
+        Raises:
+            ValueError: If a non-blank line does not split into exactly two
+                fields.
+        """
+        files = []
+        labels = []
+        root = Path(gallery_imgs_root)
+        with open(gallery_label_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                path, label = line.split(delimiter)
+                files.append((root / path).as_posix())
+                labels.append(label)
+        return labels, files

+ 15 - 0
paddlex/inference/pipelines_new/pp_shitu_v2/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import ShiTuV2Pipeline

+ 152 - 0
paddlex/inference/pipelines_new/pp_shitu_v2/pipeline.py

@@ -0,0 +1,152 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import pickle
+from pathlib import Path
+import numpy as np
+
+from ...utils.pp_option import PaddlePredictorOption
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
+from ..components import CropByBoxes, FaissIndexer, FaissBuilder
+from ..base import BasePipeline
+from .result import ShiTuResult
+
+
+class ShiTuV2Pipeline(BasePipeline):
+    """ShiTuV2 Pipeline
+
+    Detects objects in an image, crops them, extracts one feature per crop
+    with the recognition model and labels each crop by searching a Faiss
+    index.
+    """
+
+    entities = "PP-ShiTuV2"
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Args:
+            config (Dict): Pipeline configuration. ``SubModules.Detection``
+                and ``SubModules.Recognition`` are required. Optional keys:
+                ``index``, ``topk`` (or ``rec_topk``), ``rec_threshold``,
+                ``hamming_radius`` and ``det_threshold``.
+            device, pp_option, use_hpip, hpi_params: Passed unchanged to
+                ``BasePipeline.__init__``.
+        """
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        # `topk` keeps priority so existing configs behave as before;
+        # `rec_topk` is the key used in PP-ShiTuV2.yaml.
+        self._topk = config.get("topk", config.get("rec_topk", 5))
+        self._rec_threshold = config.get("rec_threshold", 0.5)
+        self._hamming_radius = config.get("hamming_radius", None)
+        self._det_threshold = config.get("det_threshold", 0.5)
+
+        index = config.get("index", None)
+        # A bare `None` in YAML is loaded as the string "None", not as null.
+        if index == "None":
+            index = None
+
+        self.img_reader = ReadImage(format="BGR")
+        self.det_model = self.create_model(config["SubModules"]["Detection"])
+        self.rec_model = self.create_model(config["SubModules"]["Recognition"])
+        self.crop_by_boxes = CropByBoxes()
+        # Build the indexer directly; this class defines no `_build_indexer`.
+        self.indexer = FaissIndexer(index) if index else None
+        self.batch_sampler = ImageBatchSampler(
+            batch_size=self.det_model.batch_sampler.batch_size
+        )
+
+    def predict(self, input, index=None, **kwargs):
+        """Run detection and retrieval on ``input``, yielding one result per image.
+
+        Args:
+            input: Passed to the image batch sampler.
+            index: Optional index for this call, anything accepted by
+                ``FaissIndexer``. Falls back to the index from the config.
+            **kwargs: Per-call overrides for ``topk``, ``rec_threshold``,
+                ``hamming_radius`` and ``det_threshold``.
+
+        Yields:
+            ShiTuResult: The result for one input image.
+
+        Raises:
+            AssertionError: If no index is given and none was configured.
+        """
+        indexer = FaissIndexer(index) if index is not None else self.indexer
+        assert (
+            indexer
+        ), "No index available: pass `index` to predict() or set `index` in the pipeline config."
+        topk = kwargs.get("topk", self._topk)
+        rec_threshold = kwargs.get("rec_threshold", self._rec_threshold)
+        hamming_radius = kwargs.get("hamming_radius", self._hamming_radius)
+        det_threshold = kwargs.get("det_threshold", self._det_threshold)
+        for batch_data in self.batch_sampler(input):
+            raw_imgs = self.img_reader(batch_data)
+            all_det_res = list(self.det_model(raw_imgs, threshold=det_threshold))
+            for input_data, raw_img, det_res in zip(batch_data, raw_imgs, all_det_res):
+                rec_res = self.get_rec_result(
+                    raw_img, det_res, indexer, rec_threshold, hamming_radius, topk
+                )
+                yield self.get_final_result(input_data, raw_img, det_res, rec_res)
+
+    def get_rec_result(
+        self, raw_img, det_res, indexer, rec_threshold, hamming_radius, topk
+    ):
+        """Recognize every detected box of one image.
+
+        When nothing was detected, a box covering the whole image is appended
+        to ``det_res["boxes"]`` in place, so the full image is recognized and
+        ``get_final_result`` sees the same box list.
+
+        Returns:
+            dict: ``{"label": [...], "score": [...]}`` with one entry per box.
+        """
+        if len(det_res["boxes"]) == 0:
+            # shape[:2] is (height, width); the box is [xmin, ymin, xmax, ymax].
+            height, width = raw_img.shape[:2]
+            det_res["boxes"].append(
+                {
+                    "cls_id": 0,
+                    "label": "full_img",
+                    "score": 0,
+                    "coordinate": [0, 0, width, height],
+                }
+            )
+        subs_of_img = list(self.crop_by_boxes(raw_img, det_res["boxes"]))
+        img_list = [img["img"] for img in subs_of_img]
+        all_rec_res = list(self.rec_model(img_list))
+        all_rec_res = indexer(
+            [rec_res["feature"] for rec_res in all_rec_res],
+            score_thres=rec_threshold,
+            hamming_radius=hamming_radius,
+            topk=topk,
+        )
+        output = {"label": [], "score": []}
+        for res in all_rec_res:
+            output["label"].append(res["label"])
+            output["score"].append(res["score"])
+        return output
+
+    def get_final_result(self, input_data, raw_img, det_res, rec_res):
+        """Merge detection boxes and recognition output into a ``ShiTuResult``.
+
+        ``rec_res`` must hold one label/score entry per box in ``det_res``.
+        """
+        single_img_res = {"input_path": input_data, "input_img": raw_img, "boxes": []}
+        for i, obj in enumerate(det_res["boxes"]):
+            rec_scores = rec_res["score"][i]
+            labels = rec_res["label"][i]
+            single_img_res["boxes"].append(
+                {
+                    "labels": labels,
+                    "rec_scores": rec_scores,
+                    "det_score": obj["score"],
+                    "coordinate": obj["coordinate"],
+                }
+            )
+        return ShiTuResult(single_img_res)
+
+    def build_index(
+        self,
+        gallery_imgs,
+        gallery_label,
+        metric_type="IP",
+        index_type="HNSW32",
+        **kwargs
+    ):
+        """Build a new index with this pipeline's recognition model.
+
+        See ``FaissBuilder.build``; extra keyword arguments are accepted but
+        not used.
+        """
+        return FaissBuilder.build(
+            gallery_imgs,
+            gallery_label,
+            self.rec_model.predict,
+            metric_type=metric_type,
+            index_type=index_type,
+        )
+
+    def remove_index(self, remove_ids, index):
+        """Remove entries from ``index``; see ``FaissBuilder.remove``."""
+        return FaissBuilder.remove(remove_ids, index)
+
+    def append_index(
+        self,
+        gallery_imgs,
+        gallery_label,
+        index,
+    ):
+        """Add gallery images to ``index``; see ``FaissBuilder.append``."""
+        return FaissBuilder.append(
+            gallery_imgs,
+            gallery_label,
+            self.rec_model.predict,
+            index,
+        )

+ 110 - 0
paddlex/inference/pipelines_new/pp_shitu_v2/result.py

@@ -0,0 +1,110 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH, create_font
+from ...common.result import BaseCVResult
+from ...utils.color_map import get_colormap, font_colormap
+
+
+def draw_box(img, boxes):
+    """Draw boxes with a "label score" caption on an image.
+
+    Args:
+        img: Image array, converted with ``Image.fromarray``.
+        boxes (list): a list of dictionaries representing detection box
+            information, each with the keys ``label``, ``coordinate`` and
+            ``score``. ``coordinate`` is ``[xmin, ymin, xmax, ymax]`` or
+            ``[x1, y1, x2, y2, x3, y3, x4, y4]``.
+    Returns:
+        img (PIL.Image.Image): visualized image
+    Raises:
+        ValueError: If a coordinate list has neither 4 nor 8 values.
+    """
+    img = Image.fromarray(img)
+    font_size = int(0.018 * int(img.width)) + 2
+    font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+
+    draw_thickness = int(max(img.size) * 0.002)
+    draw = ImageDraw.Draw(img)
+    label2color = {}
+    catid2fontcolor = {}
+    color_list = get_colormap(rgb=True)
+
+    for i, dt in enumerate(boxes):
+        label, bbox, score = dt["label"], dt["coordinate"], dt["score"]
+        # The first box seen with a label fixes the colour for that label.
+        if label not in label2color:
+            color_index = i % len(color_list)
+            label2color[label] = color_list[color_index]
+            catid2fontcolor[label] = font_colormap(color_index)
+        color = tuple(label2color[label])
+        font_color = tuple(catid2fontcolor[label])
+
+        if len(bbox) == 4:
+            # draw bbox of normal object detection
+            xmin, ymin, xmax, ymax = bbox
+            rectangle = [
+                (xmin, ymin),
+                (xmin, ymax),
+                (xmax, ymax),
+                (xmax, ymin),
+                (xmin, ymin),
+            ]
+        elif len(bbox) == 8:
+            # draw bbox of rotated object detection
+            x1, y1, x2, y2, x3, y3, x4, y4 = bbox
+            rectangle = [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)]
+            xmin = min(x1, x2, x3, x4)
+            ymin = min(y1, y2, y3, y4)
+        else:
+            raise ValueError(
+                f"Only support bbox format of [xmin,ymin,xmax,ymax] or [x1,y1,x2,y2,x3,y3,x4,y4], got bbox of shape {len(bbox)}."
+            )
+
+        # draw bbox
+        draw.line(
+            rectangle,
+            width=draw_thickness,
+            fill=color,
+        )
+
+        # draw label
+        text = "{} {:.2f}".format(dt["label"], score)
+        # `textsize` was removed in Pillow 10.0.0. Probing for the method is
+        # correct on that release and does not depend on parsing the version
+        # string.
+        if hasattr(draw, "textsize"):
+            tw, th = draw.textsize(text, font=font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), text, font)
+            tw, th = right - left, bottom - top + 4
+        # Put the caption inside the box when there is no room above it.
+        if ymin < th:
+            draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
+            draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+        else:
+            draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
+            draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
+
+    return img
+
+
+class ShiTuResult(BaseCVResult):
+    """Result of the PP-ShiTuV2 pipeline for one image."""
+
+    def _to_img(self):
+        """Draw the top-1 label and score of every recognized box.
+
+        Boxes whose ``rec_scores`` is ``None`` and boxes with an empty label
+        list are not drawn.
+
+        Returns:
+            PIL.Image.Image: The input image with the boxes drawn on it.
+        """
+        boxes = [
+            {
+                "coordinate": box["coordinate"],
+                "label": box["labels"][0],
+                "score": box["rec_scores"][0],
+            }
+            for box in self["boxes"]
+            # An empty label list would make `box["labels"][0]` raise.
+            if box["rec_scores"] is not None and len(box["labels"]) > 0
+        ]
+        image = draw_box(self["input_img"], boxes)
+        return image

+ 7 - 0
paddlex/model.py

@@ -62,6 +62,13 @@ class _ModelBasedInference(_BaseModel):
     def set_predictor(self, **kwargs):
         self._predictor.set_predictor(**kwargs)
 
+    def __getattr__(self, name):
+        """Fall back to the predictor for attributes missing on this object.
+
+        Raises:
+            AttributeError: If the predictor has no such attribute either.
+        """
+        # `__getattr__` runs only after normal lookup has failed. If
+        # `_predictor` itself is missing, reading `self._predictor` below
+        # would re-enter this method without end, so fail fast for that name.
+        if name != "_predictor" and hasattr(self._predictor, name):
+            return getattr(self._predictor, name)
+        raise AttributeError(
+            f"'{self.__class__.__name__}' object has no attribute '{name}'"
+        )
+
 
 class _ModelBasedConfig(_BaseModel):
     def __init__(self, config=None, *args, **kwargs):