Ver código fonte

feat(pipeline): pass language parameter for parsing and markdown conversion

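The pipeline now passes the language parameter (`lang`) to the parsing functions and,
together with the parse type, on to markdown conversion, so that processing can be
adapted to the specified language. Concretely, long-word splitting is skipped when a
non-English language is set, and is otherwise applied only to English text produced by
OCR parsing. This allows for more accurate parsing and markdown generation, particularly
when dealing with non-English content.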
myhloli 1 ano atrás
pai
commit
6062862c96

+ 34 - 25
magic_pdf/dict2md/ocr_mkcontent.py

@@ -116,17 +116,20 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=''):
 
 def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                       mode,
-                                      img_buket_path=''):
+                                      img_buket_path='',
+                                      parse_type="auto",
+                                      lang=None
+                                      ):
     page_markdown = []
     for para_block in paras_of_layout:
         para_text = ''
         para_type = para_block['type']
         if para_type == BlockType.Text:
-            para_text = merge_para_with_text(para_block)
+            para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
         elif para_type == BlockType.Title:
-            para_text = f'# {merge_para_with_text(para_block)}'
+            para_text = f'# {merge_para_with_text(para_block, parse_type=parse_type, lang=lang)}'
         elif para_type == BlockType.InterlineEquation:
-            para_text = merge_para_with_text(para_block)
+            para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
         elif para_type == BlockType.Image:
             if mode == 'nlp':
                 continue
@@ -139,17 +142,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                     para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 2nd.拼image_caption
                     if block['type'] == BlockType.ImageCaption:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
                 for block in para_block['blocks']:  # 2nd.拼image_caption
                     if block['type'] == BlockType.ImageFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
         elif para_type == BlockType.Table:
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
                 for block in para_block['blocks']:  # 1st.拼table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
                 for block in para_block['blocks']:  # 2nd.拼table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
@@ -164,7 +167,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TableFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block, parse_type=parse_type, lang=lang)
 
         if para_text.strip() == '':
             continue
@@ -174,7 +177,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
     return page_markdown
 
 
-def merge_para_with_text(para_block):
+def merge_para_with_text(para_block, parse_type="auto", lang=None):
 
     def detect_language(text):
         en_pattern = r'[a-zA-Z]+'
@@ -205,11 +208,15 @@ def merge_para_with_text(para_block):
                 content = span['content']
                 # language = detect_lang(content)
                 language = detect_language(content)
-                if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
-                    content = ocr_escape_special_markdown_char(
-                        split_long_words(content))
-                else:
+                # 判断是否小语种
+                if lang is not None and lang != 'en':
                     content = ocr_escape_special_markdown_char(content)
+                else:  # 非小语种逻辑
+                    if language == 'en' and parse_type == 'ocr':  # 只对英文长词进行分词处理,中文分词会丢失文本
+                        content = ocr_escape_special_markdown_char(
+                            split_long_words(content))
+                    else:
+                        content = ocr_escape_special_markdown_char(content)
             elif span_type == ContentType.InlineEquation:
                 content = f" ${span['content']}$ "
             elif span_type == ContentType.InterlineEquation:
@@ -265,25 +272,25 @@ def para_to_standard_format(para, img_buket_path):
     return para_content
 
 
-def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
+def para_to_standard_format_v2(para_block, img_buket_path, page_idx, parse_type="auto", lang=None):
     para_type = para_block['type']
     if para_type == BlockType.Text:
         para_content = {
             'type': 'text',
-            'text': merge_para_with_text(para_block),
+            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
             'page_idx': page_idx,
         }
     elif para_type == BlockType.Title:
         para_content = {
             'type': 'text',
-            'text': merge_para_with_text(para_block),
+            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
             'text_level': 1,
             'page_idx': page_idx,
         }
     elif para_type == BlockType.InterlineEquation:
         para_content = {
             'type': 'equation',
-            'text': merge_para_with_text(para_block),
+            'text': merge_para_with_text(para_block, parse_type=parse_type, lang=lang),
             'text_format': 'latex',
             'page_idx': page_idx,
         }
@@ -295,9 +302,9 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
                     img_buket_path,
                     block['lines'][0]['spans'][0]['image_path'])
             if block['type'] == BlockType.ImageCaption:
-                para_content['img_caption'] = merge_para_with_text(block)
+                para_content['img_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
             if block['type'] == BlockType.ImageFootnote:
-                para_content['img_footnote'] = merge_para_with_text(block)
+                para_content['img_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
     elif para_type == BlockType.Table:
         para_content = {'type': 'table', 'page_idx': page_idx}
         for block in para_block['blocks']:
@@ -308,9 +315,9 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx):
                     para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:
-                para_content['table_caption'] = merge_para_with_text(block)
+                para_content['table_caption'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
             if block['type'] == BlockType.TableFootnote:
-                para_content['table_footnote'] = merge_para_with_text(block)
+                para_content['table_footnote'] = merge_para_with_text(block, parse_type=parse_type, lang=lang)
 
     return para_content
 
@@ -394,7 +401,9 @@ def ocr_mk_mm_standard_format(pdf_info_dict: list):
 def union_make(pdf_info_dict: list,
                make_mode: str,
                drop_mode: str,
-               img_buket_path: str = ''):
+               img_buket_path: str = '',
+               parse_type: str = "auto",
+               lang=None):
     output_content = []
     for page_info in pdf_info_dict:
         if page_info.get('need_drop', False):
@@ -417,16 +426,16 @@ def union_make(pdf_info_dict: list,
             continue
         if make_mode == MakeMode.MM_MD:
             page_markdown = ocr_mk_markdown_with_para_core_v2(
-                paras_of_layout, 'mm', img_buket_path)
+                paras_of_layout, 'mm', img_buket_path, parse_type=parse_type, lang=lang)
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.NLP_MD:
             page_markdown = ocr_mk_markdown_with_para_core_v2(
-                paras_of_layout, 'nlp')
+                paras_of_layout, 'nlp', parse_type=parse_type, lang=lang)
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.STANDARD_FORMAT:
             for para_block in paras_of_layout:
                 para_content = para_to_standard_format_v2(
-                    para_block, img_buket_path, page_idx)
+                    para_block, img_buket_path, page_idx, parse_type=parse_type, lang=lang)
                 output_content.append(para_content)
     if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
         return '\n\n'.join(output_content)

+ 6 - 2
magic_pdf/pipe/AbsPipe.py

@@ -95,7 +95,9 @@ class AbsPipe(ABC):
         """
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_info_list = pdf_mid_data["pdf_info"]
-        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
+        parse_type = pdf_mid_data["_parse_type"]
+        lang = pdf_mid_data.get("_lang", None)
+        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path, parse_type, lang)
         return content_list
 
     @staticmethod
@@ -105,7 +107,9 @@ class AbsPipe(ABC):
         """
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_info_list = pdf_mid_data["pdf_info"]
-        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
+        parse_type = pdf_mid_data["_parse_type"]
+        lang = pdf_mid_data.get("_lang", None)
+        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path, parse_type, lang)
         return md_content
 
 

+ 2 - 1
magic_pdf/pipe/OCRPipe.py

@@ -23,7 +23,8 @@ class OCRPipe(AbsPipe):
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 2 - 1
magic_pdf/pipe/TXTPipe.py

@@ -24,7 +24,8 @@ class TXTPipe(AbsPipe):
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 2 - 1
magic_pdf/pipe/UNIPipe.py

@@ -44,7 +44,8 @@ class UNIPipe(AbsPipe):
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                               is_debug=self.is_debug,
-                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                              lang=self.lang)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 11 - 2
magic_pdf/user_api.py

@@ -26,7 +26,7 @@ PARSE_TYPE_OCR = "ocr"
 
 
 def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None,
+                  start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
     """
     解析文本类pdf
@@ -44,11 +44,14 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
 
     pdf_info_dict["_version_name"] = __version__
 
+    if lang is not None:
+        pdf_info_dict["_lang"] = lang
+
     return pdf_info_dict
 
 
 def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None,
+                  start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
     """
     解析ocr类pdf
@@ -66,6 +69,9 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
 
     pdf_info_dict["_version_name"] = __version__
 
+    if lang is not None:
+        pdf_info_dict["_lang"] = lang
+
     return pdf_info_dict
 
 
@@ -110,4 +116,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
 
     pdf_info_dict["_version_name"] = __version__
 
+    if lang is not None:
+        pdf_info_dict["_lang"] = lang
+
     return pdf_info_dict