Jelajahi Sumber

fix multi label classification bug (#2079)

changdazhou 1 tahun lalu
induk
melakukan
a8a29a9f35

+ 1 - 0
paddlex/inference/models/__init__.py

@@ -29,6 +29,7 @@ from .ts_fc import TSFcPredictor
 from .ts_ad import TSAdPredictor
 from .ts_cls import TSClsPredictor
 from .image_unwarping import WarpPredictor
+from .multilabel_classification import MLClasPredictor
 from .anomaly_detection import UadPredictor
 
 

+ 1 - 2
paddlex/inference/models/image_classification.py

@@ -16,7 +16,6 @@ import numpy as np
 
 from ...utils.func_register import FuncRegister
 from ...modules.image_classification.model_list import MODELS
-from ...modules.multilabel_classification.model_list import MODELS as ML_MODELS
 from ..components import *
 from ..results import TopkResult
 from .base import BasicPredictor
@@ -24,7 +23,7 @@ from .base import BasicPredictor
 
 class ClasPredictor(BasicPredictor):
 
-    entities = [*MODELS, *ML_MODELS]
+    entities = [*MODELS]
 
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)

+ 33 - 0
paddlex/inference/models/multilabel_classification.py

@@ -0,0 +1,33 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...utils.func_register import FuncRegister
+from ...modules.multilabel_classification.model_list import MODELS
+from ..components import *
+from ..results import MLClassResult
+from ..utils.process_hook import batchable_method
+from .image_classification import ClasPredictor
+
+
+class MLClasPredictor(ClasPredictor):
+
+    entities = [*MODELS]
+
+    def _pack_res(self, single):
+        keys = ["img_path", "class_ids", "scores"]
+        if "label_names" in single:
+            keys.append("label_names")
+        return MLClassResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/results/__init__.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from .base import BaseResult
-from .topk import TopkResult
+from .clas import TopkResult, MLClassResult
 from .text_det import TextDetResult
 from .text_rec import TextRecResult
 from .table_rec import TableRecResult, StructureTableResult, TableResult

+ 52 - 1
paddlex/inference/results/topk.py → paddlex/inference/results/clas.py

@@ -14,7 +14,7 @@
 
 
 import PIL
-from PIL import ImageDraw, ImageFont
+from PIL import ImageDraw, ImageFont, Image
 import numpy as np
 
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
@@ -84,3 +84,54 @@ class TopkResult(BaseResult):
             return light.astype("int32")
         else:
             return dark.astype("int32")
+
+
+class MLClassResult(TopkResult):
+
+    def __init__(self, data):
+        super().__init__(data)
+        self._img_reader.set_backend("pillow")
+        self._img_writer.set_backend("pillow")
+
+    def _get_res_img(self):
+        """Draw label on image"""
+        image = self._img_reader.read(self["img_path"])
+        label_names = self["label_names"]
+        scores = self["scores"]
+        image = image.convert("RGB")
+        image_width, image_height = image.size
+        font_size = int(image_width * 0.06)
+
+        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size)
+        text_lines = []
+        row_width = 0
+        row_height = 0
+        row_text = "\t"
+        for label_name, score in zip(label_names, scores):
+            text = f"{label_name}({score})\t"
+            text_width, row_height = font.getsize(text)
+            if row_width + text_width <= image_width:
+                row_text += text
+                row_width += text_width
+            else:
+                text_lines.append(row_text)
+                row_text = "\t" + text
+                row_width = text_width
+        text_lines.append(row_text)
+        color_list = get_colormap(rgb=True)
+        color = tuple(color_list[0])
+        new_image_height = image_height + len(text_lines) * int(row_height * 1.2)
+        new_image = Image.new("RGB", (image_width, new_image_height), color)
+        new_image.paste(image, (0, 0))
+
+        draw = ImageDraw.Draw(new_image)
+        font_color = tuple(self._get_font_colormap(3))
+        for i, text in enumerate(text_lines):
+            text_width, _ = font.getsize(text)
+            draw.text(
+                (0, image_height + i * int(row_height * 1.2)),
+                text,
+                fill=font_color,
+                font=font,
+            )
+        return new_image