Browse Source

support multi-label image classification with new inference

zhouchangda 11 months ago
parent
commit
673b09a03e

+ 3 - 8
paddlex/inference/models_new/__init__.py

@@ -25,23 +25,18 @@ from .object_detection import DetPredictor
 from .text_detection import TextDetPredictor
 from .text_recognition import TextRecPredictor
 from .formula_recognition import FormulaRecPredictor
-
-# from .table_recognition import TablePredictor
-# from .object_detection import DetPredictor
 from .instance_segmentation import InstanceSegPredictor
 from .semantic_segmentation import SegPredictor
 from .image_feature import ImageFeaturePredictor
-
-# from .general_recognition import ShiTuRecPredictor
-
 from .ts_forecast import TSFcPredictor
 from .ts_anomaly import TSAdPredictor
 from .ts_classify import TSClsPredictor
 from .image_unwarping import WarpPredictor
+from .image_multilabel_classification import MLClasPredictor
 
-# from .multilabel_classification import MLClasPredictor
+# from .table_recognition import TablePredictor
+# from .general_recognition import ShiTuRecPredictor
 # from .anomaly_detection import UadPredictor
-
 # from .face_recognition import FaceRecPredictor
 
 

+ 15 - 0
paddlex/inference/models_new/image_multilabel_classification/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import MLClasPredictor

+ 90 - 0
paddlex/inference/models_new/image_multilabel_classification/predictor.py

@@ -0,0 +1,90 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List
+import numpy as np
+
+from .result import MLClassResult
+from .processors import MultiLabelThreshOutput
+from ..image_classification import ClasPredictor
+from ....modules.multilabel_classification.model_list import MODELS
+
+
+class MLClasPredictor(ClasPredictor):
+    """MLClasPredictor that inherits from BasicPredictor."""
+
+    entities = MODELS
+
+    def __init__(
+        self,
+        threshold: Union[float, dict, list, None] = None,
+        *args: List,
+        **kwargs: Dict
+    ) -> None:
+        """Initializes MLClasPredictor.
+
+        Args:
+            threshold (float, dict, optional): The threshold predictions to return. If None, it will be depending on config of inference or predict. Defaults to None.
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
+        self.threshold = threshold
+        super().__init__(*args, **kwargs)
+
+    def _get_result_class(self) -> type:
+        """Returns the result class, TopkResult.
+
+        Returns:
+            type: The TopkResult class.
+        """
+
+        return MLClassResult
+
+    def process(
+        self,
+        batch_data: List[Union[str, np.ndarray]],
+        threshold: Union[int, dict, None] = None,
+    ) -> Dict[str, Any]:
+        """
+        Process a batch of data through the preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
+            threshold (float, dict, optional): The threshold predictions to return. If None, it will be depending on config of inference or predict. Defaults to None.
+
+        Returns:
+            dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
+        """
+        batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data)
+        batch_imgs = self.preprocessors["Resize"](imgs=batch_raw_imgs)
+        batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
+        batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
+        x = self.preprocessors["ToBatch"](imgs=batch_imgs)
+        batch_preds = self.infer(x=x)
+        batch_class_ids, batch_scores, batch_label_names = self.postprocessors[
+            "MultiLabelThreshOutput"
+        ](preds=batch_preds, threshold=threshold or self.threshold)
+        return {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "class_ids": batch_class_ids,
+            "scores": batch_scores,
+            "label_names": batch_label_names,
+        }
+
+    @ClasPredictor.register("MultiLabelThreshOutput")
+    def build_threshoutput(self, threshold: Union[float, dict, list], label_list=None):
+        if self.threshold is None:
+            self.threshold = threshold
+        return "MultiLabelThreshOutput", MultiLabelThreshOutput(class_ids=label_list)

+ 83 - 0
paddlex/inference/models_new/image_multilabel_classification/processors.py

@@ -0,0 +1,83 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from typing import Union
+
+
+class MultiLabelThreshOutput:
+    """MultiLabelThresh Transform"""
+
+    def __init__(self, class_ids=None, delimiter=None):
+        super().__init__()
+        self.delimiter = delimiter if delimiter is not None else " "
+        self.class_id_map = self._parse_class_id_map(class_ids)
+
+    def _parse_class_id_map(self, class_ids):
+        """parse class id to label map file"""
+        if class_ids is None:
+            return None
+        class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
+        return class_id_map
+
+    def __call__(self, preds, threshold: Union[float, dict, list]):
+        threshold_list = []
+        num_classes = preds[0].shape[-1]
+        if isinstance(threshold, float):
+            threshold_list = [threshold for _ in range(num_classes)]
+        elif isinstance(threshold, dict):
+            if threshold.get("default") is None:
+                raise ValueError(
+                    "If using dictionary format, please specify default threshold explicitly with key 'default'."
+                )
+            default_threshold = threshold.pop("default")
+            threshold_list = [default_threshold for _ in range(num_classes)]
+            for k, v in threshold.items():
+                if isinstance(k, str):
+                    assert (
+                        k.isdigit()
+                    ), f"Invalid key of threshold: {k}, it must be integer"
+                    k = int(k)
+                if not isinstance(v, float):
+                    raise ValueError(
+                        f"Invalid value type of threshold: {type(v)}, it must be float"
+                    )
+                assert (
+                    k < num_classes
+                ), f"Invalid key of threshold: {k}, it must be less than the number of classes({num_classes})"
+                threshold_list[k] = v
+        elif isinstance(threshold, list):
+            assert (
+                len(threshold) == num_classes
+            ), f"The length of threshold({len(threshold)}) should be equal to the number of classes({num_classes})."
+            threshold_list = threshold
+        else:
+            raise ValueError(
+                "Invalid type of threshold, should be 'list', 'dict' or 'float'."
+            )
+
+        pred_indexes = [
+            np.argsort(-x[x > threshold])
+            for x, threshold in zip(preds[0], threshold_list)
+        ]
+        indexes = [
+            np.where(x > threshold)[0][indices]
+            for x, indices, threshold in zip(preds[0], pred_indexes, threshold_list)
+        ]
+        scores = [
+            np.around(pred[index].astype(float), decimals=5)
+            for pred, index in zip(preds[0], indexes)
+        ]
+        label_names = [[self.class_id_map[i] for i in index] for index in indexes]
+        return indexes, scores, label_names

+ 85 - 0
paddlex/inference/models_new/image_multilabel_classification/result.py

@@ -0,0 +1,85 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ...utils.color_map import get_colormap
+from ...common.result import BaseCVResult
+
+
+class MLClassResult(BaseCVResult):
+    def _to_img(self):
+        """Draw label on image"""
+        image = Image.fromarray(self._input_img)
+        label_names = self["label_names"]
+        scores = self["scores"]
+        image = image.convert("RGB")
+        image_width, image_height = image.size
+        font_size = int(image_width * 0.06)
+
+        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size)
+        text_lines = []
+        row_width = 0
+        row_height = 0
+        row_text = "\t"
+        for label_name, score in zip(label_names, scores):
+            text = f"{label_name}({score})\t"
+            if int(PIL.__version__.split(".")[0]) < 10:
+                text_width, row_height = font.getsize(text)
+            else:
+                text_width, row_height = font.getbbox(text)[2:]
+            if row_width + text_width <= image_width:
+                row_text += text
+                row_width += text_width
+            else:
+                text_lines.append(row_text)
+                row_text = "\t" + text
+                row_width = text_width
+        text_lines.append(row_text)
+        color_list = get_colormap(rgb=True)
+        color = tuple(color_list[0])
+        new_image_height = image_height + len(text_lines) * int(row_height * 1.2)
+        new_image = Image.new("RGB", (image_width, new_image_height), color)
+        new_image.paste(image, (0, 0))
+
+        draw = ImageDraw.Draw(new_image)
+        font_color = tuple(self._get_font_colormap(3))
+        for i, text in enumerate(text_lines):
+            if int(PIL.__version__.split(".")[0]) < 10:
+                text_width, _ = font.getsize(text)
+            else:
+                text_width, _ = font.getbbox(text)[2:]
+            draw.text(
+                (0, image_height + i * int(row_height * 1.2)),
+                text,
+                fill=font_color,
+                font=font,
+            )
+        return new_image
+
+    def _get_font_colormap(self, color_index):
+        """
+        Get font colormap
+        """
+        dark = np.array([0x14, 0x0E, 0x35])
+        light = np.array([0xFF, 0xFF, 0xFF])
+        light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+        if color_index in light_indexs:
+            return light.astype("int32")
+        else:
+            return dark.astype("int32")