|
|
@@ -29,22 +29,204 @@ def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
|
|
|
return return_image, return_list
|
|
|
|
|
|
|
|
|
-# Select regions for OCR / formula regions / table regions
|
|
|
-def get_res_list_from_layout_res(layout_res):
|
|
|
def get_coords_and_area(table):
    """Return the bounding box and area of a table region.

    The corner stored at ``poly[0], poly[1]`` is read as ``(xmin, ymin)`` and
    the corner at ``poly[4], poly[5]`` as ``(xmax, ymax)``; each value is
    converted with ``int``.

    Returns:
        A tuple ``(xmin, ymin, xmax, ymax, area)``.
    """
    poly = table['poly']
    left, top, right, bottom = (int(poly[k]) for k in (0, 1, 4, 5))
    return left, top, right, bottom, (right - left) * (bottom - top)
|
|
|
+
|
|
|
+
|
|
|
def calculate_intersection(box1, box2):
    """Return the rectangle shared by two boxes, or ``None`` if there is none.

    Boxes are indexed as ``(xmin, ymin, xmax, ymax)``. Boxes that only touch
    along an edge or at a corner are treated as not intersecting.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # A zero or negative extent on either axis means no overlapping region.
    is_empty = right <= left or bottom <= top
    return None if is_empty else (left, top, right, bottom)
|
|
|
+
|
|
|
+
|
|
|
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Each box is ``(xmin, ymin, xmax, ymax, area)``; the fifth element is used
    as the box area rather than being recomputed from the coordinates.

    Returns:
        The IoU ratio, or ``0`` when the boxes do not intersect or the union
        area is not positive.
    """
    overlap = calculate_intersection(box1[:4], box2[:4])
    if not overlap:
        return 0

    left, top, right, bottom = overlap
    shared = (right - left) * (bottom - top)
    union = box1[4] + box2[4] - shared

    if union > 0:
        return shared / union
    return 0
|
|
|
+
|
|
|
+
|
|
|
def is_inside(small_box, big_box, overlap_threshold=0.8):
    """Tell whether ``small_box`` lies mostly within ``big_box``.

    Both boxes are ``(xmin, ymin, xmax, ymax, area)``. The result is True when
    the intersection area is at least ``overlap_threshold`` times the area
    stored in ``small_box[4]``; boxes with no intersection give False.
    """
    overlap = calculate_intersection(small_box[:4], big_box[:4])
    if not overlap:
        return False

    left, top, right, bottom = overlap
    covered = (right - left) * (bottom - top)
    required = overlap_threshold * small_box[4]
    return covered >= required
|
|
|
+
|
|
|
+
|
|
|
def do_overlap(box1, box2):
    """Return True when ``calculate_intersection`` finds a shared region."""
    shared_region = calculate_intersection(box1[:4], box2[:4])
    return shared_region is not None
|
|
|
+
|
|
|
+
|
|
|
def merge_high_iou_tables(table_res_list, layout_res, table_indices, iou_threshold=0.7):
    """Merge pairs of tables whose IoU exceeds ``iou_threshold``.

    Each merged pair is replaced by a single table whose ``poly`` is the
    axis-aligned union rectangle of the two; every other key is copied from
    the first table of the pair. The two originals are deleted from
    ``layout_res`` and the merged table is appended to its end. Merging
    repeats until no remaining pair exceeds the threshold.

    Args:
        table_res_list: Table dicts. Modified in place.
        layout_res: Full layout result list. Modified in place.
        table_indices: Position in ``layout_res`` of each entry of
            ``table_res_list``, in the same order. Not modified; an updated
            list is returned instead.
        iou_threshold: Pairs with IoU strictly greater than this are merged.

    Returns:
        ``(table_res_list, table_indices)`` after merging, still aligned with
        each other.
    """
    if len(table_res_list) < 2:
        return table_res_list, table_indices

    table_info = [get_coords_and_area(table) for table in table_res_list]

    while True:
        # Take the first pair (in i, then j order) above the threshold. The
        # scan restarts after every merge because the enlarged box may now
        # overlap tables it did not overlap before.
        pair = next(
            (
                (i, j)
                for i in range(len(table_info) - 1)
                for j in range(i + 1, len(table_info))
                if calculate_iou(table_info[i], table_info[j]) > iou_threshold
            ),
            None,
        )
        if pair is None:
            return table_res_list, table_indices
        i, j = pair

        x1_min, y1_min, x1_max, y1_max, _ = table_info[i]
        x2_min, y2_min, x2_max, y2_max, _ = table_info[j]
        union_xmin = min(x1_min, x2_min)
        union_ymin = min(y1_min, y2_min)
        union_xmax = max(x1_max, x2_max)
        union_ymax = max(y1_max, y2_max)

        # dict.copy() is shallow: give the merged table its own poly list so
        # the source table's list is not overwritten through the shared
        # reference. Elements past the first eight, if any, are kept.
        merged_poly = list(table_res_list[i]['poly'])
        merged_poly[:8] = [
            union_xmin, union_ymin,
            union_xmax, union_ymin,
            union_xmax, union_ymax,
            union_xmin, union_ymax,
        ]
        merged_table = table_res_list[i].copy()
        merged_table['poly'] = merged_poly

        # Delete the higher layout position first so the lower one stays valid.
        low, high = sorted((table_indices[i], table_indices[j]))
        del layout_res[high]
        del layout_res[low]
        layout_res.append(merged_table)

        # Every surviving table shifts down by one for each deleted layout
        # entry that preceded it; the merged table sits at the end.
        table_indices = [
            idx - (idx > low) - (idx > high)
            for pos, idx in enumerate(table_indices)
            if pos != i and pos != j
        ]
        table_indices.append(len(layout_res) - 1)

        # j > i, so popping j first keeps position i valid.
        for seq in (table_res_list, table_info):
            seq.pop(j)
            seq.pop(i)
        table_res_list.append(merged_table)
        table_info.append(get_coords_and_area(merged_table))
|
|
|
+
|
|
|
+
|
|
|
def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0.8):
    """Drop large tables that merely wrap several smaller tables.

    A table is dropped when all of the following hold:
      * at least two other tables are inside it, as judged by ``is_inside``
        with ``overlap_threshold``;
      * none of those inner tables overlap one another;
      * the inner tables' combined area exceeds ``area_threshold`` times the
        outer table's area.

    Lists with fewer than three tables are returned unchanged.
    """
    if len(table_res_list) < 3:
        return table_res_list

    boxes = [get_coords_and_area(table) for table in table_res_list]
    wrapper_positions = set()

    for outer_pos, outer_box in enumerate(boxes):
        inner = [
            box
            for pos, box in enumerate(boxes)
            if pos != outer_pos and is_inside(box, outer_box, overlap_threshold)
        ]
        if len(inner) < 2:
            continue

        # Inner tables that overlap each other disqualify the outer table
        # from being treated as a wrapper.
        inner_overlap = any(
            do_overlap(first, second)
            for pos, first in enumerate(inner)
            for second in inner[pos + 1:]
        )
        if inner_overlap:
            continue

        if sum(box[4] for box in inner) > area_threshold * outer_box[4]:
            wrapper_positions.add(outer_pos)

    return [
        table
        for pos, table in enumerate(table_res_list)
        if pos not in wrapper_positions
    ]
|
|
|
+
|
|
|
+
|
|
|
def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
    """Split layout detections into OCR, table and formula groups.

    Table detections are post-processed: near-duplicate tables are merged via
    ``merge_high_iou_tables`` and wrapper tables are removed via
    ``filter_nested_tables``. ``layout_res`` is modified in place so that it
    reflects the merged and removed tables.

    Returns:
        ``(ocr_res_list, table_res_list, single_page_mfdetrec_res)`` where the
        last item is a list of ``{"bbox": [xmin, ymin, xmax, ymax]}`` dicts
        for the formula regions.
    """
    ocr_res_list = []
    table_res_list = []
    table_indices = []
    single_page_mfdetrec_res = []

    for position, res in enumerate(layout_res):
        category_id = int(res['category_id'])

        if category_id in (13, 14):
            # Formula regions: keep only an integer bounding box.
            poly = res['poly']
            single_page_mfdetrec_res.append({
                "bbox": [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])],
            })
        elif category_id in (0, 1, 2, 4, 6, 7):
            # OCR regions.
            ocr_res_list.append(res)
        elif category_id == 5:
            # Table regions, remembered together with their layout position.
            table_res_list.append(res)
            table_indices.append(position)

    # Merge near-duplicate tables first, then drop wrapper tables.
    table_res_list, table_indices = merge_high_iou_tables(
        table_res_list, layout_res, table_indices, iou_threshold)
    filtered_table_res_list = filter_nested_tables(
        table_res_list, overlap_threshold, area_threshold)

    # Mirror the filtering in layout_res, deleting from the back so earlier
    # positions stay valid.
    if len(filtered_table_res_list) < len(table_res_list):
        kept_ids = {id(table) for table in filtered_table_res_list}
        dropped = sorted(
            (layout_idx
             for layout_idx, table in zip(table_indices, table_res_list)
             if id(table) not in kept_ids),
            reverse=True,
        )
        for layout_idx in dropped:
            del layout_res[layout_idx]

    return ocr_res_list, filtered_table_res_list, single_page_mfdetrec_res
|
|
|
|
|
|
|
|
|
def clean_vram(device, vram_threshold=8):
|