|
|
@@ -424,52 +424,6 @@ class WarpAffine:
|
|
|
|
|
|
return datas
|
|
|
|
|
|
-
|
|
|
def compute_iou(box1: List[Number], box2: List[Number]) -> float:
    """Compute the Intersection over Union (IoU) of two bounding boxes.

    Widths and heights are measured inclusively (``x2 - x1 + 1``), so a box
    whose corners coincide still has an area of 1.

    Args:
        box1 (List[Number]): Coordinates of the first bounding box in format [x1, y1, x2, y2].
        box2 (List[Number]): Coordinates of the second bounding box in format [x1, y1, x2, y2].

    Returns:
        float: The IoU of the two bounding boxes, or ``0.0`` when the union
            area is not positive (degenerate boxes).
    """
    # Corners of the overlap rectangle.
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Clamp to zero so disjoint boxes contribute no intersection.
    inter_area = max(0, right - left + 1) * max(0, bottom - top + 1)
    box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
    box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)

    union_area = box1_area + box2_area - inter_area
    if union_area <= 0:
        # Degenerate input: avoid dividing by zero.
        return 0.0
    return inter_area / float(union_area)
|
|
|
-
|
|
|
-
|
|
|
def is_box_mostly_inside(
    inner_box: List[Number], outer_box: List[Number], threshold: float = 0.9
) -> bool:
    """Determine if one bounding box is mostly inside another bounding box.

    Widths and heights are measured inclusively (``x2 - x1 + 1``).

    Args:
        inner_box (List[Number]): Coordinates of the inner bounding box in format [x1, y1, x2, y2].
        outer_box (List[Number]): Coordinates of the outer bounding box in format [x1, y1, x2, y2].
        threshold (float): Minimum fraction of the inner box area that must
            overlap the outer box (default is 0.9).

    Returns:
        bool: True if the ratio of the intersection area to the inner box
            area is greater than or equal to the threshold, False otherwise.
            A degenerate inner box (non-positive area) is never "inside".
    """
    # Corners of the overlap rectangle.
    left = max(inner_box[0], outer_box[0])
    top = max(inner_box[1], outer_box[1])
    right = min(inner_box[2], outer_box[2])
    bottom = min(inner_box[3], outer_box[3])

    inter_area = max(0, right - left + 1) * max(0, bottom - top + 1)
    inner_box_area = (inner_box[2] - inner_box[0] + 1) * (
        inner_box[3] - inner_box[1] + 1
    )
    if inner_box_area <= 0:
        # Degenerate input: avoid dividing by zero.
        return False
    return (inter_area / inner_box_area) >= threshold
|
|
|
-
|
|
|
-
|
|
|
def restructured_boxes(
|
|
|
boxes: ndarray, labels: List[str], img_size: Tuple[int, int]
|
|
|
) -> Boxes:
|
|
|
@@ -544,48 +498,125 @@ def restructured_rotated_boxes(
|
|
|
|
|
|
return box_list
|
|
|
|
|
|
-
|
|
|
def non_max_suppression(
    boxes: ndarray, scores: ndarray, iou_threshold: float
) -> List[int]:
    """
    Perform non-maximum suppression to remove redundant overlapping boxes with
    lower scores. This function is commonly used in object detection tasks.

    Parameters:
        boxes (ndarray): An array of shape (N, 4) representing the bounding boxes.
            Each row is in the format [x1, y1, x2, y2].
        scores (ndarray): An array of shape (N,) containing the scores for each box.
        iou_threshold (float): The Intersection over Union (IoU) threshold to use
            for suppressing overlapping boxes.

    Returns:
        List[int]: A list of indices representing the indices of the boxes to keep.
    """
    if len(boxes) == 0:
        return []
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    # Inclusive pixel areas (x2 - x1 + 1).
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Candidate indices, highest score first.
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Overlap of the best remaining box with every other candidate.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # Keep only candidates that do not overlap the chosen box too much;
        # +1 maps positions in order[1:] back to positions in order.
        inds = np.where(iou <= iou_threshold)[0]
        order = order[inds + 1]
    return keep


def unclip_boxes(boxes, unclip_ratio=None):
    """
    Scale detection boxes about their centers by per-axis ratios.

    Parameters:
    - boxes: np.ndarray of shape (N, 6), where each row is
      (cls_id, score, x1, y1, x2, y2). Columns 0 and 1 are passed through
      unchanged; columns 2-5 are the box corners.
    - unclip_ratio: tuple of (width_ratio, height_ratio), optional. When None,
      `boxes` is returned unchanged.

    Returns:
    - expanded_boxes: np.ndarray of shape (N, 6), where each row is
      (cls_id, score, x1, y1, x2, y2) with the corners rescaled. An empty
      input is returned as-is.
    """
    if unclip_ratio is None:
        return boxes
    if len(boxes) == 0:
        # Nothing to expand; also avoids column indexing on a 1-D empty array.
        return boxes

    widths = boxes[:, 4] - boxes[:, 2]
    heights = boxes[:, 5] - boxes[:, 3]

    new_w = widths * unclip_ratio[0]
    new_h = heights * unclip_ratio[1]
    center_x = boxes[:, 2] + widths / 2
    center_y = boxes[:, 3] + heights / 2

    # Re-derive the corners from the unchanged center and the scaled size.
    new_x1 = center_x - new_w / 2
    new_y1 = center_y - new_h / 2
    new_x2 = center_x + new_w / 2
    new_y2 = center_y + new_h / 2
    expanded_boxes = np.column_stack(
        (boxes[:, 0], boxes[:, 1], new_x1, new_y1, new_x2, new_y2)
    )

    return expanded_boxes
|
|
|
+
|
|
|
+
|
|
|
def iou(box1, box2):
    """Compute the Intersection over Union (IoU) of two bounding boxes.

    Widths and heights are measured inclusively (``x2 - x1 + 1``).

    Args:
        box1: Sequence of four numbers (x1, y1, x2, y2).
        box2: Sequence of four numbers (x1, y1, x2, y2).

    Returns:
        float: The IoU of the two boxes, or ``0.0`` when the union area is
            not positive (degenerate boxes).
    """
    x1, y1, x2, y2 = box1
    x1_p, y1_p, x2_p, y2_p = box2

    # Extent of the overlap rectangle, clamped to zero for disjoint boxes.
    inter_w = max(0, min(x2, x2_p) - max(x1, x1_p) + 1)
    inter_h = max(0, min(y2, y2_p) - max(y1, y1_p) + 1)
    inter_area = inter_w * inter_h

    box1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    box2_area = (x2_p - x1_p + 1) * (y2_p - y1_p + 1)

    union_area = box1_area + box2_area - inter_area
    if union_area <= 0:
        # Degenerate input: avoid dividing by zero (or nan with numpy scalars).
        return 0.0
    return inter_area / float(union_area)
|
|
|
+
|
|
|
def nms(boxes, iou_same=0.6, iou_diff=0.95):
    """Perform Non-Maximum Suppression (NMS) with different IoU thresholds for same and different classes.

    Args:
        boxes: np.ndarray whose rows are (cls_id, score, x1, y1, x2, y2).
        iou_same: IoU at or above which a lower-scored box of the *same*
            class as an already selected box is suppressed.
        iou_diff: IoU at or above which a lower-scored box of a *different*
            class is suppressed.

    Returns:
        list: Row indices into `boxes` of the kept boxes, in descending score
            order. Empty input yields an empty list.
    """
    if len(boxes) == 0:
        # Also covers a 1-D empty array, which cannot be column-indexed.
        return []

    scores = boxes[:, 1]

    # Candidate row indices, highest score first.
    indices = np.argsort(scores)[::-1]
    selected_boxes = []

    while len(indices) > 0:
        current = indices[0]
        current_box = boxes[current]
        current_class = current_box[0]
        current_coords = current_box[2:]

        selected_boxes.append(current)

        # Survivors are the remaining candidates that overlap the chosen box
        # by less than the class-dependent threshold.
        filtered_indices = []
        for i in indices[1:]:
            box = boxes[i]
            threshold = iou_same if current_class == box[0] else iou_diff
            if iou(current_coords, box[2:]) < threshold:
                filtered_indices.append(i)
        indices = filtered_indices
    return selected_boxes
|
|
|
+
|
|
|
def is_contained(box1, box2, threshold=0.9):
    """Check if box1 is contained within box2.

    Args:
        box1: Sequence of six values (cls_id, score, x1, y1, x2, y2).
        box2: Sequence of six values (cls_id, score, x1, y1, x2, y2).
        threshold: Minimum fraction of box1's area that must lie inside box2
            for box1 to count as contained. Defaults to 0.9.

    Returns:
        bool: True when the overlap covers at least `threshold` of box1's
            area. A box1 with non-positive area has a ratio of 0.
    """
    _, _, x1, y1, x2, y2 = box1
    _, _, x1_p, y1_p, x2_p, y2_p = box2
    box1_area = (x2 - x1) * (y2 - y1)

    # Extent of the overlap rectangle, clamped to zero for disjoint boxes.
    inter_width = max(0, min(x2, x2_p) - max(x1, x1_p))
    inter_height = max(0, min(y2, y2_p) - max(y1, y1_p))
    intersect_area = inter_width * inter_height

    # Ratio of overlap to box1's own area (not an IoU).
    contained_ratio = intersect_area / box1_area if box1_area > 0 else 0
    return contained_ratio >= threshold
|
|
|
+
|
|
|
def check_containment(boxes, formula_index=None):
    """Check containment relationships among boxes.

    Every ordered pair (i, j) with i != j is tested with `is_contained`.

    Args:
        boxes: Sequence of rows (cls_id, score, x1, y1, x2, y2).
        formula_index: Optional class id. When given, a box of this class is
            never reported as contained by a box of a different class.

    Returns:
        tuple: `(contains_other, contained_by_other)`, two int arrays of
            length `len(boxes)`. `contains_other[j]` is 1 when some other box
            is contained in box j; `contained_by_other[i]` is 1 when box i is
            contained in some other box.
    """
    n = len(boxes)
    contains_other = np.zeros(n, dtype=int)
    contained_by_other = np.zeros(n, dtype=int)

    for i in range(n):
        # Depends only on box i, so evaluate it once per outer iteration.
        i_is_formula = formula_index is not None and boxes[i][0] == formula_index
        for j in range(n):
            if i == j:
                continue
            # Skip pairs where a formula box would be absorbed by a
            # non-formula box.
            if i_is_formula and boxes[j][0] != formula_index:
                continue
            if is_contained(boxes[i], boxes[j]):
                contained_by_other[i] = 1
                contains_other[j] = 1
    return contains_other, contained_by_other
|
|
|
|
|
|
|
|
|
class DetPostProcess:
|
|
|
@@ -598,9 +629,7 @@ class DetPostProcess:
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- threshold: float = 0.5,
|
|
|
- labels: Optional[List[str]] = None,
|
|
|
- layout_postprocess: bool = False,
|
|
|
+ labels: Optional[List[str]] = None
|
|
|
) -> None:
|
|
|
"""Initialize the DetPostProcess class.
|
|
|
|
|
|
@@ -610,11 +639,16 @@ class DetPostProcess:
|
|
|
layout_postprocess (bool, optional): Whether to apply layout post-processing. Defaults to False.
|
|
|
"""
|
|
|
super().__init__()
|
|
|
- self.threshold = threshold
|
|
|
self.labels = labels
|
|
|
- self.layout_postprocess = layout_postprocess
|
|
|
|
|
|
- def apply(self, boxes: ndarray, img_size, threshold: Union[float, dict]) -> Boxes:
|
|
|
+ def apply(self,
|
|
|
+ boxes: ndarray,
|
|
|
+ img_size: Tuple[int, int],
|
|
|
+ threshold: Union[float, dict],
|
|
|
+ layout_nms: bool,
|
|
|
+ layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]],
|
|
|
+ layout_merge_bboxes_mode: Optional[str]
|
|
|
+ ) -> Boxes:
|
|
|
"""Apply post-processing to the detection boxes.
|
|
|
|
|
|
Args:
|
|
|
@@ -631,9 +665,8 @@ class DetPostProcess:
|
|
|
category_filtered_boxes = []
|
|
|
for cat_id in np.unique(boxes[:, 0]):
|
|
|
category_boxes = boxes[boxes[:, 0] == cat_id]
|
|
|
- category_scores = category_boxes[:, 1]
|
|
|
category_threshold = threshold.get(int(cat_id), 0.5)
|
|
|
- selected_indices = category_scores > category_threshold
|
|
|
+ selected_indices = (category_boxes[:, 1] > category_threshold) & (category_boxes[:, 0] > -1)
|
|
|
category_filtered_boxes.append(category_boxes[selected_indices])
|
|
|
boxes = (
|
|
|
np.vstack(category_filtered_boxes)
|
|
|
@@ -641,69 +674,37 @@ class DetPostProcess:
|
|
|
else np.array([])
|
|
|
)
|
|
|
|
|
|
- if self.layout_postprocess:
|
|
|
+ if layout_nms:
|
|
|
filtered_boxes = []
|
|
|
### Layout postprocess for NMS
|
|
|
- for cat_id in np.unique(boxes[:, 0]):
|
|
|
- category_boxes = boxes[boxes[:, 0] == cat_id]
|
|
|
- category_scores = category_boxes[:, 1]
|
|
|
- if len(category_boxes) > 0:
|
|
|
- nms_indices = non_max_suppression(
|
|
|
- category_boxes[:, 2:], category_scores, 0.5
|
|
|
- )
|
|
|
- category_boxes = category_boxes[nms_indices]
|
|
|
- keep_boxes = []
|
|
|
- for i, box in enumerate(category_boxes):
|
|
|
- if all(
|
|
|
- not is_box_mostly_inside(box[2:], other_box[2:])
|
|
|
- for j, other_box in enumerate(category_boxes)
|
|
|
- if i != j
|
|
|
- ):
|
|
|
- keep_boxes.append(box)
|
|
|
- filtered_boxes.extend(keep_boxes)
|
|
|
- boxes = np.array(filtered_boxes)
|
|
|
- ### Layout postprocess for removing boxes inside image category box
|
|
|
- if self.labels and "image" in self.labels:
|
|
|
- image_cls_id = self.labels.index("image")
|
|
|
- if len(boxes) > 0:
|
|
|
- image_boxes = boxes[boxes[:, 0] == image_cls_id]
|
|
|
- other_boxes = boxes[boxes[:, 0] != image_cls_id]
|
|
|
- to_keep = []
|
|
|
- for box in other_boxes:
|
|
|
- keep = True
|
|
|
- for img_box in image_boxes:
|
|
|
- if (
|
|
|
- box[2] >= img_box[2]
|
|
|
- and box[3] >= img_box[3]
|
|
|
- and box[4] <= img_box[4]
|
|
|
- and box[5] <= img_box[5]
|
|
|
- ):
|
|
|
- keep = False
|
|
|
- break
|
|
|
- if keep:
|
|
|
- to_keep.append(box)
|
|
|
- boxes = (
|
|
|
- np.vstack([image_boxes, to_keep]) if to_keep else image_boxes
|
|
|
- )
|
|
|
- ### Layout postprocess for overlaps
|
|
|
- final_boxes = []
|
|
|
- while len(boxes) > 0:
|
|
|
- current_box = boxes[0]
|
|
|
- current_score = current_box[1]
|
|
|
- overlaps = [current_box]
|
|
|
- non_overlaps = []
|
|
|
- for other_box in boxes[1:]:
|
|
|
- iou = compute_iou(current_box[2:], other_box[2:])
|
|
|
- if iou > 0.95:
|
|
|
- if other_box[1] > current_score:
|
|
|
- overlaps.append(other_box)
|
|
|
- else:
|
|
|
- non_overlaps.append(other_box)
|
|
|
- best_box = max(overlaps, key=lambda x: x[1])
|
|
|
- final_boxes.append(best_box)
|
|
|
- boxes = np.array(non_overlaps)
|
|
|
- boxes = np.array(final_boxes)
|
|
|
-
|
|
|
+ selected_indices = nms(boxes, iou_same=0.6, iou_diff=0.98)
|
|
|
+ boxes = np.array(boxes[selected_indices])
|
|
|
+
|
|
|
+ if layout_merge_bboxes_mode:
|
|
|
+ assert layout_merge_bboxes_mode in ["union", "large", "small"], \
|
|
|
+ f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
|
|
|
+
|
|
|
+ if layout_merge_bboxes_mode == "union":
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ formula_index = self.labels.index("formula") if "formula" in self.labels else None
|
|
|
+ contains_other, contained_by_other = check_containment(boxes, formula_index)
|
|
|
+ if layout_merge_bboxes_mode == "large":
|
|
|
+ boxes = boxes[contained_by_other == 0]
|
|
|
+ elif layout_merge_bboxes_mode == "small":
|
|
|
+ boxes = boxes[(contains_other == 0) | (contained_by_other == 1)]
|
|
|
+
|
|
|
+ if layout_unclip_ratio:
|
|
|
+ if isinstance(layout_unclip_ratio, float):
|
|
|
+ layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
|
|
|
+ elif isinstance(layout_unclip_ratio, (tuple, list)):
|
|
|
+ assert len(layout_unclip_ratio) == 2, f"The length of `layout_unclip_ratio` should be 2."
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ f"The type of `layout_unclip_ratio` must be float or Tuple[float, float], but got {type(layout_unclip_ratio)}."
|
|
|
+ )
|
|
|
+ boxes = unclip_boxes(boxes, layout_unclip_ratio)
|
|
|
+
|
|
|
if boxes.shape[1] == 6:
|
|
|
"""For Normal Object Detection"""
|
|
|
boxes = restructured_boxes(boxes, self.labels, img_size)
|
|
|
@@ -722,6 +723,9 @@ class DetPostProcess:
|
|
|
batch_outputs: List[dict],
|
|
|
datas: List[dict],
|
|
|
threshold: Optional[Union[float, dict]] = None,
|
|
|
+ layout_nms: bool = False,
|
|
|
+ layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
|
|
|
+ layout_merge_bboxes_mode: Optional[str] = None,
|
|
|
) -> List[Boxes]:
|
|
|
"""Apply the post-processing to a batch of outputs.
|
|
|
|
|
|
@@ -737,7 +741,10 @@ class DetPostProcess:
|
|
|
boxes = self.apply(
|
|
|
output["boxes"],
|
|
|
data["ori_img_size"],
|
|
|
- threshold if threshold is not None else self.threshold,
|
|
|
+ threshold,
|
|
|
+ layout_nms,
|
|
|
+ layout_unclip_ratio,
|
|
|
+ layout_merge_bboxes_mode
|
|
|
)
|
|
|
outputs.append(boxes)
|
|
|
return outputs
|