瀏覽代碼

Merge branch 'master' of github.com:papayalove/Magic-PDF

liukaiwen 1 年之前
父節點
當前提交
bb2bf065ac

+ 9 - 7
magic_pdf/cli/magicpdf.py

@@ -25,6 +25,8 @@ import os
 import json as json_parse
 from datetime import datetime
 import click
+from loguru import logger
+
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
@@ -77,13 +79,13 @@ def _do_parse(pdf_bytes, model_list, parse_method, image_writer, md_writer, imag
         path=f"{part_file_name}.json",
         mode=AbsReaderWriter.MODE_TXT,
     )
-    try:
-        content_list = pipe.pipe_mk_uni_format()
-    except Exception as e:
-        print(e)
-    md_writer.write(
-        str(content_list), f"{part_file_name}.txt", AbsReaderWriter.MODE_TXT
-    )
+    # try:
+    #     content_list = pipe.pipe_mk_uni_format()
+    # except Exception as e:
+    #     logger.exception(e)
+    # md_writer.write(
+    #     str(content_list), f"{part_file_name}.txt", AbsReaderWriter.MODE_TXT
+    # )
 
 
 @click.group()

+ 81 - 4
magic_pdf/dict2md/ocr_mkcontent.py

@@ -3,7 +3,7 @@ from loguru import logger
 from magic_pdf.libs.commons import join_path
 from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
-from magic_pdf.libs.ocr_content_type import ContentType
+from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 import wordninja
 import re
 
@@ -23,7 +23,7 @@ def ocr_mk_mm_markdown_with_para(pdf_info_list: list, img_buket_path):
     markdown = []
     for page_info in pdf_info_list:
         paras_of_layout = page_info.get("para_blocks")
-        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm", img_buket_path)
+        page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
 
@@ -32,10 +32,11 @@ def ocr_mk_nlp_markdown_with_para(pdf_info_dict: list):
     markdown = []
     for page_info in pdf_info_dict:
         paras_of_layout = page_info.get("para_blocks")
-        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "nlp")
+        page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "nlp")
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
 
+
 def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, img_buket_path):
     markdown_with_para_and_pagination = []
     page_no = 0
@@ -43,7 +44,7 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list, img_buket_p
         paras_of_layout = page_info.get("para_blocks")
         if not paras_of_layout:
             continue
-        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm", img_buket_path)
+        page_markdown = ocr_mk_markdown_with_para_core_v2(paras_of_layout, "mm", img_buket_path)
         markdown_with_para_and_pagination.append({
             'page_no': page_no,
             'md_content': '\n\n'.join(page_markdown)
@@ -90,6 +91,81 @@ def ocr_mk_markdown_with_para_core(paras_of_layout, mode, img_buket_path=""):
     return page_markdown
 
 
+def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
+    page_markdown = []
+    for para_block in paras_of_layout:
+        para_type = para_block.get('type')
+        if para_type == BlockType.Text:
+            para_text = merge_para_with_text(para_block)
+        elif para_type == BlockType.Title:
+            para_text = f"# {merge_para_with_text(para_block)}"
+        elif para_type == BlockType.InterlineEquation:
+            para_text = merge_para_with_text(para_block)
+        elif para_type == BlockType.Image:
+            if mode == 'nlp':
+                continue
+            elif mode == 'mm':
+                img_blocks = para_block.get('blocks')
+                for img_block in img_blocks:
+                    if img_block.get('type') == BlockType.ImageBody:
+                        for line in img_block.get('lines'):
+                            for span in line['spans']:
+                                if span.get('type') == ContentType.Image:
+                                    para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                for img_block in img_blocks:
+                    if img_block.get('type') == BlockType.ImageCaption:
+                        para_text += merge_para_with_text(img_block)
+        elif para_type == BlockType.Table:
+            if mode == 'nlp':
+                continue
+            elif mode == 'mm':
+                table_blocks = para_block.get('blocks')
+                for table_block in table_blocks:
+                    if table_block.get('type') == BlockType.TableBody:
+                        for line in table_block.get('lines'):
+                            for span in line['spans']:
+                                if span.get('type') == ContentType.Table:
+                                    para_text = f"\n![]({join_path(img_buket_path, span['image_path'])})\n"
+                for table_block in table_blocks:
+                    if table_block.get('type') == BlockType.TableCaption:
+                        para_text += merge_para_with_text(table_block)
+                    elif table_block.get('type') == BlockType.TableFootnote:
+                        para_text += merge_para_with_text(table_block)
+
+        if para_text.strip() == '':
+            continue
+        else:
+            page_markdown.append(para_text.strip() + '  ')
+
+    return page_markdown
+
+
+def merge_para_with_text(para):
+    para_text = ''
+    for line in para['lines']:
+        for span in line['spans']:
+            span_type = span.get('type')
+            content = ''
+            language = ''
+            if span_type == ContentType.Text:
+                content = span['content']
+                language = detect_lang(content)
+                if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
+                    content = ocr_escape_special_markdown_char(split_long_words(content))
+                else:
+                    content = ocr_escape_special_markdown_char(content)
+            elif span_type == ContentType.InlineEquation:
+                content = f"${span['content']}$"
+            elif span_type == ContentType.InterlineEquation:
+                content = f"\n$$\n{span['content']}\n$$\n"
+            if content != '':
+                if language == 'en':  # 英文语境下 content间需要空格分隔
+                    para_text += content + ' '
+                else:  # 中文语境下,content间不需要空格分隔
+                    para_text += content
+    return para_text
+
+
 def para_to_standard_format(para, img_buket_path):
     para_content = {}
     if len(para) == 1:
@@ -124,6 +200,7 @@ def para_to_standard_format(para, img_buket_path):
         }
     return para_content
 
+
 def make_standard_format_with_para(pdf_info_dict: list, img_buket_path: str):
     content_list = []
     for page_info in pdf_info_dict:

+ 5 - 1
magic_pdf/libs/math.py

@@ -2,4 +2,8 @@ def float_gt(a, b):
     if 0.0001 >= abs(a -b):
         return False
     return a > b
-    
+    
+def float_equal(a, b):
+    if 0.0001 >= abs(a-b):
+        return True
+    return False

+ 29 - 18
magic_pdf/model/magic_model.py

@@ -21,8 +21,8 @@ class MagicModel:
     """
 
     def __fix_axis(self):
-        need_remove_list = []
         for model_page_info in self.__model_list:
+            need_remove_list = []
             page_no = model_page_info["page_info"]["page_no"]
             horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
                 model_page_info, self.__docs[page_no]
@@ -43,12 +43,24 @@ class MagicModel:
             for need_remove in need_remove_list:
                 layout_dets.remove(need_remove)
 
+    def __fix_by_confidence(self):
+        for model_page_info in self.__model_list:
+            need_remove_list = []
+            layout_dets = model_page_info["layout_dets"]
+            for layout_det in layout_dets:
+                if layout_det["score"] < 0.6:
+                    need_remove_list.append(layout_det)
+                else:
+                    continue
+            for need_remove in need_remove_list:
+                layout_dets.remove(need_remove)
 
     def __init__(self, model_list: list, docs: fitz.Document):
         self.__model_list = model_list
         self.__docs = docs
         self.__fix_axis()
-        #@todo 移除置信度小于0.6的所有block
+        #@TODO 删除掉一些低置信度的会导致分段错误,后面再修复
+        # self.__fix_by_confidence()
 
     def __reduct_overlap(self, bboxes):
         N = len(bboxes)
@@ -63,13 +75,13 @@ class MagicModel:
         return [bboxes[i] for i in range(N) if keep[i]]
 
     def __tie_up_category_by_distance(
-        self, page_no, subject_category_id, object_category_id
+            self, page_no, subject_category_id, object_category_id
     ):
         """
         假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object 只能属于一个 subject
         """
         ret = []
-        MAX_DIS_OF_POINT = 10**9 + 7
+        MAX_DIS_OF_POINT = 10 ** 9 + 7
 
         subjects = self.__reduct_overlap(
             list(
@@ -112,8 +124,8 @@ class MagicModel:
         for i in range(N):
             for j in range(i):
                 if (
-                    all_bboxes[i]["category_id"] == subject_category_id
-                    and all_bboxes[j]["category_id"] == subject_category_id
+                        all_bboxes[i]["category_id"] == subject_category_id
+                        and all_bboxes[j]["category_id"] == subject_category_id
                 ):
                     continue
 
@@ -143,9 +155,9 @@ class MagicModel:
                 if pos_flag_count > 1:
                     continue
                 if (
-                    all_bboxes[j]["category_id"] != object_category_id
-                    or j in used
-                    or dis[i][j] == MAX_DIS_OF_POINT
+                        all_bboxes[j]["category_id"] != object_category_id
+                        or j in used
+                        or dis[i][j] == MAX_DIS_OF_POINT
                 ):
                     continue
                 arr.append((dis[i][j], j))
@@ -174,10 +186,10 @@ class MagicModel:
                         continue
 
                     if (
-                        all_bboxes[k]["category_id"] != object_category_id
-                        or k in used
-                        or k in seen
-                        or dis[j][k] == MAX_DIS_OF_POINT
+                            all_bboxes[k]["category_id"] != object_category_id
+                            or k in used
+                            or k in seen
+                            or dis[j][k] == MAX_DIS_OF_POINT
                     ):
                         continue
                     is_nearest = True
@@ -185,12 +197,10 @@ class MagicModel:
                         if l in (j, k) or l in used or l in seen:
                             continue
 
-
                         if not float_gt(dis[l][k], dis[j][k]):
                             is_nearest = False
                             break
 
-
                     if is_nearest:
                         tmp.append(k)
                         seen.add(k)
@@ -303,8 +313,8 @@ class MagicModel:
             candidates = []
             for j in range(N):
                 if (
-                    all_bboxes[j]["category_id"] != subject_category_id
-                    or j in with_caption_subject
+                        all_bboxes[j]["category_id"] != subject_category_id
+                        or j in with_caption_subject
                 ):
                     continue
                 candidates.append((dis[i][j], j))
@@ -326,7 +336,7 @@ class MagicModel:
         ]
 
     def get_tables(
-        self, page_no: int
+            self, page_no: int
     ) -> list:  # 3个坐标, caption, table主体,table-note
         with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
         with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
@@ -441,6 +451,7 @@ class MagicModel:
                     blocks.append(block)
         return blocks
 
+
 if __name__ == "__main__":
     drw = DiskReaderWriter(r"D:/project/20231108code-clean")
     if 0:

+ 0 - 1
magic_pdf/pdf_parse_by_ocr_v2.py

@@ -92,7 +92,6 @@ def parse_pdf_by_ocr(pdf_bytes,
         pdf_info_dict[f"page_{page_id}"] = page_info
 
     """分段"""
-    # if debug_mode:
     para_split(pdf_info_dict, debug_mode=debug_mode)
 
     """dict转list"""

+ 6 - 2
magic_pdf/pdf_parse_by_txt_v2.py

@@ -31,7 +31,8 @@ from magic_pdf.pre_proc.equations_replace import (
     replace_equations_in_textblock,
 )
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
-
+from magic_pdf.libs.math import float_equal
+from magic_pdf.para.para_split_v2 import para_split
 
 def txt_spans_extract(pdf_page, inline_equations, interline_equations):
     text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
@@ -48,6 +49,9 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
     for v in text_blocks:
         for line in v["lines"]:
             for span in line["spans"]:
+                bbox = span["bbox"]
+                if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
+                    continue
                 spans.append(
                     {
                         "bbox": list(span["bbox"]),
@@ -167,7 +171,7 @@ def parse_pdf_by_txt(
         pdf_info_dict[f"page_{page_id}"] = page_info
 
     """分段"""
-    pass
+    para_split(pdf_info_dict, debug_mode=debug_mode)
 
     """dict转list"""
     pdf_info_list = dict_to_list(pdf_info_dict)

+ 3 - 2
magic_pdf/pipe/AbsPipe.py

@@ -106,8 +106,9 @@ class AbsPipe(ABC):
         parse_type = pdf_mid_data["_parse_type"]
         pdf_info_list = pdf_mid_data["pdf_info"]
         if parse_type == AbsPipe.PIP_TXT:
-            content_list = mk_universal_format(pdf_info_list, img_buket_path)
-            md_content = mk_mm_markdown(content_list)
+            # content_list = mk_universal_format(pdf_info_list, img_buket_path)
+            # md_content = mk_mm_markdown(content_list)
+            md_content = ocr_mk_mm_markdown_with_para(pdf_info_list, img_buket_path)
         elif parse_type == AbsPipe.PIP_OCR:
             md_content = ocr_mk_mm_markdown_with_para(pdf_info_list, img_buket_path)
         return md_content

+ 2 - 2
magic_pdf/pre_proc/construct_page_dict.py

@@ -55,7 +55,7 @@ def ocr_construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h,
 
 
 def ocr_construct_page_component_v2(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
-                                    images, tables, interline_equations, droped_blocks):
+                                    images, tables, interline_equations, discarded_blocks):
     return_dict = {
         'preproc_blocks': blocks,
         'layout_bboxes': layout_bboxes,
@@ -65,6 +65,6 @@ def ocr_construct_page_component_v2(blocks, layout_bboxes, page_id, page_w, page
         'images': images,
         'tables': tables,
         'interline_equations': interline_equations,
-        'droped_blocks': droped_blocks,
+        'discarded_blocks': discarded_blocks,
     }
     return return_dict

+ 2 - 2
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -28,7 +28,7 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
         all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None])
 
     '''block嵌套问题解决'''
-    '''文本框与标题框重叠,优先信任标题框'''
+    '''文本框与标题框重叠,优先信任文本框'''
     all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
     '''任何框体与舍弃框重叠,优先信任舍弃框'''
     all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
@@ -60,7 +60,7 @@ def fix_text_overlap_title_blocks(all_bboxes):
             text_block_bbox = text_block[0], text_block[1], text_block[2], text_block[3]
             title_block_bbox = title_block[0], title_block[1], title_block[2], title_block[3]
             if calculate_iou(text_block_bbox, title_block_bbox) > 0.8:
-                all_bboxes.remove(text_block)
+                all_bboxes.remove(title_block)
 
     return all_bboxes
 

+ 9 - 9
magic_pdf/pre_proc/remove_bbox_overlap.py

@@ -5,7 +5,7 @@ def _remove_overlap_between_bbox(spans):
     res = []
     for v in spans:
         for i in range(len(res)):
-            if _is_in(res[i]["bbox"], v["bbox"]):
+            if _is_in(res[i]["bbox"], v["bbox"]) or _is_in(v["bbox"], res[i]["bbox"]):
                 continue
             if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]):
                 ix0, iy0, ix1, iy1 = res[i]["bbox"]
@@ -17,21 +17,21 @@ def _remove_overlap_between_bbox(spans):
                 if diff_y > diff_x:
                     if x1 >= ix1:
                         mid = (x0 + ix1) // 2
-                        ix1 = min(mid, ix1)
-                        x0 = max(mid + 1, x0)
+                        ix1 = min(mid - 0.25, ix1)
+                        x0 = max(mid + 0.25, x0)
                     else:
                         mid = (ix0 + x1) // 2
-                        ix0 = max(mid + 1, ix0)
-                        x1 = min(mid, x1)
+                        ix0 = max(mid + 0.25, ix0)
+                        x1 = min(mid -0.25, x1)
                 else:
                     if y1 >= iy1:
                         mid = (y0 + iy1) // 2
-                        y0 = max(mid + 1, y0)
-                        iy1 = min(iy1, mid)
+                        y0 = max(mid + 0.25, y0)
+                        iy1 = min(iy1, mid-0.25)
                     else:
                         mid = (iy0 + y1) // 2
-                        y1 = min(y1, mid)
-                        iy0 = max(mid + 1, iy0)
+                        y1 = min(y1, mid-0.25)
+                        iy0 = max(mid + 0.25, iy0)
                 res[i]["bbox"] = [ix0, iy0, ix1, iy1]
                 v["bbox"] = [x0, y0, x1, y1]