Przeglądaj źródła

更新mm markdown拼装函数

赵小蒙 1 rok temu
rodzic
commit
d7128a9d87

+ 80 - 2
magic_pdf/dict2md/ocr_mkcontent.py

@@ -3,7 +3,7 @@ from loguru import logger
 from magic_pdf.libs.commons import join_path
 from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
-from magic_pdf.libs.ocr_content_type import ContentType
+from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 import wordninja
 import re
 
@@ -23,7 +23,7 @@ def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
     markdown = []
     for page_info in pdf_info_list:
         paras_of_layout = page_info.get("para_blocks")
-        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm", img_buket_path)
+        page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
 
@@ -36,6 +36,7 @@ def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
 
+
 def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, img_buket_path):
     markdown_with_para_and_pagination = []
     page_no = 0
@@ -90,6 +91,82 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
     return page_markdown
 
 
+def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
+    page_markdown = []
+    for paras in paras_of_layout:
+        for para in paras:
+            para_type = para.get('type')
+            if para_type == BlockType.Text:
+                para_text = merge_para_with_text(para)
+            elif para_type == BlockType.Title:
+                para_text = f"# {merge_para_with_text(para)}"
+            elif para_type == BlockType.InterlineEquation:
+                para_text = merge_para_with_text(para)
+            elif para_type == BlockType.Image:
+                if mode == 'nlp':
+                    continue
+                elif mode == 'mm':
+                    img_blocks = para.get('blocks')
+                    for img_block in img_blocks:
+                        if img_block.get('type') == BlockType.ImageBody:
+                            for line in img_block.get('lines'):
+                                for span in line['spans']:
+                                    if span.get('type') == ContentType.Image:
+                                        para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                    for img_block in img_blocks:
+                        if img_block.get('type') == BlockType.ImageCaption:
+                            para_text += merge_para_with_text(img_block)
+            elif para_type == BlockType.Table:
+                if mode == 'nlp':
+                    continue
+                elif mode == 'mm':
+                    table_blocks = para.get('blocks')
+                    for table_block in table_blocks:
+                        if table_block.get('type') == BlockType.TableBody:
+                            for line in table_block.get('lines'):
+                                for span in line['spans']:
+                                    if span.get('type') == ContentType.Table:
+                                        para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                    for table_block in table_blocks:
+                        if table_block.get('type') == BlockType.TableCaption:
+                            para_text += merge_para_with_text(table_block)
+                        elif table_block.get('type') == BlockType.TableFootnote:
+                            para_text += merge_para_with_text(table_block)
+
+            if para_text.strip() == '':
+                continue
+            else:
+                page_markdown.append(para_text.strip() + '  ')
+
+    return page_markdown
+
+
+def merge_para_with_text(para):
+    para_text = ''
+    for line in para['lines']:
+        for span in line['spans']:
+            span_type = span.get('type')
+            content = ''
+            language = ''
+            if span_type == ContentType.Text:
+                content = span['content']
+                language = detect_lang(content)
+                if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
+                    content = ocr_escape_special_markdown_char(split_long_words(content))
+                else:
+                    content = ocr_escape_special_markdown_char(content)
+            elif span_type == ContentType.InlineEquation:
+                content = f"${span['content']}$"
+            elif span_type == ContentType.InterlineEquation:
+                content = f"\n$$\n{span['content']}\n$$\n"
+            if content != '':
+                if language == 'en':  # 英文语境下 content间需要空格分隔
+                    para_text += content + ' '
+                else:  # 中文语境下,content间不需要空格分隔
+                    para_text += content
+    return para_text
+
+
 def para_to_standard_format(para, img_buket_path):
     para_content = {}
     if len(para) == 1:
@@ -124,6 +201,7 @@ def para_to_standard_format(para, img_buket_path):
         }
     return para_content
 
+
 def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
     content_list = []
     for page_info in pdf_info_dict:

+ 0 - 1
magic_pdf/pdf_parse_by_ocr_v2.py

@@ -92,7 +92,6 @@ def parse_pdf_by_ocr(pdf_bytes,
         pdf_info_dict[f"page_{page_id}"] = page_info
 
     """分段"""
-    # if debug_mode:
     para_split(pdf_info_dict, debug_mode=debug_mode)
 
     """dict转list"""