|
|
@@ -642,7 +642,7 @@ def is_contained(box1, box2):
|
|
|
return iou >= 0.9
|
|
|
|
|
|
|
|
|
-def check_containment(boxes, formula_index=None):
|
|
|
+def check_containment(boxes, formula_index=None, category_index=None, mode=None):
|
|
|
"""Check containment relationships among boxes."""
|
|
|
n = len(boxes)
|
|
|
contains_other = np.zeros(n, dtype=int)
|
|
|
@@ -655,9 +655,19 @@ def check_containment(boxes, formula_index=None):
|
|
|
if formula_index is not None:
|
|
|
if boxes[i][0] == formula_index and boxes[j][0] != formula_index:
|
|
|
continue
|
|
|
- if is_contained(boxes[i], boxes[j]):
|
|
|
- contained_by_other[i] = 1
|
|
|
- contains_other[j] = 1
|
|
|
+ if category_index is not None and mode is not None:
|
|
|
+ if mode == "large" and boxes[j][0] == category_index:
|
|
|
+ if is_contained(boxes[i], boxes[j]):
|
|
|
+ contained_by_other[i] = 1
|
|
|
+ contains_other[j] = 1
|
|
|
+ if mode == "small" and boxes[i][0] == category_index:
|
|
|
+ if is_contained(boxes[i], boxes[j]):
|
|
|
+ contained_by_other[i] = 1
|
|
|
+ contains_other[j] = 1
|
|
|
+ else:
|
|
|
+ if is_contained(boxes[i], boxes[j]):
|
|
|
+ contained_by_other[i] = 1
|
|
|
+ contains_other[j] = 1
|
|
|
return contains_other, contained_by_other
|
|
|
|
|
|
|
|
|
@@ -724,25 +734,48 @@ class DetPostProcess:
|
|
|
boxes = np.array(boxes[selected_indices])
|
|
|
|
|
|
if layout_merge_bboxes_mode:
|
|
|
- assert layout_merge_bboxes_mode in [
|
|
|
- "union",
|
|
|
- "large",
|
|
|
- "small",
|
|
|
- ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
|
|
|
-
|
|
|
- if layout_merge_bboxes_mode == "union":
|
|
|
- pass
|
|
|
- else:
|
|
|
- formula_index = (
|
|
|
- self.labels.index("formula") if "formula" in self.labels else None
|
|
|
- )
|
|
|
- contains_other, contained_by_other = check_containment(
|
|
|
- boxes, formula_index
|
|
|
- )
|
|
|
- if layout_merge_bboxes_mode == "large":
|
|
|
- boxes = boxes[contained_by_other == 0]
|
|
|
- elif layout_merge_bboxes_mode == "small":
|
|
|
- boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
|
|
|
+ formula_index = (self.labels.index("formula") if "formula" in self.labels else None)
|
|
|
+ if isinstance(layout_merge_bboxes_mode, str):
|
|
|
+ assert layout_merge_bboxes_mode in [
|
|
|
+ "union",
|
|
|
+ "large",
|
|
|
+ "small",
|
|
|
+ ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
|
|
|
+
|
|
|
+ if layout_merge_bboxes_mode == "union":
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ contains_other, contained_by_other = check_containment(
|
|
|
+ boxes, formula_index
|
|
|
+ )
|
|
|
+ if layout_merge_bboxes_mode == "large":
|
|
|
+ boxes = boxes[contained_by_other == 0]
|
|
|
+ elif layout_merge_bboxes_mode == "small":
|
|
|
+ boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
|
|
|
+ elif isinstance(layout_merge_bboxes_mode, dict):
|
|
|
+ keep_mask = np.ones(len(boxes), dtype=bool)
|
|
|
+ for category_index, layout_mode in layout_merge_bboxes_mode.items():
|
|
|
+ assert layout_mode in [
|
|
|
+ "union",
|
|
|
+ "large",
|
|
|
+ "small",
|
|
|
+ ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_mode}"
|
|
|
+ if layout_mode == "union":
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ if layout_mode == "large":
|
|
|
+ contains_other, contained_by_other = check_containment(
|
|
|
+ boxes, formula_index, category_index, mode=layout_mode
|
|
|
+ )
|
|
|
+ # Remove boxes that are contained by other boxes
|
|
|
+ keep_mask &= (contained_by_other == 0)
|
|
|
+ elif layout_mode == "small":
|
|
|
+ contains_other, contained_by_other = check_containment(
|
|
|
+ boxes, formula_index, category_index, mode=layout_mode
|
|
|
+ )
|
|
|
+ # Keep boxes that do not contain others or are contained by others
|
|
|
+ keep_mask &= (contains_other == 0) | (contained_by_other == 1)
|
|
|
+ boxes = boxes[keep_mask]
|
|
|
|
|
|
if layout_unclip_ratio:
|
|
|
if isinstance(layout_unclip_ratio, float):
|