Преглед изворни кода

update layout_postprocess (#2637)

* update layout_postprocess

* update layout_postprocess

* update layout_postprocess
Sunflower7788 пре 11 месеци
родитељ
комит
4d6b62f2ce

+ 115 - 4
paddlex/inference/components/task_related/det.py

@@ -228,6 +228,55 @@ class WarpAffine(BaseComponent):
         }
 
 
+def compute_iou(box1, box2):
+    x1 = max(box1[0], box2[0])
+    y1 = max(box1[1], box2[1])
+    x2 = min(box1[2], box2[2])
+    y2 = min(box1[3], box2[3])
+    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
+    box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
+    box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
+    iou = inter_area / float(box1_area + box2_area - inter_area)
+    return iou
+
+
+def is_box_mostly_inside(inner_box, outer_box, threshold=0.9):
+    x1 = max(inner_box[0], outer_box[0])
+    y1 = max(inner_box[1], outer_box[1])
+    x2 = min(inner_box[2], outer_box[2])
+    y2 = min(inner_box[3], outer_box[3])
+    inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
+    inner_box_area = (inner_box[2] - inner_box[0] + 1) * (inner_box[3] - inner_box[1] + 1)
+    return (inter_area / inner_box_area) >= threshold
+
+
+def non_max_suppression(boxes, scores, iou_threshold):
+    if len(boxes) == 0:
+        return []
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        iou = inter / (areas[i] + areas[order[1:]] - inter)
+        inds = np.where(iou <= iou_threshold)[0]
+        order = order[inds + 1]
+    return keep
+
+
 class DetPostProcess(BaseComponent):
     """Save Result Transform"""
 
@@ -236,15 +285,78 @@ class DetPostProcess(BaseComponent):
     DEAULT_INPUTS = {"boxes": "boxes", "img_size": "ori_img_size"}
     DEAULT_OUTPUTS = {"boxes": "boxes"}
 
-    def __init__(self, threshold=0.5, labels=None):
+    def __init__(self, threshold=0.5, labels=None, layout_postprocess=False):
         super().__init__()
         self.threshold = threshold
         self.labels = labels
+        self.layout_postprocess = layout_postprocess
 
     def apply(self, boxes, img_size):
         """apply"""
-        expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
-        boxes = boxes[expect_boxes, :]
+        if isinstance(self.threshold, float):
+            expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
+            boxes = boxes[expect_boxes, :]
+        elif isinstance(self.threshold, dict):
+            category_filtered_boxes = []
+            for cat_id in np.unique(boxes[:, 0]):
+                category_boxes = boxes[boxes[:, 0] == cat_id]
+                category_scores = category_boxes[:, 1]
+                category_threshold = self.threshold.get(int(cat_id), 0.5)
+                selected_indices = category_scores > category_threshold
+                category_filtered_boxes.append(category_boxes[selected_indices])
+            boxes = np.vstack(category_filtered_boxes) if category_filtered_boxes else np.array([])
+
+        if self.layout_postprocess:
+            filtered_boxes = []
+            ### Layout postprocess for NMS
+            for cat_id in np.unique(boxes[:, 0]):
+                category_boxes = boxes[boxes[:, 0] == cat_id]
+                category_scores = category_boxes[:, 1]
+                if len(category_boxes) > 0:
+                    nms_indices = non_max_suppression(category_boxes[:, 2:], category_scores, 0.5)
+                    category_boxes = category_boxes[nms_indices]
+                    keep_boxes = []
+                    for i, box in enumerate(category_boxes):
+                        if all(not is_box_mostly_inside(box[2:], other_box[2:]) for j, other_box in enumerate(category_boxes) if i != j):
+                            keep_boxes.append(box)
+                    filtered_boxes.extend(keep_boxes)
+            boxes = np.array(filtered_boxes)
+            ### Layout postprocess for removing boxes inside image category box
+            if self.labels and "image" in self.labels:
+                image_cls_id = self.labels.index('image')
+                if len(boxes) > 0:
+                    image_boxes = boxes[boxes[:, 0] == image_cls_id]
+                    other_boxes = boxes[boxes[:, 0] != image_cls_id]
+                    to_keep = []
+                    for box in other_boxes:
+                        keep = True
+                        for img_box in image_boxes:
+                            if (box[2] >= img_box[2] and box[3] >= img_box[3] and
+                                box[4] <= img_box[4] and box[5] <= img_box[5]):
+                                keep = False
+                                break
+                        if keep:
+                            to_keep.append(box)
+                    boxes = np.vstack([image_boxes, to_keep]) if to_keep else image_boxes
+            ### Layout postprocess for overlaps
+            final_boxes = []
+            while len(boxes) > 0:
+                current_box = boxes[0]
+                current_score = current_box[1]
+                overlaps = [current_box]
+                non_overlaps = []
+                for other_box in boxes[1:]:
+                    iou = compute_iou(current_box[2:], other_box[2:])
+                    if iou > 0.95:
+                        if other_box[1] > current_score:
+                            overlaps.append(other_box)
+                    else:
+                        non_overlaps.append(other_box)
+                best_box = max(overlaps, key=lambda x: x[1])
+                final_boxes.append(best_box)
+                boxes = np.array(non_overlaps)
+            boxes = np.array(final_boxes)
+
         if boxes.shape[1] == 6:
             """For Normal Object Detection"""
             boxes = restructured_boxes(boxes, self.labels, img_size)
@@ -257,7 +369,6 @@ class DetPostProcess(BaseComponent):
                 f"The shape of boxes should be 6 or 10, instead of {boxes.shape[1]}"
             )
         result = {"boxes": boxes}
-
         return result
 
 

+ 1 - 0
paddlex/inference/models/object_detection.py

@@ -67,6 +67,7 @@ class DetPredictor(BasicPredictor):
                 DetPostProcess(
                     threshold=self.config["draw_threshold"],
                     labels=self.config["label_list"],
+                    layout_postprocess=self.config.get("layout_postprocess", False),
                 ),
             ]
         )

+ 2 - 2
paddlex/inference/results/det.py

@@ -31,10 +31,10 @@ def draw_box(img, boxes):
     Returns:
         img (PIL.Image.Image): visualized image
     """
-    font_size = int(0.024 * int(img.width)) + 2
+    font_size = int(0.018 * int(img.width)) + 2
     font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
 
-    draw_thickness = int(max(img.size) * 0.005)
+    draw_thickness = int(max(img.size) * 0.002)
     draw = ImageDraw.Draw(img)
     label2color = {}
     catid2fontcolor = {}