浏览代码

feat(model): improve table recognition by merging and filtering tables

- Add functions to calculate IoU, check if tables are inside each other, and merge tables
- Implement table merging for high IoU tables
- Add filtering to remove an outer table when the tables nested inside it don't overlap each other but together cover a large part of its area
- Update table_res_list and layout_res to reflect these changes
myhloli 7 月之前
父节点
当前提交
df7ae4042d
共有 2 个文件被更改,包括 189 次插入8 次删除
  1. 0 1
      magic_pdf/model/batch_analyze.py
  2. 189 7
      magic_pdf/model/sub_modules/model_utils.py

+ 0 - 1
magic_pdf/model/batch_analyze.py

@@ -150,7 +150,6 @@ class BatchAnalyze:
         # 表格识别 table recognition
         # 表格识别 table recognition
         if self.model.apply_table:
         if self.model.apply_table:
             table_start = time.time()
             table_start = time.time()
-            table_count = 0
             # for table_res_list_dict in table_res_list_all_page:
             # for table_res_list_dict in table_res_list_all_page:
             for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
             for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
                 _lang = table_res_dict['lang']
                 _lang = table_res_dict['lang']

+ 189 - 7
magic_pdf/model/sub_modules/model_utils.py

@@ -29,22 +29,204 @@ def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
     return return_image, return_list
     return return_image, return_list
 
 
 
 
-# Select regions for OCR / formula regions / table regions
-def get_res_list_from_layout_res(layout_res):
+def get_coords_and_area(table):
+    """Return (xmin, ymin, xmax, ymax, area) built from entries 0, 1, 4, 5 of table['poly']."""
+    poly = table['poly']
+    # Entries 0/1 and 4/5 are truncated with int() and used as the two opposite box corners.
+    left, top, right, bottom = (int(poly[k]) for k in (0, 1, 4, 5))
+    return left, top, right, bottom, (right - left) * (bottom - top)
+
+
+def calculate_intersection(box1, box2):
+    """Return the overlap rectangle of two boxes, or None when they do not overlap.
+
+    Boxes are indexed as (xmin, ymin, xmax, ymax). Boxes that only touch along
+    an edge or at a corner have zero overlap area and also yield None.
+    """
+    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
+    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])
+
+    if not (right <= left or bottom <= top):
+        return left, top, right, bottom
+    return None
+
+
+def calculate_iou(box1, box2):
+    """Return intersection-over-union of two (xmin, ymin, xmax, ymax, area) boxes.
+
+    The area stored at index 4 of each box is trusted as-is rather than being
+    recomputed. The result is 0 when the boxes do not overlap or when the
+    union area is not positive.
+    """
+    overlap = calculate_intersection(box1[:4], box2[:4])
+    if overlap is None:
+        return 0
+    left, top, right, bottom = overlap
+    shared = (right - left) * (bottom - top)
+    union = box1[4] + box2[4] - shared
+    return shared / union if union > 0 else 0
+
+
+def is_inside(small_box, big_box, overlap_threshold=0.8):
+    """Tell whether big_box covers at least overlap_threshold of small_box's area.
+
+    Both boxes are (xmin, ymin, xmax, ymax, area); small_box[4] is the reference
+    area. Boxes with no overlap are never considered inside.
+    """
+    overlap = calculate_intersection(small_box[:4], big_box[:4])
+    if overlap is None:
+        return False
+    left, top, right, bottom = overlap
+    covered = (right - left) * (bottom - top)
+    return covered >= overlap_threshold * small_box[4]
+
+
+def do_overlap(box1, box2):
+    """Return True when the first four coordinates of the two boxes share a non-empty area."""
+    return bool(calculate_intersection(box1[:4], box2[:4]))
+
+
+def merge_high_iou_tables(table_res_list, layout_res, table_indices, iou_threshold=0.7):
+    """Merge table detections whose IoU exceeds iou_threshold into their union box.
+
+    The first pair (in list order) with IoU > iou_threshold is replaced by one
+    merged entry, and the search restarts until no such pair is left, so a
+    merged box can itself be merged again.
+
+    Args:
+        table_res_list: Table entries; each needs a 'poly' sequence whose
+            entries 0, 1, 4, 5 are read as xmin, ymin, xmax, ymax. Modified in
+            place: merged pairs are removed and the merged entry is appended.
+        layout_res: Full layout result list. Modified in place the same way.
+        table_indices: table_indices[k] is the position of table_res_list[k]
+            inside layout_res. Not modified; an updated list is returned.
+        iou_threshold: Pairs are merged only when their IoU is strictly
+            greater than this value.
+
+    Returns:
+        (table_res_list, table_indices) after all merges.
+    """
+    while len(table_res_list) >= 2:
+        table_info = [get_coords_and_area(table) for table in table_res_list]
+        count = len(table_info)
+
+        # First pair in (i, j) order whose IoU is above the threshold, if any.
+        pair = next(
+            ((i, j)
+             for i in range(count - 1)
+             for j in range(i + 1, count)
+             if calculate_iou(table_info[i], table_info[j]) > iou_threshold),
+            None)
+        if pair is None:
+            break
+        i, j = pair
+
+        # Union rectangle of the two boxes.
+        x1_min, y1_min, x1_max, y1_max, _ = table_info[i]
+        x2_min, y2_min, x2_max, y2_max, _ = table_info[j]
+        union_xmin, union_ymin = min(x1_min, x2_min), min(y1_min, y2_min)
+        union_xmax, union_ymax = max(x1_max, x2_max), max(y1_max, y2_max)
+
+        # dict.copy() is shallow, so give the merged entry its own 'poly' list;
+        # writing into the shared list would silently rewrite the source
+        # detection's polygon as well.
+        merged_poly = list(table_res_list[i]['poly'])
+        merged_poly[:8] = [union_xmin, union_ymin, union_xmax, union_ymin,
+                           union_xmax, union_ymax, union_xmin, union_ymax]
+        merged_table = table_res_list[i].copy()
+        merged_table['poly'] = merged_poly
+
+        # Drop both sources from layout_res (higher index first so the lower
+        # one stays valid) and append the merged entry at the end.
+        low, high = sorted((table_indices[i], table_indices[j]))
+        del layout_res[high]
+        del layout_res[low]
+        layout_res.append(merged_table)
+
+        # Each surviving index shifts down by one for every removed position
+        # that was in front of it; the merged entry sits at the new last slot.
+        table_indices = [k - (k > low) - (k > high)
+                         for k in table_indices if k not in (low, high)]
+        table_indices.append(len(layout_res) - 1)
+
+        # Mirror the change in the table list; j > i, so pop j first.
+        table_res_list.pop(j)
+        table_res_list.pop(i)
+        table_res_list.append(merged_table)
+
+    return table_res_list, table_indices
+
+
+def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0.8):
+    """Drop tables that are really a frame around several separate smaller tables.
+
+    A table is dropped when at least two other tables lie inside it (per
+    is_inside with overlap_threshold), none of those inner tables overlap
+    one another, and their summed area is more than area_threshold times
+    the outer table's area. Lists with fewer than three tables are returned
+    unchanged; otherwise a new list is returned.
+    """
+    if len(table_res_list) < 3:
+        return table_res_list
+
+    boxes = [get_coords_and_area(table) for table in table_res_list]
+    dropped = set()
+
+    for outer, outer_box in enumerate(boxes):
+        # Every other table that is mostly contained in the current candidate.
+        inner = [box for k, box in enumerate(boxes)
+                 if k != outer and is_inside(box, outer_box, overlap_threshold)]
+        if len(inner) < 2:
+            continue
+        # Only non-overlapping inner tables qualify the outer table for removal.
+        if any(do_overlap(a, b) for pos, a in enumerate(inner) for b in inner[pos + 1:]):
+            continue
+        # The inner tables must together exceed the required share of the outer area.
+        if sum(box[4] for box in inner) > area_threshold * outer_box[4]:
+            dropped.add(outer)
+
+    return [table for k, table in enumerate(table_res_list) if k not in dropped]
+
+
+def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
+    """Split layout_res into OCR, table and formula-bbox lists; tables are merged/filtered and layout_res is updated in place."""
     ocr_res_list = []
     ocr_res_list = []
     table_res_list = []
     table_res_list = []
+    table_indices = []
     single_page_mfdetrec_res = []
     single_page_mfdetrec_res = []
-    for res in layout_res:
-        if int(res['category_id']) in [13, 14]:
+
+    # Categorize regions by category_id, recording each table's position in layout_res
+    for i, res in enumerate(layout_res):
+        category_id = int(res['category_id'])
+
+        if category_id in [13, 14]:  # Formula regions
             single_page_mfdetrec_res.append({
             single_page_mfdetrec_res.append({
                 "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                 "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                          int(res['poly'][4]), int(res['poly'][5])],
                          int(res['poly'][4]), int(res['poly'][5])],
             })
             })
-        elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
+        elif category_id in [0, 1, 2, 4, 6, 7]:  # OCR regions
             ocr_res_list.append(res)
             ocr_res_list.append(res)
-        elif int(res['category_id']) in [5]:
+        elif category_id == 5:  # Table regions
             table_res_list.append(res)
             table_res_list.append(res)
-    return ocr_res_list, table_res_list, single_page_mfdetrec_res
+            table_indices.append(i)
+
+    # Process tables: merge high IoU tables first (this also rewrites layout_res), then filter nested tables
+    table_res_list, table_indices = merge_high_iou_tables(
+        table_res_list, layout_res, table_indices, iou_threshold)
+
+    filtered_table_res_list = filter_nested_tables(
+        table_res_list, overlap_threshold, area_threshold)
+
+    # Remove filtered-out tables from layout_res, matching them by object identity via id()
+    if len(filtered_table_res_list) < len(table_res_list):
+        kept_tables = set(id(table) for table in filtered_table_res_list)
+        to_remove = [table_indices[i] for i, table in enumerate(table_res_list)
+                     if id(table) not in kept_tables]
+
+        for idx in sorted(to_remove, reverse=True):
+            del layout_res[idx]
+
+    return ocr_res_list, filtered_table_res_list, single_page_mfdetrec_res
 
 
 
 
 def clean_vram(device, vram_threshold=8):
 def clean_vram(device, vram_threshold=8):