Quellcode durchsuchen

Merge pull request #2873 from myhloli/dev

Dev
Xiaomeng Zhao vor 4 Monaten
Ursprung
Commit
516f4926b4

+ 0 - 3
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -152,9 +152,6 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
     """对block进行fix操作"""
     fix_blocks = fix_block_spans(block_with_spans)
 
-    """同一行被断开的titile合并"""
-    # merge_title_blocks(fix_blocks)
-
     """对block进行排序"""
     sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
 

+ 98 - 38
mineru/backend/pipeline/pipeline_magic_model.py

@@ -1,4 +1,4 @@
-from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
+from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
 from mineru.utils.enum_class import CategoryId, ContentType
 
 
@@ -13,7 +13,62 @@ class MagicModel:
         self.__fix_by_remove_low_confidence()
         """删除高iou(>0.9)数据中置信度较低的那个"""
         self.__fix_by_remove_high_iou_and_low_confidence()
+        """将部分tbale_footnote修正为image_footnote"""
         self.__fix_footnote()
+        """处理重叠的image_body和table_body"""
+        self.__fix_by_remove_overlap_image_table_body()
+
+    def __fix_by_remove_overlap_image_table_body(self):
+        need_remove_list = []
+        layout_dets = self.__page_model_info['layout_dets']
+        image_blocks = list(filter(
+            lambda x: x['category_id'] == CategoryId.ImageBody, layout_dets
+        ))
+        table_blocks = list(filter(
+            lambda x: x['category_id'] == CategoryId.TableBody, layout_dets
+        ))
+
+        def add_need_remove_block(blocks):
+            for i in range(len(blocks)):
+                for j in range(i + 1, len(blocks)):
+                    block1 = blocks[i]
+                    block2 = blocks[j]
+                    overlap_box = get_minbox_if_overlap_by_ratio(
+                        block1['bbox'], block2['bbox'], 0.8
+                    )
+                    if overlap_box is not None:
+                        # 判断哪个区块的面积更小,移除较小的区块
+                        area1 = (block1['bbox'][2] - block1['bbox'][0]) * (block1['bbox'][3] - block1['bbox'][1])
+                        area2 = (block2['bbox'][2] - block2['bbox'][0]) * (block2['bbox'][3] - block2['bbox'][1])
+
+                        if area1 <= area2:
+                            block_to_remove = block1
+                            large_block = block2
+                        else:
+                            block_to_remove = block2
+                            large_block = block1
+
+                        if block_to_remove not in need_remove_list:
+                            # 扩展大区块的边界框
+                            x1, y1, x2, y2 = large_block['bbox']
+                            sx1, sy1, sx2, sy2 = block_to_remove['bbox']
+                            x1 = min(x1, sx1)
+                            y1 = min(y1, sy1)
+                            x2 = max(x2, sx2)
+                            y2 = max(y2, sy2)
+                            large_block['bbox'] = [x1, y1, x2, y2]
+                            need_remove_list.append(block_to_remove)
+
+        # 处理图像-图像重叠
+        add_need_remove_block(image_blocks)
+        # 处理表格-表格重叠
+        add_need_remove_block(table_blocks)
+
+        # 从布局中移除标记的区块
+        for need_remove in need_remove_list:
+            if need_remove in layout_dets:
+                layout_dets.remove(need_remove)
+
 
     def __fix_axis(self):
         need_remove_list = []
@@ -46,42 +101,46 @@ class MagicModel:
 
     def __fix_by_remove_high_iou_and_low_confidence(self):
         need_remove_list = []
-        layout_dets = self.__page_model_info['layout_dets']
+        layout_dets = list(filter(
+            lambda x: x['category_id'] in [
+                    CategoryId.Title,
+                    CategoryId.Text,
+                    CategoryId.ImageBody,
+                    CategoryId.ImageCaption,
+                    CategoryId.TableBody,
+                    CategoryId.TableCaption,
+                    CategoryId.TableFootnote,
+                    CategoryId.InterlineEquation_Layout,
+                    CategoryId.InterlineEquationNumber_Layout,
+                ], self.__page_model_info['layout_dets']
+            )
+        )
         for i in range(len(layout_dets)):
             for j in range(i + 1, len(layout_dets)):
                 layout_det1 = layout_dets[i]
                 layout_det2 = layout_dets[j]
-                if layout_det1['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
-                    if (
-                        calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
-                        > 0.9
-                    ):
-                        if layout_det1['score'] < layout_det2['score']:
-                            layout_det_need_remove = layout_det1
-                        else:
-                            layout_det_need_remove = layout_det2
 
-                        if layout_det_need_remove not in need_remove_list:
-                            need_remove_list.append(layout_det_need_remove)
-                    else:
-                        continue
-                else:
-                    continue
+                if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
+
+                    layout_det_need_remove = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
+
+                    if layout_det_need_remove not in need_remove_list:
+                        need_remove_list.append(layout_det_need_remove)
+
         for need_remove in need_remove_list:
-            layout_dets.remove(need_remove)
+            self.__page_model_info['layout_dets'].remove(need_remove)
 
     def __fix_footnote(self):
-        # 3: figure, 5: table, 7: footnote
         footnotes = []
         figures = []
         tables = []
 
         for obj in self.__page_model_info['layout_dets']:
-            if obj['category_id'] == 7:
+            if obj['category_id'] == CategoryId.TableFootnote:
                 footnotes.append(obj)
-            elif obj['category_id'] == 3:
+            elif obj['category_id'] == CategoryId.ImageBody:
                 figures.append(obj)
-            elif obj['category_id'] == 5:
+            elif obj['category_id'] == CategoryId.TableBody:
                 tables.append(obj)
             if len(footnotes) * len(figures) == 0:
                 continue
@@ -314,10 +373,10 @@ class MagicModel:
 
     def get_imgs(self):
         with_captions = self.__tie_up_category_by_distance_v3(
-            3, 4
+            CategoryId.ImageBody, CategoryId.ImageCaption
         )
         with_footnotes = self.__tie_up_category_by_distance_v3(
-            3, CategoryId.ImageFootnote
+            CategoryId.ImageBody, CategoryId.ImageFootnote
         )
         ret = []
         for v in with_captions:
@@ -333,10 +392,10 @@ class MagicModel:
 
     def get_tables(self) -> list:
         with_captions = self.__tie_up_category_by_distance_v3(
-            5, 6
+            CategoryId.TableBody, CategoryId.TableCaption
         )
         with_footnotes = self.__tie_up_category_by_distance_v3(
-            5, 7
+            CategoryId.TableBody, CategoryId.TableFootnote
         )
         ret = []
         for v in with_captions:
@@ -385,20 +444,21 @@ class MagicModel:
 
         all_spans = []
         layout_dets = self.__page_model_info['layout_dets']
-        allow_category_id_list = [3, 5, 13, 14, 15]
+        allow_category_id_list = [
+            CategoryId.ImageBody,
+            CategoryId.TableBody,
+            CategoryId.InlineEquation,
+            CategoryId.InterlineEquation_YOLO,
+            CategoryId.OcrText,
+        ]
         """当成span拼接的"""
-        #  3: 'image', # 图片
-        #  5: 'table',       # 表格
-        #  13: 'inline_equation',     # 行内公式
-        #  14: 'interline_equation',      # 行间公式
-        #  15: 'text',      # ocr识别文本
         for layout_det in layout_dets:
             category_id = layout_det['category_id']
             if category_id in allow_category_id_list:
                 span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
-                if category_id == 3:
+                if category_id == CategoryId.ImageBody:
                     span['type'] = ContentType.IMAGE
-                elif category_id == 5:
+                elif category_id == CategoryId.TableBody:
                     # 获取table模型结果
                     latex = layout_det.get('latex', None)
                     html = layout_det.get('html', None)
@@ -407,13 +467,13 @@ class MagicModel:
                     elif html:
                         span['html'] = html
                     span['type'] = ContentType.TABLE
-                elif category_id == 13:
+                elif category_id == CategoryId.InlineEquation:
                     span['content'] = layout_det['latex']
                     span['type'] = ContentType.INLINE_EQUATION
-                elif category_id == 14:
+                elif category_id == CategoryId.InterlineEquation_YOLO:
                     span['content'] = layout_det['latex']
                     span['type'] = ContentType.INTERLINE_EQUATION
-                elif category_id == 15:
+                elif category_id == CategoryId.OcrText:
                     span['content'] = layout_det['text']
                     span['type'] = ContentType.TEXT
                 all_spans.append(span)
@@ -438,4 +498,4 @@ class MagicModel:
                 for col in extra_col:
                     block[col] = item.get(col, None)
                 blocks.append(block)
-        return blocks
+        return blocks

+ 9 - 9
mineru/cli/fast_api.py

@@ -34,10 +34,10 @@ async def parse_pdf(
         formula_enable: bool = Form(True),
         table_enable: bool = Form(True),
         server_url: Optional[str] = Form(None),
-        reuturn_md: bool = Form(True),
-        reuturn_middle_json: bool = Form(False),
+        return_md: bool = Form(True),
+        return_middle_json: bool = Form(False),
         return_model_output: bool = Form(False),
-        reuturn_content_list: bool = Form(False),
+        return_content_list: bool = Form(False),
         return_images: bool = Form(False),
         start_page_id: int = Form(0),
         end_page_id: int = Form(99999),
@@ -98,11 +98,11 @@ async def parse_pdf(
             server_url=server_url,
             f_draw_layout_bbox=False,
             f_draw_span_bbox=False,
-            f_dump_md=reuturn_md,
-            f_dump_middle_json=reuturn_middle_json,
+            f_dump_md=return_md,
+            f_dump_middle_json=return_middle_json,
             f_dump_model_output=return_model_output,
             f_dump_orig_pdf=False,
-            f_dump_content_list=reuturn_content_list,
+            f_dump_content_list=return_content_list,
             start_page_id=start_page_id,
             end_page_id=end_page_id,
         )
@@ -128,16 +128,16 @@ async def parse_pdf(
 
 
             if os.path.exists(parse_dir):
-                if reuturn_md:
+                if return_md:
                     data["md_content"] = get_infer_result(".md")
-                if reuturn_middle_json:
+                if return_middle_json:
                     data["middle_json"] = get_infer_result("_middle.json")
                 if return_model_output:
                     if backend.startswith("pipeline"):
                         data["model_output"] = get_infer_result("_model.json")
                     else:
                         data["model_output"] = get_infer_result("_model_output.txt")
-                if reuturn_content_list:
+                if return_content_list:
                     data["content_list"] = get_infer_result("_content_list.json")
                 if return_images:
                     image_paths = glob(f"{parse_dir}/images/*.jpg")

+ 35 - 31
mineru/utils/block_pre_proc.py

@@ -90,8 +90,8 @@ def prepare_block_bboxes(
     """经过以上处理后,还存在大框套小框的情况,则删除小框"""
     all_bboxes = remove_overlaps_min_blocks(all_bboxes)
     all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
-    """将剩余的bbox做分离处理,防止后面分layout时出错"""
-    # all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
+
+    """粗排序后返回"""
     all_bboxes.sort(key=lambda x: x[0]+x[1])
     return all_bboxes, all_discarded_blocks, footnote_blocks
 
@@ -213,35 +213,39 @@ def remove_overlaps_min_blocks(all_bboxes):
     #  重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
     #  删除重叠blocks中较小的那些
     need_remove = []
-    for block1 in all_bboxes:
-        for block2 in all_bboxes:
-            if block1 != block2:
-                block1_bbox = block1[:4]
-                block2_bbox = block2[:4]
-                overlap_box = get_minbox_if_overlap_by_ratio(
-                    block1_bbox, block2_bbox, 0.8
-                )
-                if overlap_box is not None:
-                    block_to_remove = next(
-                        (block for block in all_bboxes if block[:4] == overlap_box),
-                        None,
-                    )
-                    if (
-                        block_to_remove is not None
-                        and block_to_remove not in need_remove
-                    ):
-                        large_block = block1 if block1 != block_to_remove else block2
-                        x1, y1, x2, y2 = large_block[:4]
-                        sx1, sy1, sx2, sy2 = block_to_remove[:4]
-                        x1 = min(x1, sx1)
-                        y1 = min(y1, sy1)
-                        x2 = max(x2, sx2)
-                        y2 = max(y2, sy2)
-                        large_block[:4] = [x1, y1, x2, y2]
-                        need_remove.append(block_to_remove)
-
-    if len(need_remove) > 0:
-        for block in need_remove:
+    for i in range(len(all_bboxes)):
+        for j in range(i + 1, len(all_bboxes)):
+            block1 = all_bboxes[i]
+            block2 = all_bboxes[j]
+            block1_bbox = block1[:4]
+            block2_bbox = block2[:4]
+            overlap_box = get_minbox_if_overlap_by_ratio(
+                block1_bbox, block2_bbox, 0.8
+            )
+            if overlap_box is not None:
+                # 判断哪个区块的面积更小,移除较小的区块
+                area1 = (block1[2] - block1[0]) * (block1[3] - block1[1])
+                area2 = (block2[2] - block2[0]) * (block2[3] - block2[1])
+
+                if area1 <= area2:
+                    block_to_remove = block1
+                    large_block = block2
+                else:
+                    block_to_remove = block2
+                    large_block = block1
+
+                if block_to_remove not in need_remove:
+                    x1, y1, x2, y2 = large_block[:4]
+                    sx1, sy1, sx2, sy2 = block_to_remove[:4]
+                    x1 = min(x1, sx1)
+                    y1 = min(y1, sy1)
+                    x2 = max(x2, sx2)
+                    y2 = max(y2, sy2)
+                    large_block[:4] = [x1, y1, x2, y2]
+                    need_remove.append(block_to_remove)
+
+    for block in need_remove:
+        if block in all_bboxes:
             all_bboxes.remove(block)
 
     return all_bboxes

+ 1 - 1
pyproject.toml

@@ -43,7 +43,7 @@ vlm = [
     "pydantic",
 ]
 sglang = [
-    "sglang[all]>=0.4.7,<0.4.9",
+    "sglang[all]>=0.4.8,<0.4.9",
 ]
 pipeline = [
     "matplotlib>=3.10,<4",