Pārlūkot izejas kodu

Merge branch 'master' of github.com:papayalove/Magic-PDF

# Conflicts:
#	magic_pdf/dict2md/ocr_mkcontent.py
liukaiwen 1 gadu atpakaļ
vecāks
revīzija
ab3501694d

+ 2 - 1
.github/workflows/cli.yml

@@ -32,6 +32,7 @@ jobs:
       uses: actions/checkout@v3
       with:
         fetch-depth: 2
+      
     - name: check-requirements
       run: |
         changed_files=$(git diff --name-only -r HEAD~1 HEAD)
@@ -77,4 +78,4 @@ jobs:
         "text": {
             "mentioned_list": ["${{ env.METIONS }}"] , "content": "'${{ github.repository }}' GitHubAction Failed!\n 细节请查看:https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"
         } 
-        }'     
+        }'     

+ 5 - 1
magic_pdf/cli/magicpdf.py

@@ -60,7 +60,11 @@ def prepare_env(pdf_file_name, method):
 
 def _do_parse(pdf_file_name, pdf_bytes, model_list, parse_method, image_writer, md_writer, image_dir, local_md_dir):
     if parse_method == "auto":
-        pipe = UNIPipe(pdf_bytes, model_list, image_writer, is_debug=True)
+        jso_useful_key = {
+            "_pdf_type": "",
+            "model_list": model_list
+        }
+        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True)
     elif parse_method == "txt":
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True)
     elif parse_method == "ocr":

+ 14 - 13
magic_pdf/dict2md/ocr_mkcontent.py

@@ -106,29 +106,30 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
-                for block in para_block['blocks']:
+                for block in para_block['blocks']:  # 1st.拼image_body
                     if block['type'] == BlockType.ImageBody:
                         for line in block['lines']:
                             for span in line['spans']:
                                 if span['type'] == ContentType.Image:
-                                    para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
-                for block in para_block['blocks']:
+                                    para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                for block in para_block['blocks']:  # 2nd.拼image_caption
                     if block['type'] == BlockType.ImageCaption:
                         para_text += merge_para_with_text(block)
         elif para_type == BlockType.Table:
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
-                for block in para_block['blocks']:
+                for block in para_block['blocks']:  # 1st.拼table_caption
+                    if block['type'] == BlockType.TableCaption:
+                        para_text += merge_para_with_text(block)
+                for block in para_block['blocks']:  # 2nd.拼table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
                             for span in line['spans']:
                                 if span['type'] == ContentType.Table:
-                                    para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
-                for block in para_block['blocks']:
-                    if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block)
-                    elif block['type'] == BlockType.TableFootnote:
+                                    para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                for block in para_block['blocks']:  # 3rd.拼table_footnote
+                    if block['type'] == BlockType.TableFootnote:
                         para_text += merge_para_with_text(block)
 
         if para_text.strip() == '':
@@ -159,10 +160,10 @@ def merge_para_with_text(para_block):
                 content = f"\n$$\n{span['content']}\n$$\n"
 
             if content != '':
-                if language in ['en', 'un']:  # 英文语境下 content间需要空格分隔
-                    para_text += content + ' '
-                else:  # 中文语境下,content间不需要空格分隔
-                    para_text += content
+                if 'zh' in language:
+                    para_text += content  # 中文语境下,content间不需要空格分隔
+                else:
+                    para_text += content + ' '  # 英文语境下 content间需要空格分隔
     return para_text
 
 

+ 17 - 9
magic_pdf/pipe/UNIPipe.py

@@ -10,11 +10,12 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 
 class UNIPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug)
+    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False):
+        self.pdf_type = jso_useful_key["_pdf_type"]
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug)
 
     def pipe_classify(self):
-        self.pdf_type = UNIPipe.classify(self.pdf_bytes)
+        self.pdf_type = AbsPipe.classify(self.pdf_bytes)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
@@ -46,14 +47,21 @@ if __name__ == '__main__':
     img_bucket_path = "imgs"
     img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
 
-    pipe = UNIPipe(pdf_bytes, model_list, img_writer, img_bucket_path)
+    # pdf_type = UNIPipe.classify(pdf_bytes)
+    # jso_useful_key = {
+    #     "_pdf_type": pdf_type,
+    #     "model_list": model_list
+    # }
+
+    jso_useful_key = {
+        "_pdf_type": "",
+        "model_list": model_list
+    }
+    pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
     pipe.pipe_classify()
     pipe.pipe_parse()
-    md_content = pipe.pipe_mk_markdown()
-    try:
-        content_list = pipe.pipe_mk_uni_format()
-    except Exception as e:
-        logger.exception(e)
+    md_content = pipe.pipe_mk_markdown(img_bucket_path)
+    content_list = pipe.pipe_mk_uni_format(img_bucket_path)
 
     md_writer = DiskReaderWriter(write_path)
     md_writer.write(md_content, "19983-00.md", AbsReaderWriter.MODE_TXT)

+ 5 - 4
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -156,7 +156,7 @@ def fill_spans_in_blocks(blocks, spans):
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']
-            if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.7:
+            if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.6:
                 block_spans.append(span)
 
         '''行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)'''
@@ -167,8 +167,8 @@ def fill_spans_in_blocks(blocks, spans):
         '''模型识别错误的行间公式, type类型转换成行内公式'''
         block_spans = modify_inline_equation(block_spans, displayed_list, text_inline_lines)
 
-        '''bbox去除粘连'''
-        block_spans = remove_overlap_between_bbox(block_spans)
+        '''bbox去除粘连'''  # 去粘连会影响span的bbox,导致后续fill的时候出错
+        # block_spans = remove_overlap_between_bbox(block_spans)
 
         block_dict['spans'] = block_spans
         block_with_spans.append(block_dict)
@@ -208,7 +208,7 @@ def merge_spans_to_block(spans: list, block_bbox: list, block_type: str):
     block_spans = []
     # 如果有img_caption,则将img_block中的text_spans放入img_caption_block中
     for span in spans:
-        if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > 0.8:
+        if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > 0.6:
             block_spans.append(span)
     block_lines = merge_spans_to_line(block_spans)
     # 对line中的span进行排序
@@ -268,6 +268,7 @@ def fix_table_block(block, table_blocks):
     # 遍历table_blocks,找到与当前block匹配的table_block
     for table_block in table_blocks:
         if table_block['bbox'] == block['bbox']:
+
             # 创建table_body_block
             for span in block['spans']:
                 if span['type'] == ContentType.Table and span['bbox'] == table_block['table_body_bbox']:

+ 4 - 4
magic_pdf/rw/S3ReaderWriter.py

@@ -1,5 +1,5 @@
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key
+from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key, join_path
 import boto3
 from loguru import logger
 from boto3.s3.transfer import TransferConfig
@@ -30,7 +30,7 @@ class S3ReaderWriter(AbsReaderWriter):
         if s3_relative_path.startswith("s3://"):
             s3_path = s3_relative_path
         else:
-            s3_path = os.path.join(self.path, s3_relative_path)
+            s3_path = join_path(self.path, s3_relative_path)
         bucket_name, key = parse_bucket_key(s3_path)
         res = self.client.get_object(Bucket=bucket_name, Key=key)
         body = res["Body"].read()
@@ -46,7 +46,7 @@ class S3ReaderWriter(AbsReaderWriter):
         if s3_relative_path.startswith("s3://"):
             s3_path = s3_relative_path
         else:
-            s3_path = os.path.join(self.path, s3_relative_path)
+            s3_path = join_path(self.path, s3_relative_path)
         if mode == MODE_TXT:
             body = content.encode(encoding)  # Encode text data as bytes
         elif mode == MODE_BIN:
@@ -61,7 +61,7 @@ class S3ReaderWriter(AbsReaderWriter):
         if path.startswith("s3://"):
             s3_path = path
         else:
-            s3_path = os.path.join(self.path, path)
+            s3_path = join_path(self.path, path)
         bucket_name, key = parse_bucket_key(s3_path)
 
         range_header = f'bytes={byte_start}-{byte_end}' if byte_end else f'bytes={byte_start}-'