Răsfoiți Sursa

Merge pull request #9 from myhloli/refactor-pipeline

feat(model): add text region handling and improve overlap resolution
Xiaomeng Zhao 7 luni în urmă
părinte
comite
b7ff7ded64
1 a modificat fișierele cu 59 adăugiri și 1 ștergeri
  1. 59 1
      magic_pdf/model/sub_modules/model_utils.py

+ 59 - 1
magic_pdf/model/sub_modules/model_utils.py

@@ -2,6 +2,8 @@ import time
 import torch
 from loguru import logger
 import numpy as np
+
+from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio
 from magic_pdf.libs.clean_memory import clean_memory
 
 
@@ -188,9 +190,46 @@ def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0
     return [table for i, table in enumerate(table_res_list) if i not in big_tables_idx]
 
 
+def remove_overlaps_min_blocks(res_list):
+    #  重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
+    #  删除重叠blocks中较小的那些
+    need_remove = []
+    for res1 in res_list:
+        for res2 in res_list:
+            if res1 != res2:
+                overlap_box = get_minbox_if_overlap_by_ratio(
+                    res1['bbox'], res2['bbox'], 0.8
+                )
+                if overlap_box is not None:
+                    res_to_remove = next(
+                        (res for res in res_list if res['bbox'] == overlap_box),
+                        None,
+                    )
+                    if (
+                        res_to_remove is not None
+                        and res_to_remove not in need_remove
+                    ):
+                        large_res = res1 if res1 != res_to_remove else res2
+                        x1, y1, x2, y2 = large_res['bbox']
+                        sx1, sy1, sx2, sy2 = res_to_remove['bbox']
+                        x1 = min(x1, sx1)
+                        y1 = min(y1, sy1)
+                        x2 = max(x2, sx2)
+                        y2 = max(y2, sy2)
+                        large_res['bbox'] = [x1, y1, x2, y2]
+                        need_remove.append(res_to_remove)
+
+    if len(need_remove) > 0:
+        for res in need_remove:
+            res_list.remove(res)
+
+    return res_list, need_remove
+
+
 def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
     """Extract OCR, table and other regions from layout results."""
     ocr_res_list = []
+    text_res_list = []
     table_res_list = []
     table_indices = []
     single_page_mfdetrec_res = []
@@ -204,11 +243,14 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
                 "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                          int(res['poly'][4]), int(res['poly'][5])],
             })
-        elif category_id in [0, 1, 2, 4, 6, 7]:  # OCR regions
+        elif category_id in [0, 2, 4, 6, 7]:  # OCR regions
             ocr_res_list.append(res)
         elif category_id == 5:  # Table regions
             table_res_list.append(res)
             table_indices.append(i)
+        elif category_id in [1]:  # Text regions
+            res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
+            text_res_list.append(res)
 
     # Process tables: merge high IoU tables first, then filter nested tables
     table_res_list, table_indices = merge_high_iou_tables(
@@ -226,6 +268,22 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
         for idx in sorted(to_remove, reverse=True):
             del layout_res[idx]
 
+    # Remove overlaps in OCR and text regions
+    text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
+    for res in text_res_list:
+        # 将res的poly使用bbox重构
+        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
+                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
+        # 删除res的bbox
+        del res['bbox']
+
+    ocr_res_list.extend(text_res_list)
+
+    if len(need_remove) > 0:
+        for res in need_remove:
+            del res['bbox']
+            layout_res.remove(res)
+
     return ocr_res_list, filtered_table_res_list, single_page_mfdetrec_res