Parcourir la source

refactor: enhance overlap handling in pipeline_magic_model.py for image and table bodies

myhloli il y a 4 mois
Parent
commit
6094699cdf
1 fichier modifié avec 91 ajouts et 39 suppressions
  1. 91 39
      mineru/backend/pipeline/pipeline_magic_model.py

+ 91 - 39
mineru/backend/pipeline/pipeline_magic_model.py

@@ -1,4 +1,4 @@
-from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
+from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
 from mineru.utils.enum_class import CategoryId, ContentType
 
 
@@ -13,7 +13,54 @@ class MagicModel:
         self.__fix_by_remove_low_confidence()
         """删除高iou(>0.9)数据中置信度较低的那个"""
         self.__fix_by_remove_high_iou_and_low_confidence()
+        """将部分table_footnote修正为image_footnote"""
         self.__fix_footnote()
+        """处理重叠的image_body和table_body"""
+        self.__fix_by_remove_overlap_image_table_body()
+
+
+    def __fix_by_remove_overlap_image_table_body(self):
+        need_remove_list = []
+        layout_dets = self.__page_model_info['layout_dets']
+        image_blocks = list(filter(
+            lambda x: x['category_id'] == CategoryId.ImageBody, layout_dets
+        ))
+        table_blocks = list(filter(
+            lambda x: x['category_id'] == CategoryId.TableBody, layout_dets
+        ))
+        def add_need_remove_block(blocks):
+            for i in range(len(blocks)):
+                for j in range(i + 1, len(blocks)):
+                    block1 = blocks[i]
+                    block2 = blocks[j]
+                    overlap_box = get_minbox_if_overlap_by_ratio(
+                        block1['bbox'], block2['bbox'], 0.8
+                    )
+                    if overlap_box is not None:
+                        block_to_remove = next(
+                            (block for block in blocks if block['bbox'] == overlap_box),
+                            None,
+                        )
+                        if (
+                            block_to_remove is not None
+                            and block_to_remove not in need_remove_list
+                        ):
+                            large_block = block1 if block1 != block_to_remove else block2
+                            x1, y1, x2, y2 = large_block['bbox']
+                            sx1, sy1, sx2, sy2 = block_to_remove['bbox']
+                            x1 = min(x1, sx1)
+                            y1 = min(y1, sy1)
+                            x2 = max(x2, sx2)
+                            y2 = max(y2, sy2)
+                            large_block['bbox'] = [x1, y1, x2, y2]
+                            need_remove_list.append(block_to_remove)
+
+        add_need_remove_block(image_blocks)
+        add_need_remove_block(table_blocks)
+
+        for need_remove in need_remove_list:
+            layout_dets.remove(need_remove)
+
 
     def __fix_axis(self):
         need_remove_list = []
@@ -46,42 +93,46 @@ class MagicModel:
 
     def __fix_by_remove_high_iou_and_low_confidence(self):
         need_remove_list = []
-        layout_dets = self.__page_model_info['layout_dets']
+        layout_dets = list(filter(
+            lambda x: x['category_id'] in [
+                    CategoryId.Title,
+                    CategoryId.Text,
+                    CategoryId.ImageBody,
+                    CategoryId.ImageCaption,
+                    CategoryId.TableBody,
+                    CategoryId.TableCaption,
+                    CategoryId.TableFootnote,
+                    CategoryId.InterlineEquation_Layout,
+                    CategoryId.InterlineEquationNumber_Layout,
+                ], self.__page_model_info['layout_dets']
+            )
+        )
         for i in range(len(layout_dets)):
             for j in range(i + 1, len(layout_dets)):
                 layout_det1 = layout_dets[i]
                 layout_det2 = layout_dets[j]
-                if layout_det1['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
-                    if (
-                        calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
-                        > 0.9
-                    ):
-                        if layout_det1['score'] < layout_det2['score']:
-                            layout_det_need_remove = layout_det1
-                        else:
-                            layout_det_need_remove = layout_det2
-
-                        if layout_det_need_remove not in need_remove_list:
-                            need_remove_list.append(layout_det_need_remove)
-                    else:
-                        continue
-                else:
-                    continue
+
+                if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
+
+                    layout_det_need_remove = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
+
+                    if layout_det_need_remove not in need_remove_list:
+                        need_remove_list.append(layout_det_need_remove)
+
         for need_remove in need_remove_list:
-            layout_dets.remove(need_remove)
+            self.__page_model_info['layout_dets'].remove(need_remove)
 
     def __fix_footnote(self):
-        # 3: figure, 5: table, 7: footnote
         footnotes = []
         figures = []
         tables = []
 
         for obj in self.__page_model_info['layout_dets']:
-            if obj['category_id'] == 7:
+            if obj['category_id'] == CategoryId.TableFootnote:
                 footnotes.append(obj)
-            elif obj['category_id'] == 3:
+            elif obj['category_id'] == CategoryId.ImageBody:
                 figures.append(obj)
-            elif obj['category_id'] == 5:
+            elif obj['category_id'] == CategoryId.TableBody:
                 tables.append(obj)
             if len(footnotes) * len(figures) == 0:
                 continue
@@ -314,10 +365,10 @@ class MagicModel:
 
     def get_imgs(self):
         with_captions = self.__tie_up_category_by_distance_v3(
-            3, 4
+            CategoryId.ImageBody, CategoryId.ImageCaption
         )
         with_footnotes = self.__tie_up_category_by_distance_v3(
-            3, CategoryId.ImageFootnote
+            CategoryId.ImageBody, CategoryId.ImageFootnote
         )
         ret = []
         for v in with_captions:
@@ -333,10 +384,10 @@ class MagicModel:
 
     def get_tables(self) -> list:
         with_captions = self.__tie_up_category_by_distance_v3(
-            5, 6
+            CategoryId.TableBody, CategoryId.TableCaption
         )
         with_footnotes = self.__tie_up_category_by_distance_v3(
-            5, 7
+            CategoryId.TableBody, CategoryId.TableFootnote
         )
         ret = []
         for v in with_captions:
@@ -385,20 +436,21 @@ class MagicModel:
 
         all_spans = []
         layout_dets = self.__page_model_info['layout_dets']
-        allow_category_id_list = [3, 5, 13, 14, 15]
+        allow_category_id_list = [
+            CategoryId.ImageBody,
+            CategoryId.TableBody,
+            CategoryId.InlineEquation,
+            CategoryId.InterlineEquation_YOLO,
+            CategoryId.OcrText,
+        ]
         """当成span拼接的"""
-        #  3: 'image', # 图片
-        #  5: 'table',       # 表格
-        #  13: 'inline_equation',     # 行内公式
-        #  14: 'interline_equation',      # 行间公式
-        #  15: 'text',      # ocr识别文本
         for layout_det in layout_dets:
             category_id = layout_det['category_id']
             if category_id in allow_category_id_list:
                 span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
-                if category_id == 3:
+                if category_id == CategoryId.ImageBody:
                     span['type'] = ContentType.IMAGE
-                elif category_id == 5:
+                elif category_id == CategoryId.TableBody:
                     # 获取table模型结果
                     latex = layout_det.get('latex', None)
                     html = layout_det.get('html', None)
@@ -407,13 +459,13 @@ class MagicModel:
                     elif html:
                         span['html'] = html
                     span['type'] = ContentType.TABLE
-                elif category_id == 13:
+                elif category_id == CategoryId.InlineEquation:
                     span['content'] = layout_det['latex']
                     span['type'] = ContentType.INLINE_EQUATION
-                elif category_id == 14:
+                elif category_id == CategoryId.InterlineEquation_YOLO:
                     span['content'] = layout_det['latex']
                     span['type'] = ContentType.INTERLINE_EQUATION
-                elif category_id == 15:
+                elif category_id == CategoryId.OcrText:
                     span['content'] = layout_det['text']
                     span['type'] = ContentType.TEXT
                 all_spans.append(span)
@@ -438,4 +490,4 @@ class MagicModel:
                 for col in extra_col:
                     block[col] = item.get(col, None)
                 blocks.append(block)
-        return blocks
+        return blocks