소스 검색

refactor: enhance title block processing and improve markdown generation

myhloli 5 달 전
부모
커밋
6f2c3ad89e
3개의 변경된 파일, 38개의 추가 및 13개의 삭제
  1. + 6 - 10
      mineru/api/vlm_middle_json_mkcontent.py
  2. + 5 - 3
      mineru/backend/vlm/token_to_middle_json.py
  3. + 27 - 0
      mineru/utils/vlm_magic_model.py

+ 6 - 10
mineru/api/vlm_middle_json_mkcontent.py

@@ -22,8 +22,11 @@ def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
     for para_block in para_blocks:
         para_text = ''
         para_type = para_block['type']
-        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.TITLE, BlockType.INTERLINE_EQUATION]:
+        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
             para_text = merge_para_with_text(para_block)
+        elif para_type == BlockType.TITLE:
+            title_level = get_title_level(para_block)
+            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
         elif para_type == BlockType.IMAGE:
             if make_mode == MakeMode.NLP_MD:
                 continue
@@ -87,13 +90,7 @@ def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
     return page_markdown
 
 
-def count_leading_hashes(text):
-    match = re.match(r'^(#+)', text)
-    return len(match.group(1)) if match else 0
 
-def strip_leading_hashes(text):
-    # 去除开头的#和紧随其后的空格
-    return re.sub(r'^#+\s*', '', text)
 
 
 def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
@@ -105,11 +102,10 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
             'text': merge_para_with_text(para_block),
         }
     elif para_type == BlockType.TITLE:
-        title_content = merge_para_with_text(para_block)
-        title_level = count_leading_hashes(title_content)
+        title_level = get_title_level(para_block)
         para_content = {
             'type': 'text',
-            'text': strip_leading_hashes(title_content),
+            'text': merge_para_with_text(para_block),
         }
         if title_level != 0:
             para_content['text_level'] = title_level

+ 5 - 3
mineru/backend/vlm/token_to_middle_json.py

@@ -1,9 +1,10 @@
 import re
 
+from mineru.utils.block_pre_proc import fix_text_overlap_title_blocks
 from mineru.utils.cut_image import cut_image_and_table
 from mineru.utils.enum_class import BlockType, ContentType
 from mineru.utils.hash_utils import str_md5
-from mineru.utils.vlm_magic_model import fix_two_layer_blocks
+from mineru.utils.vlm_magic_model import fix_two_layer_blocks, fix_title_blocks
 from mineru.version import __version__
 
 
@@ -103,13 +104,14 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
 
     image_blocks = fix_two_layer_blocks(blocks, BlockType.IMAGE)
     table_blocks = fix_two_layer_blocks(blocks, BlockType.TABLE)
+    title_blocks = fix_title_blocks(blocks)
 
     page_blocks = [
         block
         for block in blocks
-        if block["type"] in [BlockType.TEXT, BlockType.TITLE, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]
+        if block["type"] in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]
     ]
-    page_blocks.extend([*image_blocks, *table_blocks])
+    page_blocks.extend([*image_blocks, *table_blocks, *title_blocks])
     # 对page_blocks根据index的值进行排序
     page_blocks.sort(key=lambda x: x["index"])
 

+ 27 - 0
mineru/utils/vlm_magic_model.py

@@ -1,6 +1,9 @@
+import re
 from typing import Literal
 
 from .boxbase import bbox_distance, is_in
+from .enum_class import BlockType
+from ..api.vlm_middle_json_mkcontent import merge_para_with_text
 
 
 def __reduct_overlap(bboxes):
@@ -217,3 +220,27 @@ def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
         fixed_blocks.append(two_layer_block)
 
     return fixed_blocks
+
+
+def fix_title_blocks(blocks):
+    for block in blocks:
+        if block["type"] == BlockType.TITLE:
+            title_content = merge_para_with_text(block)
+            title_level = count_leading_hashes(title_content)
+            block['level'] = title_level
+            for line in block['lines']:
+                for span in line['spans']:
+                    span['content'] = strip_leading_hashes(span['content'])
+                    break
+                break
+    return blocks
+
+
+def count_leading_hashes(text):
+    match = re.match(r'^(#+)', text)
+    return len(match.group(1)) if match else 0
+
+
+def strip_leading_hashes(text):
+    # 去除开头的#和紧随其后的空格
+    return re.sub(r'^#+\s*', '', text)