Bladeren bron

import_result_of_instseg

zhangyubo0722 1 jaar geleden
bovenliggende
commit
97c0431443

+ 16 - 6
paddlex/inference/components/task_related/det.py

@@ -18,16 +18,25 @@ from ...utils.io import ImageReader
 from ..base import BaseComponent
 
 
+def restructured_boxes(boxes, labels):
+    """Convert raw detection rows into a list of per-box dictionaries.
+
+    Args:
+        boxes (Iterable): rows indexed as ``[cls_id, score, *coordinate]``;
+            element 0 is cast to int, element 1 to float and the remaining
+            elements to int.
+        labels (Sequence | None): class names indexed by class id. ``None``
+            is accepted because the post-processors default ``labels`` to
+            ``None``.
+
+    Returns:
+        list[dict]: one dict per row with the keys ``cls_id``, ``label``,
+        ``score`` and ``coordinate``.
+    """
+    restructured = []
+    for box in boxes:
+        cls_id = int(box[0])
+        # Without a label list, fall back to the class id as text instead of
+        # failing with "'NoneType' object is not subscriptable".
+        label = str(cls_id) if labels is None else labels[cls_id]
+        restructured.append(
+            {
+                "cls_id": cls_id,
+                "label": label,
+                "score": float(box[1]),
+                # int() truncates toward zero; it does not round.
+                "coordinate": [int(value) for value in box[2:]],
+            }
+        )
+    return restructured
+
+
 class DetPostProcess(BaseComponent):
     """Save Result Transform"""
 
     INPUT_KEYS = ["img_path", "boxes"]
-    OUTPUT_KEYS = ["boxes", "labels"]
+    OUTPUT_KEYS = ["boxes"]
     DEAULT_INPUTS = {"boxes": "boxes"}
-    DEAULT_OUTPUTS = {
-        "boxes": "boxes",
-        "labels": "labels",
-    }
+    DEAULT_OUTPUTS = {"boxes": "boxes"}
 
     def __init__(self, threshold=0.5, labels=None):
         super().__init__()
@@ -38,7 +47,8 @@ class DetPostProcess(BaseComponent):
         """apply"""
         expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
         boxes = boxes[expect_boxes, :]
-        result = {"boxes": boxes, "labels": self.labels}
+        boxes = restructured_boxes(boxes, self.labels)
+        result = {"boxes": boxes}
 
         return result
 

+ 24 - 7
paddlex/inference/components/task_related/instance_seg.py

@@ -14,20 +14,39 @@
 
 import os
 
+import numpy as np
 from ....utils import logging
 from ..base import BaseComponent
+from .det import restructured_boxes
+
+
+def extract_masks_from_boxes(boxes, masks):
+    """Crop every mask down to the rectangle of its matching box.
+
+    Args:
+        boxes: sequence of dicts whose ``"coordinate"`` entry unpacks into
+            ``x_min, y_min, x_max, y_max``.
+        masks: indexable collection where ``masks[i]`` is a 2-D sliceable
+            mask belonging to ``boxes[i]``.
+
+    Returns:
+        list: ``masks[i][y_min:y_max, x_min:x_max]`` for every box, in order.
+    """
+    cropped = []
+    for idx, box in enumerate(boxes):
+        left, top, right, bottom = box["coordinate"]
+        # Round to the nearest integer so the values are valid slice bounds.
+        # NOTE(review): negative values are passed to the slice unchanged.
+        left, top, right, bottom = (
+            int(round(value)) for value in (left, top, right, bottom)
+        )
+        cropped.append(masks[idx][top:bottom, left:right])
+    return cropped
 
 
 class InstanceSegPostProcess(BaseComponent):
     """Save Result Transform"""
 
     INPUT_KEYS = ["boxes", "masks"]
-    OUTPUT_KEYS = ["img_path", "boxes", "masks", "labels"]
+    OUTPUT_KEYS = ["img_path", "boxes", "masks"]
     DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks"}
     DEAULT_OUTPUTS = {
         "boxes": "boxes",
         "masks": "masks",
-        "labels": "labels",
     }
 
     def __init__(self, threshold=0.5, labels=None):
@@ -39,11 +58,9 @@ class InstanceSegPostProcess(BaseComponent):
         """apply"""
         expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
         boxes = boxes[expect_boxes, :]
+        boxes = restructured_boxes(boxes, self.labels)
         masks = masks[expect_boxes, :, :]
-        result = {
-            "boxes": boxes,
-            "masks": masks,
-            "labels": self.labels,
-        }
+        masks = extract_masks_from_boxes(boxes, masks)
+        result = {"boxes": boxes, "masks": masks}
 
         return result

+ 1 - 1
paddlex/inference/models/instance_segmentation.py

@@ -64,5 +64,5 @@ class InstanceSegPredictor(DetPredictor):
         return ops
 
     def _pack_res(self, single):
-        keys = ["img_path", "boxes", "masks", "labels"]
+        keys = ["img_path", "boxes", "masks"]
         return InstanceSegResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/object_detection.py

@@ -108,5 +108,5 @@ class DetPredictor(BasicPredictor):
         return ToCHWImage()
 
     def _pack_res(self, single):
-        keys = ["img_path", "boxes", "labels"]
+        keys = ["img_path", "boxes"]
         return DetResult({key: single[key] for key in keys})

+ 6 - 9
paddlex/inference/results/det.py

@@ -26,13 +26,11 @@ from ..utils.color_map import get_colormap, font_colormap
 from .base import BaseResult
 
 
-def draw_box(img, np_boxes, labels):
+def draw_box(img, boxes):
     """
     Args:
         img (PIL.Image.Image): PIL image
-        np_boxes (np.ndarray): shape:[N,6], N: number of box,
-                               matix element:[class, score, x_min, y_min, x_max, y_max]
-        labels (list): labels:['class1', ..., 'classn']
+        boxes (list): a list of dictionaries representing detection box information.
     Returns:
         img (PIL.Image.Image): visualized image
     """
@@ -45,8 +43,8 @@ def draw_box(img, np_boxes, labels):
     catid2fontcolor = {}
     color_list = get_colormap(rgb=True)
 
-    for i, dt in enumerate(np_boxes):
-        clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+    for i, dt in enumerate(boxes):
+        clsid, bbox, score = dt["cls_id"], dt["coordinate"], dt["score"]
         if clsid not in clsid2color:
             color_index = i % len(color_list)
             clsid2color[clsid] = color_list[color_index]
@@ -63,7 +61,7 @@ def draw_box(img, np_boxes, labels):
         )
 
         # draw label
-        text = "{} {:.2f}".format(labels[clsid], score)
+        text = "{} {:.2f}".format(dt["label"], score)
         if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
             tw, th = draw.textsize(text, font=font)
         else:
@@ -92,10 +90,9 @@ class DetResult(BaseResult):
         """apply"""
         boxes = self["boxes"]
         img_path = self["img_path"]
-        labels = self["labels"]
         file_name = os.path.basename(img_path)
 
         image = self._img_reader.read(img_path)
-        image = draw_box(image, boxes, labels=labels)
+        image = draw_box(image, boxes)
 
         return image

+ 29 - 12
paddlex/inference/results/instance_seg.py

@@ -22,34 +22,51 @@ from PIL import Image, ImageDraw, ImageFont
 from ...utils import logging
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..utils.io import ImageWriter, ImageReader
-from ..utils.color_map import get_color_map_list, font_colormap
+from ..utils.color_map import get_colormap, font_colormap
 from .base import BaseResult
 from .det import draw_box
 
 
-def draw_mask(im, np_boxes, np_masks, labels):
+def restore_to_draw_masks(img_size, boxes, masks):
+    """Paste cropped masks back onto blank canvases of the full image shape.
+
+    Args:
+        img_size (Sequence[int]): shape of each blank canvas; the visible
+            caller passes the reversed ``image.size``.
+        boxes: sequence of dicts whose ``"coordinate"`` entry unpacks into
+            ``x_min, y_min, x_max, y_max``.
+        masks: sequence of cropped masks, paired with ``boxes`` by position.
+
+    Returns:
+        np.ndarray: ``uint8`` array of shape ``(N, *img_size)`` where ``N`` is
+        the number of box/mask pairs. With no pairs the result has shape
+        ``(0, *img_size)``, so callers can still slice it on three axes.
+    """
+    restored_masks = []
+
+    for box, mask in zip(boxes, masks):
+        canvas = np.zeros(img_size, dtype=np.uint8)
+        x_min, y_min, x_max, y_max = (int(round(v)) for v in box["coordinate"])
+        canvas[y_min:y_max, x_min:x_max] = mask
+        restored_masks.append(canvas)
+
+    if not restored_masks:
+        # np.array([]) would be 1-D with shape (0,), which breaks callers
+        # that index the result as [:, :h, :w].
+        return np.zeros((0, *img_size), dtype=np.uint8)
+
+    return np.array(restored_masks)
+
+
+def draw_mask(im, boxes, np_masks, img_size):
     """
     Args:
         im (PIL.Image.Image): PIL image
-        np_boxes (np.ndarray): shape:[N,6], N: number of box,
-            matix element:[class, score, x_min, y_min, x_max, y_max]
+        boxes (list): a list of dictionaries representing detection box information.
         np_masks (np.ndarray): shape:[N, im_h, im_w]
-        labels (list): labels:['class1', ..., 'classn']
     Returns:
         im (PIL.Image.Image): visualized image
     """
-    color_list = get_color_map_list(len(labels))
+    color_list = get_colormap(rgb=True)
     w_ratio = 0.4
     alpha = 0.7
     im = np.array(im).astype("float32")
     clsid2color = {}
+    np_masks = restore_to_draw_masks(img_size, boxes, np_masks)
     im_h, im_w = im.shape[:2]
     np_masks = np_masks[:, :im_h, :im_w]
     for i in range(len(np_masks)):
-        clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
+        clsid, score = int(boxes[i]["cls_id"]), boxes[i]["score"]
         mask = np_masks[i]
         if clsid not in clsid2color:
-            clsid2color[clsid] = color_list[clsid]
+            color_index = i % len(color_list)
+            clsid2color[clsid] = color_list[color_index]
         color_mask = clsid2color[clsid]
         for c in range(3):
             color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
@@ -71,14 +88,14 @@ class InstanceSegResult(BaseResult):
 
     def _get_res_img(self):
         """apply"""
-        boxes = self["boxes"]
+        boxes = np.array(self["boxes"])
         masks = self["masks"]
         img_path = self["img_path"]
-        labels = self["labels"]
         file_name = os.path.basename(img_path)
 
         image = self._img_reader.read(img_path)
-        image = draw_mask(image, boxes, masks, labels)
-        image = draw_box(image, boxes, labels=labels)
+        ori_img_size = list(image.size)[::-1]
+        image = draw_mask(image, boxes, masks, ori_img_size)
+        image = draw_box(image, boxes)
 
         return image