
Merge branch 'master' of https://github.com/myhloli/Magic-PDF

liusilu 1 year ago
parent
commit
fffee0ae97

+ 1 - 1
demo/demo_test.py

@@ -34,7 +34,7 @@ def get_json_from_local_or_s3(book_name=None):
         s3_config = get_s3_config(json_path)
         file_content = read_file(json_path, s3_config)
         json_str = file_content.decode("utf-8")
-        logger.info(json_str)
+        # logger.info(json_str)
         json_object = json.loads(json_str)
     return json_object
 

+ 57 - 40
demo/ocr_demo.py

@@ -4,8 +4,16 @@ import os
 from loguru import logger
 from pathlib import Path
 
+from app.common.s3 import get_s3_config
 from demo.demo_test import get_json_from_local_or_s3
-from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown_with_para, ocr_mk_nlp_markdown, ocr_mk_mm_markdown, ocr_mk_mm_standard_format
+from magic_pdf.dict2md.ocr_mkcontent import (
+    ocr_mk_mm_markdown_with_para,
+    ocr_mk_nlp_markdown,
+    ocr_mk_mm_markdown,
+    ocr_mk_mm_standard_format,
+    ocr_mk_mm_markdown_with_para_and_pagination,
+    make_standard_format_with_para
+)
 from magic_pdf.libs.commons import join_path
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
 
@@ -35,50 +43,59 @@ def ocr_local_parse(ocr_pdf_path, ocr_json_file_path):
         ocr_pdf_model_info = read_json_file(ocr_json_file_path)
         pth = Path(ocr_json_file_path)
         book_name = pth.name
-        save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
-        save_path = join_path(save_tmp_path, "md")
-        save_path_with_bookname = os.path.join(save_path, book_name)
-        text_content_save_path = f"{save_path_with_bookname}/book.md"
-        pdf_info_dict = parse_pdf_by_ocr(
-            ocr_pdf_path,
-            None,
-            ocr_pdf_model_info,
-            save_path,
-            book_name,
-            debug_mode=True)
-
-        parent_dir = os.path.dirname(text_content_save_path)
-        if not os.path.exists(parent_dir):
-            os.makedirs(parent_dir)
-
-        # markdown_content = mk_nlp_markdown(pdf_info_dict)
-        markdown_content = ocr_mk_mm_markdown_with_para(pdf_info_dict)
-
-        with open(text_content_save_path, "w", encoding="utf-8") as f:
-            f.write(markdown_content)
-
-        standard_format = ocr_mk_mm_standard_format(pdf_info_dict)
-        standard_format_save_path = f"{save_path_with_bookname}/standard_format.txt"
-        with open(standard_format_save_path, "w", encoding="utf-8") as f:
-            f.write(str(standard_format))
-
-        # logger.info(markdown_content)
-        # save_markdown(markdown_text, ocr_json_file_path)
+        ocr_parse_core(book_name, ocr_pdf_path, ocr_pdf_model_info)
     except Exception as e:
         logger.exception(e)
 
 
 def ocr_online_parse(book_name, start_page_id=0, debug_mode=True):
-    json_object = get_json_from_local_or_s3(book_name)
-    logger.info(json_object)
+    try:
+        json_object = get_json_from_local_or_s3(book_name)
+        # logger.info(json_object)
+        s3_pdf_path = json_object["file_location"]
+        s3_config = get_s3_config(s3_pdf_path)
+        ocr_pdf_model_info = json_object.get("doc_layout_result")
+        ocr_parse_core(book_name, s3_pdf_path, ocr_pdf_model_info, s3_config=s3_config)
+    except Exception as e:
+        logger.exception(e)
+
+
+def ocr_parse_core(book_name, ocr_pdf_path, ocr_pdf_model_info, start_page_id=0, s3_config=None):
+    save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
+    save_path = join_path(save_tmp_path, "md")
+    save_path_with_bookname = os.path.join(save_path, book_name)
+    text_content_save_path = f"{save_path_with_bookname}/book.md"
+    pdf_info_dict = parse_pdf_by_ocr(
+        ocr_pdf_path,
+        s3_config,
+        ocr_pdf_model_info,
+        save_path,
+        book_name,
+        debug_mode=True)
+
+    parent_dir = os.path.dirname(text_content_save_path)
+    if not os.path.exists(parent_dir):
+        os.makedirs(parent_dir)
+
+    # markdown_content = mk_nlp_markdown(pdf_info_dict)
+    markdown_content = ocr_mk_mm_markdown_with_para(pdf_info_dict)
+    # markdown_pagination = ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict)
+
+    with open(text_content_save_path, "w", encoding="utf-8") as f:
+        f.write(markdown_content)
+
+    standard_format = make_standard_format_with_para(pdf_info_dict)
+    standard_format_save_path = f"{save_path_with_bookname}/standard_format.txt"
+    with open(standard_format_save_path, "w", encoding="utf-8") as f:
+        # dump standard_format to JSON text and save it
+        f.write(json.dumps(standard_format, ensure_ascii=False))
+
 
 if __name__ == '__main__':
-    #ocr_pdf_path = r"D:\project\20231108code-clean\ocr\new\双栏\s0043-1354(02)00581-x.pdf"
-    #ocr_json_file_path = r"D:\project\20231108code-clean\ocr\new\双栏\s0043-1354(02)00581-x.json"
-    # ocr_pdf_path = r"D:\project\20231108code-clean\ocr\new\双栏\j.1540-627x.2006.00176.x.pdf"
-    # ocr_json_file_path = r"D:\project\20231108code-clean\ocr\new\双栏\j.1540-627x.2006.00176.x.json"
-    ocr_pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/ocr_1.pdf"
-    ocr_json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/ocr_1.json"
-    ocr_online_parse(book_name="数学新星网/edu_00001236")
-    ocr_local_parse(ocr_pdf_path, ocr_json_file_path)
+    pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.pdf"
+    json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.json"
+    # ocr_local_parse(pdf_path, json_file_path)
+    book_name = "科数网/edu_00011318"
+    ocr_online_parse(book_name)
+    
     pass
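
Note: standard_format.txt is now serialized with json.dumps(..., ensure_ascii=False) instead of str(), so it can be parsed back with the json module. A minimal sketch of reading it back (the path is illustrative, and this assumes the demo above has been run and that ocr_demo.py imports json at the top):

    import json

    save_dir = "tmp/unittest/md/<book_name>"  # hypothetical; mirrors save_path_with_bookname above
    with open(f"{save_dir}/standard_format.txt", encoding="utf-8") as f:
        content_list = json.load(f)
    print(len(content_list), "paragraph entries")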

+ 89 - 16
magic_pdf/dict2md/ocr_mkcontent.py

@@ -1,6 +1,19 @@
 from magic_pdf.libs.commons import s3_image_save_path, join_path
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
 from magic_pdf.libs.ocr_content_type import ContentType
+import wordninja
+import re
+
+
+def split_long_words(text):
+    segments = text.split(' ')
+    for i in range(len(segments)):
+        words = re.findall(r'\w+|[^\w\s]', segments[i], re.UNICODE)
+        for j in range(len(words)):
+            if len(words[j]) > 15:
+                words[j] = ' '.join(wordninja.split(words[j]))
+        segments[i] = ''.join(words)
+    return ' '.join(segments)
 
 
 def ocr_mk_nlp_markdown(pdf_info_dict: dict):
@@ -58,37 +71,96 @@ def ocr_mk_mm_markdown(pdf_info_dict: dict):
 def ocr_mk_mm_markdown_with_para(pdf_info_dict: dict):
     markdown = []
     for _, page_info in pdf_info_dict.items():
-        paras = page_info.get("para_blocks")
-        if not paras:
+        paras_of_layout = page_info.get("para_blocks")
+        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
+        markdown.extend(page_markdown)
+    return '\n\n'.join(markdown)
+
+
+def ocr_mk_nlp_markdown_with_para(pdf_info_dict: dict):
+    markdown = []
+    for _, page_info in pdf_info_dict.items():
+        paras_of_layout = page_info.get("para_blocks")
+        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "nlp")
+        markdown.extend(page_markdown)
+    return '\n\n'.join(markdown)
+
+def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
+    markdown_with_para_and_pagination = []
+    for page_no, page_info in pdf_info_dict.items():
+        paras_of_layout = page_info.get("para_blocks")
+        if not paras_of_layout:
             continue
+        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
+        markdown_with_para_and_pagination.append({
+            'page_no': page_no,
+            'md_content': '\n\n'.join(page_markdown)
+        })
+    return markdown_with_para_and_pagination
+
+
+def ocr_mk_mm_markdown_with_para_core(paras_of_layout, mode):
+    page_markdown = []
+    for paras in paras_of_layout:
         for para in paras:
             para_text = ''
             for line in para:
                 for span in line['spans']:
                     span_type = span.get('type')
+                    content = ''
                     if span_type == ContentType.Text:
-                        para_text += span['content']
+                        content = ocr_escape_special_markdown_char(split_long_words(span['content']))
                     elif span_type == ContentType.InlineEquation:
-                        para_text += f" ${span['content']}$ "
+                        content = f"${ocr_escape_special_markdown_char(span['content'])}$"
                     elif span_type == ContentType.InterlineEquation:
-                        para_text += f"$$\n{span['content']}\n$$ "
-                    elif span_type == ContentType.Image:
-                        para_text += f"![]({join_path(s3_image_save_path, span['image_path'])})"
-            markdown.append(para_text)
+                        content = f"\n$$\n{ocr_escape_special_markdown_char(span['content'])}\n$$\n"
+                    elif span_type in [ContentType.Image, ContentType.Table]:
+                        if mode == 'mm':
+                            content = f"\n![]({join_path(s3_image_save_path, span['image_path'])})\n"
+                        elif mode == 'nlp':
+                            pass
+                    if content != '':
+                        para_text += content + ' '
+            if para_text.strip() == '':
+                continue
+            else:
+                page_markdown.append(para_text.strip() + '  ')
+    return page_markdown
 
-    return '\n\n'.join(markdown)
 
+def para_to_standard_format(para):
+    para_content = {}
+    if len(para) == 1:
+        para_content = line_to_standard_format(para[0])
+    elif len(para) > 1:
+        para_text = ''
+        inline_equation_num = 0
+        for line in para:
+            for span in line['spans']:
+                span_type = span.get('type')
+                if span_type == ContentType.Text:
+                    content = ocr_escape_special_markdown_char(split_long_words(span['content']))
+                elif span_type == ContentType.InlineEquation:
+                    content = f"${ocr_escape_special_markdown_char(span['content'])}$"
+                    inline_equation_num += 1
+                para_text += content + ' '
+        para_content = {
+            'type': 'text',
+            'text': para_text,
+            'inline_equation_num': inline_equation_num
+        }
+    return para_content
 
 def make_standard_format_with_para(pdf_info_dict: dict):
     content_list = []
     for _, page_info in pdf_info_dict.items():
-        paras = page_info.get("para_blocks")
-        if not paras:
+        paras_of_layout = page_info.get("para_blocks")
+        if not paras_of_layout:
             continue
-        for para in paras:
-            for line in para:
-                content = line_to_standard_format(line)
-                content_list.append(content)
+        for paras in paras_of_layout:
+            for para in paras:
+                para_content = para_to_standard_format(para)
+                content_list.append(para_content)
     return content_list
 
 
@@ -125,7 +197,8 @@ def line_to_standard_format(line):
                 line_text += f"${inline_equation}$"
                 inline_equation_num += 1
             elif span['type'] == ContentType.Text:
-                line_text += span['content']
+                text_content = ocr_escape_special_markdown_char(span['content'])  # escape special characters
+                line_text += text_content
     content = {
         'type': 'text',
         'text': line_text,
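
The new split_long_words helper uses wordninja to break run-together words longer than 15 characters (a common OCR artifact) while leaving punctuation attached. A small self-contained illustration; the exact split depends on wordninja's default word-frequency model, so treat the output as indicative:

    import re
    import wordninja

    def split_long_words(text):
        segments = text.split(' ')
        for i in range(len(segments)):
            # separate word characters from punctuation
            words = re.findall(r'\w+|[^\w\s]', segments[i], re.UNICODE)
            for j in range(len(words)):
                if len(words[j]) > 15:
                    words[j] = ' '.join(wordninja.split(words[j]))
            segments[i] = ''.join(words)
        return ' '.join(segments)

    print(split_long_words("informationretrievalsystems, but short words stay."))
    # typically -> "information retrieval systems, but short words stay."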

+ 27 - 0
magic_pdf/libs/boxbase.py

@@ -18,6 +18,33 @@ def _is_in_or_part_overlap(box1, box2) -> bool:
                 y1_1 < y0_2 or  # box1 is above box2
                 y0_1 > y1_2)    # box1 is below box2
 
+def _is_in_or_part_overlap_with_area_ratio(box1, box2, area_ratio_threshold=0.6):
+    """
+    判断box1是否在box2里面,或者box1和box2有部分重叠,且重叠面积占box1的比例超过area_ratio_threshold
+    
+    """
+    if box1 is None or box2 is None:
+        return False
+    
+    x0_1, y0_1, x1_1, y1_1 = box1
+    x0_2, y0_2, x1_2, y1_2 = box2
+
+    if not _is_in_or_part_overlap(box1, box2):
+        return False
+    
+    # compute the overlap area
+    x_left = max(x0_1, x0_2)
+    y_top = max(y0_1, y0_2)
+    x_right = min(x1_1, x1_2)
+    y_bottom = min(y1_1, y1_2)
+    overlap_area = (x_right - x_left) * (y_bottom - y_top)
+    
+    # compute the area of box1
+    box1_area = (x1_1 - x0_1) * (y1_1 - y0_1)
+    
+    return overlap_area / box1_area > area_ratio_threshold
+    
+    
 def _is_in(box1, box2) -> bool:
     """
     Whether box1 is completely inside box2
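
A worked example of the new area-ratio overlap check (boxes are (x0, y0, x1, y1); values are illustrative):

    box1 = (0, 0, 10, 10)    # area 100
    box2 = (5, 5, 20, 20)    # overlap rectangle is (5, 5, 10, 10), area 25
    # 25 / 100 = 0.25 <= 0.6  -> _is_in_or_part_overlap_with_area_ratio(box1, box2) is False
    box2 = (2, 2, 20, 20)    # overlap rectangle is (2, 2, 10, 10), area 64
    # 64 / 100 = 0.64 > 0.6   -> True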

+ 8 - 10
magic_pdf/libs/draw_bbox.py

@@ -27,7 +27,7 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config):
         page.insert_text((x0, y0), str(j + 1), fontsize=10, color=new_rgb)  # Insert the index at the top left corner of the rectangle
 
 
-def draw_layout_bbox(pdf_info_dict, input_path, out_path):
+def draw_layout_bbox(pdf_info_dict, pdf_bytes, out_path):
     layout_bbox_list = []
     dropped_bbox_list = []
     for page in pdf_info_dict.values():
@@ -40,15 +40,14 @@ def draw_layout_bbox(pdf_info_dict, input_path, out_path):
             for dropped_bbox in dropped_bboxes:
                 page_dropped_list.append(dropped_bbox)
         dropped_bbox_list.append(page_dropped_list)
-
-    doc = fitz.open(input_path)
-    for i, page in enumerate(doc):
+    pdf_docs = fitz.open("pdf", pdf_bytes)
+    for i, page in enumerate(pdf_docs):
         draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0])
         draw_bbox_without_number(i, dropped_bbox_list, page, [0, 255, 0])
     # Save the PDF
-    doc.save(f"{out_path}/layout.pdf")
+    pdf_docs.save(f"{out_path}/layout.pdf")
 
-def draw_text_bbox(pdf_info_dict, input_path, out_path):
+def draw_text_bbox(pdf_info_dict, pdf_bytes, out_path):
     text_list = []
     inline_equation_list = []
     interline_equation_list = []
@@ -68,13 +67,12 @@ def draw_text_bbox(pdf_info_dict, input_path, out_path):
         text_list.append(page_text_list)
         inline_equation_list.append(page_inline_equation_list)
         interline_equation_list.append(page_interline_equation_list)
-
-    doc = fitz.open(input_path)
-    for i, page in enumerate(doc):
+    pdf_docs = fitz.open("pdf", pdf_bytes)
+    for i, page in enumerate(pdf_docs):
         # get the current page's data
         draw_bbox_without_number(i, text_list, page, [255, 0, 0])
         draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0])
         draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255])
 
     # Save the PDF
-    doc.save(f"{out_path}/text.pdf")
+    pdf_docs.save(f"{out_path}/text.pdf")
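
Both draw helpers now take the PDF as bytes rather than a file path, so a caller that already holds pdf_bytes (e.g. from read_file on a local or S3 path) avoids re-reading the document. A minimal sketch of the calling pattern, with an illustrative path:

    import fitz  # PyMuPDF

    with open("ocr_1.pdf", "rb") as f:  # or pdf_bytes = read_file(pdf_path, s3_config)
        pdf_bytes = f.read()
    pdf_docs = fitz.open("pdf", pdf_bytes)  # open from memory; no temp file needed
    print(pdf_docs.page_count)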

+ 381 - 80
magic_pdf/para/para_split.py

@@ -2,14 +2,14 @@ from sklearn.cluster import DBSCAN
 import numpy as np
 from loguru import logger
 
-from magic_pdf.libs.boxbase import _is_in
+from magic_pdf.libs.boxbase import _is_in_or_part_overlap_with_area_ratio as is_in_layout
 from magic_pdf.libs.ocr_content_type import ContentType
 
 
 LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?',":", ":", ")", ")", ";"]
 INLINE_EQUATION = ContentType.InlineEquation
 INTERLINE_EQUATION = ContentType.InterlineEquation
-TEXT = "text"
+TEXT = ContentType.Text
 
 
 def __get_span_text(span):
@@ -19,24 +19,102 @@ def __get_span_text(span):
         
     return c
     
-    
-def __add_line_period(blocks, layout_bboxes):
+
+def __detect_list_lines(lines, new_layout_bboxes, lang):
     """
-    Add a period to each line.
-    If a line
-    1. ends with an inline equation but has no punctuation, append a period and treat it as the paragraph end.
+    Detect whether the lines contain a list, and split the list lines apart.
+    Such blocks start with a flush-left line beginning with an uppercase letter or digit, followed by several indented lines whose first letters include lowercase ones.
     """
-    for block in blocks:
-        for line in block['lines']:
-            last_span = line['spans'][-1]
-            span_type = last_span['type']
-            if span_type in [INLINE_EQUATION]:
-                span_content = last_span['content'].strip()
-                if span_type==INLINE_EQUATION and span_content[-1] not in LINE_STOP_FLAG:
-                    if span_type in [INLINE_EQUATION, INTERLINE_EQUATION]:
-                        last_span['content'] = span_content + '.'
+    def find_repeating_patterns(lst):
+        indices = []
+        ones_indices = []
+        i = 0
+        while i < len(lst) - 1:  # make sure at least 2 elements remain
+            if lst[i] == 1 and lst[i+1] in [2, 3]:  # extra check to guard against consecutive 1s
+                start = i
+                ones_in_this_interval = [i]
+                i += 1
+                while i < len(lst) and lst[i] in [2, 3]:
+                    i += 1
+                # verify whether the next sequence also qualifies
+                if i < len(lst) - 1 and lst[i] == 1 and lst[i+1] in [2, 3] and lst[i-1] in [2, 3]:
+                    while i < len(lst) and lst[i] in [1, 2, 3]:
+                        if lst[i] == 1:
+                            ones_in_this_interval.append(i)
+                        i += 1
+                    indices.append((start, i - 1))
+                    ones_indices.append(ones_in_this_interval)
+                else:
+                    i += 1
+            else:
+                i += 1
+        return indices, ones_indices
+    """===================="""
+    def split_indices(slen, index_array):
+        result = []
+        last_end = 0
+        
+        for start, end in sorted(index_array):
+            if start > last_end:
+                # the stretch between the previous interval's end and this interval's start is marked as "text"
+                result.append(('text', last_end, start - 1))
+            # the interval itself is marked as "list"
+            result.append(('list', start, end))
+            last_end = end + 1
+
+        if last_end < slen:
+            # if anything remains after the last interval, mark it as "text"
+            result.append(('text', last_end, slen - 1))
 
+        return result
+    """===================="""
 
+    if lang!='en':
+        return lines, None
+    else:
+        total_lines = len(lines)
+        line_fea_encode = []
+        """
+        对每一行进行特征编码,编码规则如下:
+        1. 如果行顶格,且大写字母开头或者数字开头,编码为1
+        2. 如果顶格,其他非大写开头编码为4
+        3. 如果非顶格,首字符大写,编码为2
+        4. 如果非顶格,首字符非大写编码为3
+        """
+        for l in lines:
+            first_char = __get_span_text(l['spans'][0])[0]
+            layout_left = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)[0]
+            if l['bbox'][0] == layout_left:
+                if first_char.isupper() or first_char.isdigit():
+                    line_fea_encode.append(1)
+                else:
+                    line_fea_encode.append(4)
+            else:
+                if first_char.isupper():
+                    line_fea_encode.append(2)
+                else:
+                    line_fea_encode.append(3)
+                    
+        # Then segment by the encoding: select lines where the 1, 2, 3 pattern occurs consecutively at least twice, and treat them as a list.
+        
+        list_indice, list_start_idx  = find_repeating_patterns(line_fea_encode)
+        if len(list_indice)>0:
+            logger.info(f"发现了列表,列表行数:{list_indice}, {list_start_idx}")
+        
+        # TODO: check whether the indented lines in these lists are left-aligned.
+        segments = []
+        for start, end in list_indice:
+            for i in range(start, end+1):
+                if i>0:
+                    if line_fea_encode[i] == 4:
+                        logger.info(f"列表行的第{i}行不是顶格的")
+                        break
+            else:
+                logger.info(f"列表行的第{start}到第{end}行是列表")
+        
+        return split_indices(total_lines, list_indice), list_start_idx
+        
+            
 
 def __valign_lines(blocks, layout_bboxes):
     """
@@ -50,7 +128,7 @@ def __valign_lines(blocks, layout_bboxes):
     new_layout_bboxes = []
     
     for layout_box in layout_bboxes:
-        blocks_in_layoutbox = [b for b in blocks if _is_in(b['bbox'], layout_box['layout_bbox'])]
+        blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], layout_box['layout_bbox'])]
         if len(blocks_in_layoutbox)==0:
             continue
         
@@ -136,7 +214,7 @@ def __group_line_by_layout(blocks, layout_bboxes, lang="en"):
     lines_group = []
     
     for lyout in layout_bboxes:
-        lines = [line for block in blocks if _is_in(block['bbox'], lyout['layout_bbox']) for line in block['lines']]
+        lines = [line for block in blocks if is_in_layout(block['bbox'], lyout['layout_bbox']) for line in block['lines']]
         lines_group.append(lines)
 
     return lines_group
@@ -151,45 +229,156 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
         and the next line starts without leading whitespace.
     
     """
-    paras = []
+    list_info = [] # records, for each layout, whether it starts with a list and whether it ends with one
+    layout_paras = []
     right_tail_distance = 1.5 * char_avg_len
+    
+    
     for lines in lines_group:
+        paras = []
         total_lines = len(lines)
-        if total_lines<=1: # 0 lines need no processing; 1 line cannot be split.
+        if total_lines==0:
+            continue # 0 lines need no processing
+        if total_lines==1: # 1 line cannot be split into paragraphs.
+            layout_paras.append([lines])
+            list_info.append([False, False])
             continue
-        #layout_right = max([line['bbox'][2] for line in lines])
+        
+        """在进入到真正的分段之前,要对文字块从统计维度进行对齐方式的探测,
+            对齐方式分为以下:
+            1. 左对齐的文本块(特点是左侧顶格,或者左侧不顶格但是右侧顶格的行数大于非顶格的行数,顶格的首字母有大写也有小写)
+                1) 右侧对齐的行,单独成一段
+                2) 中间对齐的行,按照字体/行高聚合成一段
+            2. 左对齐的列表块(其特点是左侧顶格的行数小于等于非顶格的行数,非定格首字母会有小写,顶格90%是大写。并且左侧顶格行数大于1,大于1是为了这种模式连续出现才能称之为列表)
+                这样的文本块,顶格的为一个段落开头,紧随其后非顶格的行属于这个段落。
+        """
+        
+        text_segments, list_start_line = __detect_list_lines(lines, new_layout_bbox, lang)
+        """根据list_range,把lines分成几个部分
+        
+        """
+        
         layout_right = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[2]
+        layout_left = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[0]
+        para = [] # elements are lines
-        
-        for i, line in enumerate(lines):
-            # If line i has a next line, use the next line's position to decide whether to break; if there is no next line, just check this line's ending features.
-            
-            cur_line_type = line['spans'][-1]['type']
-            #cur_line_last_char = line['spans'][-1]['content'][-1]
-            next_line = lines[i+1] if i<total_lines-1 else None
-            
-            if cur_line_type in [TEXT, INLINE_EQUATION]:
-                if line['bbox'][2] < layout_right - right_tail_distance:
-                    para.append(line)
-                    paras.append(para)
-                    para = []
-                elif line['bbox'][2] >= layout_right - right_tail_distance and next_line and next_line['bbox'][0] == layout_right: # this line runs to the right edge, and the next line exists and is flush-left.
-                    para.append(line)
-                else: 
-                    para.append(line)
+        layout_list_info = [False, False] # whether this layout starts with a list and whether it ends with one
+        for content_type, start, end in text_segments:
+            if content_type == 'list':
+                for i, line in enumerate(lines[start:end+1]):
+                    line_x0 = line['bbox'][0]
+                    if line_x0 == layout_left: # start of a list item
+                        if len(para)>0:
+                            paras.append(para)
+                            para = []
+                        para.append(line)
+                    else:
+                        para.append(line)
+                if len(para)>0:
                     paras.append(para)
                     para = []
-            else: # otherwise (image, table, interline equation): each takes its own paragraph
-                if len(para)>0:  # first flush the pending paragraph into the results
+                if start==0:
+                    layout_list_info[0] = True
+                if end==total_lines-1:
+                    layout_list_info[1] = True
+            else:
+                for i, line in enumerate(lines[start:end+1]):
+                    # If line i has a next line, use the next line's position to decide whether to break; if there is no next line, just check this line's ending features.
+                    cur_line_type = line['spans'][-1]['type']
+                    next_line = lines[i+1] if i<total_lines-1 else None
+                    
+                    if cur_line_type in [TEXT, INLINE_EQUATION]:
+                        if line['bbox'][2] < layout_right - right_tail_distance:
+                            para.append(line)
+                            paras.append(para)
+                            para = []
+                        elif line['bbox'][2] >= layout_right - right_tail_distance and next_line and next_line['bbox'][0] == layout_left: # this line runs to the right edge, and the next line exists and is flush-left.
+                            para.append(line)
+                        else: 
+                            para.append(line)
+                            paras.append(para)
+                            para = []
+                    else: # otherwise (image, table, interline equation): each takes its own paragraph
+                        if len(para)>0:  # first flush the pending paragraph into the results
+                            paras.append(para)
+                            para = []
+                        paras.append([line]) # then append the current line as its own paragraph (it is an interline equation, figure, table, etc.)
+                        para = []
+                        
+                if len(para)>0:
                     paras.append(para)
                     para = []
-                paras.append([line]) # then append the current line as its own paragraph (it is an interline equation, figure, table, etc.)
-                para = []
-        if len(para)>0:
-            paras.append(para)
-            para = []
+                
+        list_info.append(layout_list_info)
+        layout_paras.append(paras)
+        paras = []
+                
                     
-    return paras
+    return layout_paras, list_info
+
+def __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang):
+    """
+    If the previous layout's last paragraph is a list and the next layout's first paragraph is also a list, connect them. TODO: lists and paragraphs are not distinguished yet, so this method is not fully implemented for now.
+    Use layout_list_info to decide whether something is a list. If the next layout's first paragraph is not a list, check whether several of its lines share the same indentation.
+    """
+    if len(layout_paras)==0 or len(layout_list_info)==0: # with 0 elements the final return would fail
+        return layout_paras, [False, False]
+        
+    for i in range(1, len(layout_paras)):
+        pre_layout_list_info = layout_list_info[i-1]
+        next_layout_list_info = layout_list_info[i]
+        pre_last_para = layout_paras[i-1][-1]
+        next_paras = layout_paras[i]
+        next_first_para = next_paras[0]
+        
+        if pre_layout_list_info[1] and not next_layout_list_info[0]: # the previous layout ends with a list and the next starts with a non-list; check whether they share the same indentation
+            logger.info(f"Connecting lists within page {page_num}")
+            # look in layout_paras[i] for leading consecutive lines with the same indentation
+            may_list_lines = []
+            for j in range(len(next_paras)):
+                line = next_paras[j]
+                if len(line)==1: # can only be a single line; multi-line cases need further analysis
+                    if line[0]['bbox'][0] > __find_layout_bbox_by_line(line[0]['bbox'], new_layout_bbox)[0]:
+                        may_list_lines.append(line[0])
+                    else:
+                        break
+                else:
+                    break
+            # if these lines share the same indentation, attach them to the previous layout's last paragraph.
+            if len(may_list_lines)>0 and len(set([x['bbox'][0] for x in may_list_lines]))==1:
+                pre_last_para.extend(may_list_lines)
+                layout_paras[i] = layout_paras[i][len(may_list_lines):]
+                           
+    return layout_paras, [layout_list_info[0][0], layout_list_info[-1][1]] # also returns page-level flags for whether this page starts and ends with a list
+
+
+def __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox,  pre_page_list_info, next_page_list_info, page_num, lang):
+    """
+    If the previous layout's last paragraph is a list and the next layout's first paragraph is also a list, connect them. TODO: lists and paragraphs are not distinguished yet, so this method is not fully implemented for now.
+    Use layout_list_info to decide whether something is a list. If the next layout's first paragraph is not a list, check whether several of its lines share the same indentation.
+    """
+    if len(pre_page_paras)==0 or len(next_page_paras)==0: # with 0 elements the final return would fail
+        return False
+    
+    if pre_page_list_info[1] and not next_page_list_info[0]: # the previous page ends with a list and the next starts with a non-list; check whether they share the same indentation
+        logger.info(f"Connecting lists within page {page_num}")
+        # look in layout_paras[i] for leading consecutive lines with the same indentation
+        may_list_lines = []
+        for j in range(len(next_page_paras[0])):
+            line = next_page_paras[0][j]
+            if len(line)==1: # can only be a single line; multi-line cases need further analysis
+                if line[0]['bbox'][0] > __find_layout_bbox_by_line(line[0]['bbox'], next_page_layout_bbox)[0]:
+                    may_list_lines.append(line[0])
+                else:
+                    break
+            else:
+                break
+        # if these lines share the same indentation, attach them to the previous layout's last paragraph.
+        if len(may_list_lines)>0 and len(set([x['bbox'][0] for x in may_list_lines]))==1:
+            pre_page_paras[-1].append(may_list_lines)
+            next_page_paras[0] = next_page_paras[0][len(may_list_lines):]
+            return True
+                       
+    return False
 
 
 def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
@@ -197,12 +386,12 @@ def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
     Find the layout this line belongs to
     """
     for layout in layout_bboxes:
-        if _is_in(line_bbox, layout):
+        if is_in_layout(line_bbox, layout):
             return layout
     return None
 
 
-def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang="en"):
+def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang):
     """
     Split paragraphs across layouts.
     Mainly checks whether the previous layout's last line and the next layout's first line can be connected.
@@ -212,21 +401,27 @@ def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang="en"):
 
     """
     connected_layout_paras = []
-    for i, para in enumerate(layout_paras):
-        if i==0:
-            connected_layout_paras.append(para)
+    if len(layout_paras)==0:
+        return connected_layout_paras
+    
+    connected_layout_paras.append(layout_paras[0])
+    for i in range(1, len(layout_paras)):
+        try:
+            if len(layout_paras[i])==0 or len(layout_paras[i-1])==0: #  TODO think about the connection cases,
+                continue
+            pre_last_line = layout_paras[i-1][-1][-1]
+            next_first_line = layout_paras[i][0][0]
+        except Exception as e:
+            logger.error(f"page layout {i} has no line")
             continue
-        pre_last_line = layout_paras[i-1][-1]
-        next_first_line = layout_paras[i][0]
         pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
         pre_last_line_type = pre_last_line['spans'][-1]['type']
         next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
         next_first_line_type = next_first_line['spans'][0]['type']
-        if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]: # TODO: to do this properly, handle cases that span tables, images, and interline equations
-            connected_layout_paras.append(para)
+        if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
+            connected_layout_paras.append(layout_paras[i])
             continue
         
-        
         pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)[2]
         next_x0_min = __find_layout_bbox_by_line(next_first_line['bbox'], new_layout_bbox)[0]
         
@@ -234,23 +429,31 @@ def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang="en"):
         next_first_line_text = next_first_line_text.strip()
         if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and next_first_line['bbox'][0]==next_x0_min: # the previous line fills the entire row with no terminal punctuation, and the next line has no leading whitespace.
             """The join condition holds: connect the previous layout's paragraph to the next layout's paragraph."""
-            connected_layout_paras[-1].extend(para)
-        else:
+            connected_layout_paras[-1][-1].extend(layout_paras[i][0])
+            layout_paras[i].pop(0) # remove the next layout's first paragraph, since it has been merged into the previous layout's last paragraph.
+            if len(layout_paras[i])==0:
+                layout_paras.pop(i)
+            else:
+                connected_layout_paras.append(layout_paras[i])
+        else:                            
             """连接段落条件不成立,将前一个layout的段落加入到结果中。"""
-            connected_layout_paras.append(para)
+            connected_layout_paras.append(layout_paras[i])
     
     return connected_layout_paras
 
 
-def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, lang):
+def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num, lang):
     """
     Connect paragraphs across two adjacent pages: the previous page's last paragraph and the next page's first paragraph.
     The join conditions are:
     1. The last line of the previous page's last paragraph fills the entire row, with no terminal punctuation.
     2. The first line of the next page's first paragraph has no leading whitespace.
     """
-    pre_last_para = pre_page_paras[-1]
-    next_first_para = next_page_paras[0]
+    # some pages may contain no text at all
+    if len(pre_page_paras)==0 or len(next_page_paras)==0 or len(pre_page_paras[0])==0 or len(next_page_paras[0])==0: # TODO why does [[]] appear in pre_page_paras?
+        return False
+    pre_last_para = pre_page_paras[-1][-1]
+    next_first_para = next_page_paras[0][0]
     pre_last_line = pre_last_para[-1]
     next_first_line = next_first_para[0]
     pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
@@ -269,14 +472,91 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
     next_first_line_text = next_first_line_text.strip()
     if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and next_first_line['bbox'][0]==next_x0_min: # the previous line fills the entire row with no terminal punctuation, and the next line has no leading whitespace.
         """The join condition holds: connect the previous layout's paragraph to the next layout's paragraph."""
-        pre_page_paras[-1].extend(next_first_para)
-        next_page_paras.pop(0) # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。
+        pre_last_para.extend(next_first_para)
+        next_page_paras[0].pop(0) # remove the next page's first paragraph, since it has been merged into the previous page's last paragraph.
         return True
     else:
         return False
 
+def find_consecutive_true_regions(input_array):
+    start_index = None  # start index of the current run of consecutive True values
+    regions = []  # stores the (start, end) indices of every run of consecutive True values
+
+    for i in range(len(input_array)):
+        # we found a True value and are not currently inside a run of True values
+        if input_array[i] and start_index is None:
+            start_index = i  # record the start of the run
+
+        # we found a False value while inside a run of True values
+        elif not input_array[i] and start_index is not None:
+            # keep the run only if it is longer than one element
+            if i - start_index > 1:
+                regions.append((start_index, i-1))
+            start_index = None  # reset the start index
+
+    # if the array ends inside a run of True values, close out the final run
+    if start_index is not None and len(input_array) - start_index > 1:
+        regions.append((start_index, len(input_array)-1))
+
+    return regions
+
+
+def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode):
+    """
+    Find consecutive center-aligned single-line texts; if the consecutive lines share the same height, merge them into one paragraph.
+    A line counts as centered when:
+    1. its horizontal span crosses the layout's center point, and
+    2. there is whitespace on both its left and right.
+    """
+    
+    for layout_i, layout_para in enumerate(page_paras):
+        layout_box = new_layout_bbox[layout_i]
+        single_line_paras_tag = []
+        for i in range(len(layout_para)):
+            single_line_paras_tag.append(len(layout_para[i])==1 and layout_para[i][0]['spans'][0]['type']==TEXT)
+            
+        """找出来连续的单行文本,如果连续行高度相同,那么合并为一个段落。"""
+        consecutive_single_line_indices = find_consecutive_true_regions(single_line_paras_tag)
+        if len(consecutive_single_line_indices)>0:
+            index_offset = 0
+            """检查这些行是否是高度相同的,居中的"""
+            for start, end in consecutive_single_line_indices:
+                start += index_offset
+                end += index_offset
+                line_hi = np.array([line[0]['bbox'][3]-line[0]['bbox'][1] for line in layout_para[start:end+1]])
+                first_line_text = ''.join([__get_span_text(span) for span in layout_para[start][0]['spans']])
+                if "Table" in first_line_text or "Figure" in first_line_text:
+                    pass
+                if debug_mode:
+                    logger.info(line_hi.std())
+                
+                if line_hi.std()<2:
+                    """行高度相同,那么判断是否居中"""
+                    all_left_x0 = [line[0]['bbox'][0] for line in layout_para[start:end+1]]
+                    all_right_x1 = [line[0]['bbox'][2] for line in layout_para[start:end+1]]
+                    layout_center = (layout_box[0] + layout_box[2]) / 2
+                    if all([x0 < layout_center < x1 for x0, x1 in zip(all_left_x0, all_right_x1)]) \
+                    and not all([x0==layout_box[0] for x0 in all_left_x0]) \
+                    and not all([x1==layout_box[2] for x1 in all_right_x1]):
+                        merge_para = [l[0] for l in layout_para[start:end+1]]
+                        para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']])
+                        if debug_mode:
+                            logger.info(para_text)
+                        layout_para[start:end+1] = [merge_para]
+                        index_offset -= end-start
+                        
+    return
+            
+
+def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang):
+    """
+    Find consecutive single-line texts; if the first line is flush-left and the following single-line paragraphs are indented and aligned, merge them into one paragraph.
+    """
+    
+    pass
+
 
-def __do_split(blocks, layout_bboxes, new_layout_bbox, lang="en"):
+def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
     """
     Split paragraphs based on the line and layout information.
     Start with a simple method based on line-ending features.
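
The find_consecutive_true_regions helper introduced above keeps only runs of True longer than one element, which __connect_middle_align_text then treats as candidate centered blocks. A quick hand-checked example:

    tags = [False, True, True, False, True, False, True, True, True]
    # runs of True: (1,2) length 2 -> kept; (4,4) length 1 -> dropped; (6,8) length 3 -> kept
    # find_consecutive_true_regions(tags) -> [(1, 2), (6, 8)]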
@@ -289,34 +569,55 @@ def __do_split(blocks, layout_bboxes, new_layout_bbox, lang="en"):
     4. Figures and tables currently each occupy a single line; no paragraph splitting is considered for them.
     """
     lines_group = __group_line_by_layout(blocks, layout_bboxes, lang) # split within blocks
-    layout_paras = __split_para_in_layoutbox(lines_group, new_layout_bbox, lang) # split within layouts
-    connected_layout_paras = __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang) # connect paragraphs across layouts
-    return connected_layout_paras
+    layout_paras, layout_list_info = __split_para_in_layoutbox(lines_group, new_layout_bbox, lang) # split within layouts
+    layout_paras2, page_list_info = __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang) # connect list paragraphs across layouts
+    connected_layout_paras = __connect_para_inter_layoutbox(layout_paras2, new_layout_bbox, lang) # connect paragraphs across layouts
     
     
-def para_split(pdf_info_dict, lang="en"):
+    return connected_layout_paras, page_list_info
+   
+
+def para_split(pdf_info_dict, debug_mode, lang="en"):
     """
     Split paragraphs based on the line and layout information.
     """
     new_layout_of_pages = [] # array of arrays; each element holds one page's layouts
-    for _, page in pdf_info_dict.items():
+    all_page_list_info = [] # records whether each page starts and ends with a list
+    for page_num, page in pdf_info_dict.items():
         blocks = page['preproc_blocks']
         layout_bboxes = page['layout_bboxes']
         new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
         new_layout_of_pages.append(new_layout_bbox)
-        splited_blocks = __do_split(blocks, layout_bboxes, new_layout_bbox, lang)
+        splited_blocks, page_list_info = __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang)
+        all_page_list_info.append(page_list_info)
         page['para_blocks'] = splited_blocks
         
     """连接页面与页面之间的可能合并的段落"""
     pdf_infos = list(pdf_info_dict.values())
-    for i, page in enumerate(pdf_info_dict.values()):
-        if i==0:
+    for page_num, page in enumerate(pdf_info_dict.values()):
+        if page_num==0:
             continue
-        pre_page_paras = pdf_infos[i-1]['para_blocks']
-        next_page_paras = pdf_infos[i]['para_blocks']
-        pre_page_layout_bbox = new_layout_of_pages[i-1]
-        next_page_layout_bbox = new_layout_of_pages[i]
+        pre_page_paras = pdf_infos[page_num-1]['para_blocks']
+        next_page_paras = pdf_infos[page_num]['para_blocks']
+        pre_page_layout_bbox = new_layout_of_pages[page_num-1]
+        next_page_layout_bbox = new_layout_of_pages[page_num]
         
-        is_conn= __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, lang) 
-        if is_conn:
-            logger.info(f"连接了第{i-1}页和第{i}页的段落")
+        is_conn = __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num, lang)
+        if debug_mode:
+            if is_conn:
+                logger.info(f"连接了第{page_num-1}页和第{page_num}页的段落")
+            
+        is_list_conn = __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, all_page_list_info[page_num-1], all_page_list_info[page_num], page_num, lang)
+        if debug_mode:
+            if is_list_conn:
+                logger.info(f"连接了第{page_num-1}页和第{page_num}页的列表段落")
+            
+    """接下来可能会漏掉一些特别的一些可以合并的内容,对他们进行段落连接
+    1. 正文中有时出现一个行顶格,接下来几行缩进的情况。
+    2. 居中的一些连续单行,如果高度相同,那么可能是一个段落。
+    """
+    for page_num, page in enumerate(pdf_info_dict.values()):
+        page_paras = page['para_blocks']
+        new_layout_bbox = new_layout_of_pages[page_num]
+        __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode=debug_mode)
+        __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang)
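
As a sanity check on the centering rule in __connect_middle_align_text: consecutive single-line paragraphs merge only when their heights are near-equal (std < 2) and every line straddles the layout's horizontal center without being flush against either edge. A hand-worked example under those rules (coordinates are illustrative):

    layout_box = (50, 0, 550, 800)        # layout center x = 300
    line_bboxes = [(180, 100, 420, 112),  # heights 12, 12, 13 -> std ~0.47 < 2
                   (200, 115, 400, 127),
                   (190, 130, 415, 143)]
    # every line has x0 < 300 < x1, no x0 == 50 and no x1 == 550,
    # so these three single-line paragraphs would merge into one.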

+ 37 - 48
magic_pdf/pdf_parse_by_ocr.py

@@ -57,16 +57,16 @@ def construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, lay
 
 
 def parse_pdf_by_ocr(
-    pdf_path,
-    s3_pdf_profile,
-    pdf_model_output,
-    save_path,
-    book_name,
-    pdf_model_profile=None,
-    image_s3_config=None,
-    start_page_id=0,
-    end_page_id=None,
-    debug_mode=False,
+        pdf_path,
+        s3_pdf_profile,
+        pdf_model_output,
+        save_path,
+        book_name,
+        pdf_model_profile=None,
+        image_s3_config=None,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
 ):
     pdf_bytes = read_file(pdf_path, s3_pdf_profile)
     save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
@@ -95,7 +95,6 @@ def parse_pdf_by_ocr(
 
     start_time = time.time()
 
-
     end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
     for page_id in range(start_page_id, end_page_id + 1):
 
@@ -125,13 +124,6 @@ def parse_pdf_by_ocr(
             page_id, page, ocr_page_info, md_bookname_save_path, debug_mode=debug_mode
         )
 
-        # 构建需要remove的bbox列表
-        # need_remove_spans_bboxes = []
-        # need_remove_spans_bboxes.extend(page_no_bboxes)
-        # need_remove_spans_bboxes.extend(header_bboxes)
-        # need_remove_spans_bboxes.extend(footer_bboxes)
-        # need_remove_spans_bboxes.extend(footnote_bboxes)
-
         # 构建需要remove的bbox字典
         need_remove_spans_bboxes_dict = {
             DropTag.PAGE_NUMBER: page_no_bboxes,
@@ -199,50 +191,48 @@ def parse_pdf_by_ocr(
             else:
                 continue
 
-
-
-
-        # 删除重叠spans中较小的那些
+        '''Remove the smaller of any overlapping spans'''
         spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
 
-        # remove the bboxes in remove_span_block_bboxes
-        # spans = remove_spans_by_bboxes(spans, need_remove_spans_bboxes)
-        # add drop-related data as required by QA
+        '''
+        Remove the bboxes in remove_span_block_bboxes
+        and add drop-related data
+        '''
         spans, dropped_spans_by_removed_bboxes = remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict)
 
-        # crop screenshots of images and tables
+        '''Crop screenshots of images and tables'''
         spans = cut_image_and_table(spans, page, page_id, book_name, save_path, img_s3_client)
 
-        # adjust inline equations: match their height to the text on the same line (prefer the left side, then the right)
+        '''Adjust inline equations: match their height to the text on the same line (prefer the left side, then the right)'''
         displayed_list = []
         text_inline_lines = []
         modify_y_axis(spans, displayed_list, text_inline_lines)
-        # interline equations the model misrecognized: convert their type to inline equations
+
+        '''Interline equations the model misrecognized: convert their type to inline equations'''
         spans = modify_inline_equation(spans, displayed_list, text_inline_lines)
 
-        # remove adhesion between bboxes
+        '''Remove adhesion between bboxes'''
         spans = remove_overlap_between_bbox(spans)
 
-        # extra handling for type=["interline_equation", "image", "table"]: if there is text on the left, adjust the span bbox's y0 so it is no higher than the text's y0
+        '''
+        Extra handling for type=["interline_equation", "image", "table"]:
+        if there is text on the left, adjust the span bbox's y0 so it is no higher than the text's y0
+        '''
         spans = adjust_bbox_for_standalone_block(spans)
 
-
-        # parse layout info from ocr_page_info (sorted in natural reading order, with overlapping and interleaved bad cases fixed)
+        '''Parse layout info from ocr_page_info (sorted in natural reading order, with overlapping and interleaved bad cases fixed)'''
         layout_bboxes, layout_tree = layout_detect(ocr_page_info['subfield_dets'], page, ocr_page_info)
 
-        # merge spans into lines (within each layout, top to bottom, left to right)
+        '''Merge spans into lines (within each layout, top to bottom, left to right)'''
         lines, dropped_spans_by_layout = merge_spans_to_line_by_layout(spans, layout_bboxes)
 
-        # merge lines into blocks
+        '''Merge lines into blocks'''
         blocks = merge_lines_to_block(lines)
 
-        # merge paragraphs based on blocks
-        #para_blocks = para_split(blocks, layout_bboxes)
-        
-        # get the lists QA needs to externalize
+        '''Get the lists QA needs to externalize'''
         images, tables, interline_equations, inline_equations = get_qa_need_list(blocks)
 
-        # merge the dropped span lists
+        '''Merge the dropped span lists'''
         dropped_spans = []
         dropped_spans.extend(dropped_spans_by_span_overlap)
         dropped_spans.extend(dropped_spans_by_removed_bboxes)
@@ -263,19 +253,18 @@ def parse_pdf_by_ocr(
             elif span['type'] in [ContentType.InlineEquation, ContentType.InterlineEquation]:
                 dropped_equation_block.append(span)
 
-
-
-        # build pdf_info_dict
+        '''Build pdf_info_dict'''
         page_info = construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
                                              images, tables, interline_equations, inline_equations,
-                                             dropped_text_block, dropped_image_block, dropped_table_block, dropped_equation_block,
+                                             dropped_text_block, dropped_image_block, dropped_table_block,
+                                             dropped_equation_block,
                                              need_remove_spans_bboxes_dict)
         pdf_info_dict[f"page_{page_id}"] = page_info
 
     """分段"""
-    para_split(pdf_info_dict)
-    
-    # save debug info when testing
+    para_split(pdf_info_dict, debug_mode=debug_mode)
+
+    '''Save debug info when testing'''
     if debug_mode:
         params_file_save_path = join_path(
             save_tmp_path, "md", book_name, "preproc_out.json"
@@ -284,7 +273,7 @@ def parse_pdf_by_ocr(
             json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
 
         # draw_bbox
-        draw_layout_bbox(pdf_info_dict, pdf_path, md_bookname_save_path)
-        draw_text_bbox(pdf_info_dict, pdf_path, md_bookname_save_path)
+        draw_layout_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
+        draw_text_bbox(pdf_info_dict, pdf_bytes, md_bookname_save_path)
 
     return pdf_info_dict

+ 686 - 0
magic_pdf/pdf_parse_for_train.py

@@ -0,0 +1,686 @@
+import time
+
+# from anyio import Path
+
+from magic_pdf.libs.commons import (
+    fitz,
+    get_delta_time,
+    get_img_s3_client,
+    get_docx_model_output,
+)
+import json
+import os
+from copy import deepcopy
+import math
+from loguru import logger
+from magic_pdf.layout.bbox_sort import (
+    prepare_bboxes_for_layout_split,
+)
+from magic_pdf.layout.layout_sort import (
+    LAYOUT_UNPROC,
+    get_bboxes_layout,
+    get_columns_cnt_of_layout,
+    sort_text_block,
+)
+from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.libs.markdown_utils import escape_special_markdown_char
+from magic_pdf.libs.safe_filename import sanitize_filename
+from magic_pdf.libs.vis_utils import draw_bbox_on_page, draw_layout_bbox_on_page
+from magic_pdf.pre_proc.detect_images import parse_images
+from magic_pdf.pre_proc.detect_tables import parse_tables  # get the bboxes of tables
+from magic_pdf.pre_proc.detect_equation import parse_equations  # get the bboxes of equations
+from magic_pdf.pre_proc.detect_header import parse_headers  # get the bboxes of headers
+from magic_pdf.pre_proc.detect_page_number import parse_pageNos  # get the bboxes of pageNos
+from magic_pdf.pre_proc.detect_footnote import (
+    parse_footnotes_by_model,
+    parse_footnotes_by_rule,
+)  # get the bboxes of footnotes
+from magic_pdf.pre_proc.detect_footer_by_model import parse_footers  # get the bboxes of footers
+
+from magic_pdf.post_proc.detect_para import (
+    ParaProcessPipeline,
+    TitleDetectionException,
+    TitleLevelException,
+    ParaSplitException,
+    ParaMergeException,
+    DenseSingleLineBlockException,
+)
+from magic_pdf.pre_proc.main_text_font import get_main_text_font
+from magic_pdf.pre_proc.remove_colored_strip_bbox import remove_colored_strip_textblock
+from magic_pdf.pre_proc.remove_footer_header import remove_headder_footer_one_page
+from magic_pdf.train_utils.extract_caption import extract_caption_bbox
+
+"""
+from para.para_pipeline import ParaProcessPipeline
+from para.exceptions import (
+    TitleDetectionException,
+    TitleLevelException,
+    ParaSplitException,
+    ParaMergeException,
+    DenseSingleLineBlockException,
+)
+"""
+
+from magic_pdf.libs.commons import read_file, join_path
+from magic_pdf.libs.pdf_image_tools import save_images_by_bboxes
+from magic_pdf.post_proc.remove_footnote import (
+    merge_footnote_blocks,
+    remove_footnote_blocks,
+)
+from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
+from magic_pdf.pre_proc.equations_replace import (
+    combine_chars_to_pymudict,
+    remove_chars_in_text_blocks,
+    replace_equations_in_textblock,
+)
+from magic_pdf.pre_proc.pdf_pre_filter import pdf_filter
+from magic_pdf.pre_proc.detect_footer_header_by_statistics import drop_footer_header
+from magic_pdf.pre_proc.construct_paras import construct_page_component
+from magic_pdf.pre_proc.fix_image import (
+    combine_images,
+    fix_image_vertical,
+    fix_seperated_image,
+    include_img_title,
+)
+from magic_pdf.post_proc.pdf_post_filter import pdf_post_filter
+from magic_pdf.pre_proc.remove_rotate_bbox import (
+    get_side_boundry,
+    remove_rotate_side_textblock,
+    remove_side_blank_block,
+)
+from magic_pdf.pre_proc.resolve_bbox_conflict import (
+    check_text_block_horizontal_overlap,
+    resolve_bbox_overlap_conflict,
+)
+from magic_pdf.pre_proc.fix_table import (
+    fix_table_text_block,
+    fix_tables,
+    include_table_title,
+)
+from magic_pdf.pre_proc.solve_line_alien import solve_inline_too_large_interval
+
+denseSingleLineBlockException_msg = DenseSingleLineBlockException().message
+titleDetectionException_msg = TitleDetectionException().message
+titleLevelException_msg = TitleLevelException().message
+paraSplitException_msg = ParaSplitException().message
+paraMergeException_msg = ParaMergeException().message
+
+
+def parse_pdf_for_train(
+    s3_pdf_path,
+    s3_pdf_profile,
+    pdf_model_output,
+    save_path,
+    book_name,
+    pdf_model_profile=None,
+    image_s3_config=None,
+    start_page_id=0,
+    end_page_id=None,
+    junk_img_bojids=[],
+    debug_mode=False,
+):
+    pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
+    save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
+    md_bookname_save_path = ""
+    book_name = sanitize_filename(book_name)
+    if debug_mode:
+        save_path = join_path(save_tmp_path, "md")
+        pdf_local_path = join_path(save_tmp_path, "download-pdfs", book_name)
+
+        if not os.path.exists(os.path.dirname(pdf_local_path)):
+            # create the directory if it does not exist
+            os.makedirs(os.path.dirname(pdf_local_path))
+
+        md_bookname_save_path = join_path(save_tmp_path, "md", book_name)
+        if not os.path.exists(md_bookname_save_path):
+            # create the directory if it does not exist
+            os.makedirs(md_bookname_save_path)
+
+        with open(pdf_local_path + ".pdf", "wb") as pdf_file:
+            pdf_file.write(pdf_bytes)
+
+    pdf_docs = fitz.open("pdf", pdf_bytes)
+    pdf_info_dict = {}
+    img_s3_client = get_img_s3_client(
+        save_path, image_s3_config
+    )  # renamed the function and its parameters to avoid ambiguity
+    # img_s3_client = "img_s3_client"  # don't create this object; use a string placeholder instead
+
+    start_time = time.time()
+
+    """通过统计pdf全篇文字,识别正文字体"""
+    main_text_font = get_main_text_font(pdf_docs)
+
+    end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
+    for page_id in range(start_page_id, end_page_id + 1):
+        page = pdf_docs[page_id]
+        page_width = page.rect.width
+        page_height = page.rect.height
+
+        if debug_mode:
+            time_now = time.time()
+            logger.info(
+                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
+            )
+            start_time = time_now
+        """
+        # Use a rule to filter out PDFs with more than 1500 non-junk images on a single page
+        # Count the images with distinct ids on the page; if the current page exceeds 1500, return need_drop directly
+        """
+        page_imgs = page.get_images()
+        img_counts = 0
+        for img in page_imgs:
+            img_bojid = img[0]
+            if img_bojid in junk_img_bojids:  # check whether this image is in the junk list
+                continue  # if it is in the junk list, skip it
+            else:
+                recs = page.get_image_rects(img, transform=True)
+                if recs:  # if this image is displayed on the current page
+                    img_counts += 1
+        if (
+            img_counts >= 1500
+        ):  # if a single page still exceeds 1500 images after excluding junk images, drop this PDF
+            logger.warning(
+                f"page_id: {page_id}, img_counts: {img_counts}, drop this pdf: {book_name}, drop_reason: {DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS}"
+            )
+            result = {
+                "need_drop": True,
+                "drop_reason": DropReason.HIGH_COMPUTATIONAL_lOAD_BY_IMGS,
+            }
+            if not debug_mode:
+                return result
+
+        """
+        ==================================================================================================================================
+        First obtain the basic block data: decompose the PDF to get the bboxes of images, tables, equations, and text
+        """
+        # parse the PDF's raw text blocks
+        text_raw_blocks = page.get_text(
+            "dict",
+            flags=fitz.TEXTFLAGS_TEXT,
+        )["blocks"]
+        model_output_json = get_docx_model_output(
+            pdf_model_output, pdf_model_profile, page_id
+        )
+
+        # parse images
+        image_bboxes = parse_images(page_id, page, model_output_json, junk_img_bojids)
+        image_bboxes = fix_image_vertical(
+            image_bboxes, text_raw_blocks
+        )  # fix the image positions
+        image_bboxes = fix_seperated_image(image_bboxes)  # merge images whose edges coincide
+
+        old_image_bboxes = deepcopy(image_bboxes)
+        image_bboxes = include_img_title(
+            text_raw_blocks, image_bboxes
+        )  # look above and below images for titles, matched by rules; currently only English rules are supported
+        """At this point image_bboxes can contain this case: two horizontally side-by-side images, each with its own subtitle below it, plus a main caption (like Figxxx) below both subtitles; both image bboxes then contain that main caption, and the images need to be merged"""
+        image_bboxes = combine_images(image_bboxes)  # merge images
+
+        # parse tables and fine-tune table_bboxes so that text around the tables is not cut off
+        table_bboxes = parse_tables(page_id, page, model_output_json)
+        table_bboxes = fix_tables(
+            page, table_bboxes, include_table_title=False, scan_line_num=2
+        )  # fix
+        table_bboxes = fix_table_text_block(
+            text_raw_blocks, table_bboxes
+        )  # fix the relation with text blocks; some tables do not fully contain the text blocks pymupdf finds inside them, so a correction pass is needed.
+        # debug_show_bbox(pdf_docs, page_id, table_bboxes, [], [b['bbox'] for b in text_raw_blocks], join_path(save_path, book_name, f"{book_name}_debug.pdf"), 7)
+
+        old_table_bboxes = deepcopy(table_bboxes)
+        table_bboxes = include_table_title(
+            text_raw_blocks, table_bboxes
+        )  # look above and below tables for titles, matched by rules; currently only English rules are supported
+
+        # parse equations
+        equations_inline_bboxes, equations_interline_bboxes = parse_equations(
+            page_id, page, model_output_json
+        )
+
+        # get image box and caption !
+        image_bboxes_with_caption = extract_caption_bbox(image_bboxes, old_image_bboxes)
+
+        # get table box and caption !
+        table_bboxes_with_caption = extract_caption_bbox(table_bboxes, old_table_bboxes)
+
+        """
+        ==================================================================================================================================
+        Entering preprocessing stage 1
+        -------------------
+        # # Parse titles
+        # title_bboxs = parse_titles(page_id, page, model_output_json)
+        # # Evaluate whether the layout is regular and simple
+        # isSimpleLayout_flag, fullColumn_cnt, subColumn_cnt, curPage_loss = evaluate_pdf_layout(page_id, page, model_output_json)
+        The preprocessing steps start below
+        """
+        # title_bboxs = parse_titles(page_id, page, model_output_json)
+        
+        """去掉每页的页码、页眉、页脚"""
+        page_no_bboxs = parse_pageNos(page_id, page, model_output_json)
+        header_bboxs = parse_headers(page_id, page, model_output_json)
+        footer_bboxs = parse_footers(page_id, page, model_output_json)
+        (
+            image_bboxes,
+            table_bboxes,
+            remain_text_blocks,
+            removed_hdr_foot_txt_block,
+            removed_hdr_foot_img_block,
+            removed_hdr_foot_table,
+        ) = remove_headder_footer_one_page(
+            text_raw_blocks,
+            image_bboxes,
+            table_bboxes,
+            header_bboxs,
+            footer_bboxs,
+            page_no_bboxs,
+            page_width,
+            page_height,
+        )
+
+        """去除页面上半部分长条色块内的文本块"""
+        remain_text_blocks, removed_colored_narrow_strip_background_text_block = (
+            remove_colored_strip_textblock(remain_text_blocks, page)
+        )
+
+        # debug_show_bbox(pdf_docs, page_id, footnote_bboxes_by_model, [b['bbox'] for b in remain_text_blocks], header_bboxs, join_path(save_path, book_name, f"{book_name}_debug.pdf"), 7)
+
+        """去掉旋转的文字:水印、垂直排列的文字"""
+        remain_text_blocks, removed_non_horz_text_block = remove_rotate_side_textblock(
+            remain_text_blocks, page_width, page_height
+        )  # 去掉水印,非水平文字
+        remain_text_blocks, removed_empty_side_block = remove_side_blank_block(
+            remain_text_blocks, page_width, page_height
+        )  # 删除页面四周可能会留下的完全空白的textblock,这种block形成原因未知
+
+        """出现在图片、表格上的文字块去掉,把层叠的图片单独分离出来,不参与layout的计算"""
+        (
+            image_bboxes,
+            table_bboxes,
+            equations_interline_bboxes,
+            equations_inline_bboxes,
+            remain_text_blocks,
+            text_block_on_image_removed,
+            images_overlap_backup,
+            interline_eq_temp_text_block,
+        ) = resolve_bbox_overlap_conflict(
+            image_bboxes,
+            table_bboxes,
+            equations_interline_bboxes,
+            equations_inline_bboxes,
+            remain_text_blocks,
+        )
+
+        # """去掉footnote, 从文字和图片中"""
+        # # 通过模型识别到的footnote
+        # footnote_bboxes_by_model = parse_footnotes_by_model(page_id, page, model_output_json, md_bookname_save_path,
+        #                                                     debug_mode=debug_mode)
+        # # 通过规则识别到的footnote
+        # footnote_bboxes_by_rule = parse_footnotes_by_rule(remain_text_blocks, page_height, page_id)
+        """
+        ==================================================================================================================================
+        """
+        if debug_mode:  # in debug mode, save crops locally
+            save_path = join_path(save_tmp_path, "md")
+
+        # crop the images, tables, and equations, save the crops to storage, and return the image paths as content
+        image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info = (
+            save_images_by_bboxes(
+                book_name,
+                page_id,
+                page,
+                save_path,
+                image_bboxes,
+                images_overlap_backup,
+                table_bboxes,
+                equations_inline_bboxes,
+                equations_interline_bboxes,
+                # pass in img_s3_client
+                img_s3_client,
+            )
+        )  # only the crops of tables and images are needed
+
+        """"以下进入到公式替换环节 """
+        char_level_text_blocks = page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
+            "blocks"
+        ]
+        remain_text_blocks = combine_chars_to_pymudict(
+            remain_text_blocks, char_level_text_blocks
+        )  # merge chars
+        remain_text_blocks = replace_equations_in_textblock(
+            remain_text_blocks, inline_eq_info, interline_eq_info
+        )
+        remain_text_blocks = remove_citation_marker(
+            remain_text_blocks
+        )  # remove citation markers after equation replacement so the replacement does not fail; the trade-off is that markers may be mistaken for equations. Each order has pros and cons.
+        remain_text_blocks = remove_chars_in_text_blocks(
+            remain_text_blocks
+        )  # reduce the size of the intermediate data
+        # debug_show_bbox(pdf_docs, page_id, [b['bbox'] for b in inline_eq_info], [b['bbox'] for b in interline_eq_info], [], join_path(save_path, book_name, f"{book_name}_debug.pdf"), 3)
+
+        """去掉footnote, 从文字和图片中(先去角标再去footnote试试)"""
+        # 通过模型识别到的footnote
+        footnote_bboxes_by_model = parse_footnotes_by_model(
+            page_id,
+            page,
+            model_output_json,
+            md_bookname_save_path,
+            debug_mode=debug_mode,
+        )
+        # footnotes detected by rules
+        footnote_bboxes_by_rule = parse_footnotes_by_rule(
+            remain_text_blocks, page_height, page_id, main_text_font
+        )
+        """进入pdf过滤器,去掉一些不合理的pdf"""
+        is_good_pdf, err = pdf_filter(
+            page, remain_text_blocks, table_bboxes, image_bboxes
+        )
+        if not is_good_pdf:
+            logger.warning(
+                f"page_id: {page_id}, drop this pdf: {book_name}, reason: {err}"
+            )
+            if not debug_mode:
+                return err
+
+        """
+        ==================================================================================================================================
+        Split and filter the page layout
+        """
+        """Before splitting, first check whether any bboxes overlap horizontally. If they do, we consider this pdf beyond our current ability to handle: such horizontal overlap is most likely caused by interline equations or tables that were not recognized correctly."""
+
+        is_text_block_horz_overlap = check_text_block_horizontal_overlap(
+            remain_text_blocks, header_bboxs, footer_bboxs
+        )
+
+        if is_text_block_horz_overlap:
+            # debug_show_bbox(pdf_docs, page_id, [b['bbox'] for b in remain_text_blocks], [], [], join_path(save_path, book_name, f"{book_name}_debug.pdf"), 0)
+            logger.warning(
+                f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TEXT_BLCOK_HOR_OVERLAP}"
+            )
+            result = {
+                "need_drop": True,
+                "drop_reason": DropReason.TEXT_BLCOK_HOR_OVERLAP,
+            }
+            if not debug_mode:
+                return result
+
+        """统一格式化成一个数据结构用于计算layout"""
+        page_y0 = 0 if len(header_bboxs) == 0 else max([b[3] for b in header_bboxs])
+        page_y1 = (
+            page_height if len(footer_bboxs) == 0 else min([b[1] for b in footer_bboxs])
+        )
+        left_x, right_x = get_side_boundry(
+            removed_non_horz_text_block, page_width, page_height
+        )
+        page_boundry = [
+            math.floor(left_x),
+            page_y0 + 1,
+            math.ceil(right_x),
+            page_y1 - 1,
+        ]
+        # Returns an array whose elements are [x0, y0, x1, y1, block_content, idx_x, idx_y]; idx_x and idx_y start as None. For images and equations, block_content is the image path; for paragraphs, it is the paragraph content.
+
+        all_bboxes = prepare_bboxes_for_layout_split(
+            image_info,
+            image_backup_info,
+            table_info,
+            inline_eq_info,
+            interline_eq_info,
+            remain_text_blocks,
+            page_boundry,
+            page,
+        )
+        # debug_show_bbox(pdf_docs, page_id, [], [], all_bboxes, join_path(save_path, book_name, f"{book_name}_debug.pdf"), 1)
+        """page_y0, page_y1能够过滤掉页眉和页脚,不会算作layout内"""
+        layout_bboxes, layout_tree = get_bboxes_layout(
+            all_bboxes, page_boundry, page_id
+        )
+
+        if (
+            len(remain_text_blocks) > 0
+            and len(all_bboxes) > 0
+            and len(layout_bboxes) == 0
+        ):
+            logger.warning(
+                f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}"
+            )
+            result = {
+                "need_drop": True,
+                "drop_reason": DropReason.CAN_NOT_DETECT_PAGE_LAYOUT,
+            }
+            if not debug_mode:
+                return result
+
+        """以下去掉复杂的布局和超过2列的布局"""
+        if any(
+            [lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]
+        ):  # 复杂的布局
+            logger.warning(
+                f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.COMPLICATED_LAYOUT}"
+            )
+            result = {"need_drop": True, "drop_reason": DropReason.COMPLICATED_LAYOUT}
+            if not debug_mode:
+                return result
+
+        layout_column_width = get_columns_cnt_of_layout(layout_tree)
+        if layout_column_width > 2:  # drop pdfs whose layout has more than 2 columns
+            logger.warning(
+                f"page_id: {page_id}, drop this pdf: {book_name}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}"
+            )
+            result = {
+                "need_drop": True,
+                "drop_reason": DropReason.TOO_MANY_LAYOUT_COLUMNS,
+                "extra_info": {"column_cnt": layout_column_width},
+            }
+            if not debug_mode:
+                return result
+
+        """
+        ==================================================================================================================================
+        Build the data structures needed downstream
+        """
+        remain_text_blocks = (
+            remain_text_blocks + interline_eq_temp_text_block
+        )  # put back the interline equations temporarily removed while computing the layout, so they are not lost during equation replacement.
+        removed_text_blocks = []
+        removed_text_blocks.extend(removed_hdr_foot_txt_block)
+        # removed_text_blocks.extend(removed_footnote_text_block)
+        removed_text_blocks.extend(text_block_on_image_removed)
+        removed_text_blocks.extend(removed_non_horz_text_block)
+        removed_text_blocks.extend(removed_colored_narrow_strip_background_text_block)
+
+        removed_images = []
+        # removed_images.extend(footnote_imgs)
+        removed_images.extend(removed_hdr_foot_img_block)
+
+        images_backup = []
+        images_backup.extend(image_backup_info)
+        remain_text_blocks = escape_special_markdown_char(
+            remain_text_blocks
+        )  # escape the text inside spans
+        sorted_text_remain_text_block = sort_text_block(
+            remain_text_blocks, layout_bboxes
+        )
+
+        footnote_bboxes_tmp = []
+        footnote_bboxes_tmp.extend(footnote_bboxes_by_model)
+        footnote_bboxes_tmp.extend(footnote_bboxes_by_rule)
+
+        page_info = construct_page_component(
+            page_id,
+            image_info,
+            table_info,
+            sorted_text_remain_text_block,
+            layout_bboxes,
+            inline_eq_info,
+            interline_eq_info,
+            page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"],
+            removed_text_blocks=removed_text_blocks,
+            removed_image_blocks=removed_images,
+            images_backup=images_backup,
+            droped_table_block=[],
+            table_backup=[],
+            layout_tree=layout_tree,
+            page_w=page.rect.width,
+            page_h=page.rect.height,
+            footnote_bboxes_tmp=footnote_bboxes_tmp,
+        )
+
+        page_info["image_bboxes_with_caption"] = image_bboxes_with_caption  # add by xr
+        page_info["table_bboxes_with_caption"] = table_bboxes_with_caption
+
+        page_info["bak_page_no_bboxes"] = page_no_bboxs
+        page_info["bak_header_bboxes"] = header_bboxs
+        page_info["bak_footer_bboxes"] = footer_bboxs
+        page_info["bak_footer_note_bboxes"] = footnote_bboxes_tmp
+
+        pdf_info_dict[f"page_{page_id}"] = page_info
+
+    # end page for
+
+    """计算后处理阶段耗时"""
+    start_time = time.time()
+
+    """
+    ==================================================================================================================================
+    Remove headers and footers. This needs some page-level statistics, so it is done last.
+    Headers and footers are removed mainly from the text boxes and image boxes located around the page edges.
+    The function below modifies pdf_info_dict in place: it deletes header/footer content from the text blocks and images, and records what was deleted.
+    """
+    # remove headers and footers
+    header, footer = drop_footer_header(
+        pdf_info_dict
+    )  # TODO: using header and footer boxes here !
+
+    """对单个layout内footnote和他下面的所有textbbox合并"""
+
+    for page_key, page_info in pdf_info_dict.items():
+        page_info = merge_footnote_blocks(page_info, main_text_font)
+        page_info = remove_footnote_blocks(page_info)
+        pdf_info_dict[page_key] = page_info
+
+    """进入pdf后置过滤器,去掉一些不合理的pdf"""
+
+    i = 0
+    for page_info in pdf_info_dict.values():
+        is_good_pdf, err = pdf_post_filter(page_info)
+        if not is_good_pdf:
+            logger.warning(f"page_id: {i}, drop this pdf: {book_name}, reason: {err}")
+            if not debug_mode:
+                return err
+        i += 1
+
+    if debug_mode:
+        params_file_save_path = join_path(
+            save_tmp_path, "md", book_name, "preproc_out.json"
+        )
+        page_draw_rect_save_path = join_path(
+            save_tmp_path, "md", book_name, "layout.pdf"
+        )
+        # dir_path = os.path.dirname(page_draw_rect_save_path)
+        # if not os.path.exists(dir_path):
+        #     # create the directory if it does not exist
+        #     os.makedirs(dir_path)
+
+        with open(params_file_save_path, "w", encoding="utf-8") as f:
+            json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
+        # first check whether page_draw_rect_save_path exists locally, and delete it if it does
+        if os.path.exists(page_draw_rect_save_path):
+            os.remove(page_draw_rect_save_path)
+        # draw the bboxes and layout onto the pdf
+        draw_bbox_on_page(pdf_docs, pdf_info_dict, page_draw_rect_save_path)
+        draw_layout_bbox_on_page(
+            pdf_docs, pdf_info_dict, header, footer, page_draw_rect_save_path
+        )
+
+    if debug_mode:
+        # log the post-processing time
+        logger.info(f"post_processing_time: {get_delta_time(start_time)}")
+
+    """
+    ==================================================================================================================================
+    Entering paragraph-processing stage 2
+    """
+
+    # handle overly large spacing between inline text
+    pdf_info_dict = solve_inline_too_large_interval(pdf_info_dict)
+
+    start_time = time.time()
+
+    para_process_pipeline = ParaProcessPipeline()
+
+    def _deal_with_text_exception(error_info):
+        logger.warning(
+            f"page_id: {page_id}, drop this pdf: {book_name}, reason: {error_info}"
+        )
+        if error_info == denseSingleLineBlockException_msg:
+            logger.warning(
+                f"Drop this pdf: {book_name}, reason: {DropReason.DENSE_SINGLE_LINE_BLOCK}"
+            )
+            result = {
+                "need_drop": True,
+                "drop_reason": DropReason.DENSE_SINGLE_LINE_BLOCK,
+            }
+            return result
+        elif error_info == titleDetectionException_msg:
+            logger.warning(
+                f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_DETECTION_FAILED}"
+            )
+            result = {
+                "need_drop": True,
+                "drop_reason": DropReason.TITLE_DETECTION_FAILED,
+            }
+            return result
+        elif error_info == titleLevelException_msg:
+            logger.warning(
+                f"Drop this pdf: {book_name}, reason: {DropReason.TITLE_LEVEL_FAILED}"
+            )
+            result = {"need_drop": True, "drop_reason": DropReason.TITLE_LEVEL_FAILED}
+            return result
+        elif error_info == paraSplitException_msg:
+            logger.warning(
+                f"Drop this pdf: {book_name}, reason: {DropReason.PARA_SPLIT_FAILED}"
+            )
+            result = {"need_drop": True, "drop_reason": DropReason.PARA_SPLIT_FAILED}
+            return result
+        elif error_info == paraMergeException_msg:
+            logger.warning(
+                f"Drop this pdf: {book_name}, reason: {DropReason.PARA_MERGE_FAILED}"
+            )
+            result = {"need_drop": True, "drop_reason": DropReason.PARA_MERGE_FAILED}
+            return result
+
+    if debug_mode:
+        input_pdf_file = f"{pdf_local_path}.pdf"
+        output_dir = f"{save_path}/{book_name}"
+        output_pdf_file = f"{output_dir}/pdf_annos.pdf"
+
+        """
+        Call the para_process_pipeline function to process the pdf_info_dict.
+        
+        Parameters:
+        para_debug_mode: str or None
+            If para_debug_mode is None, the para_process_pipeline will not keep any intermediate results.
+            If para_debug_mode is "simple", the para_process_pipeline will only keep the annos on the pdf and the final results as a json file.
+            If para_debug_mode is "full", the para_process_pipeline will keep all the intermediate results generated during each step.
+        """
+        pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(
+            pdf_info_dict,
+            para_debug_mode="simple",
+            input_pdf_path=input_pdf_file,
+            output_pdf_path=output_pdf_file,
+        )
+        # log the paragraph-processing time
+        logger.info(f"para_process_time: {get_delta_time(start_time)}")
+
+        # in debug mode, do not return the drop info
+        if error_info is not None:
+            _deal_with_text_exception(error_info)
+        return pdf_info_dict
+    else:
+        pdf_info_dict, error_info = para_process_pipeline.para_process_pipeline(
+            pdf_info_dict
+        )
+        if error_info is not None:
+            return _deal_with_text_exception(error_info)
+
+    return pdf_info_dict
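
Every early exit in parse_pdf_for_train above follows the same contract: when not in debug mode, return a dict of the form {"need_drop": True, "drop_reason": <DropReason member>} instead of the parsed result. A minimal sketch of how a caller can consume that contract (the helper name handle_parse_result is hypothetical, not part of the repo):

def handle_parse_result(result):
    # Parse functions either return the per-page info dict, or a small dict
    # carrying need_drop/drop_reason when the pdf should be skipped.
    if isinstance(result, dict) and result.get("need_drop", False):
        # drop_reason is a DropReason constant, e.g. COMPLICATED_LAYOUT
        print(f"dropped, reason: {result.get('drop_reason')}")
        return None
    return result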

+ 388 - 145
magic_pdf/pipeline.py

@@ -3,9 +3,16 @@ import sys
 import time
 from urllib.parse import quote
 
-from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_nlp_markdown, ocr_mk_mm_markdown, ocr_mk_mm_standard_format, \
-    ocr_mk_mm_markdown_with_para
-from magic_pdf.libs.commons import read_file, join_path, parse_bucket_key, formatted_time, s3_image_save_path
+from magic_pdf.dict2md.ocr_mkcontent import ocr_mk_mm_markdown, ocr_mk_nlp_markdown_with_para, \
+    ocr_mk_mm_markdown_with_para_and_pagination, ocr_mk_mm_markdown_with_para, ocr_mk_mm_standard_format, \
+    make_standard_format_with_para
+from magic_pdf.libs.commons import (
+    read_file,
+    join_path,
+    parse_bucket_key,
+    formatted_time,
+    s3_image_save_path,
+)
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.json_compressor import JsonCompressor
 from magic_pdf.dict2md.mkcontent import mk_nlp_markdown, mk_universal_format
@@ -14,50 +21,41 @@ from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
 from loguru import logger
 
-from app.common.s3 import get_s3_config, get_s3_client
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
+from magic_pdf.pdf_parse_for_train import parse_pdf_for_train
+from magic_pdf.spark.base import exception_handler, get_data_source
+from magic_pdf.train_utils.convert_to_train_format import convert_to_train_format
+from app.common.s3 import get_s3_config, get_s3_client
 
 
-def exception_handler(jso: dict, e):
-    logger.exception(e)
-    jso['need_drop'] = True
-    jso['drop_reason'] = DropReason.Exception
-    jso['exception'] = f"ERROR: {e}"
-    return jso
-
 
 def get_data_type(jso: dict):
-    data_type = jso.get('data_type')
+    data_type = jso.get("data_type")
     if data_type is None:
-        data_type = jso.get('file_type')
+        data_type = jso.get("file_type")
     return data_type
 
 
 def get_bookid(jso: dict):
-    book_id = jso.get('bookid')
+    book_id = jso.get("bookid")
     if book_id is None:
-        book_id = jso.get('original_file_id')
+        book_id = jso.get("original_file_id")
     return book_id
 
 
-def get_data_source(jso: dict):
-    data_source = jso.get('data_source')
-    if data_source is None:
-        data_source = jso.get('file_source')
-    return data_source
-
-
 def meta_scan(jso: dict, doc_layout_check=True) -> dict:
-    s3_pdf_path = jso.get('file_location')
+    s3_pdf_path = jso.get("file_location")
     s3_config = get_s3_config(s3_pdf_path)
     if doc_layout_check:
-        if 'doc_layout_result' not in jso:  # check whether model data exists in the json; if not, this pdf must be skipped
-            jso['need_drop'] = True
-            jso['drop_reason'] = DropReason.MISS_DOC_LAYOUT_RESULT
+        if (
+            "doc_layout_result" not in jso
+        ):  # check whether model data exists in the json; if not, this pdf must be skipped
+            jso["need_drop"] = True
+            jso["drop_reason"] = DropReason.MISS_DOC_LAYOUT_RESULT
             return jso
     try:
         data_source = get_data_source(jso)
-        file_id = jso.get('file_id')
+        file_id = jso.get("file_id")
         book_name = f"{data_source}/{file_id}"
 
         # the first page has a problem with an excessive number of drawings
@@ -68,90 +66,111 @@ def meta_scan(jso: dict, doc_layout_check=True) -> dict:
         #     return jso
 
         start_time = time.time()  # record start time
-        logger.info(f"book_name is:{book_name},start_time is:{formatted_time(start_time)}", file=sys.stderr)
+        logger.info(
+            f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
+            file=sys.stderr,
+        )
         file_content = read_file(s3_pdf_path, s3_config)
         read_file_time = int(time.time() - start_time)  # compute elapsed time
 
         start_time = time.time()  # record start time
         res = pdf_meta_scan(s3_pdf_path, file_content)
-        if res.get('need_drop', False):  # if the returned dict has need_drop, extract drop_reason and skip this parse
-            jso['need_drop'] = True
-            jso['drop_reason'] = res["drop_reason"]
+        if res.get(
+            "need_drop", False
+        ):  # if the returned dict has need_drop, extract drop_reason and skip this parse
+            jso["need_drop"] = True
+            jso["drop_reason"] = res["drop_reason"]
         else:  # normal return
-            jso['pdf_meta'] = res
-            jso['content'] = ""
-            jso['remark'] = ""
-            jso['data_url'] = ""
+            jso["pdf_meta"] = res
+            jso["content"] = ""
+            jso["remark"] = ""
+            jso["data_url"] = ""
         end_time = time.time()  # record end time
         meta_scan_time = int(end_time - start_time)  # compute elapsed time
-        logger.info(f"book_name is:{book_name},end_time is:{formatted_time(end_time)},read_file_time is:{read_file_time},meta_scan_time is:{meta_scan_time}", file=sys.stderr)
-        jso['read_file_time'] = read_file_time
-        jso['meta_scan_time'] = meta_scan_time
+        logger.info(
+            f"book_name is:{book_name},end_time is:{formatted_time(end_time)},read_file_time is:{read_file_time},meta_scan_time is:{meta_scan_time}",
+            file=sys.stderr,
+        )
+        jso["read_file_time"] = read_file_time
+        jso["meta_scan_time"] = meta_scan_time
     except Exception as e:
         jso = exception_handler(jso, e)
     return jso
 
 
 def classify_by_type(jso: dict, debug_mode=False) -> dict:
-    #check the debug switch
+    # check the debug switch
     if debug_mode:
         pass
-    else:# if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
             return jso
     # start the main logic
     try:
-        pdf_meta = jso.get('pdf_meta')
+        pdf_meta = jso.get("pdf_meta")
         data_source = get_data_source(jso)
-        file_id = jso.get('file_id')
+        file_id = jso.get("file_id")
         book_name = f"{data_source}/{file_id}"
         total_page = pdf_meta["total_page"]
         page_width = pdf_meta["page_width_pts"]
         page_height = pdf_meta["page_height_pts"]
         img_sz_list = pdf_meta["image_info_per_page"]
-        img_num_list = pdf_meta['imgs_per_page']
-        text_len_list = pdf_meta['text_len_per_page']
-        text_layout_list = pdf_meta['text_layout_per_page']
-        text_language = pdf_meta['text_language']
+        img_num_list = pdf_meta["imgs_per_page"]
+        text_len_list = pdf_meta["text_len_per_page"]
+        text_layout_list = pdf_meta["text_layout_per_page"]
+        text_language = pdf_meta["text_language"]
         # allow_language = ['zh', 'en']  # allowed languages; currently only Simplified Chinese and English

         # if text_language not in allow_language:  # if the language is not in the allowed list, drop
         #     jso['need_drop'] = True
         #     jso['drop_reason'] = DropReason.NOT_ALLOW_LANGUAGE
         #     return jso
-        pdf_path = pdf_meta['pdf_path']
-        is_encrypted = pdf_meta['is_encrypted']
-        is_needs_password = pdf_meta['is_needs_password']
-        if is_encrypted or is_needs_password:  # encrypted or password-protected pdfs, and pdfs without pages, are not processed
-            jso['need_drop'] = True
-            jso['drop_reason'] = DropReason.ENCRYPTED
+        pdf_path = pdf_meta["pdf_path"]
+        is_encrypted = pdf_meta["is_encrypted"]
+        is_needs_password = pdf_meta["is_needs_password"]
+        if (
+            is_encrypted or is_needs_password
+        ):  # encrypted or password-protected pdfs, and pdfs without pages, are not processed
+            jso["need_drop"] = True
+            jso["drop_reason"] = DropReason.ENCRYPTED
         else:
             start_time = time.time()  # record start time
-            is_text_pdf, results = classify(pdf_path, total_page, page_width, page_height, img_sz_list, text_len_list, img_num_list, text_layout_list)
+            is_text_pdf, results = classify(
+                pdf_path,
+                total_page,
+                page_width,
+                page_height,
+                img_sz_list,
+                text_len_list,
+                img_num_list,
+                text_layout_list,
+            )
             classify_time = int(time.time() - start_time)  # compute elapsed time
             if is_text_pdf:
-                pdf_meta['is_text_pdf'] = is_text_pdf
-                jso['pdf_meta'] = pdf_meta
-                jso['classify_time'] = classify_time
+                pdf_meta["is_text_pdf"] = is_text_pdf
+                jso["pdf_meta"] = pdf_meta
+                jso["classify_time"] = classify_time
                 # print(json.dumps(pdf_meta, ensure_ascii=False))
 
-                allow_language = ['zh', 'en']  # allowed languages; currently only Simplified Chinese and English
-                if text_language not in allow_language:  # if the language is not in the allowed list, drop
-                    jso['need_drop'] = True
-                    jso['drop_reason'] = DropReason.NOT_ALLOW_LANGUAGE
+                allow_language = ["zh", "en"]  # allowed languages; currently only Simplified Chinese and English
+                if (
+                    text_language not in allow_language
+                ):  # if the language is not in the allowed list, drop
+                    jso["need_drop"] = True
+                    jso["drop_reason"] = DropReason.NOT_ALLOW_LANGUAGE
                     return jso
             else:
                 # do not drop for now
-                pdf_meta['is_text_pdf'] = is_text_pdf
-                jso['pdf_meta'] = pdf_meta
-                jso['classify_time'] = classify_time
-                jso['need_drop'] = True
-                jso['drop_reason'] = DropReason.NOT_IS_TEXT_PDF
+                pdf_meta["is_text_pdf"] = is_text_pdf
+                jso["pdf_meta"] = pdf_meta
+                jso["classify_time"] = classify_time
+                jso["need_drop"] = True
+                jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
                 extra_info = {"classify_rules": []}
                 for condition, result in results.items():
                     if not result:
                         extra_info["classify_rules"].append(condition)
-                jso['extra_info'] = extra_info
+                jso["extra_info"] = extra_info
 
     except Exception as e:
         jso = exception_handler(jso, e)
@@ -162,48 +181,69 @@ def save_tables_to_s3(jso: dict, debug_mode=False) -> dict:
 
     if debug_mode:
         pass
-    else:# if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
-            logger.info(f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop", file=sys.stderr)
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
+            logger.info(
+                f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop",
+                file=sys.stderr,
+            )
             jso["dropped"] = True
             return jso
     try:
         data_source = get_data_source(jso)
-        file_id = jso.get('file_id')
+        file_id = jso.get("file_id")
         book_name = f"{data_source}/{file_id}"
-        title = jso.get('title')
-        url_encode_title = quote(title, safe='')
-        if data_source != 'scihub':
+        title = jso.get("title")
+        url_encode_title = quote(title, safe="")
+        if data_source != "scihub":
             return jso
-        pdf_intermediate_dict = jso['pdf_intermediate_dict']
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
         # decompress pdf_intermediate_dict
         pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
         i = 0
         for page in pdf_intermediate_dict.values():
-            if page.get('tables'):
-                if len(page['tables']) > 0:
+            if page.get("tables"):
+                if len(page["tables"]) > 0:
                     j = 0
-                    for table in page['tables']:
+                    for table in page["tables"]:
                         if debug_mode:
-                            image_path = join_path("s3://mllm-raw-media/pdf2md_img/", book_name, table['image_path'])
+                            image_path = join_path(
+                                "s3://mllm-raw-media/pdf2md_img/",
+                                book_name,
+                                table["image_path"],
+                            )
                         else:
-                            image_path = join_path("s3://mllm-raw-media/pdf2md_img/", table['image_path'])
+                            image_path = join_path(
+                                "s3://mllm-raw-media/pdf2md_img/", table["image_path"]
+                            )
 
-                        if image_path.endswith('.jpg'):
+                        if image_path.endswith(".jpg"):
                             j += 1
                             s3_client = get_s3_client(image_path)
                             bucket_name, bucket_key = parse_bucket_key(image_path)
                             # fetch the image into memory via s3_client
-                            image_bytes = s3_client.get_object(Bucket=bucket_name, Key=bucket_key)['Body'].read()
+                            image_bytes = s3_client.get_object(
+                                Bucket=bucket_name, Key=bucket_key
+                            )["Body"].read()
                             # save the image to its new location
                             if debug_mode:
-                                new_image_path = join_path("s3://mllm-raw-media/pdf2md_img/table_new/", url_encode_title + "_" + table['image_path'].lstrip('tables/'))
+                                new_image_path = join_path(
+                                    "s3://mllm-raw-media/pdf2md_img/table_new/",
+                                    url_encode_title
+                                    + "_"
+                                    + table["image_path"].lstrip("tables/"),
+                                )
                             else:
-                                new_image_path = join_path("s3://mllm-raw-media/pdf2md_img/table_new/", url_encode_title + f"_page{i}_{j}.jpg")
+                                new_image_path = join_path(
+                                    "s3://mllm-raw-media/pdf2md_img/table_new/",
+                                    url_encode_title + f"_page{i}_{j}.jpg",
+                                )
 
                             logger.info(new_image_path, file=sys.stderr)
                             bucket_name, bucket_key = parse_bucket_key(new_image_path)
-                            s3_client.put_object(Bucket=bucket_name, Key=bucket_key, Body=image_bytes)
+                            s3_client.put_object(
+                                Bucket=bucket_name, Key=bucket_key, Body=image_bytes
+                            )
                         else:
                             continue
             i += 1
@@ -218,8 +258,11 @@ def save_tables_to_s3(jso: dict, debug_mode=False) -> dict:
 
 
 def drop_needdrop_pdf(jso: dict) -> dict:
-    if jso.get('need_drop', False):
-        logger.info(f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop", file=sys.stderr)
+    if jso.get("need_drop", False):
+        logger.info(
+            f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop",
+            file=sys.stderr,
+        )
         jso["dropped"] = True
     return jso
 
@@ -228,19 +271,19 @@ def pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
 
     if debug_mode:
         pass
-    else:# if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
-            book_name = join_path(get_data_source(jso), jso['file_id'])
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
             logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
             jso["dropped"] = True
             return jso
     try:
-        pdf_intermediate_dict = jso['pdf_intermediate_dict']
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
         # decompress pdf_intermediate_dict
         pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
-        #markdown_content = mk_nlp_markdown(pdf_intermediate_dict)
-        jso['content_list'] = mk_universal_format(pdf_intermediate_dict)
-        #jso["content"] = markdown_content
+        # markdown_content = mk_nlp_markdown(pdf_intermediate_dict)
+        jso["content_list"] = mk_universal_format(pdf_intermediate_dict)
+        # jso["content"] = markdown_content
         logger.info(f"book_name is:{get_data_source(jso)}/{jso['file_id']}")
         # clear out fields that are no longer needed
         jso["doc_layout_result"] = ""
@@ -252,18 +295,18 @@ def pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
 
 
 def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
-    #check the debug switch
+    # check the debug switch
     if debug_mode:
         pass
-    else:# if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
             return jso
     # start the main logic
-    s3_pdf_path = jso.get('file_location')
+    s3_pdf_path = jso.get("file_location")
     s3_config = get_s3_config(s3_pdf_path)
-    model_output_json_list = jso.get('doc_layout_result')
+    model_output_json_list = jso.get("doc_layout_result")
     data_source = get_data_source(jso)
-    file_id = jso.get('file_id')
+    file_id = jso.get("file_id")
     book_name = f"{data_source}/{file_id}"
 
     # fixed in 1.23.22
@@ -275,15 +318,15 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
     #         jso['drop_reason'] = DropReason.SPECIAL_PDF
     #         return jso
 
-    junk_img_bojids = jso['pdf_meta']['junk_img_bojids']
+    junk_img_bojids = jso["pdf_meta"]["junk_img_bojids"]
     # total_page = jso['pdf_meta']['total_page']
 
     # add a check on the max_svgs count: if max_svgs exceeds 3000, drop
-    svgs_per_page_list = jso['pdf_meta']['svgs_per_page']
+    svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
     max_svgs = max(svgs_per_page_list)
     if max_svgs > 3000:
-        jso['need_drop'] = True
-        jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
+        jso["need_drop"] = True
+        jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
     # elif total_page > 1000:
     #     jso['need_drop'] = True
     #     jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
@@ -293,44 +336,144 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
             image_s3_config = get_s3_config(save_path)
             start_time = time.time()  # 记录开始时间
             # log book_name and the parse start time first
-            logger.info(f"book_name is:{book_name},start_time is:{formatted_time(start_time)}", file=sys.stderr)
-            pdf_info_dict = parse_pdf_by_model(s3_pdf_path, s3_config, model_output_json_list, save_path,
-                                                  book_name, pdf_model_profile=None,
-                                                  image_s3_config=image_s3_config,
-                                                  start_page_id=start_page_id, junk_img_bojids=junk_img_bojids,
-                                                  debug_mode=debug_mode)
-            if pdf_info_dict.get('need_drop', False):  # if the returned dict has need_drop, extract drop_reason and skip this parse
-                jso['need_drop'] = True
-                jso['drop_reason'] = pdf_info_dict["drop_reason"]
+            logger.info(
+                f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
+                file=sys.stderr,
+            )
+            pdf_info_dict = parse_pdf_by_model(
+                s3_pdf_path,
+                s3_config,
+                model_output_json_list,
+                save_path,
+                book_name,
+                pdf_model_profile=None,
+                image_s3_config=image_s3_config,
+                start_page_id=start_page_id,
+                junk_img_bojids=junk_img_bojids,
+                debug_mode=debug_mode,
+            )
+            if pdf_info_dict.get(
+                "need_drop", False
+            ):  # if the returned dict has need_drop, extract drop_reason and skip this parse
+                jso["need_drop"] = True
+                jso["drop_reason"] = pdf_info_dict["drop_reason"]
             else:  # normal return: compress and store pdf_info_dict
                 pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
-                jso['pdf_intermediate_dict'] = pdf_info_dict
+                jso["pdf_intermediate_dict"] = pdf_info_dict
             end_time = time.time()  # record end time
             parse_time = int(end_time - start_time)  # compute elapsed time
             # after parsing, log book_name and the elapsed time
-            logger.info(f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}", file=sys.stderr)
-            jso['parse_time'] = parse_time
+            logger.info(
+                f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
+                file=sys.stderr,
+            )
+            jso["parse_time"] = parse_time
         except Exception as e:
             jso = exception_handler(jso, e)
     return jso
 
-'''
+
+"""
 Unified processing logic:
 1. First call parse_pdf to handle text-type pdfs.
 2. Then call ocr_dropped_parse_pdf to handle the pdfs that were dropped earlier.
-'''
+"""
+
+
 def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
     jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
     jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
     return jso
 
+def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> dict:
+    # check the debug switch
+    if debug_mode:
+        pass
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
+            return jso
+    # start the main logic
+    s3_pdf_path = jso.get("file_location")
+    s3_config = get_s3_config(s3_pdf_path)
+    model_output_json_list = jso.get("doc_layout_result")
+    data_source = get_data_source(jso)
+    file_id = jso.get("file_id")
+    book_name = f"{data_source}/{file_id}"
+
+    # fixed in 1.23.22
+    # if debug_mode:
+    #     pass
+    # else:
+    #     if book_name == "zlib/zlib_21929367":
+    #         jso['need_drop'] = True
+    #         jso['drop_reason'] = DropReason.SPECIAL_PDF
+    #         return jso
+
+    junk_img_bojids = jso["pdf_meta"]["junk_img_bojids"]
+    # total_page = jso['pdf_meta']['total_page']
+
+    # add a check on the max_svgs count: if max_svgs exceeds 3000, drop
+    svgs_per_page_list = jso["pdf_meta"]["svgs_per_page"]
+    max_svgs = max(svgs_per_page_list)
+    if max_svgs > 3000:
+        jso["need_drop"] = True
+        jso["drop_reason"] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_SVGS
+    # elif total_page > 1000:
+    #     jso['need_drop'] = True
+    #     jso['drop_reason'] = DropReason.HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
+    else:
+        try:
+            save_path = s3_image_save_path
+            image_s3_config = get_s3_config(save_path)
+            start_time = time.time()  # record start time
+            # log book_name and the parse start time first
+            logger.info(
+                f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
+                file=sys.stderr,
+            )
+            pdf_info_dict = parse_pdf_for_train(
+                s3_pdf_path,
+                s3_config,
+                model_output_json_list,
+                save_path,
+                book_name,
+                pdf_model_profile=None,
+                image_s3_config=image_s3_config,
+                start_page_id=start_page_id,
+                junk_img_bojids=junk_img_bojids,
+                debug_mode=debug_mode,
+            )
+            if pdf_info_dict.get(
+                "need_drop", False
+            ):  # if the returned dict has need_drop, extract drop_reason and skip this parse
+                jso["need_drop"] = True
+                jso["drop_reason"] = pdf_info_dict["drop_reason"]
+            else:  # normal return: compress and store pdf_info_dict
+                jso["parsed_results"] = convert_to_train_format(pdf_info_dict)
+                pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
+                jso["pdf_intermediate_dict"] = pdf_info_dict
+            end_time = time.time()  # record end time
+            parse_time = int(end_time - start_time)  # compute elapsed time
+            # after parsing, log book_name and the elapsed time
+            logger.info(
+                f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
+                file=sys.stderr,
+            )
+            jso["parse_time"] = parse_time
+        except Exception as e:
+            jso = exception_handler(jso, e)
+    return jso
+
+
+# Dedicated to re-running pdfs that were dropped; after the run, need_drop must be set back to False
 def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
-    if not jso.get('need_drop', False):
+    if not jso.get("need_drop", False):
         return jso
     else:
-        jso = ocr_parse_pdf_core(jso, start_page_id=start_page_id, debug_mode=debug_mode)
-        jso['need_drop'] = False
+        jso = ocr_parse_pdf_core(
+            jso, start_page_id=start_page_id, debug_mode=debug_mode
+        )
+        jso["need_drop"] = False
         return jso
 
 
@@ -339,7 +482,7 @@ def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
     if debug_mode:
         pass
     else:  # if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
+        if jso.get("need_drop", False):
             return jso
 
     jso = ocr_parse_pdf_core(jso, start_page_id=start_page_id, debug_mode=debug_mode)
@@ -347,18 +490,21 @@ def ocr_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
 
 
 def ocr_parse_pdf_core(jso: dict, start_page_id=0, debug_mode=False) -> dict:
-    s3_pdf_path = jso.get('file_location')
+    s3_pdf_path = jso.get("file_location")
     s3_config = get_s3_config(s3_pdf_path)
-    model_output_json_list = jso.get('doc_layout_result')
+    model_output_json_list = jso.get("doc_layout_result")
     data_source = get_data_source(jso)
-    file_id = jso.get('file_id')
+    file_id = jso.get("file_id")
     book_name = f"{data_source}/{file_id}"
     try:
         save_path = s3_image_save_path
         image_s3_config = get_s3_config(save_path)
         start_time = time.time()  # record start time
         # log book_name and the parse start time first
-        logger.info(f"book_name is:{book_name},start_time is:{formatted_time(start_time)}", file=sys.stderr)
+        logger.info(
+            f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
+            file=sys.stderr,
+        )
         pdf_info_dict = parse_pdf_by_ocr(
             s3_pdf_path,
             s3_config,
@@ -368,37 +514,42 @@ def ocr_parse_pdf_core(jso: dict, start_page_id=0, debug_mode=False) -> dict:
             pdf_model_profile=None,
             image_s3_config=image_s3_config,
             start_page_id=start_page_id,
-            debug_mode=debug_mode
+            debug_mode=debug_mode,
         )
         pdf_info_dict = JsonCompressor.compress_json(pdf_info_dict)
-        jso['pdf_intermediate_dict'] = pdf_info_dict
+        jso["pdf_intermediate_dict"] = pdf_info_dict
         end_time = time.time()  # record end time
         parse_time = int(end_time - start_time)  # compute elapsed time
         # after parsing, log book_name and the elapsed time
-        logger.info(f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}", file=sys.stderr)
-        jso['parse_time'] = parse_time
+        logger.info(
+            f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
+            file=sys.stderr,
+        )
+        jso["parse_time"] = parse_time
     except Exception as e:
         jso = exception_handler(jso, e)
     return jso
 
 
 def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
-
     if debug_mode:
         pass
     else:  # if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
-            book_name = join_path(get_data_source(jso), jso['file_id'])
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
             logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
             jso["dropped"] = True
             return jso
     try:
-        pdf_intermediate_dict = jso['pdf_intermediate_dict']
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
         # decompress pdf_intermediate_dict
         pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
         markdown_content = ocr_mk_mm_markdown(pdf_intermediate_dict)
         jso["content"] = markdown_content
-        logger.info(f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}", file=sys.stderr)
+        logger.info(
+            f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
+            file=sys.stderr,
+        )
         # clear out fields that are no longer needed
         jso["doc_layout_result"] = ""
         jso["pdf_intermediate_dict"] = ""
@@ -408,26 +559,88 @@ def ocr_pdf_intermediate_dict_to_markdown(jso: dict, debug_mode=False) -> dict:
     return jso
 
 
-def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa(jso: dict, debug_mode=False) -> dict:
+def ocr_pdf_intermediate_dict_to_markdown_with_para(jso: dict, debug_mode=False) -> dict:
+    if debug_mode:
+        pass
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
+            logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
+            jso["dropped"] = True
+            return jso
+    try:
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
+        # decompress pdf_intermediate_dict
+        pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
+        # markdown_content = ocr_mk_mm_markdown_with_para(pdf_intermediate_dict)
+        markdown_content = ocr_mk_nlp_markdown_with_para(pdf_intermediate_dict)
+        jso["content"] = markdown_content
+        logger.info(
+            f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
+            file=sys.stderr,
+        )
+        # clear out fields that are no longer needed
+        jso["doc_layout_result"] = ""
+        jso["pdf_intermediate_dict"] = ""
+        jso["pdf_meta"] = ""
+    except Exception as e:
+        jso = exception_handler(jso, e)
+    return jso
+
 
+def ocr_pdf_intermediate_dict_to_markdown_with_para_and_pagination(jso: dict, debug_mode=False) -> dict:
     if debug_mode:
         pass
     else:  # if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
-            book_name = join_path(get_data_source(jso), jso['file_id'])
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
             logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
             jso["dropped"] = True
             return jso
     try:
-        pdf_intermediate_dict = jso['pdf_intermediate_dict']
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
+        # decompress pdf_intermediate_dict
+        pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
+        markdown_content = ocr_mk_mm_markdown_with_para_and_pagination(pdf_intermediate_dict)
+        jso["content"] = markdown_content
+        logger.info(
+            f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
+            file=sys.stderr,
+        )
+        # clear out fields that are no longer needed
+        # jso["doc_layout_result"] = ""
+        jso["pdf_intermediate_dict"] = ""
+        # jso["pdf_meta"] = ""
+    except Exception as e:
+        jso = exception_handler(jso, e)
+    return jso
+
+
+def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa(
+        jso: dict, debug_mode=False
+) -> dict:
+    if debug_mode:
+        pass
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
+            logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
+            jso["dropped"] = True
+            return jso
+    try:
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
         # decompress pdf_intermediate_dict
         pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
         markdown_content = ocr_mk_mm_markdown_with_para(pdf_intermediate_dict)
         jso["content_ocr"] = markdown_content
-        logger.info(f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}", file=sys.stderr)
+        logger.info(
+            f"book_name is:{get_data_source(jso)}/{jso['file_id']},markdown content length is {len(markdown_content)}",
+            file=sys.stderr,
+        )
         # clear out fields that are no longer needed
         jso["doc_layout_result"] = ""
         jso["pdf_intermediate_dict"] = ""
+        jso["mid_json_ocr"] = pdf_intermediate_dict
         jso["pdf_meta"] = ""
     except Exception as e:
         jso = exception_handler(jso, e)
@@ -435,22 +648,52 @@ def ocr_pdf_intermediate_dict_to_markdown_with_para_for_qa(jso: dict, debug_mode
 
 
 def ocr_pdf_intermediate_dict_to_standard_format(jso: dict, debug_mode=False) -> dict:
-
     if debug_mode:
         pass
     else:  # if debug is off, check for the need_drop field
-        if jso.get('need_drop', False):
-            book_name = join_path(get_data_source(jso), jso['file_id'])
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
             logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
             jso["dropped"] = True
             return jso
     try:
-        pdf_intermediate_dict = jso['pdf_intermediate_dict']
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
         # decompress pdf_intermediate_dict
         pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
         standard_format = ocr_mk_mm_standard_format(pdf_intermediate_dict)
         jso["content_list"] = standard_format
-        logger.info(f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}", file=sys.stderr)
+        logger.info(
+            f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}",
+            file=sys.stderr,
+        )
+        # clear out fields that are no longer needed
+        jso["doc_layout_result"] = ""
+        jso["pdf_intermediate_dict"] = ""
+        jso["pdf_meta"] = ""
+    except Exception as e:
+        jso = exception_handler(jso, e)
+    return jso
+
+
+def ocr_pdf_intermediate_dict_to_standard_format_with_para(jso: dict, debug_mode=False) -> dict:
+    if debug_mode:
+        pass
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
+            logger.info(f"book_name is:{book_name} need drop", file=sys.stderr)
+            jso["dropped"] = True
+            return jso
+    try:
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
+        # decompress pdf_intermediate_dict
+        pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
+        standard_format = make_standard_format_with_para(pdf_intermediate_dict)
+        jso["content_list"] = standard_format
+        logger.info(
+            f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}",
+            file=sys.stderr,
+        )
         # clear out fields that are no longer needed
         jso["doc_layout_result"] = ""
         jso["pdf_intermediate_dict"] = ""

+ 37 - 0
magic_pdf/pipeline_txt.py

@@ -0,0 +1,37 @@
+"""
+Convert text-type pdfs into the unified cleaned format
+"""
+
+
+
+from loguru import logger
+from magic_pdf.dict2md.mkcontent import mk_universal_format
+from magic_pdf.libs.commons import join_path
+from magic_pdf.libs.json_compressor import JsonCompressor
+from magic_pdf.spark.base import exception_handler, get_data_source
+
+
+def txt_pdf_to_standard_format(jso: dict, debug_mode=False) -> dict:
+
+    if debug_mode:
+        pass
+    else:  # if debug is off, check for the need_drop field
+        if jso.get("need_drop", False):
+            book_name = join_path(get_data_source(jso), jso["file_id"])
+            logger.info(f"book_name is:{book_name} need drop")
+            jso["dropped"] = True
+            return jso
+    try:
+        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
+        # decompress pdf_intermediate_dict
+        pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
+        standard_format = mk_universal_format(pdf_intermediate_dict)
+        jso["content_list"] = standard_format
+        logger.info(f"book_name is:{get_data_source(jso)}/{jso['file_id']},content_list length is {len(standard_format)}",)
+        # 把无用的信息清空
+        jso["doc_layout_result"] = ""
+        jso["pdf_intermediate_dict"] = ""
+        jso["pdf_meta"] = ""
+    except Exception as e:
+        jso = exception_handler(jso, e)
+    return jso
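
A quick usage sketch for the new stage (the record below is a placeholder: a real pdf_intermediate_dict is the compressed output of the text-pdf parse stage, carrying the full per-page structure rather than an empty dict):

from magic_pdf.libs.json_compressor import JsonCompressor
from magic_pdf.pipeline_txt import txt_pdf_to_standard_format

jso = {
    "data_source": "demo",      # illustrative values, not real records
    "file_id": "book_0001",
    "pdf_intermediate_dict": JsonCompressor.compress_json({"page_0": {}}),
}
jso = txt_pdf_to_standard_format(jso)
print(jso.get("content_list", jso.get("drop_reason")))  # content on success, reason on failure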

+ 1 - 1
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -66,7 +66,7 @@ def merge_spans_to_line_by_layout(spans, layout_bboxes):
         # iterate over the spans and put each span into its corresponding layout
         layout_sapns = []
         for span in spans:
-            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], layout_bbox) > 0.65:
+            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], layout_bbox) > 0.6:
                 layout_sapns.append(span)
         # if layout_sapns is not empty, add it to new_spans
         if len(layout_sapns) > 0:
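
The threshold change above (0.65 to 0.6) makes span-to-layout assignment slightly more permissive. The quantity being compared is the overlap area as a fraction of the span's own area; a self-contained re-implementation of that measure, written here for illustration and not imported from the repo:

def overlap_ratio_in_bbox1(bbox1, bbox2):
    # intersection rectangle of the two [x0, y0, x1, y1] boxes
    x0, y0 = max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1])
    x1, y1 = min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
    if x1 <= x0 or y1 <= y0:
        return 0.0
    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    return (x1 - x0) * (y1 - y0) / area1 if area1 else 0.0

# a span 65% inside the layout box qualifies at > 0.6 but not at > 0.65:
assert overlap_ratio_in_bbox1([0, 0, 10, 10], [0, 0, 10, 6.5]) == 0.65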

+ 5 - 0
magic_pdf/pre_proc/ocr_span_list_modify.py

@@ -44,10 +44,15 @@ def remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict):
         # logger.info(f"remove spans by bbox dict, drop_tag: {drop_tag}, removed_bboxes: {removed_bboxes}")
         need_remove_spans = []
         for span in spans:
+            # decide whether the span should be removed by checking whether its bbox falls inside removed_bboxes
             for removed_bbox in removed_bboxes:
                 if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], removed_bbox) > 0.5:
                     need_remove_spans.append(span)
                     break
+                # when drop_tag is DropTag.FOOTNOTE, check whether the span lies below any bbox in removed_bboxes; if it does, remove the span
+                elif drop_tag == DropTag.FOOTNOTE and (span['bbox'][1]+span['bbox'][3])/2 > removed_bbox[3] and removed_bbox[0] < (span['bbox'][0]+span['bbox'][2])/2 < removed_bbox[2]:
+                    need_remove_spans.append(span)
+                    break
 
         for span in need_remove_spans:
             spans.remove(span)
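
The new FOOTNOTE branch above uses a purely geometric test: a span is removed when its vertical midpoint lies below the bottom edge of a detected footnote bbox and its horizontal midpoint falls inside that bbox's x-range. A standalone sketch of the same test (illustrative only; bboxes are [x0, y0, x1, y1] with y growing downward):

def is_below_footnote(span_bbox, footnote_bbox):
    # span's vertical center is below the footnote's bottom edge ...
    below = (span_bbox[1] + span_bbox[3]) / 2 > footnote_bbox[3]
    # ... and its horizontal center falls within the footnote's x-range
    inside_x = footnote_bbox[0] < (span_bbox[0] + span_bbox[2]) / 2 < footnote_bbox[2]
    return below and inside_x

assert is_below_footnote([100, 720, 300, 735], [90, 680, 320, 700])  # a line directly under a footnote box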

+ 21 - 0
magic_pdf/spark/base.py

@@ -0,0 +1,21 @@
+
+
+from loguru import logger
+
+from magic_pdf.libs.drop_reason import DropReason
+
+
+def get_data_source(jso: dict):
+    data_source = jso.get("data_source")
+    if data_source is None:
+        data_source = jso.get("file_source")
+    return data_source
+
+
+def exception_handler(jso: dict, e):
+    logger.exception(e)
+    jso["need_drop"] = True
+    jso["drop_reason"] = DropReason.Exception
+    jso["exception"] = f"ERROR: {e}"
+    return jso
+
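
These two helpers are factored out of pipeline.py so that both pipelines can import them without a circular dependency. The intended pattern, sketched for illustration (some_stage is a hypothetical stage, not part of the repo):

from magic_pdf.spark.base import exception_handler, get_data_source

def some_stage(jso: dict) -> dict:
    try:
        book_name = f"{get_data_source(jso)}/{jso['file_id']}"
        # ... the stage's actual work on jso, e.g. logging book_name ...
    except Exception as e:
        jso = exception_handler(jso, e)  # marks need_drop and records the exception
    return jso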

+ 0 - 0
magic_pdf/train_utils/__init__.py


+ 65 - 0
magic_pdf/train_utils/convert_to_train_format.py

@@ -0,0 +1,65 @@
+def convert_to_train_format(jso: dict) -> list:
+    pages = []
+    for k, v in jso.items():
+        if not k.startswith("page_"):
+            continue
+        page_idx = v["page_idx"]
+        width, height = v["page_size"]
+
+        info = {"page_info": {"page_no": page_idx, "height": height, "width": width}}
+
+        bboxes: list[dict] = []
+        for img_bbox in v["image_bboxes_with_caption"]:
+            bbox = {"category_id": 1, "bbox": img_bbox["bbox"]}
+            if "caption" in img_bbox:
+                bbox["caption_bbox"] = img_bbox["caption"]
+            bboxes.append(bbox)
+
+        for tbl_bbox in v["table_bboxes_with_caption"]:
+            bbox = {"category_id": 7, "bbox": tbl_bbox["bbox"]}
+            if "caption" in tbl_bbox:
+                bbox["caption_bbox"] = tbl_bbox["caption"]
+            bboxes.append(bbox)
+
+        for bbox in v["bak_page_no_bboxes"]:
+            n_bbox = {"category_id": 4, "bbox": bbox}
+            bboxes.append(n_bbox)
+
+        for bbox in v["bak_header_bboxes"]:
+            n_bbox = {"category_id": 3, "bbox": bbox}
+            bboxes.append(n_bbox)
+
+        for bbox in v["bak_footer_bboxes"]:
+            n_bbox = {"category_id": 6, "bbox": bbox}
+            bboxes.append(n_bbox)
+
+        # Footnotes; no example seen in the data so far
+        for para in v["para_blocks"]:
+            if "paras" in para:
+                paras = para["paras"]
+                for para_key, para_content in paras.items():
+                    para_bbox = para_content["para_bbox"]
+                    is_para_title = para_content["is_para_title"]
+                    if is_para_title:
+                        n_bbox = {"category_id": 0, "bbox": para_bbox}
+                    else:
+                        n_bbox = {"category_id": 2, "bbox": para_bbox}
+                    bboxes.append(n_bbox)
+
+        for inline_equation in v["inline_equations"]:
+            n_bbox = {"category_id": 13, "bbox": inline_equation["bbox"]}
+            bboxes.append(n_bbox)
+
+        for inter_equation in v["interline_equations"]:
+            n_bbox = {"category_id": 10, "bbox": inter_equation["bbox"]}
+            bboxes.append(n_bbox)
+
+        for footnote_bbox in v["bak_footer_note_bboxes"]:
+            n_bbox = {"category_id": 5, "bbox": list(footnote_bbox)}
+            bboxes.append(n_bbox)
+
+        info["bboxes"] = bboxes
+        info["layout_tree"] = v["layout_bboxes"]
+        pages.append(info)
+
+    return pages
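
The category ids used above follow the layout model's label map (1 = figure, 7 = table, 4 = page number, 3 = header, 6 = footer, 0 = title, 2 = plain text, 13 = inline equation, 10 = interline equation, 5 = footnote; see the legend in magic_pdf/train_utils/vis_utils.py below). One element of the returned list looks like this (values illustrative):

example_page = {
    "page_info": {"page_no": 0, "height": 1650, "width": 1275},
    "bboxes": [
        # a figure with its caption bbox attached
        {"category_id": 1, "bbox": [100, 200, 300, 400], "caption_bbox": [100, 410, 300, 440]},
        # a plain-text paragraph
        {"category_id": 2, "bbox": [50, 460, 550, 700]},
    ],
    "layout_tree": [],  # copied from v["layout_bboxes"]
}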

+ 59 - 0
magic_pdf/train_utils/extract_caption.py

@@ -0,0 +1,59 @@
+from magic_pdf.libs.boxbase import _is_in
+
+
+def extract_caption_bbox(outer: list, inner: list) -> list:
+    """
+    ret: list of {
+                    "bbox": [1,2,3,4],
+                    "caption": [5,6,7,8] # may existed
+                }
+
+    """
+    found_count = 0  # for debug
+    print(outer, inner)
+
+    def is_float_equal(a, b):
+        if 0.01 > abs(a - b):  # non-strict float equality comparison
+            return True
+        return False
+
+    outer_h = {i: outer[i] for i in range(len(outer))}
+    ret = []
+    for v in inner:
+        ix0, iy0, ix1, iy1 = v
+        found_idx = None
+        d = {"bbox": v[:4]}
+        for k in outer_h:
+            ox0, oy0, ox1, oy1 = outer_h[k]
+            equal_float_flags = [
+                is_float_equal(ix0, ox0),
+                is_float_equal(iy0, oy0),
+                is_float_equal(ix1, ox1),
+                is_float_equal(iy1, oy1),
+            ]
+            if _is_in(v, outer_h[k]) and not all(equal_float_flags):
+                found_idx = k
+                break
+        if found_idx is not None:
+            found_count += 1
+            captions: list[list] = []
+            ox0, oy0, ox1, oy1 = outer_h[found_idx]
+            captions = [
+                [ox0, oy0, ix0, oy1],
+                [ox0, oy0, ox1, iy0],
+                [ox0, iy1, ox1, oy1],
+                [ix1, oy0, ox1, oy1],
+            ]
+            captions = sorted(
+                captions,
+                key=lambda rect: abs(rect[0] - rect[2]) * abs(rect[1] - rect[3]),
+            )  # the strip with the largest area is taken as the caption
+            d["caption"] = captions[-1]
+            outer_h.pop(
+                found_idx
+            )  # each outer box may determine the caption position of only one inner box
+
+        ret.append(d)
+
+    print("found_count: ", found_count)
+    return ret
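
A worked example of the strip selection above: the four candidates are the left, top, bottom, and right strips of the outer box left over after removing the inner box, and the largest one wins. Assuming _is_in treats shared edges as containment (the function also emits its debug prints):

outer = [[0, 0, 100, 120]]   # figure plus caption region
inner = [[0, 0, 100, 100]]   # figure body detected separately

# Candidate strips: [0,0,0,120], [0,0,100,0], [0,100,100,120], [100,0,100,120].
# Only the bottom strip has nonzero area (100 * 20), so it becomes the caption.
print(extract_caption_bbox(outer, inner))
# -> [{'bbox': [0, 0, 100, 100], 'caption': [0, 100, 100, 120]}]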

+ 159 - 0
magic_pdf/train_utils/remove_footer_header.py

@@ -0,0 +1,159 @@
+import re
+
+from magic_pdf.libs.boxbase import _is_in_or_part_overlap
+from magic_pdf.libs.drop_tag import CONTENT_IN_FOOT_OR_HEADER, PAGE_NO
+
+
+"""
+    Copied from pre_proc/remove_footer_header.py
+"""
+
+
+def remove_headder_footer_one_page(
+    text_raw_blocks,
+    image_bboxes,
+    table_bboxes,
+    header_bboxs,
+    footer_bboxs,
+    page_no_bboxs,
+    page_w,
+    page_h,
+):
+    """
+    删除页眉页脚,页码
+    从line级别进行删除,删除之后观察这个text-block是否是空的,如果是空的,则移动到remove_list中
+    """
+    if 1:
+        return image_bboxes, table_bboxes, text_raw_blocks, [], [], []
+
+    header = []
+    footer = []
+    if len(header) == 0:
+        model_header = header_bboxs
+        if model_header:
+            x0 = min([x for x, _, _, _ in model_header])
+            y0 = min([y for _, y, _, _ in model_header])
+            x1 = max([x1 for _, _, x1, _ in model_header])
+            y1 = max([y1 for _, _, _, y1 in model_header])
+            header = [x0, y0, x1, y1]
+    if len(footer) == 0:
+        model_footer = footer_bboxs
+        if model_footer:
+            x0 = min([x for x, _, _, _ in model_footer])
+            y0 = min([y for _, y, _, _ in model_footer])
+            x1 = max([x1 for _, _, x1, _ in model_footer])
+            y1 = max([y1 for _, _, _, y1 in model_footer])
+            footer = [x0, y0, x1, y1]
+
+    header_y0 = 0 if len(header) == 0 else header[3]
+    footer_y0 = page_h if len(footer) == 0 else footer[1]
+    if page_no_bboxs:
+        top_part = [b for b in page_no_bboxs if b[3] < page_h / 2]
+        btn_part = [b for b in page_no_bboxs if b[1] > page_h / 2]
+
+        top_max_y0 = max([b[1] for b in top_part]) if top_part else 0
+        btn_min_y1 = min([b[3] for b in btn_part]) if btn_part else page_h
+
+        header_y0 = max(header_y0, top_max_y0)
+        footer_y0 = min(footer_y0, btn_min_y1)
+
+    content_boundry = [0, header_y0, page_w, footer_y0]
+
+    header = [0, 0, page_w, header_y0]
+    footer = [0, footer_y0, page_w, page_h]
+
+    """以上计算出来了页眉页脚的边界,下面开始进行删除"""
+    text_block_to_remove = []
+    # First, inspect every text block
+    for blk in text_raw_blocks:
+        if len(blk["lines"]) > 0:
+            for line in blk["lines"]:
+                line_del = []
+                for span in line["spans"]:
+                    span_del = []
+                    if span["bbox"][3] < header_y0:
+                        span_del.append(span)
+                    elif _is_in_or_part_overlap(
+                        span["bbox"], header
+                    ) or _is_in_or_part_overlap(span["bbox"], footer):
+                        span_del.append(span)
+                for span in span_del:
+                    line["spans"].remove(span)
+                if not line["spans"]:
+                    line_del.append(line)
+
+            for line in line_del:
+                blk["lines"].remove(line)
+        else:
+            # if not blk['lines']:
+            blk["tag"] = CONTENT_IN_FOOT_OR_HEADER
+            text_block_to_remove.append(blk)
+
+    """有的时候由于pageNo太小了,总是会有一点和content_boundry重叠一点,被放入正文,因此对于pageNo,进行span粒度的删除"""
+    page_no_block_2_remove = []
+    if page_no_bboxs:
+        for pagenobox in page_no_bboxs:
+            for block in text_raw_blocks:
+                if _is_in_or_part_overlap(
+                    pagenobox, block["bbox"]
+                ):  # remove the page number at span level
+                    for line in block["lines"]:
+                        for span in line["spans"]:
+                            if _is_in_or_part_overlap(pagenobox, span["bbox"]):
+                                # span['text'] = ''
+                                span["tag"] = PAGE_NO
+                                # If this is the only span in the block's only line, remove the whole block as well
+                                if len(line["spans"]) == 1 and len(block["lines"]) == 1:
+                                    page_no_block_2_remove.append(block)
+    else:
+        # Check whether the last block is a page number: it must have exactly one line containing one span whose text has at least one digit, no letters, and otherwise only spaces and symbols
+        if len(text_raw_blocks) > 0:
+            text_raw_blocks.sort(key=lambda x: x["bbox"][1], reverse=True)
+            last_block = text_raw_blocks[0]
+            if len(last_block["lines"]) == 1:
+                last_line = last_block["lines"][0]
+                if len(last_line["spans"]) == 1:
+                    last_span = last_line["spans"][0]
+                    if (
+                        last_span["text"].strip()
+                        and not re.search("[a-zA-Z]", last_span["text"])
+                        and re.search("[0-9]", last_span["text"])
+                    ):
+                        last_span["tag"] = PAGE_NO
+                        page_no_block_2_remove.append(last_block)
+
+    for b in page_no_block_2_remove:
+        text_block_to_remove.append(b)
+
+    for blk in text_block_to_remove:
+        if blk in text_raw_blocks:
+            text_raw_blocks.remove(blk)
+
+    text_block_remain = text_raw_blocks
+    image_bbox_to_remove = [
+        bbox
+        for bbox in image_bboxes
+        if not _is_in_or_part_overlap(bbox, content_boundry)
+    ]
+
+    image_bbox_remain = [
+        bbox for bbox in image_bboxes if _is_in_or_part_overlap(bbox, content_boundry)
+    ]
+    table_bbox_to_remove = [
+        bbox
+        for bbox in table_bboxes
+        if not _is_in_or_part_overlap(bbox, content_boundry)
+    ]
+    table_bbox_remain = [
+        bbox for bbox in table_bboxes if _is_in_or_part_overlap(bbox, content_boundry)
+    ]
+
+    # Return the remaining boxes first (image, table, text), then the removed ones (text, image, table)
+    return (
+        image_bbox_remain,
+        table_bbox_remain,
+        text_block_remain,
+        text_block_to_remove,
+        image_bbox_to_remove,
+        table_bbox_to_remove,
+    )
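
Note that the `if 1:` guard makes this copied function a pass-through for now. For reference, the disabled body computes the content region by unioning the model's header and footer boxes; condensed as a sketch (bboxes [x0, y0, x1, y1], y grows downward):

def content_bounds(header_bboxs, footer_bboxs, page_w, page_h):
    # Bottom edge of the union of header boxes; top edge of the union of
    # footer boxes. Everything between the two is kept as body content.
    header_y0 = max((b[3] for b in header_bboxs), default=0)
    footer_y0 = min((b[1] for b in footer_bboxs), default=page_h)
    return [0, header_y0, page_w, footer_y0]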

+ 327 - 0
magic_pdf/train_utils/vis_utils.py

@@ -0,0 +1,327 @@
+from magic_pdf.libs.commons import fitz
+import os
+from magic_pdf.libs.coordinate_transform import get_scale_ratio
+
+
+def draw_model_output(
+    raw_pdf_doc: fitz.Document, paras_dict_arr: list[dict], save_path: str
+):
+    """
+    在page上画出bbox,保存到save_path
+    """
+    """
+    
+        # {0: 'title',  # 标题
+    # 1: 'figure', # 图片
+    #  2: 'plain text',  # 文本
+    #  3: 'header',      # 页眉
+    #  4: 'page number', # 页码
+    #  5: 'footnote',    # 脚注
+    #  6: 'footer',      # 页脚
+    #  7: 'table',       # 表格
+    #  8: 'table caption',  # 表格描述
+    #  9: 'figure caption', # 图片描述
+    #  10: 'equation',      # 公式
+    #  11: 'full column',   # 单栏
+    #  12: 'sub column',    # 多栏
+    #  13: 'embedding',     # 嵌入公式
+    #  14: 'isolated'}      # 单行公式
+    
+    """
+
+    color_map = {
+        "body": fitz.pdfcolor["green"],
+        "non_body": fitz.pdfcolor["red"],
+    }
+    """
+    {"layout_dets": [], "subfield_dets": [], "page_info": {"page_no": 22, "height": 1650, "width": 1275}}
+    """
+    for i, page in enumerate(raw_pdf_doc):
+        v = paras_dict_arr[i]
+        page_idx = v["page_info"]["page_no"]
+        width = v["page_info"]["width"]
+        height = v["page_info"]["height"]
+
+        horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
+            paras_dict_arr[i], page
+        )
+
+        for order, block in enumerate(v["layout_dets"]):
+            L = block["poly"][0] / horizontal_scale_ratio
+            U = block["poly"][1] / vertical_scale_ratio
+            R = block["poly"][2] / horizontal_scale_ratio
+            D = block["poly"][5] / vertical_scale_ratio
+            # L += pageL          # some pages have an artBox offset away from (0, 0)
+            # R += pageL
+            # U += pageU
+            # D += pageU
+            L, R = min(L, R), max(L, R)
+            U, D = min(U, D), max(U, D)
+            bbox = [L, U, R, D]
+            color = color_map["body"]
+            if block["category_id"] in (3, 4, 5, 6, 0):
+                color = color_map["non_body"]
+
+            rect = fitz.Rect(bbox)
+            page.draw_rect(rect, fill=None, width=0.5, overlay=True, color=color)
+
+    parent_dir = os.path.dirname(save_path)
+    if not os.path.exists(parent_dir):
+        os.makedirs(parent_dir)
+    raw_pdf_doc.save(save_path)
+
+
+def debug_show_bbox(
+    raw_pdf_doc: fitz.Document,
+    page_idx: int,
+    bboxes: list,
+    droped_bboxes: list,
+    expect_drop_bboxes: list,
+    save_path: str,
+    expected_page_id: int,
+):
+    """
+    以覆盖的方式写个临时的pdf,用于debug
+    """
+    if page_idx != expected_page_id:
+        return
+
+    if os.path.exists(save_path):
+        # 删除已经存在的文件
+        os.remove(save_path)
+    # 创建一个新的空白 PDF 文件
+    doc = fitz.open("")
+
+    width = raw_pdf_doc[page_idx].rect.width
+    height = raw_pdf_doc[page_idx].rect.height
+    new_page = doc.new_page(width=width, height=height)
+
+    shape = new_page.new_shape()
+    for bbox in bboxes:
+        # Draw the original boxes (red border, blue fill)
+        rect = fitz.Rect(*bbox[0:4])
+        shape = new_page.new_shape()
+        shape.draw_rect(rect)
+        shape.finish(
+            color=fitz.pdfcolor["red"], fill=fitz.pdfcolor["blue"], fill_opacity=0.2
+        )
+        shape.finish()
+        shape.commit()
+
+    for bbox in droped_bboxes:
+        # Draw the dropped boxes (yellow fill)
+        rect = fitz.Rect(*bbox[0:4])
+        shape = new_page.new_shape()
+        shape.draw_rect(rect)
+        shape.finish(color=None, fill=fitz.pdfcolor["yellow"], fill_opacity=0.2)
+        shape.finish()
+        shape.commit()
+
+    for bbox in expect_drop_bboxes:
+        # Draw the expected-drop boxes (red border only)
+        rect = fitz.Rect(*bbox[0:4])
+        shape = new_page.new_shape()
+        shape.draw_rect(rect)
+        shape.finish(color=fitz.pdfcolor["red"], fill=None)
+        shape.finish()
+        shape.commit()
+
+    # shape.insert_textbox(fitz.Rect(200, 0, 600, 20), f"total bboxes: {len(bboxes)}", fontname="helv", fontsize=12,
+    #                      color=(0, 0, 0))
+    # shape.finish(color=fitz.pdfcolor['black'])
+    # shape.commit()
+
+    parent_dir = os.path.dirname(save_path)
+    if not os.path.exists(parent_dir):
+        os.makedirs(parent_dir)
+
+    doc.save(save_path)
+    doc.close()
+
+
+def debug_show_page(
+    page,
+    bboxes1: list,
+    bboxes2: list,
+    bboxes3: list,
+):
+    save_path = "./tmp/debug.pdf"
+    if os.path.exists(save_path):
+        # Remove the existing file
+        os.remove(save_path)
+    # Create a new blank PDF document
+    doc = fitz.open("")
+
+    width = page.rect.width
+    height = page.rect.height
+    new_page = doc.new_page(width=width, height=height)
+
+    shape = new_page.new_shape()
+    for bbox in bboxes1:
+        # Draw the bboxes1 group (red border, blue fill)
+        rect = fitz.Rect(*bbox[0:4])
+        shape = new_page.new_shape()
+        shape.draw_rect(rect)
+        shape.finish(
+            color=fitz.pdfcolor["red"], fill=fitz.pdfcolor["blue"], fill_opacity=0.2
+        )
+        shape.finish()
+        shape.commit()
+
+    for bbox in bboxes2:
+        # Draw the bboxes2 group (yellow fill)
+        rect = fitz.Rect(*bbox[0:4])
+        shape = new_page.new_shape()
+        shape.draw_rect(rect)
+        shape.finish(color=None, fill=fitz.pdfcolor["yellow"], fill_opacity=0.2)
+        shape.finish()
+        shape.commit()
+
+    for bbox in bboxes3:
+        # Draw the bboxes3 group (red border only)
+        rect = fitz.Rect(*bbox[0:4])
+        shape = new_page.new_shape()
+        shape.draw_rect(rect)
+        shape.finish(color=fitz.pdfcolor["red"], fill=None)
+        shape.finish()
+        shape.commit()
+
+    parent_dir = os.path.dirname(save_path)
+    if not os.path.exists(parent_dir):
+        os.makedirs(parent_dir)
+
+    doc.save(save_path)
+    doc.close()
+
+
+def draw_layout_bbox_on_page(
+    raw_pdf_doc: fitz.Document, paras_dict: dict, header, footer, pdf_path: str
+):
+    """
+    在page上画出bbox,保存到save_path
+    """
+    # 检查文件是否存在
+    is_new_pdf = False
+    if os.path.exists(pdf_path):
+        # 打开现有的 PDF 文件
+        doc = fitz.open(pdf_path)
+    else:
+        # 创建一个新的空白 PDF 文件
+        is_new_pdf = True
+        doc = fitz.open("")
+
+    for k, v in paras_dict.items():
+        page_idx = v["page_idx"]
+        layouts = v["layout_bboxes"]
+        page = doc[page_idx]
+        shape = page.new_shape()
+        for order, layout in enumerate(layouts):
+            border_offset = 1
+            rect_box = layout["layout_bbox"]
+            layout_label = layout["layout_label"]
+            fill_color = fitz.pdfcolor["pink"] if layout_label == "U" else None
+            rect_box = [
+                rect_box[0] + 1,
+                rect_box[1] - border_offset,
+                rect_box[2] - 1,
+                rect_box[3] + border_offset,
+            ]
+            rect = fitz.Rect(*rect_box)
+            shape.draw_rect(rect)
+            shape.finish(color=fitz.pdfcolor["red"], fill=fill_color, fill_opacity=0.4)
+            """
+            draw order text on layout box
+            """
+            font_size = 10
+            shape.insert_text(
+                (rect_box[0] + 1, rect_box[1] + font_size),
+                f"{order}",
+                fontsize=font_size,
+                color=(0, 0, 0),
+            )
+
+        """画上footer header"""
+        if header:
+            shape.draw_rect(fitz.Rect(header))
+            shape.finish(color=None, fill=fitz.pdfcolor["black"], fill_opacity=0.2)
+        if footer:
+            shape.draw_rect(fitz.Rect(footer))
+            shape.finish(color=None, fill=fitz.pdfcolor["black"], fill_opacity=0.2)
+
+        shape.commit()
+
+    if is_new_pdf:
+        doc.save(pdf_path)
+    else:
+        doc.saveIncr()
+    doc.close()
+
+
+@DeprecationWarning
+def draw_layout_on_page(
+    raw_pdf_doc: fitz.Document, page_idx: int, page_layout: list, pdf_path: str
+):
+    """
+    Draw the layout boxes with red borders on page page_idx of pdf_path
+    """
+
+    def draw(shape, layout, fill_color=fitz.pdfcolor["pink"]):
+        border_offset = 1
+        rect_box = layout["layout_bbox"]
+        layout_label = layout["layout_label"]
+        sub_layout = layout["sub_layout"]
+        if len(sub_layout) == 0:
+            fill_color = fill_color if layout_label == "U" else None
+            rect_box = [
+                rect_box[0] + 1,
+                rect_box[1] - border_offset,
+                rect_box[2] - 1,
+                rect_box[3] + border_offset,
+            ]
+            rect = fitz.Rect(*rect_box)
+            shape.draw_rect(rect)
+            shape.finish(color=fitz.pdfcolor["red"], fill=fill_color, fill_opacity=0.2)
+            # if layout_label=='U':
+            #     bad_boxes = layout.get("bad_boxes", [])
+            #     for bad_box in bad_boxes:
+            #         rect = fitz.Rect(*bad_box)
+            #         shape.draw_rect(rect)
+            #         shape.finish(color=fitz.pdfcolor['red'], fill=fitz.pdfcolor['red'], fill_opacity=0.2)
+        # else:
+        #     rect = fitz.Rect(*rect_box)
+        #     shape.draw_rect(rect)
+        #     shape.finish(color=fitz.pdfcolor['blue'])
+
+        for sub_layout in sub_layout:
+            draw(shape, sub_layout)
+        shape.commit()
+
+    # Check whether the file exists
+    is_new_pdf = False
+    if os.path.exists(pdf_path):
+        # Open the existing PDF file
+        doc = fitz.open(pdf_path)
+    else:
+        # Create a new blank PDF file
+        is_new_pdf = True
+        doc = fitz.open("")
+
+    page = doc[page_idx]
+    shape = page.new_shape()
+    for order, layout in enumerate(page_layout):
+        draw(shape, layout, fitz.pdfcolor["yellow"])
+
+    # shape.insert_textbox(fitz.Rect(200, 0, 600, 20), f"total bboxes: {len(layout)}", fontname="helv", fontsize=12,
+    #                      color=(0, 0, 0))
+    # shape.finish(color=fitz.pdfcolor['black'])
+    # shape.commit()
+
+    parent_dir = os.path.dirname(pdf_path)
+    if not os.path.exists(parent_dir):
+        os.makedirs(parent_dir)
+
+    if is_new_pdf:
+        doc.save(pdf_path)
+    else:
+        doc.saveIncr()
+    doc.close()
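
A usage sketch for the annotator above; the paths and the minimal model output are illustrative, and get_scale_ratio is assumed to derive its ratios from page_info:

from magic_pdf.libs.commons import fitz

doc = fitz.open("demo/sample.pdf")    # hypothetical input path
model_pages = [                       # one dict per page, shape as in the docstring above
    {"page_info": {"page_no": 0, "height": 1650, "width": 1275},
     "layout_dets": []},
]
draw_model_output(doc, model_pages, "tmp/vis/sample_annotated.pdf")

One caveat: @DeprecationWarning is not a deprecation decorator; applying it rebinds draw_layout_on_page to a DeprecationWarning instance, which is not callable. Emitting warnings.warn(..., DeprecationWarning) inside the function body would be the conventional way to mark it deprecated.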

+ 2 - 1
requirements.txt

@@ -2,7 +2,7 @@ boto3>=1.28.43
 Brotli>=1.1.0
 click>=8.1.7
 Distance>=0.1.3
-PyMuPDF>=1.23.26
+PyMuPDF>=1.24.0
 loguru>=0.6.0
 matplotlib>=3.8.3
 numpy>=1.21.6
@@ -12,5 +12,6 @@ regex>=2023.12.25
 spacy>=3.7.4
 termcolor>=2.4.0
 scikit-learn
+wordninja>=2.0.0
 en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl
 zh_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.7.0/zh_core_web_sm-3.7.0-py3-none-any.whl