Преглед на файлове

improve det and instanceseg inference

zhangyubo0722 преди 1 година
родител
ревизия
01aff66316

+ 5 - 0
paddlex/inference/components/base.py

@@ -72,6 +72,11 @@ class BaseComponent(ABC):
         def _check_args_key(args):
             sig = inspect.signature(self.apply)
             for param in sig.parameters.values():
+                if param.kind == inspect.Parameter.VAR_KEYWORD:
+                    logging.debug(
+                        f"The apply function parameter of {self.__class__.__name__} is **kwargs, so would not inspect!"
+                    )
+                    continue
                 if param.default == inspect.Parameter.empty and param.name not in args:
                     raise Exception(
                         f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"

+ 1 - 1
paddlex/inference/components/paddle_predictor/__init__.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .predictor import ImagePredictor
+from .predictor import ImagePredictor, ImageDetPredictor, ImageInstanceSegPredictor

+ 58 - 5
paddlex/inference/components/paddle_predictor/predictor.py

@@ -14,7 +14,7 @@
 
 import os
 from abc import abstractmethod
-
+import numpy as np
 import paddle
 from paddle.inference import Config, create_predictor
 import numpy as np
@@ -150,20 +150,21 @@ No need to generate again."
         """get input names"""
         return self.input_names
 
-    def apply(self, batch_data):
-        x = self.to_batch(batch_data)
+    def apply(self, **kwargs):
+        x = self.to_batch(**kwargs)
         for idx in range(len(x)):
             self.input_handlers[idx].reshape(x[idx].shape)
             self.input_handlers[idx].copy_from_cpu(x[idx])
 
         self.predictor.run()
-
         output = []
         for out_tensor in self.output_handlers:
             batch = out_tensor.copy_to_cpu()
             output.append(batch)
+        return self.format_output(output)
 
-        return [{"pred": res} for res in zip(*output)]
+    def format_output(self, pred):
+        return [{"pred": res} for res in zip(*pred)]
 
     @abstractmethod
     def to_batch(self):
@@ -175,3 +176,55 @@ class ImagePredictor(BasePaddlePredictor):
 
     def to_batch(self, imgs):
         return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]
+
+
+class ImageDetPredictor(BasePaddlePredictor):
+    INPUT_KEYS = [["img", "scale_factors"], ["img", "scale_factors", "img_size"]]
+    OUTPUT_KEYS = [["boxes"], ["boxes", "masks"]]
+    DEAULT_INPUTS = {"img": "img", "scale_factors": "scale_factors"}
+    DEAULT_OUTPUTS = {"boxes": "boxes"}
+
+    def to_batch(self, img, scale_factors, img_size=None):
+        scale_factors = [scale_factor[::-1] for scale_factor in scale_factors]
+        if img_size is None:
+            return [
+                np.stack(img, axis=0).astype(dtype=np.float32, copy=False),
+                np.stack(scale_factors, axis=0).astype(dtype=np.float32, copy=False),
+            ]
+        else:
+            return [
+                np.stack(img_size, axis=0).astype(dtype=np.float32, copy=False),
+                np.stack(img, axis=0).astype(dtype=np.float32, copy=False),
+                np.stack(scale_factors, axis=0).astype(dtype=np.float32, copy=False),
+            ]
+
+    def format_output(self, pred):
+        box_idx_start = 0
+        pred_box = []
+        if len(pred) == 3:
+            pred_mask = []
+        for idx in range(len(pred[1])):
+            np_boxes_num = pred[1][idx]
+            box_idx_end = box_idx_start + np_boxes_num
+            np_boxes = pred[0][box_idx_start:box_idx_end]
+            pred_box.append(np_boxes)
+            if len(pred) == 3:
+                np_masks = pred[2][box_idx_start:box_idx_end]
+                pred_mask.append(np_masks)
+            box_idx_start = box_idx_end
+
+        boxes = [{"boxes": np.array(res)} for res in pred_box]
+        if len(pred) == 3:
+            masks = [{"masks": np.array(res)} for res in pred_mask]
+            return [{"boxes": boxes[0]["boxes"], "masks": masks[0]["masks"]}]
+        else:
+            return [{"boxes": np.array(res)} for res in pred_box]
+
+
+class ImageInstanceSegPredictor(ImageDetPredictor):
+    DEAULT_INPUTS = {
+        "img": "img",
+        "scale_factors": "scale_factors",
+        "img_size": "img_size",
+    }
+    DEAULT_OUTPUTS = {"boxes": "boxes", "masks": "masks"}

+ 2 - 0
paddlex/inference/components/task_related/__init__.py

@@ -16,3 +16,5 @@ from .clas import Topk, MultiLabelThreshOutput
 from .text_det import DetResizeForTest, NormalizeImage, DBPostProcess, CropByPolys
 from .text_rec import OCRReisizeNormImg, CTCLabelDecode
 from .table_rec import TableLabelDecode, TableMasterLabelDecode
+from .det import DetPostProcess
+from .instance_seg import InstanceSegPostProcess

+ 43 - 0
paddlex/inference/components/task_related/det.py

@@ -0,0 +1,43 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ....utils import logging
+from ..base import BaseComponent
+
+
+class DetPostProcess(BaseComponent):
+    """Save Result Transform"""
+
+    INPUT_KEYS = ["img_path", "boxes"]
+    OUTPUT_KEYS = ["boxes", "labels"]
+    DEAULT_INPUTS = {"boxes": "boxes"}
+    DEAULT_OUTPUTS = {
+        "boxes": "boxes",
+        "labels": "labels",
+    }
+
+    def __init__(self, threshold=0.5, labels=None):
+        super().__init__()
+        self.threshold = threshold
+        self.labels = labels
+
+    def apply(self, boxes):
+        """apply"""
+        expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
+        boxes = boxes[expect_boxes, :]
+        result = {"boxes": boxes, "labels": self.labels}
+
+        return result

+ 49 - 0
paddlex/inference/components/task_related/instance_seg.py

@@ -0,0 +1,49 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ....utils import logging
+from ..base import BaseComponent
+
+
+class InstanceSegPostProcess(BaseComponent):
+    """Save Result Transform"""
+
+    INPUT_KEYS = ["boxes", "masks"]
+    OUTPUT_KEYS = ["img_path", "boxes", "masks", "labels"]
+    DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks"}
+    DEAULT_OUTPUTS = {
+        "boxes": "boxes",
+        "masks": "masks",
+        "labels": "labels",
+    }
+
+    def __init__(self, threshold=0.5, labels=None):
+        super().__init__()
+        self.threshold = threshold
+        self.labels = labels
+
+    def apply(self, boxes, masks):
+        """apply"""
+        expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
+        boxes = boxes[expect_boxes, :]
+        masks = masks[expect_boxes, :, :]
+        result = {
+            "boxes": boxes,
+            "masks": masks,
+            "labels": self.labels,
+        }
+
+        return result

+ 2 - 0
paddlex/inference/pipelines/__init__.py

@@ -14,3 +14,5 @@
 
 from .image_classification import ClasPipeline
 from .ocr import OCRPipeline
+from .object_detection import DetPipeline
+from .instance_segmentation import InstanceSegPipeline

+ 33 - 0
paddlex/inference/pipelines/instance_segmentation.py

@@ -0,0 +1,33 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BasePipeline
+from ..predictors import create_predictor
+
+
+class InstanceSegPipeline(BasePipeline):
+    """InstanceSeg Pipeline"""
+
+    entities = "instance_segmentation"
+
+    def __init__(self, model, batch_size=1, device="gpu"):
+        super().__init__()
+        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+
+    def predict(self, x):
+        self._check_input(x)
+        yield from self._predict(x)
+
+    def _check_input(self, x):
+        pass

+ 33 - 0
paddlex/inference/pipelines/object_detection.py

@@ -0,0 +1,33 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BasePipeline
+from ..predictors import create_predictor
+
+
+class DetPipeline(BasePipeline):
+    """Det Pipeline"""
+
+    entities = "object_detection"
+
+    def __init__(self, model, batch_size=1, device="gpu"):
+        super().__init__()
+        self._predict = create_predictor(model, batch_size=batch_size, device=device)
+
+    def predict(self, x):
+        self._check_input(x)
+        yield from self._predict(x)
+
+    def _check_input(self, x):
+        pass

+ 2 - 0
paddlex/inference/predictors/__init__.py

@@ -20,6 +20,8 @@ from .image_classification import ClasPredictor
 from .text_detection import TextDetPredictor
 from .text_recognition import TextRecPredictor
 from .table_recognition import TablePredictor
+from .object_detection import DetPredictor
+from .instance_segmentation import InstanceSegPredictor
 from .official_models import official_models
 
 

+ 59 - 0
paddlex/inference/predictors/instance_segmentation.py

@@ -0,0 +1,59 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from .object_detection import DetPredictor
+from ...utils.func_register import FuncRegister
+from ...modules.instance_segmentation.model_list import MODELS
+from ..components import *
+from ..results import InstanceSegResults
+from ..utils.process_hook import batchable_method
+
+
+class InstanceSegPredictor(DetPredictor):
+
+    entities = MODELS
+
+    def _build_components(self):
+        ops = {}
+        ops["ReadImage"] = ReadImage(
+            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
+        )
+        for cfg in self.config["Preprocess"]:
+            tf_key = cfg["type"]
+            func = self._FUNC_MAP.get(tf_key)
+            cfg.pop("type")
+            args = cfg
+            op = func(self, **args) if args else func(self)
+            ops[tf_key] = op
+
+        predictor = ImageInstanceSegPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        ops["predictor"] = predictor
+
+        ops["postprocess"] = InstanceSegPostProcess(
+            threshold=self.config["draw_threshold"], labels=self.config["label_list"]
+        )
+
+        return ops
+
+    @batchable_method
+    def _pack_res(self, data):
+        keys = ["img_path", "boxes", "masks", "labels"]
+        return {"result": InstanceSegResults({key: data[key] for key in keys})}

+ 97 - 0
paddlex/inference/predictors/object_detection.py

@@ -0,0 +1,97 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...utils.func_register import FuncRegister
+from ...modules.object_detection.model_list import MODELS
+from ..components import *
+from ..results import DetResults
+from ..utils.process_hook import batchable_method
+from .base import BasicPredictor
+
+
+class DetPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def _build_components(self):
+        ops = {}
+        ops["ReadImage"] = ReadImage(
+            batch_size=self.kwargs.get("batch_size", 1), format="RGB"
+        )
+        for cfg in self.config["Preprocess"]:
+            tf_key = cfg["type"]
+            func = self._FUNC_MAP.get(tf_key)
+            cfg.pop("type")
+            args = cfg
+            op = func(self, **args) if args else func(self)
+            ops[tf_key] = op
+
+        predictor = ImageDetPredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        ops["predictor"] = predictor
+
+        ops["postprocess"] = DetPostProcess(
+            threshold=self.config["draw_threshold"], labels=self.config["label_list"]
+        )
+
+        return ops
+
+    @register("Resize")
+    def build_resize(self, target_size, keep_ratio=False, interp=2):
+        assert target_size
+        if isinstance(interp, int):
+            interp = {
+                0: "NEAREST",
+                1: "LINEAR",
+                2: "CUBIC",
+                3: "AREA",
+                4: "LANCZOS4",
+            }[interp]
+        op = Resize(target_size=target_size, keep_ratio=keep_ratio, interp=interp)
+        return op
+
+    @register("NormalizeImage")
+    def build_normalize(
+        self,
+        norm_type=None,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+        is_scale=None,
+    ):
+        if is_scale:
+            scale = 1.0 / 255.0
+        else:
+            scale = 1
+        if norm_type != "mean_std":
+            mean = 0
+            std = 1
+        return Normalize(mean=mean, std=std)
+
+    @register("Permute")
+    def build_to_chw(self):
+        return ToCHWImage()
+
+    @batchable_method
+    def _pack_res(self, data):
+        keys = ["img_path", "boxes", "labels"]
+        return {"result": DetResults({key: data[key] for key in keys})}

+ 2 - 0
paddlex/inference/results/__init__.py

@@ -17,3 +17,5 @@ from .text_det import TextDetResult
 from .text_rec import TextRecResult
 from .table_rec import TableRecResult
 from .ocr import OCRResult
+from .det import DetResults
+from .instance_seg import InstanceSegResults

+ 105 - 0
paddlex/inference/results/det.py

@@ -0,0 +1,105 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+import math
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ...utils import logging
+from ...utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..utils.io import ImageWriter, ImageReader
+from ..utils.color_map import get_colormap, font_colormap
+from .base import BaseResult
+
+
+def draw_box(img, np_boxes, labels):
+    """
+    Args:
+        img (PIL.Image.Image): PIL image
+        np_boxes (np.ndarray): shape:[N,6], N: number of box,
+                               matix element:[class, score, x_min, y_min, x_max, y_max]
+        labels (list): labels:['class1', ..., 'classn']
+    Returns:
+        img (PIL.Image.Image): visualized image
+    """
+    font_size = int(0.024 * int(img.width)) + 2
+    font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+
+    draw_thickness = int(max(img.size) * 0.005)
+    draw = ImageDraw.Draw(img)
+    clsid2color = {}
+    catid2fontcolor = {}
+    color_list = get_colormap(rgb=True)
+    expect_boxes = np_boxes[:, 0] > -1
+    np_boxes = np_boxes[expect_boxes, :]
+
+    for i, dt in enumerate(np_boxes):
+        clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+        if clsid not in clsid2color:
+            color_index = i % len(color_list)
+            clsid2color[clsid] = color_list[color_index]
+            catid2fontcolor[clsid] = font_colormap(color_index)
+        color = tuple(clsid2color[clsid])
+        font_color = tuple(catid2fontcolor[clsid])
+
+        xmin, ymin, xmax, ymax = bbox
+        # draw bbox
+        draw.line(
+            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
+            width=draw_thickness,
+            fill=color,
+        )
+
+        # draw label
+        text = "{} {:.2f}".format(labels[clsid], score)
+        if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
+            tw, th = draw.textsize(text, font=font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), text, font)
+            tw, th = right - left, bottom - top
+        if ymin < th:
+            draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
+            draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+        else:
+            draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
+            draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
+
+    return img
+
+
+class DetResults(BaseResult):
+    """Save Result Transform"""
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.data = data
+        # We use pillow backend to save both numpy arrays and PIL Image objects
+        self._img_reader.set_backend("pillow")
+        self._img_writer.set_backend("pillow")
+
+    def _get_res_img(self):
+        """apply"""
+        boxes = self["boxes"]
+        img_path = self["img_path"]
+        labels = self.data["labels"]
+        file_name = os.path.basename(img_path)
+
+        image = self._img_reader.read(img_path)
+        image = draw_box(image, boxes, labels=labels)
+        self["boxes"] = boxes.tolist()
+
+        return image

+ 87 - 0
paddlex/inference/results/instance_seg.py

@@ -0,0 +1,87 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+import math
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ...utils import logging
+from ...utils.fonts import PINGFANG_FONT_FILE_PATH
+from ..utils.io import ImageWriter, ImageReader
+from ..utils.color_map import get_color_map_list, font_colormap
+from .base import BaseResult
+from .det import draw_box
+
+
+def draw_mask(im, np_boxes, np_masks, labels):
+    """
+    Args:
+        im (PIL.Image.Image): PIL image
+        np_boxes (np.ndarray): shape:[N,6], N: number of box,
+            matix element:[class, score, x_min, y_min, x_max, y_max]
+        np_masks (np.ndarray): shape:[N, im_h, im_w]
+        labels (list): labels:['class1', ..., 'classn']
+    Returns:
+        im (PIL.Image.Image): visualized image
+    """
+    color_list = get_color_map_list(len(labels))
+    w_ratio = 0.4
+    alpha = 0.7
+    im = np.array(im).astype("float32")
+    clsid2color = {}
+    im_h, im_w = im.shape[:2]
+    np_masks = np_masks[:, :im_h, :im_w]
+    for i in range(len(np_masks)):
+        clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
+        mask = np_masks[i]
+        if clsid not in clsid2color:
+            clsid2color[clsid] = color_list[clsid]
+        color_mask = clsid2color[clsid]
+        for c in range(3):
+            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+        idx = np.nonzero(mask)
+        color_mask = np.array(color_mask)
+        im[idx[0], idx[1], :] *= 1.0 - alpha
+        im[idx[0], idx[1], :] += alpha * color_mask
+    return Image.fromarray(im.astype("uint8"))
+
+
+class InstanceSegResults(BaseResult):
+    """Save Result Transform"""
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.data = data
+        # We use pillow backend to save both numpy arrays and PIL Image objects
+        self._img_reader.set_backend("pillow")
+        self._img_writer.set_backend("pillow")
+
+    def _get_res_img(self):
+        """apply"""
+        boxes = self["boxes"]
+        masks = self["masks"]
+        img_path = self["img_path"]
+        labels = self.data["labels"]
+        file_name = os.path.basename(img_path)
+
+        image = self._img_reader.read(img_path)
+        image = draw_mask(image, boxes, masks, labels)
+        image = draw_box(image, boxes, labels=labels)
+        self["boxes"] = boxes.tolist()
+        self["masks"] = masks.tolist()
+
+        return image

+ 34 - 0
paddlex/inference/utils/color_map.py

@@ -87,3 +87,37 @@ def get_colormap(rgb=False):
     if not rgb:
         color_list = color_list[:, ::-1]
     return color_list.astype("int32")
+
+
+def get_color_map_list(num_classes):
+    """
+    Args:
+        num_classes (int): number of class
+    Returns:
+        color_map (list): RGB color list
+    """
+    color_map = num_classes * [0, 0, 0]
+    for i in range(0, num_classes):
+        j = 0
+        lab = i
+        while lab:
+            color_map[i * 3] |= ((lab >> 0) & 1) << (7 - j)
+            color_map[i * 3 + 1] |= ((lab >> 1) & 1) << (7 - j)
+            color_map[i * 3 + 2] |= ((lab >> 2) & 1) << (7 - j)
+            j += 1
+            lab >>= 3
+    color_map = [color_map[i : i + 3] for i in range(0, len(color_map), 3)]
+    return color_map
+
+
+def font_colormap(color_index):
+    """
+    Get font color according to the index of colormap
+    """
+    dark = np.array([0x14, 0x0E, 0x35])
+    light = np.array([0xFF, 0xFF, 0xFF])
+    light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+    if color_index in light_indexs:
+        return light.astype("int32")
+    else:
+        return dark.astype("int32")

+ 0 - 1
paddlex/modules/base/predictor/utils/paddle_inference_predictor.py

@@ -145,7 +145,6 @@ No need to generate again."
         for idx in range(len(x)):
             self.input_handlers[idx].reshape(x[idx].shape)
             self.input_handlers[idx].copy_from_cpu(x[idx])
-
         self.predictor.run()
 
         res = []