Sunflower7788 пре 8 месеци
родитељ
комит
74f2c4895d
1 измењених фајлова са 56 додато и 23 уклоњено
  1. 56 23
      paddlex/inference/models/object_detection/processors.py

+ 56 - 23
paddlex/inference/models/object_detection/processors.py

@@ -642,7 +642,7 @@ def is_contained(box1, box2):
     return iou >= 0.9
 
 
-def check_containment(boxes, formula_index=None):
+def check_containment(boxes, formula_index=None, category_index=None, mode=None):
     """Check containment relationships among boxes."""
     n = len(boxes)
     contains_other = np.zeros(n, dtype=int)
@@ -655,9 +655,19 @@ def check_containment(boxes, formula_index=None):
             if formula_index is not None:
                 if boxes[i][0] == formula_index and boxes[j][0] != formula_index:
                     continue
-            if is_contained(boxes[i], boxes[j]):
-                contained_by_other[i] = 1
-                contains_other[j] = 1
+            if category_index is not None and mode is not None:
+                if mode == "large" and boxes[j][0] == category_index:
+                    if is_contained(boxes[i], boxes[j]):
+                        contained_by_other[i] = 1
+                        contains_other[j] = 1                   
+                if mode == "small" and boxes[i][0] == category_index:
+                    if is_contained(boxes[i], boxes[j]):
+                        contained_by_other[i] = 1
+                        contains_other[j] = 1
+            else:
+                if is_contained(boxes[i], boxes[j]):
+                    contained_by_other[i] = 1
+                    contains_other[j] = 1
     return contains_other, contained_by_other
 
 
@@ -724,25 +734,48 @@ class DetPostProcess:
             boxes = np.array(boxes[selected_indices])
 
         if layout_merge_bboxes_mode:
-            assert layout_merge_bboxes_mode in [
-                "union",
-                "large",
-                "small",
-            ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
-
-            if layout_merge_bboxes_mode == "union":
-                pass
-            else:
-                formula_index = (
-                    self.labels.index("formula") if "formula" in self.labels else None
-                )
-                contains_other, contained_by_other = check_containment(
-                    boxes, formula_index
-                )
-                if layout_merge_bboxes_mode == "large":
-                    boxes = boxes[contained_by_other == 0]
-                elif layout_merge_bboxes_mode == "small":
-                    boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
+            formula_index = (self.labels.index("formula") if "formula" in self.labels else None)
+            if isinstance(layout_merge_bboxes_mode, str): 
+                assert layout_merge_bboxes_mode in [
+                    "union",
+                    "large",
+                    "small",
+                ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
+
+                if layout_merge_bboxes_mode == "union":
+                    pass
+                else:
+                    contains_other, contained_by_other = check_containment(
+                        boxes, formula_index
+                    )
+                    if layout_merge_bboxes_mode == "large":
+                        boxes = boxes[contained_by_other == 0]
+                    elif layout_merge_bboxes_mode == "small":
+                        boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
+            elif isinstance(layout_merge_bboxes_mode, dict):
+                keep_mask = np.ones(len(boxes), dtype=bool)
+                for category_index, layout_mode in layout_merge_bboxes_mode.items():
+                    assert layout_mode in [
+                        "union",
+                        "large",
+                        "small",
+                    ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_mode}"
+                    if layout_mode == "union":
+                        pass
+                    else:
+                        if layout_mode == "large":
+                            contains_other, contained_by_other = check_containment(
+                                boxes, formula_index, category_index, mode=layout_mode
+                            )
+                            # Remove boxes that are contained by other boxes
+                            keep_mask &= (contained_by_other == 0)
+                        elif layout_mode == "small":
+                            contains_other, contained_by_other = check_containment(
+                                boxes, formula_index, category_index, mode=layout_mode
+                            )
+                            # Keep boxes that do not contain others or are contained by others
+                            keep_mask &= (contains_other == 0) | (contained_by_other == 1)
+                boxes = boxes[keep_mask]
 
         if layout_unclip_ratio:
             if isinstance(layout_unclip_ratio, float):