Просмотр исходного кода

refactor: introduce SplitFlag class and update references in para_split and vlm_magic_model

myhloli 5 месяцев назад
Родитель
Коммит
4359b36f90

+ 5 - 7
mineru/backend/pipeline/para_split.py

@@ -1,10 +1,8 @@
 import copy
 from loguru import logger
-from mineru.utils.enum_class import ContentType, BlockType
+from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
 from mineru.utils.language import detect_lang
 
-CROSS_PAGE = 'cross_page'
-LINES_DELETED = 'lines_deleted'
 
 LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
 LIST_END_FLAG = ('.', '。', ';', ';')
@@ -284,10 +282,10 @@ def __merge_2_text_blocks(block1, block2):
                             if block1['page_num'] != block2['page_num']:
                                 for line in block1['lines']:
                                     for span in line['spans']:
-                                        span[CROSS_PAGE] = True
+                                        span[SplitFlag.CROSS_PAGE] = True
                             block2['lines'].extend(block1['lines'])
                             block1['lines'] = []
-                            block1[LINES_DELETED] = True
+                            block1[SplitFlag.LINES_DELETED] = True
 
     return block1, block2
 
@@ -296,10 +294,10 @@ def __merge_2_list_blocks(block1, block2):
     if block1['page_num'] != block2['page_num']:
         for line in block1['lines']:
             for span in line['spans']:
-                span[CROSS_PAGE] = True
+                span[SplitFlag.CROSS_PAGE] = True
     block2['lines'].extend(block1['lines'])
     block1['lines'] = []
-    block1[LINES_DELETED] = True
+    block1[SplitFlag.LINES_DELETED] = True
 
     return block1, block2
 

+ 33 - 3
mineru/backend/vlm/vlm_magic_model.py

@@ -2,7 +2,7 @@ import re
 from typing import Literal
 
 from mineru.utils.boxbase import bbox_distance, is_in
-from mineru.utils.enum_class import BlockType, ContentType
+from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
 from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
 from mineru.utils.format_utils import convert_otsl_to_html
 
@@ -187,7 +187,7 @@ class MagicModel:
         return fix_title_blocks(self.title_blocks)
 
     def get_text_blocks(self):
-        return self.text_blocks
+        return fix_text_blocks(self.text_blocks)
 
     def get_interline_equation_blocks(self):
         return self.interline_equation_blocks
@@ -440,4 +440,34 @@ def count_leading_hashes(text):
 
 def strip_leading_hashes(text):
     # 去除开头的#和紧随其后的空格
-    return re.sub(r'^#+\s*', '', text)
+    return re.sub(r'^#+\s*', '', text)
+
+
+def fix_text_blocks(blocks):
+    i = 0
+    while i < len(blocks):
+        block = blocks[i]
+        last_line = block["lines"][-1]if block["lines"] else None
+        if last_line:
+            last_span = last_line["spans"][-1] if last_line["spans"] else None
+            if last_span and last_span['content'].endswith('<|txt_contd|>'):
+                last_span['content'] = last_span['content'][:-len('<|txt_contd|>')]
+
+                # 查找下一个未被清空的块
+                next_idx = i + 1
+                while next_idx < len(blocks) and blocks[next_idx].get(SplitFlag.LINES_DELETED, False):
+                    next_idx += 1
+
+                # 如果找到下一个有效块,则合并
+                if next_idx < len(blocks):
+                    next_block = blocks[next_idx]
+                    # 将下一个块的lines扩展到当前块的lines中
+                    block["lines"].extend(next_block["lines"])
+                    # 清空下一个块的lines
+                    next_block["lines"] = []
+                    # 在下一个块中添加标志
+                    next_block[SplitFlag.LINES_DELETED] = True
+                    # 不增加i,继续检查当前块(现在已包含下一个块的内容)
+                    continue
+        i += 1
+    return blocks

+ 1 - 1
mineru/backend/vlm/vlm_middle_json_mkcontent.py

@@ -28,7 +28,7 @@ def merge_para_with_text(para_block):
                 content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
             elif span_type == ContentType.INTERLINE_EQUATION:
                 content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
-            content = content.strip()
+            # content = content.strip()
             if content:
                 if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
                     if j == len(line['spans']) - 1:

+ 6 - 1
mineru/utils/enum_class.py

@@ -54,4 +54,9 @@ class ModelPath:
     pytorch_paddle = "models/OCR/paddleocr_torch"
     layout_reader = "models/ReadingOrder/layout_reader"
     vlm_root_hf = "opendatalab/MinerU-VLM-1.0"
-    vlm_root_modelscope = "OpenDataLab/MinerU-VLM-1.0"
+    vlm_root_modelscope = "OpenDataLab/MinerU-VLM-1.0"
+
+
+class SplitFlag:
+    CROSS_PAGE = 'cross_page'
+    LINES_DELETED = 'lines_deleted'