Browse Source

文本框与标题框重叠,优先信任文本框

赵小蒙 1 year ago
parent
commit
83641d3d97
2 changed files with 31 additions and 20 deletions
  1. 29 18
      magic_pdf/model/magic_model.py
  2. 2 2
      magic_pdf/pre_proc/ocr_detect_all_bboxes.py

+ 29 - 18
magic_pdf/model/magic_model.py

@@ -21,8 +21,8 @@ class MagicModel:
     """
 
     def __fix_axis(self):
-        need_remove_list = []
         for model_page_info in self.__model_list:
+            need_remove_list = []
             page_no = model_page_info["page_info"]["page_no"]
             horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
                 model_page_info, self.__docs[page_no]
@@ -43,12 +43,24 @@ class MagicModel:
             for need_remove in need_remove_list:
                 layout_dets.remove(need_remove)
 
+    def __fix_by_confidence(self):
+        for model_page_info in self.__model_list:
+            need_remove_list = []
+            layout_dets = model_page_info["layout_dets"]
+            for layout_det in layout_dets:
+                if layout_det["score"] < 0.6:
+                    need_remove_list.append(layout_det)
+                else:
+                    continue
+            for need_remove in need_remove_list:
+                layout_dets.remove(need_remove)
 
     def __init__(self, model_list: list, docs: fitz.Document):
         self.__model_list = model_list
         self.__docs = docs
         self.__fix_axis()
-        #@todo 移除置信度小于0.6的所有block
+        #@TODO 删除掉一些低置信度的会导致分段错误,后面再修复
+        # self.__fix_by_confidence()
 
     def __reduct_overlap(self, bboxes):
         N = len(bboxes)
@@ -63,13 +75,13 @@ class MagicModel:
         return [bboxes[i] for i in range(N) if keep[i]]
 
     def __tie_up_category_by_distance(
-        self, page_no, subject_category_id, object_category_id
+            self, page_no, subject_category_id, object_category_id
     ):
         """
         假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject
         """
         ret = []
-        MAX_DIS_OF_POINT = 10**9 + 7
+        MAX_DIS_OF_POINT = 10 ** 9 + 7
 
         subjects = self.__reduct_overlap(
             list(
@@ -112,8 +124,8 @@ class MagicModel:
         for i in range(N):
             for j in range(i):
                 if (
-                    all_bboxes[i]["category_id"] == subject_category_id
-                    and all_bboxes[j]["category_id"] == subject_category_id
+                        all_bboxes[i]["category_id"] == subject_category_id
+                        and all_bboxes[j]["category_id"] == subject_category_id
                 ):
                     continue
 
@@ -143,9 +155,9 @@ class MagicModel:
                 if pos_flag_count > 1:
                     continue
                 if (
-                    all_bboxes[j]["category_id"] != object_category_id
-                    or j in used
-                    or dis[i][j] == MAX_DIS_OF_POINT
+                        all_bboxes[j]["category_id"] != object_category_id
+                        or j in used
+                        or dis[i][j] == MAX_DIS_OF_POINT
                 ):
                     continue
                 arr.append((dis[i][j], j))
@@ -174,10 +186,10 @@ class MagicModel:
                         continue
 
                     if (
-                        all_bboxes[k]["category_id"] != object_category_id
-                        or k in used
-                        or k in seen
-                        or dis[j][k] == MAX_DIS_OF_POINT
+                            all_bboxes[k]["category_id"] != object_category_id
+                            or k in used
+                            or k in seen
+                            or dis[j][k] == MAX_DIS_OF_POINT
                     ):
                         continue
                     is_nearest = True
@@ -185,12 +197,10 @@ class MagicModel:
                         if l in (j, k) or l in used or l in seen:
                             continue
 
-
                         if not float_gt(dis[l][k], dis[j][k]):
                             is_nearest = False
                             break
 
-
                     if is_nearest:
                         tmp.append(k)
                         seen.add(k)
@@ -303,8 +313,8 @@ class MagicModel:
             candidates = []
             for j in range(N):
                 if (
-                    all_bboxes[j]["category_id"] != subject_category_id
-                    or j in with_caption_subject
+                        all_bboxes[j]["category_id"] != subject_category_id
+                        or j in with_caption_subject
                 ):
                     continue
                 candidates.append((dis[i][j], j))
@@ -326,7 +336,7 @@ class MagicModel:
         ]
 
     def get_tables(
-        self, page_no: int
+            self, page_no: int
     ) -> list:  # 3个坐标, caption, table主体,table-note
         with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
         with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
@@ -441,6 +451,7 @@ class MagicModel:
                     blocks.append(block)
         return blocks
 
+
 if __name__ == "__main__":
     drw = DiskReaderWriter(r"D:/project/20231108code-clean")
     if 0:

+ 2 - 2
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -28,7 +28,7 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
         all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None])
 
     '''block嵌套问题解决'''
-    '''文本框与标题框重叠,优先信任标题框'''
+    '''文本框与标题框重叠,优先信任文本框'''
     all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
     '''任何框体与舍弃框重叠,优先信任舍弃框'''
     all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
@@ -60,7 +60,7 @@ def fix_text_overlap_title_blocks(all_bboxes):
             text_block_bbox = text_block[0], text_block[1], text_block[2], text_block[3]
             title_block_bbox = title_block[0], title_block[1], title_block[2], title_block[3]
             if calculate_iou(text_block_bbox, title_block_bbox) > 0.8:
-                all_bboxes.remove(text_block)
+                all_bboxes.remove(title_block)
 
     return all_bboxes