Pārlūkot izejas kodu

refactor: move some constants or enums defs to config folder

icecraft 1 gadu atpakaļ
vecāks
revīzija
b492c19c4c
43 mainītis faili ar 2485 papildinājumiem un 1621 dzēšanām
  1. 53 0
      magic_pdf/config/constants.py
  2. 35 0
      magic_pdf/config/drop_reason.py
  3. 19 0
      magic_pdf/config/drop_tag.py
  4. 11 0
      magic_pdf/config/make_content_config.py
  5. 2 1
      magic_pdf/config/model_block_type.py
  6. 0 0
      magic_pdf/config/ocr_content_type.py
  7. 226 185
      magic_pdf/dict2md/mkcontent.py
  8. 7 8
      magic_pdf/dict2md/ocr_mkcontent.py
  9. 101 79
      magic_pdf/filter/pdf_meta_scan.py
  10. 1 1
      magic_pdf/integrations/rag/utils.py
  11. 0 55
      magic_pdf/libs/Constants.py
  12. 0 11
      magic_pdf/libs/MakeContentConfig.py
  13. 5 5
      magic_pdf/libs/config_reader.py
  14. 3 2
      magic_pdf/libs/draw_bbox.py
  15. 0 27
      magic_pdf/libs/drop_reason.py
  16. 0 19
      magic_pdf/libs/drop_tag.py
  17. 2 2
      magic_pdf/model/magic_model.py
  18. 109 59
      magic_pdf/model/pdf_extract_kit.py
  19. 39 34
      magic_pdf/model/sub_modules/model_init.py
  20. 30 28
      magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py
  21. 408 247
      magic_pdf/para/para_split.py
  22. 352 182
      magic_pdf/para/para_split_v2.py
  23. 81 46
      magic_pdf/para/para_split_v3.py
  24. 174 100
      magic_pdf/pdf_parse_union_core.py
  25. 10 8
      magic_pdf/pdf_parse_union_core_v2.py
  26. 2 2
      magic_pdf/pipe/AbsPipe.py
  27. 1 1
      magic_pdf/pipe/OCRPipe.py
  28. 1 1
      magic_pdf/pipe/TXTPipe.py
  29. 1 1
      magic_pdf/pipe/UNIPipe.py
  30. 7 14
      magic_pdf/post_proc/pdf_post_filter.py
  31. 9 11
      magic_pdf/pre_proc/cut_image.py
  32. 203 212
      magic_pdf/pre_proc/equations_replace.py
  33. 235 49
      magic_pdf/pre_proc/ocr_detect_all_bboxes.py
  34. 3 3
      magic_pdf/pre_proc/ocr_dict_merge.py
  35. 123 60
      magic_pdf/pre_proc/ocr_span_list_modify.py
  36. 37 33
      magic_pdf/pre_proc/pdf_pre_filter.py
  37. 20 18
      magic_pdf/pre_proc/remove_bbox_overlap.py
  38. 36 14
      magic_pdf/pre_proc/remove_colored_strip_bbox.py
  39. 2 5
      magic_pdf/pre_proc/remove_footer_header.py
  40. 111 63
      magic_pdf/pre_proc/remove_rotate_bbox.py
  41. 10 17
      magic_pdf/pre_proc/resolve_bbox_conflict.py
  42. 15 17
      magic_pdf/spark/spark_api.py
  43. 1 1
      magic_pdf/tools/common.py

+ 53 - 0
magic_pdf/config/constants.py

@@ -0,0 +1,53 @@
+"""span维度自定义字段."""
+# span是否是跨页合并的
+CROSS_PAGE = 'cross_page'
+
+"""
+block维度自定义字段
+"""
+# block中lines是否被删除
+LINES_DELETED = 'lines_deleted'
+
+# table recognition max time default value
+TABLE_MAX_TIME_VALUE = 400
+
+# pp_table_result_max_length
+TABLE_MAX_LEN = 480
+
+# table master structure dict
+TABLE_MASTER_DICT = 'table_master_structure_dict.txt'
+
+# table master dir
+TABLE_MASTER_DIR = 'table_structure_tablemaster_infer/'
+
+# pp detect model dir
+DETECT_MODEL_DIR = 'ch_PP-OCRv4_det_infer'
+
+# pp rec model dir
+REC_MODEL_DIR = 'ch_PP-OCRv4_rec_infer'
+
+# pp rec char dict path
+REC_CHAR_DICT = 'ppocr_keys_v1.txt'
+
+# pp rec copy rec directory
+PP_REC_DIRECTORY = '.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer'
+
+# pp rec copy det directory
+PP_DET_DIRECTORY = '.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer'
+
+
+class MODEL_NAME:
+    # pp table structure algorithm
+    TABLE_MASTER = 'tablemaster'
+    # struct eqtable
+    STRUCT_EQTABLE = 'struct_eqtable'
+
+    DocLayout_YOLO = 'doclayout_yolo'
+
+    LAYOUTLMv3 = 'layoutlmv3'
+
+    YOLO_V8_MFD = 'yolo_v8_mfd'
+
+    UniMerNet_v2_Small = 'unimernet_small'
+
+    RAPID_TABLE = 'rapid_table'

+ 35 - 0
magic_pdf/config/drop_reason.py

@@ -0,0 +1,35 @@
+class DropReason:
+    TEXT_BLCOK_HOR_OVERLAP = 'text_block_horizontal_overlap'  # 文字块有水平互相覆盖,导致无法准确定位文字顺序
+    USEFUL_BLOCK_HOR_OVERLAP = (
+        'useful_block_horizontal_overlap'  # 需保留的block水平覆盖
+    )
+    COMPLICATED_LAYOUT = 'complicated_layout'  # 复杂的布局,暂时不支持
+    TOO_MANY_LAYOUT_COLUMNS = 'too_many_layout_columns'  # 目前不支持分栏超过2列的
+    COLOR_BACKGROUND_TEXT_BOX = 'color_background_text_box'  # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
+    HIGH_COMPUTATIONAL_lOAD_BY_IMGS = (
+        'high_computational_load_by_imgs'  # 含特殊图片,计算量太大,从而丢弃
+    )
+    HIGH_COMPUTATIONAL_lOAD_BY_SVGS = (
+        'high_computational_load_by_svgs'  # 特殊的SVG图,计算量太大,从而丢弃
+    )
+    HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = 'high_computational_load_by_total_pages'  # 计算量超过负荷,当前方法下计算量消耗过大
+    MISS_DOC_LAYOUT_RESULT = 'missing doc_layout_result'  # 版面分析失败
+    Exception = '_exception'  # 解析中发生异常
+    ENCRYPTED = 'encrypted'  # PDF是加密的
+    EMPTY_PDF = 'total_page=0'  # PDF页面总数为0
+    NOT_IS_TEXT_PDF = 'not_is_text_pdf'  # 不是文字版PDF,无法直接解析
+    DENSE_SINGLE_LINE_BLOCK = 'dense_single_line_block'  # 无法清晰的分段
+    TITLE_DETECTION_FAILED = 'title_detection_failed'  # 探测标题失败
+    TITLE_LEVEL_FAILED = (
+        'title_level_failed'  # 分析标题级别失败(例如一级、二级、三级标题)
+    )
+    PARA_SPLIT_FAILED = 'para_split_failed'  # 识别段落失败
+    PARA_MERGE_FAILED = 'para_merge_failed'  # 段落合并失败
+    NOT_ALLOW_LANGUAGE = 'not_allow_language'  # 不支持的语种
+    SPECIAL_PDF = 'special_pdf'
+    PSEUDO_SINGLE_COLUMN = 'pseudo_single_column'  # 无法精确判断文字分栏
+    CAN_NOT_DETECT_PAGE_LAYOUT = 'can_not_detect_page_layout'  # 无法分析页面的版面
+    NEGATIVE_BBOX_AREA = 'negative_bbox_area'  # 缩放导致 bbox 面积为负
+    OVERLAP_BLOCKS_CAN_NOT_SEPARATION = (
+        'overlap_blocks_can_t_separation'  # 无法分离重叠的block
+    )

+ 19 - 0
magic_pdf/config/drop_tag.py

@@ -0,0 +1,19 @@
+
+COLOR_BG_HEADER_TXT_BLOCK = 'color_background_header_txt_block'
+PAGE_NO = 'page-no'  # 页码
+CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area'  # 页眉页脚内的文本
+VERTICAL_TEXT = 'vertical-text'  # 垂直文本
+ROTATE_TEXT = 'rotate-text'  # 旋转文本
+EMPTY_SIDE_BLOCK = 'empty-side-block'  # 边缘上的空白没有任何内容的block
+ON_IMAGE_TEXT = 'on-image-text'  # 文本在图片上
+ON_TABLE_TEXT = 'on-table-text'  # 文本在表格上
+
+
+class DropTag:
+    PAGE_NUMBER = 'page_no'
+    HEADER = 'header'
+    FOOTER = 'footer'
+    FOOTNOTE = 'footnote'
+    NOT_IN_LAYOUT = 'not_in_layout'
+    SPAN_OVERLAP = 'span_overlap'
+    BLOCK_OVERLAP = 'block_overlap'

+ 11 - 0
magic_pdf/config/make_content_config.py

@@ -0,0 +1,11 @@
+class MakeMode:
+    MM_MD = 'mm_markdown'
+    NLP_MD = 'nlp_markdown'
+    STANDARD_FORMAT = 'standard_format'
+
+
+class DropMode:
+    WHOLE_PDF = 'whole_pdf'
+    SINGLE_PAGE = 'single_page'
+    NONE = 'none'
+    NONE_WITH_REASON = 'none_with_reason'

+ 2 - 1
magic_pdf/libs/ModelBlockTypeEnum.py → magic_pdf/config/model_block_type.py

@@ -1,9 +1,10 @@
 from enum import Enum
 
+
 class ModelBlockTypeEnum(Enum):
     TITLE = 0
     PLAIN_TEXT = 1
     ABANDON = 2
     ISOLATE_FORMULA = 8
     EMBEDDING = 13
-    ISOLATED = 14
+    ISOLATED = 14

+ 0 - 0
magic_pdf/libs/ocr_content_type.py → magic_pdf/config/ocr_content_type.py


+ 226 - 185
magic_pdf/dict2md/mkcontent.py

@@ -1,9 +1,11 @@
 import math
+
 from loguru import logger
 
-from magic_pdf.libs.boxbase import find_bottom_nearest_text_bbox, find_top_nearest_text_bbox
+from magic_pdf.config.ocr_content_type import ContentType
+from magic_pdf.libs.boxbase import (find_bottom_nearest_text_bbox,
+                                    find_top_nearest_text_bbox)
 from magic_pdf.libs.commons import join_path
-from magic_pdf.libs.ocr_content_type import ContentType
 
 TYPE_INLINE_EQUATION = ContentType.InlineEquation
 TYPE_INTERLINE_EQUATION = ContentType.InterlineEquation
@@ -12,33 +14,30 @@ UNI_FORMAT_TEXT_TYPE = ['text', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']
 
 @DeprecationWarning
 def mk_nlp_markdown_1(para_dict: dict):
-    """
-    对排序后的bboxes拼接内容
-    """
+    """对排序后的bboxes拼接内容."""
     content_lst = []
     for _, page_info in para_dict.items():
-        para_blocks = page_info.get("para_blocks")
+        para_blocks = page_info.get('para_blocks')
         if not para_blocks:
             continue
 
         for block in para_blocks:
-            item = block["paras"]
+            item = block['paras']
             for _, p in item.items():
-                para_text = p["para_text"]
-                is_title = p["is_para_title"]
+                para_text = p['para_text']
+                is_title = p['is_para_title']
                 title_level = p['para_title_level']
-                md_title_prefix = "#"*title_level
+                md_title_prefix = '#' * title_level
                 if is_title:
-                    content_lst.append(f"{md_title_prefix} {para_text}")
+                    content_lst.append(f'{md_title_prefix} {para_text}')
                 else:
                     content_lst.append(para_text)
 
-    content_text = "\n\n".join(content_lst)
+    content_text = '\n\n'.join(content_lst)
 
     return content_text
 
 
-
 # 找到目标字符串在段落中的索引
 def __find_index(paragraph, target):
     index = paragraph.find(target)
@@ -48,69 +47,76 @@ def __find_index(paragraph, target):
         return None
 
 
-def __insert_string(paragraph, target, postion):
-    new_paragraph = paragraph[:postion] + target + paragraph[postion:] 
+def __insert_string(paragraph, target, position):
+    new_paragraph = paragraph[:position] + target + paragraph[position:]
     return new_paragraph
 
 
 def __insert_after(content, image_content, target):
-    """
-    在content中找到target,将image_content插入到target后面
-    """
+    """在content中找到target,将image_content插入到target后面."""
     index = content.find(target)
     if index != -1:
-        content = content[:index+len(target)] + "\n\n" + image_content + "\n\n" + content[index+len(target):]
+        content = (
+            content[: index + len(target)]
+            + '\n\n'
+            + image_content
+            + '\n\n'
+            + content[index + len(target) :]
+        )
     else:
-        logger.error(f"Can't find the location of image {image_content} in the markdown file, search target is {target}")
+        logger.error(
+            f"Can't find the location of image {image_content} in the markdown file, search target is {target}"
+        )
     return content
 
+
 def __insert_before(content, image_content, target):
-    """
-    在content中找到target,将image_content插入到target前面
-    """
+    """在content中找到target,将image_content插入到target前面."""
     index = content.find(target)
     if index != -1:
-        content = content[:index] + "\n\n" + image_content + "\n\n" + content[index:]
+        content = content[:index] + '\n\n' + image_content + '\n\n' + content[index:]
     else:
-        logger.error(f"Can't find the location of image {image_content} in the markdown file, search target is {target}")
+        logger.error(
+            f"Can't find the location of image {image_content} in the markdown file, search target is {target}"
+        )
     return content
 
 
 @DeprecationWarning
 def mk_mm_markdown_1(para_dict: dict):
-    """拼装多模态markdown"""
+    """拼装多模态markdown."""
     content_lst = []
     for _, page_info in para_dict.items():
-        page_lst = [] # 一个page内的段落列表
-        para_blocks = page_info.get("para_blocks")
-        pymu_raw_blocks = page_info.get("preproc_blocks")
-        
+        page_lst = []  # 一个page内的段落列表
+        para_blocks = page_info.get('para_blocks')
+        pymu_raw_blocks = page_info.get('preproc_blocks')
+
         all_page_images = []
-        all_page_images.extend(page_info.get("images",[]))
-        all_page_images.extend(page_info.get("image_backup", []) )
-        all_page_images.extend(page_info.get("tables",[]))
-        all_page_images.extend(page_info.get("table_backup",[]) )
-        
-        if not para_blocks or not pymu_raw_blocks: # 只有图片的拼接的场景
+        all_page_images.extend(page_info.get('images', []))
+        all_page_images.extend(page_info.get('image_backup', []))
+        all_page_images.extend(page_info.get('tables', []))
+        all_page_images.extend(page_info.get('table_backup', []))
+
+        if not para_blocks or not pymu_raw_blocks:  # 只有图片的拼接的场景
             for img in all_page_images:
-                page_lst.append(f"![]({img['image_path']})") # TODO 图片顺序
-            page_md = "\n\n".join(page_lst)
-            
+                page_lst.append(f"![]({img['image_path']})")  # TODO 图片顺序
+            page_md = '\n\n'.join(page_lst)
+
         else:
             for block in para_blocks:
-                item = block["paras"]
+                item = block['paras']
                 for _, p in item.items():
-                    para_text = p["para_text"]
-                    is_title = p["is_para_title"]
+                    para_text = p['para_text']
+                    is_title = p['is_para_title']
                     title_level = p['para_title_level']
-                    md_title_prefix = "#"*title_level
+                    md_title_prefix = '#' * title_level
                     if is_title:
-                        page_lst.append(f"{md_title_prefix} {para_text}")
+                        page_lst.append(f'{md_title_prefix} {para_text}')
                     else:
                         page_lst.append(para_text)
-                        
+
             """拼装成一个页面的文本"""
-            page_md = "\n\n".join(page_lst)
+            page_md = '\n\n'.join(page_lst)
             """插入图片"""
             for img in all_page_images:
                 imgbox = img['bbox']
@@ -118,192 +124,215 @@ def mk_mm_markdown_1(para_dict: dict):
                 # 先看在哪个block内
                 for block in pymu_raw_blocks:
                     bbox = block['bbox']
-                    if bbox[0]-1 <= imgbox[0] < bbox[2]+1 and bbox[1]-1 <= imgbox[1] < bbox[3]+1:# 确定在block内
-                        for l in block['lines']:
+                    if (
+                        bbox[0] - 1 <= imgbox[0] < bbox[2] + 1
+                        and bbox[1] - 1 <= imgbox[1] < bbox[3] + 1
+                    ):  # 确定在block内
+                        for l in block['lines']:  # noqa: E741
                             line_box = l['bbox']
-                            if line_box[0]-1 <= imgbox[0] < line_box[2]+1 and line_box[1]-1 <= imgbox[1] < line_box[3]+1: # 在line内的,插入line前面
-                                line_txt = "".join([s['text'] for s in l['spans']])
-                                page_md = __insert_before(page_md, img_content, line_txt)
+                            if (
+                                line_box[0] - 1 <= imgbox[0] < line_box[2] + 1
+                                and line_box[1] - 1 <= imgbox[1] < line_box[3] + 1
+                            ):  # 在line内的,插入line前面
+                                line_txt = ''.join([s['text'] for s in l['spans']])
+                                page_md = __insert_before(
+                                    page_md, img_content, line_txt
+                                )
                                 break
                             break
-                        else:# 在行与行之间
+                        else:  # 在行与行之间
                             # 找到图片x0,y0与line的x0,y0最近的line
                             min_distance = 100000
                             min_line = None
-                            for l in block['lines']:
+                            for l in block['lines']:  # noqa: E741
                                 line_box = l['bbox']
-                                distance = math.sqrt((line_box[0] - imgbox[0])**2 + (line_box[1] - imgbox[1])**2)
+                                distance = math.sqrt(
+                                    (line_box[0] - imgbox[0]) ** 2
+                                    + (line_box[1] - imgbox[1]) ** 2
+                                )
                                 if distance < min_distance:
                                     min_distance = distance
                                     min_line = l
                             if min_line:
-                                line_txt = "".join([s['text'] for s in min_line['spans']])
+                                line_txt = ''.join(
+                                    [s['text'] for s in min_line['spans']]
+                                )
                                 img_h = imgbox[3] - imgbox[1]
-                                if min_distance<img_h: # 文字在图片前面
-                                    page_md = __insert_after(page_md, img_content, line_txt)
+                                if min_distance < img_h:  # 文字在图片前面
+                                    page_md = __insert_after(
+                                        page_md, img_content, line_txt
+                                    )
                                 else:
-                                    page_md = __insert_before(page_md, img_content, line_txt)
+                                    page_md = __insert_before(
+                                        page_md, img_content, line_txt
+                                    )
                             else:
-                                logger.error(f"Can't find the location of image {img['image_path']} in the markdown file #1")
-                else:# 应当在两个block之间
+                                logger.error(
+                                    f"Can't find the location of image {img['image_path']} in the markdown file  #1"
+                                )
+                else:  # 应当在两个block之间
                     # 找到上方最近的block,如果上方没有就找大下方最近的block
                     top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, imgbox)
                     if top_txt_block:
-                        line_txt = "".join([s['text'] for s in top_txt_block['lines'][-1]['spans']])
+                        line_txt = ''.join(
+                            [s['text'] for s in top_txt_block['lines'][-1]['spans']]
+                        )
                         page_md = __insert_after(page_md, img_content, line_txt)
                     else:
-                        bottom_txt_block = find_bottom_nearest_text_bbox(pymu_raw_blocks, imgbox)
+                        bottom_txt_block = find_bottom_nearest_text_bbox(
+                            pymu_raw_blocks, imgbox
+                        )
                         if bottom_txt_block:
-                            line_txt = "".join([s['text'] for s in bottom_txt_block['lines'][0]['spans']])
+                            line_txt = ''.join(
+                                [
+                                    s['text']
+                                    for s in bottom_txt_block['lines'][0]['spans']
+                                ]
+                            )
                             page_md = __insert_before(page_md, img_content, line_txt)
                         else:
-                            logger.error(f"Can't find the location of image {img['image_path']} in the markdown file #2")
-                    
+                            logger.error(
+                                f"Can't find the location of image {img['image_path']} in the markdown file  #2"
+                            )
+
         content_lst.append(page_md)
-                    
+
     """拼装成全部页面的文本"""
-    content_text = "\n\n".join(content_lst)
+    content_text = '\n\n'.join(content_lst)
 
     return content_text
 
 
 def __insert_after_para(text, type, element, content_list):
-    """
-    在content_list中找到text,将image_path作为一个新的node插入到text后面
-    """
+    """在content_list中找到text,将image_path作为一个新的node插入到text后面."""
     for i, c in enumerate(content_list):
-        content_type = c.get("type")
-        if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get("text", ''):
-            if type == "image":
+        content_type = c.get('type')
+        if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get('text', ''):
+            if type == 'image':
                 content_node = {
-                    "type": "image",
-                    "img_path": element.get("image_path"),
-                    "img_alt": "",
-                    "img_title": "",
-                    "img_caption": "",
+                    'type': 'image',
+                    'img_path': element.get('image_path'),
+                    'img_alt': '',
+                    'img_title': '',
+                    'img_caption': '',
                 }
-            elif type == "table":
+            elif type == 'table':
                 content_node = {
-                    "type": "table",
-                    "img_path": element.get("image_path"),
-                    "table_latex": element.get("text"),
-                    "table_title": "",
-                    "table_caption": "",
-                    "table_quality": element.get("quality"),
+                    'type': 'table',
+                    'img_path': element.get('image_path'),
+                    'table_latex': element.get('text'),
+                    'table_title': '',
+                    'table_caption': '',
+                    'table_quality': element.get('quality'),
                 }
-            content_list.insert(i+1, content_node)
+            content_list.insert(i + 1, content_node)
             break
     else:
-        logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}")
-    
+        logger.error(
+            f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}"
+        )
 
 
 def __insert_before_para(text, type, element, content_list):
-    """
-    在content_list中找到text,将image_path作为一个新的node插入到text前面
-    """
+    """在content_list中找到text,将image_path作为一个新的node插入到text前面."""
     for i, c in enumerate(content_list):
-        content_type = c.get("type")
-        if content_type in  UNI_FORMAT_TEXT_TYPE and text in c.get("text", ''):
-            if type == "image":
+        content_type = c.get('type')
+        if content_type in UNI_FORMAT_TEXT_TYPE and text in c.get('text', ''):
+            if type == 'image':
                 content_node = {
-                    "type": "image",
-                    "img_path": element.get("image_path"),
-                    "img_alt": "",
-                    "img_title": "",
-                    "img_caption": "",
+                    'type': 'image',
+                    'img_path': element.get('image_path'),
+                    'img_alt': '',
+                    'img_title': '',
+                    'img_caption': '',
                 }
-            elif type == "table":
+            elif type == 'table':
                 content_node = {
-                    "type": "table",
-                    "img_path": element.get("image_path"),
-                    "table_latex": element.get("text"),
-                    "table_title": "",
-                    "table_caption": "",
-                    "table_quality": element.get("quality"),
+                    'type': 'table',
+                    'img_path': element.get('image_path'),
+                    'table_latex': element.get('text'),
+                    'table_title': '',
+                    'table_caption': '',
+                    'table_quality': element.get('quality'),
                 }
             content_list.insert(i, content_node)
             break
     else:
-        logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}")
-         
+        logger.error(
+            f"Can't find the location of image {element.get('image_path')} in the markdown file, search target is {text}"
+        )
+
 
 def mk_universal_format(pdf_info_list: list, img_buket_path):
-    """
-    构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY
-    """
+    """构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY."""
     content_lst = []
     for page_info in pdf_info_list:
-        page_lst = [] # 一个page内的段落列表
-        para_blocks = page_info.get("para_blocks")
-        pymu_raw_blocks = page_info.get("preproc_blocks")
-        
+        page_lst = []  # 一个page内的段落列表
+        para_blocks = page_info.get('para_blocks')
+        pymu_raw_blocks = page_info.get('preproc_blocks')
+
         all_page_images = []
-        all_page_images.extend(page_info.get("images",[]))
-        all_page_images.extend(page_info.get("image_backup", []) )
+        all_page_images.extend(page_info.get('images', []))
+        all_page_images.extend(page_info.get('image_backup', []))
         # all_page_images.extend(page_info.get("tables",[]))
         # all_page_images.extend(page_info.get("table_backup",[]) )
         all_page_tables = []
-        all_page_tables.extend(page_info.get("tables", []))
+        all_page_tables.extend(page_info.get('tables', []))
 
-        if not para_blocks or not pymu_raw_blocks: # 只有图片的拼接的场景
+        if not para_blocks or not pymu_raw_blocks:  # 只有图片的拼接的场景
             for img in all_page_images:
                 content_node = {
-                    "type": "image",
-                    "img_path": join_path(img_buket_path, img['image_path']),
-                    "img_alt":"",
-                    "img_title":"",
-                    "img_caption":""
+                    'type': 'image',
+                    'img_path': join_path(img_buket_path, img['image_path']),
+                    'img_alt': '',
+                    'img_title': '',
+                    'img_caption': '',
                 }
-                page_lst.append(content_node) # TODO 图片顺序
+                page_lst.append(content_node)  # TODO 图片顺序
             for table in all_page_tables:
                 content_node = {
-                    "type": "table",
-                    "img_path": join_path(img_buket_path, table['image_path']),
-                    "table_latex": table.get("text"),
-                    "table_title": "",
-                    "table_caption": "",
-                    "table_quality": table.get("quality"),
+                    'type': 'table',
+                    'img_path': join_path(img_buket_path, table['image_path']),
+                    'table_latex': table.get('text'),
+                    'table_title': '',
+                    'table_caption': '',
+                    'table_quality': table.get('quality'),
                 }
-                page_lst.append(content_node) # TODO 图片顺序
+                page_lst.append(content_node)  # TODO 图片顺序
         else:
             for block in para_blocks:
-                item = block["paras"]
+                item = block['paras']
                 for _, p in item.items():
-                    font_type = p['para_font_type']# 对于文本来说,要么是普通文本,要么是个行间公式
+                    font_type = p[
+                        'para_font_type'
+                    ]  # 对于文本来说,要么是普通文本,要么是个行间公式
                     if font_type == TYPE_INTERLINE_EQUATION:
-                        content_node = {
-                            "type": "equation",
-                            "latex": p["para_text"]
-                        }
+                        content_node = {'type': 'equation', 'latex': p['para_text']}
                         page_lst.append(content_node)
                     else:
-                        para_text = p["para_text"]
-                        is_title = p["is_para_title"]
+                        para_text = p['para_text']
+                        is_title = p['is_para_title']
                         title_level = p['para_title_level']
-                        
+
                         if is_title:
                             content_node = {
-                                "type": f"h{title_level}",
-                                "text": para_text
+                                'type': f'h{title_level}',
+                                'text': para_text,
                             }
                             page_lst.append(content_node)
                         else:
-                            content_node = {
-                                "type": "text",
-                                "text": para_text
-                            }
+                            content_node = {'type': 'text', 'text': para_text}
                             page_lst.append(content_node)
-                            
+
         content_lst.extend(page_lst)
-        
+
         """插入图片"""
         for img in all_page_images:
-            insert_img_or_table("image", img, pymu_raw_blocks, content_lst)
+            insert_img_or_table('image', img, pymu_raw_blocks, content_lst)
 
         """插入表格"""
         for table in all_page_tables:
-            insert_img_or_table("table", table, pymu_raw_blocks, content_lst)
+            insert_img_or_table('table', table, pymu_raw_blocks, content_lst)
     # end for
     return content_lst
 
@@ -313,13 +342,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
     # 先看在哪个block内
     for block in pymu_raw_blocks:
         bbox = block['bbox']
-        if bbox[0] - 1 <= element_bbox[0] < bbox[2] + 1 and bbox[1] - 1 <= element_bbox[1] < bbox[
-            3] + 1:  # 确定在这个大的block内,然后进入逐行比较距离
-            for l in block['lines']:
+        if (
+            bbox[0] - 1 <= element_bbox[0] < bbox[2] + 1
+            and bbox[1] - 1 <= element_bbox[1] < bbox[3] + 1
+        ):  # 确定在这个大的block内,然后进入逐行比较距离
+            for l in block['lines']:  # noqa: E741
                 line_box = l['bbox']
-                if line_box[0] - 1 <= element_bbox[0] < line_box[2] + 1 and line_box[1] - 1 <= element_bbox[1] < line_box[
-                    3] + 1:  # 在line内的,插入line前面
-                    line_txt = "".join([s['text'] for s in l['spans']])
+                if (
+                    line_box[0] - 1 <= element_bbox[0] < line_box[2] + 1
+                    and line_box[1] - 1 <= element_bbox[1] < line_box[3] + 1
+                ):  # 在line内的,插入line前面
+                    line_txt = ''.join([s['text'] for s in l['spans']])
                     __insert_before_para(line_txt, type, element, content_lst)
                     break
                 break
@@ -327,14 +360,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
                 # 找到图片x0,y0与line的x0,y0最近的line
                 min_distance = 100000
                 min_line = None
-                for l in block['lines']:
+                for l in block['lines']:  # noqa: E741
                     line_box = l['bbox']
-                    distance = math.sqrt((line_box[0] - element_bbox[0]) ** 2 + (line_box[1] - element_bbox[1]) ** 2)
+                    distance = math.sqrt(
+                        (line_box[0] - element_bbox[0]) ** 2
+                        + (line_box[1] - element_bbox[1]) ** 2
+                    )
                     if distance < min_distance:
                         min_distance = distance
                         min_line = l
                 if min_line:
-                    line_txt = "".join([s['text'] for s in min_line['spans']])
+                    line_txt = ''.join([s['text'] for s in min_line['spans']])
                     img_h = element_bbox[3] - element_bbox[1]
                     if min_distance < img_h:  # 文字在图片前面
                         __insert_after_para(line_txt, type, element, content_lst)
@@ -342,56 +378,61 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
                         __insert_before_para(line_txt, type, element, content_lst)
                     break
                 else:
-                    logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file #1")
+                    logger.error(
+                        f"Can't find the location of image {element.get('image_path')} in the markdown file  #1"
+                    )
     else:  # 应当在两个block之间
         # 找到上方最近的block,如果上方没有就找大下方最近的block
         top_txt_block = find_top_nearest_text_bbox(pymu_raw_blocks, element_bbox)
         if top_txt_block:
-            line_txt = "".join([s['text'] for s in top_txt_block['lines'][-1]['spans']])
+            line_txt = ''.join([s['text'] for s in top_txt_block['lines'][-1]['spans']])
             __insert_after_para(line_txt, type, element, content_lst)
         else:
-            bottom_txt_block = find_bottom_nearest_text_bbox(pymu_raw_blocks, element_bbox)
+            bottom_txt_block = find_bottom_nearest_text_bbox(
+                pymu_raw_blocks, element_bbox
+            )
             if bottom_txt_block:
-                line_txt = "".join([s['text'] for s in bottom_txt_block['lines'][0]['spans']])
+                line_txt = ''.join(
+                    [s['text'] for s in bottom_txt_block['lines'][0]['spans']]
+                )
                 __insert_before_para(line_txt, type, element, content_lst)
             else:  # TODO ,图片可能独占一列,这种情况上下是没有图片的
-                logger.error(f"Can't find the location of image {element.get('image_path')} in the markdown file #2")
+                logger.error(
+                    f"Can't find the location of image {element.get('image_path')} in the markdown file  #2"
+                )
 
 
 def mk_mm_markdown(content_list):
-    """
-    基于同一格式的内容列表,构造markdown,含图片
-    """
+    """基于同一格式的内容列表,构造markdown,含图片."""
     content_md = []
     for c in content_list:
-        content_type = c.get("type")
-        if content_type == "text":
-            content_md.append(c.get("text"))
-        elif content_type == "equation":
-            content = c.get("latex")
-            if content.startswith("$$") and content.endswith("$$"):
+        content_type = c.get('type')
+        if content_type == 'text':
+            content_md.append(c.get('text'))
+        elif content_type == 'equation':
+            content = c.get('latex')
+            if content.startswith('$$') and content.endswith('$$'):
                 content_md.append(content)
             else:
                 content_md.append(f"\n$$\n{c.get('latex')}\n$$\n")
         elif content_type in UNI_FORMAT_TEXT_TYPE:
             content_md.append(f"{'#'*int(content_type[1])} {c.get('text')}")
-        elif content_type == "image":
+        elif content_type == 'image':
             content_md.append(f"![]({c.get('img_path')})")
-    return "\n\n".join(content_md)
+    return '\n\n'.join(content_md)
+
 
 def mk_nlp_markdown(content_list):
-    """
-    基于同一格式的内容列表,构造markdown,不含图片
-    """
+    """基于同一格式的内容列表,构造markdown,不含图片."""
     content_md = []
     for c in content_list:
-        content_type = c.get("type")
-        if content_type == "text":
-            content_md.append(c.get("text"))
-        elif content_type == "equation":
+        content_type = c.get('type')
+        if content_type == 'text':
+            content_md.append(c.get('text'))
+        elif content_type == 'equation':
             content_md.append(f"$$\n{c.get('latex')}\n$$")
-        elif content_type == "table":
+        elif content_type == 'table':
             content_md.append(f"$$$\n{c.get('table_latex')}\n$$$")
         elif content_type in UNI_FORMAT_TEXT_TYPE:
             content_md.append(f"{'#'*int(content_type[1])} {c.get('text')}")
-    return "\n\n".join(content_md)
+    return '\n\n'.join(content_md)

+ 7 - 8
magic_pdf/dict2md/ocr_mkcontent.py

@@ -2,21 +2,20 @@ import re
 
 from loguru import logger
 
+from magic_pdf.config.make_content_config import DropMode, MakeMode
+from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.libs.commons import join_path
 from magic_pdf.libs.language import detect_lang
-from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
-from magic_pdf.libs.ocr_content_type import BlockType, ContentType
 from magic_pdf.para.para_split_v3 import ListLineTag
 
 
 def __is_hyphen_at_line_end(line):
-    """
-    Check if a line ends with one or more letters followed by a hyphen.
-    
+    """Check if a line ends with one or more letters followed by a hyphen.
+
     Args:
     line (str): The line of text to check.
-    
+
     Returns:
     bool: True if the line ends with one or more letters followed by a hyphen, False otherwise.
     """
@@ -162,7 +161,7 @@ def merge_para_with_text(para_block):
                     if span_type in [ContentType.Text, ContentType.InterlineEquation]:
                         para_text += content  # 中文/日语/韩文语境下,content间不需要空格分隔
                     elif span_type == ContentType.InlineEquation:
-                        para_text += f" {content} "
+                        para_text += f' {content} '
                 else:
                     if span_type in [ContentType.Text, ContentType.InlineEquation]:
                         # 如果是前一行带有-连字符,那么末尾不应该加空格
@@ -171,7 +170,7 @@ def merge_para_with_text(para_block):
                         elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
                             para_text += content
                         else:  # 西方文本语境下 content间需要空格分隔
-                            para_text += f"{content} "
+                            para_text += f'{content} '
                     elif span_type == ContentType.InterlineEquation:
                         para_text += content
             else:

+ 101 - 79
magic_pdf/filter/pdf_meta_scan.py

@@ -1,16 +1,13 @@
-"""
-输入: s3路径,每行一个
-输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置
-"""
+"""输入: s3路径,每行一个 输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置."""
+
 import sys
-import click
+from collections import Counter
 
-from magic_pdf.libs.commons import read_file, mymax, get_top_percent_list
-from magic_pdf.libs.commons import fitz
+import click
 from loguru import logger
-from collections import Counter
 
-from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.config.drop_reason import DropReason
+from magic_pdf.libs.commons import fitz, get_top_percent_list, mymax, read_file
 from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.pdf_check import detect_invalid_chars
 
@@ -19,8 +16,10 @@ junk_limit_min = 10
 
 
 def calculate_max_image_area_per_page(result: list, page_width_pts, page_height_pts):
-    max_image_area_per_page = [mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz]) for page_img_sz in
-                               result]
+    max_image_area_per_page = [
+        mymax([(x1 - x0) * (y1 - y0) for x0, y0, x1, y1, _ in page_img_sz])
+        for page_img_sz in result
+    ]
     page_area = int(page_width_pts) * int(page_height_pts)
     max_image_area_per_page = [area / page_area for area in max_image_area_per_page]
     max_image_area_per_page = [area for area in max_image_area_per_page if area > 0.6]
@@ -32,8 +31,10 @@ def process_image(page, junk_img_bojids=[]):
     items = page.get_images()
     dedup = set()
     for img in items:
-        # 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
-        img_bojid = img[0]  # 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
+        #  这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
+        img_bojid = img[
+            0
+        ]  # 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
         if img_bojid in junk_img_bojids:  # 如果是垃圾图像,就跳过
             continue
         recs = page.get_image_rects(img, transform=True)
@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]):
             x0, y0, x1, y1 = map(int, rec)
             width = x1 - x0
             height = y1 - y0
-            if (x0, y0, x1, y1, img_bojid) in dedup:  # 这里面会出现一些重复的bbox,无需重复出现,需要去掉
+            if (
+                x0,
+                y0,
+                x1,
+                y1,
+                img_bojid,
+            ) in dedup:  # 这里面会出现一些重复的bbox,无需重复出现,需要去掉
                 continue
-            if not all([width, height]):  # 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
+            if not all(
+                [width, height]
+            ):  # 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
                 continue
             dedup.add((x0, y0, x1, y1, img_bojid))
             page_result.append([x0, y0, x1, y1, img_bojid])
@@ -52,29 +61,33 @@ def process_image(page, junk_img_bojids=[]):
 
 
 def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
-    """
-    返回每个页面里的图片的四元组,每个页面多个图片。
+    """返回每个页面里的图片的四元组,每个页面多个图片。
+
     :param doc:
     :return:
     """
-    # 使用 Counter 计数 img_bojid 的出现次数
+    #  使用 Counter 计数 img_bojid 的出现次数
     img_bojid_counter = Counter(img[0] for page in doc for img in page.get_images())
-    # 找出出现次数超过 len(doc) 半数的 img_bojid
+    #  找出出现次数超过 len(doc) 半数的 img_bojid
 
     junk_limit = max(len(doc) * 0.5, junk_limit_min)  # 对一些页数比较少的进行豁免
 
-    junk_img_bojids = [img_bojid for img_bojid, count in img_bojid_counter.items() if count >= junk_limit]
-
-    #todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
-    #有两种扫描版,一种文字版,这里可能会有误判
-    #扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
-    #扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
-    #文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
+    junk_img_bojids = [
+        img_bojid
+        for img_bojid, count in img_bojid_counter.items()
+        if count >= junk_limit
+    ]
+
+    #  todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
+    #  有两种扫描版,一种文字版,这里可能会有误判
+    #  扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
+    #  扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
+    # 文  字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
     imgs_len_list = [len(page.get_images()) for page in doc]
 
     special_limit_pages = 10
 
-    # 统一用前十页结果做判断
+    #  统一用前十页结果做判断
     result = []
     break_loop = False
     for i, page in enumerate(doc):
@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
             break
         if i >= special_limit_pages:
             break
-        page_result = process_image(page)  # 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
+        page_result = process_image(
+            page
+        )  # 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
         result.append(page_result)
         for item in result:
-            if not any(item):  # 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
-                if max(imgs_len_list) == min(imgs_len_list) and max(
-                        imgs_len_list) >= junk_limit_min:  # 如果是特殊文字版,就把junklist置空并break
+            if not any(
+                item
+            ):  # 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
+                if (
+                    max(imgs_len_list) == min(imgs_len_list)
+                    and max(imgs_len_list) >= junk_limit_min
+                ):  # 如果是特殊文字版,就把junklist置空并break
                     junk_img_bojids = []
                 else:  # 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist
                     pass
@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
         top_eighty_percent = get_top_percent_list(imgs_len_list, 0.8)
         # 检查前80%的元素是否都相等
         if len(set(top_eighty_percent)) == 1 and max(imgs_len_list) >= junk_limit_min:
-
             # # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist
             # if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:
 
-            #前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
-            max_image_area_per_page = calculate_max_image_area_per_page(result, page_width_pts, page_height_pts)
-            if len(max_image_area_per_page) < 0.8 * special_limit_pages:  # 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
+            # 前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
+            max_image_area_per_page = calculate_max_image_area_per_page(
+                result, page_width_pts, page_height_pts
+            )
+            if (
+                len(max_image_area_per_page) < 0.8 * special_limit_pages
+            ):  # 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
                 junk_img_bojids = []
             else:  # 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist
                 pass
         else:  # 每页图片数量不一致,需要清掉junklist全量跑前50页图片
             junk_img_bojids = []
 
-    #正式进入取前50页图片的信息流程
+    # 正式进入取前50页图片的信息流程
     result = []
     for i, page in enumerate(doc):
         if i >= scan_max_page:
@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
 def get_pdf_page_size_pts(doc: fitz.Document):
     page_cnt = len(doc)
     l: int = min(page_cnt, 50)
-    #把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
+    # 把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
     page_width_list = []
     page_height_list = []
     for i in range(l):
@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
         # 拿所有text的blocks
         # text_block = page.get_text("words")
         # text_block_len = sum([len(t[4]) for t in text_block])
-        #拿所有text的str
-        text_block = page.get_text("text")
+        # 拿所有text的str
+        text_block = page.get_text('text')
         text_block_len = len(text_block)
         # logger.info(f"page {page.number} text_block_len: {text_block_len}")
         text_len_lst.append(text_block_len)
@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
 
 
 def get_pdf_text_layout_per_page(doc: fitz.Document):
-    """
-    根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
+    """根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
 
     Args:
         doc (fitz.Document): PDF文档对象。
 
     Returns:
         List[str]: 每一页的文本布局(横向、纵向、未知)。
-
     """
     text_layout_list = []
 
@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
         # 创建每一页的纵向和横向的文本行数计数器
         vertical_count = 0
         horizontal_count = 0
-        text_dict = page.get_text("dict")
-        if "blocks" in text_dict:
-            for block in text_dict["blocks"]:
+        text_dict = page.get_text('dict')
+        if 'blocks' in text_dict:
+            for block in text_dict['blocks']:
                 if 'lines' in block:
-                    for line in block["lines"]:
+                    for line in block['lines']:
                         # 获取line的bbox顶点坐标
                         x0, y0, x1, y1 = line['bbox']
                         # 计算bbox的宽高
@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
                         if len(font_sizes) > 0:
                             average_font_size = sum(font_sizes) / len(font_sizes)
                         else:
-                            average_font_size = 10  # 有的line拿不到font_size,先定一个阈值100
-                        if area <= average_font_size ** 2:  # 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
+                            average_font_size = (
+                                10  # 有的line拿不到font_size,先定一个阈值100
+                            )
+                        if (
+                            area <= average_font_size**2
+                        ):  # 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
                             continue
                         else:
                             if 'wmode' in line:  # 通过wmode判断文本方向
@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
         # print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
         # 判断每一页的文本布局
         if vertical_count == 0 and horizontal_count == 0:  # 该页没有文本,无法判断
-            text_layout_list.append("unknow")
+            text_layout_list.append('unknow')
             continue
         else:
             if vertical_count > horizontal_count:  # 该页的文本纵向行数大于横向的
-                text_layout_list.append("vertical")
+                text_layout_list.append('vertical')
             else:  # 该页的文本横向行数大于纵向的
-                text_layout_list.append("horizontal")
+                text_layout_list.append('horizontal')
         # logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
     return text_layout_list
 
 
-'''定义一个自定义异常用来抛出单页svg太多的pdf'''
+"""定义一个自定义异常用来抛出单页svg太多的pdf"""
 
 
 class PageSvgsTooManyError(Exception):
-    def __init__(self, message="Page SVGs are too many"):
+    def __init__(self, message='Page SVGs are too many'):
         self.message = message
         super().__init__(self.message)
 
@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document):
         if page_id >= scan_max_page:
             break
         # 拿所有text的str
-        text_block = page.get_text("text")
+        text_block = page.get_text('text')
         page_language = detect_lang(text_block)
         language_lst.append(page_language)
 
@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document):
 
 
 def check_invalid_chars(pdf_bytes):
-    """
-    乱码检测
-    """
+    """乱码检测."""
     return detect_invalid_chars(pdf_bytes)
 
 
@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes):
     :param pdf_bytes: pdf文件的二进制数据
     几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取
     """
-    doc = fitz.open("pdf", pdf_bytes)
+    doc = fitz.open('pdf', pdf_bytes)
     is_needs_password = doc.needs_pass
     is_encrypted = doc.is_encrypted
     total_page = len(doc)
     if total_page == 0:
-        logger.warning(f"drop this pdf, drop_reason: {DropReason.EMPTY_PDF}")
-        result = {"_need_drop": True, "_drop_reason": DropReason.EMPTY_PDF}
+        logger.warning(f'drop this pdf, drop_reason: {DropReason.EMPTY_PDF}')
+        result = {'_need_drop': True, '_drop_reason': DropReason.EMPTY_PDF}
         return result
     else:
         page_width_pts, page_height_pts = get_pdf_page_size_pts(doc)
@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
         imgs_per_page = get_imgs_per_page(doc)
         # logger.info(f"imgs_per_page: {imgs_per_page}")
 
-        image_info_per_page, junk_img_bojids = get_image_info(doc, page_width_pts, page_height_pts)
+        image_info_per_page, junk_img_bojids = get_image_info(
+            doc, page_width_pts, page_height_pts
+        )
         # logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
         text_len_per_page = get_pdf_textlen_per_page(doc)
         # logger.info(f"text_len_per_page: {text_len_per_page}")
@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes):
 
         # 最后输出一条json
         res = {
-            "is_needs_password": is_needs_password,
-            "is_encrypted": is_encrypted,
-            "total_page": total_page,
-            "page_width_pts": int(page_width_pts),
-            "page_height_pts": int(page_height_pts),
-            "image_info_per_page": image_info_per_page,
-            "text_len_per_page": text_len_per_page,
-            "text_layout_per_page": text_layout_per_page,
-            "text_language": text_language,
+            'is_needs_password': is_needs_password,
+            'is_encrypted': is_encrypted,
+            'total_page': total_page,
+            'page_width_pts': int(page_width_pts),
+            'page_height_pts': int(page_height_pts),
+            'image_info_per_page': image_info_per_page,
+            'text_len_per_page': text_len_per_page,
+            'text_layout_per_page': text_layout_per_page,
+            'text_language': text_language,
             # "svgs_per_page": svgs_per_page,
-            "imgs_per_page": imgs_per_page,  # 增加每页img数量list
-            "junk_img_bojids": junk_img_bojids,  # 增加垃圾图片的bojid list
-            "invalid_chars": invalid_chars,
-            "metadata": doc.metadata
+            'imgs_per_page': imgs_per_page,  # 增加每页img数量list
+            'junk_img_bojids': junk_img_bojids,  # 增加垃圾图片的bojid list
+            'invalid_chars': invalid_chars,
+            'metadata': doc.metadata,
         }
         # logger.info(json.dumps(res, ensure_ascii=False))
         return res
@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes):
 @click.option('--s3-pdf-path', help='s3上pdf文件的路径')
 @click.option('--s3-profile', help='s3上的profile')
 def main(s3_pdf_path: str, s3_profile: str):
-    """
-
-    """
+    """"""
     try:
         file_content = read_file(s3_pdf_path, s3_profile)
         pdf_meta_scan(file_content)
     except Exception as e:
-        print(f"ERROR: {s3_pdf_path}, {e}", file=sys.stderr)
+        print(f'ERROR: {s3_pdf_path}, {e}', file=sys.stderr)
         logger.exception(e)
 
 
@@ -381,7 +403,7 @@ if __name__ == '__main__':
     # "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
     # "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
     # "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
-    # file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")
+    # file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")  # noqa: E501
     # file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
     # doc = fitz.open("pdf", file_content)
     # text_layout_lst = get_pdf_text_layout_per_page(doc)

+ 1 - 1
magic_pdf/integrations/rag/utils.py

@@ -5,13 +5,13 @@ from pathlib import Path
 from loguru import logger
 
 import magic_pdf.model as model_config
+from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
                                              ElementRelation, ElementRelType,
                                              LayoutElements,
                                              LayoutElementsExtra, PageInfo)
-from magic_pdf.libs.ocr_content_type import BlockType, ContentType
 from magic_pdf.tools.common import do_parse, prepare_env
 
 

+ 0 - 55
magic_pdf/libs/Constants.py

@@ -1,55 +0,0 @@
-"""
-span维度自定义字段
-"""
-# span是否是跨页合并的
-CROSS_PAGE = "cross_page"
-
-"""
-block维度自定义字段
-"""
-# block中lines是否被删除
-LINES_DELETED = "lines_deleted"
-
-# table recognition max time default value
-TABLE_MAX_TIME_VALUE = 400
-
-# pp_table_result_max_length
-TABLE_MAX_LEN = 480
-
-# table master structure dict
-TABLE_MASTER_DICT = "table_master_structure_dict.txt"
-
-# table master dir
-TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
-
-# pp detect model dir
-DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer"
-
-# pp rec model dir
-REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer"
-
-# pp rec char dict path
-REC_CHAR_DICT = "ppocr_keys_v1.txt"
-
-# pp rec copy rec directory
-PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
-
-# pp rec copy det directory
-PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
-
-
-class MODEL_NAME:
-    # pp table structure algorithm
-    TABLE_MASTER = "tablemaster"
-    # struct eqtable
-    STRUCT_EQTABLE = "struct_eqtable"
-
-    DocLayout_YOLO = "doclayout_yolo"
-
-    LAYOUTLMv3 = "layoutlmv3"
-
-    YOLO_V8_MFD = "yolo_v8_mfd"
-
-    UniMerNet_v2_Small = "unimernet_small"
-
-    RAPID_TABLE = "rapid_table"

+ 0 - 11
magic_pdf/libs/MakeContentConfig.py

@@ -1,11 +0,0 @@
-class MakeMode:
-    MM_MD = "mm_markdown"
-    NLP_MD = "nlp_markdown"
-    STANDARD_FORMAT = "standard_format"
-
-
-class DropMode:
-    WHOLE_PDF = "whole_pdf"
-    SINGLE_PAGE = "single_page"
-    NONE = "none"
-    NONE_WITH_REASON = "none_with_reason"

+ 5 - 5
magic_pdf/libs/config_reader.py

@@ -5,7 +5,7 @@ import os
 
 from loguru import logger
 
-from magic_pdf.libs.Constants import MODEL_NAME
+from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.libs.commons import parse_bucket_key
 
 # 定义配置文件名常量
@@ -99,7 +99,7 @@ def get_table_recog_config():
 
 def get_layout_config():
     config = read_config()
-    layout_config = config.get("layout-config")
+    layout_config = config.get('layout-config')
     if layout_config is None:
         logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
         return json.loads(f'{{"model": "{MODEL_NAME.LAYOUTLMv3}"}}')
@@ -109,7 +109,7 @@ def get_layout_config():
 
 def get_formula_config():
     config = read_config()
-    formula_config = config.get("formula-config")
+    formula_config = config.get('formula-config')
     if formula_config is None:
         logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
         return json.loads(f'{{"mfd_model": "{MODEL_NAME.YOLO_V8_MFD}","mfr_model": "{MODEL_NAME.UniMerNet_v2_Small}","enable": true}}')
@@ -117,5 +117,5 @@ def get_formula_config():
         return formula_config
 
 
-if __name__ == "__main__":
-    ak, sk, endpoint = get_s3_config("llm-raw")
+if __name__ == '__main__':
+    ak, sk, endpoint = get_s3_config('llm-raw')

+ 3 - 2
magic_pdf/libs/draw_bbox.py

@@ -1,7 +1,8 @@
+from magic_pdf.config.constants import CROSS_PAGE
+from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
+                                               ContentType)
 from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.libs.commons import fitz  # PyMuPDF
-from magic_pdf.libs.Constants import CROSS_PAGE
-from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
 from magic_pdf.model.magic_model import MagicModel
 
 

+ 0 - 27
magic_pdf/libs/drop_reason.py

@@ -1,27 +0,0 @@
-
-class DropReason:
-    TEXT_BLCOK_HOR_OVERLAP = "text_block_horizontal_overlap" # 文字块有水平互相覆盖,导致无法准确定位文字顺序
-    USEFUL_BLOCK_HOR_OVERLAP = "useful_block_horizontal_overlap" # 需保留的block水平覆盖
-    COMPLICATED_LAYOUT = "complicated_layout" # 复杂的布局,暂时不支持
-    TOO_MANY_LAYOUT_COLUMNS = "too_many_layout_columns" # 目前不支持分栏超过2列的
-    COLOR_BACKGROUND_TEXT_BOX = "color_background_text_box" # 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
-    HIGH_COMPUTATIONAL_lOAD_BY_IMGS = "high_computational_load_by_imgs" # 含特殊图片,计算量太大,从而丢弃
-    HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs" # 特殊的SVG图,计算量太大,从而丢弃
-    HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages" # 计算量超过负荷,当前方法下计算量消耗过大
-    MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result" # 版面分析失败
-    Exception = "_exception" # 解析中发生异常
-    ENCRYPTED = "encrypted" # PDF是加密的
-    EMPTY_PDF = "total_page=0" # PDF页面总数为0
-    NOT_IS_TEXT_PDF = "not_is_text_pdf" # 不是文字版PDF,无法直接解析
-    DENSE_SINGLE_LINE_BLOCK = "dense_single_line_block" # 无法清晰的分段
-    TITLE_DETECTION_FAILED = "title_detection_failed" # 探测标题失败
-    TITLE_LEVEL_FAILED = "title_level_failed" # 分析标题级别失败(例如一级、二级、三级标题)
-    PARA_SPLIT_FAILED = "para_split_failed" # 识别段落失败
-    PARA_MERGE_FAILED = "para_merge_failed" # 段落合并失败
-    NOT_ALLOW_LANGUAGE = "not_allow_language" # 不支持的语种
-    SPECIAL_PDF = "special_pdf"
-    PSEUDO_SINGLE_COLUMN = "pseudo_single_column" # 无法精确判断文字分栏
-    CAN_NOT_DETECT_PAGE_LAYOUT="can_not_detect_page_layout" # 无法分析页面的版面
-    NEGATIVE_BBOX_AREA = "negative_bbox_area" # 缩放导致 bbox 面积为负
-    OVERLAP_BLOCKS_CAN_NOT_SEPARATION = "overlap_blocks_can_t_separation" # 无法分离重叠的block
-    

+ 0 - 19
magic_pdf/libs/drop_tag.py

@@ -1,19 +0,0 @@
-
-COLOR_BG_HEADER_TXT_BLOCK = "color_background_header_txt_block"
-PAGE_NO = "page-no" # 页码
-CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area' # 页眉页脚内的文本
-VERTICAL_TEXT = 'vertical-text' # 垂直文本
-ROTATE_TEXT = 'rotate-text' # 旋转文本
-EMPTY_SIDE_BLOCK = 'empty-side-block' # 边缘上的空白没有任何内容的block
-ON_IMAGE_TEXT = 'on-image-text' # 文本在图片上
-ON_TABLE_TEXT = 'on-table-text' # 文本在表格上
-
-
-class DropTag:
-    PAGE_NUMBER = "page_no"
-    HEADER = "header"
-    FOOTER = "footer"
-    FOOTNOTE = "footnote"
-    NOT_IN_LAYOUT = "not_in_layout"
-    SPAN_OVERLAP = "span_overlap"
-    BLOCK_OVERLAP = "block_overlap"

+ 2 - 2
magic_pdf/model/magic_model.py

@@ -1,6 +1,8 @@
 import enum
 import json
 
+from magic_pdf.config.model_block_type import ModelBlockTypeEnum
+from magic_pdf.config.ocr_content_type import CategoryId, ContentType
 from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
                                                FileBasedDataWriter)
 from magic_pdf.data.dataset import Dataset
@@ -11,8 +13,6 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
 from magic_pdf.libs.commons import fitz, join_path
 from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.local_math import float_gt
-from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
-from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
 from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
 
 CAPATION_OVERLAP_AREA_RATIO = 0.6

+ 109 - 59
magic_pdf/model/pdf_extract_kit.py

@@ -1,10 +1,12 @@
-import numpy as np
-import torch
-from loguru import logger
+# flake8: noqa
 import os
 import time
+
 import cv2
+import numpy as np
+import torch
 import yaml
+from loguru import logger
 from PIL import Image
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -13,20 +15,21 @@ os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
 try:
     import torchtext
 
-    if torchtext.__version__ >= "0.18.0":
+    if torchtext.__version__ >= '0.18.0':
         torchtext.disable_torchtext_deprecation_warning()
 except ImportError:
     pass
 
-from magic_pdf.libs.Constants import *
+from magic_pdf.config.constants import *
 from magic_pdf.model.model_list import AtomicModel
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
-from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
+from magic_pdf.model.sub_modules.model_utils import (
+    clean_vram, crop_img, get_res_list_from_layout_res)
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+    get_adjusted_mfdetrec_res, get_ocr_result_list)
 
 
 class CustomPEKModel:
-
     def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
         """
         ======== model init ========
@@ -41,42 +44,54 @@ class CustomPEKModel:
         model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
         # 构建 model_configs.yaml 文件的完整路径
         config_path = os.path.join(model_config_dir, 'model_configs.yaml')
-        with open(config_path, "r", encoding='utf-8') as f:
+        with open(config_path, 'r', encoding='utf-8') as f:
             self.configs = yaml.load(f, Loader=yaml.FullLoader)
         # 初始化解析配置
 
         # layout config
-        self.layout_config = kwargs.get("layout_config")
-        self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
+        self.layout_config = kwargs.get('layout_config')
+        self.layout_model_name = self.layout_config.get(
+            'model', MODEL_NAME.DocLayout_YOLO
+        )
 
         # formula config
-        self.formula_config = kwargs.get("formula_config")
-        self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
-        self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
-        self.apply_formula = self.formula_config.get("enable", True)
+        self.formula_config = kwargs.get('formula_config')
+        self.mfd_model_name = self.formula_config.get(
+            'mfd_model', MODEL_NAME.YOLO_V8_MFD
+        )
+        self.mfr_model_name = self.formula_config.get(
+            'mfr_model', MODEL_NAME.UniMerNet_v2_Small
+        )
+        self.apply_formula = self.formula_config.get('enable', True)
 
         # table config
-        self.table_config = kwargs.get("table_config")
-        self.apply_table = self.table_config.get("enable", False)
-        self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
+        self.table_config = kwargs.get('table_config')
+        self.apply_table = self.table_config.get('enable', False)
+        self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
+        self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
 
         # ocr config
         self.apply_ocr = ocr
-        self.lang = kwargs.get("lang", None)
+        self.lang = kwargs.get('lang', None)
 
         logger.info(
-            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
-            "apply_table: {}, table_model: {}, lang: {}".format(
-                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
-                self.lang
+            'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
+            'apply_table: {}, table_model: {}, lang: {}'.format(
+                self.layout_model_name,
+                self.apply_formula,
+                self.apply_ocr,
+                self.apply_table,
+                self.table_model_name,
+                self.lang,
             )
         )
         # 初始化解析方案
-        self.device = kwargs.get("device", "cpu")
-        logger.info("using device: {}".format(self.device))
-        models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
-        logger.info("using models_dir: {}".format(models_dir))
+        self.device = kwargs.get('device', 'cpu')
+        logger.info('using device: {}'.format(self.device))
+        models_dir = kwargs.get(
+            'models_dir', os.path.join(root_dir, 'resources', 'models')
+        )
+        logger.info('using models_dir: {}'.format(models_dir))
 
         atom_model_manager = AtomModelSingleton()
 
@@ -85,18 +100,24 @@ class CustomPEKModel:
             # 初始化公式检测模型
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
-                device=self.device
+                mfd_weights=str(
+                    os.path.join(
+                        models_dir, self.configs['weights'][self.mfd_model_name]
+                    )
+                ),
+                device=self.device,
             )
 
             # 初始化公式解析模型
-            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
-            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
+            mfr_weight_dir = str(
+                os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
+            )
+            mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
             self.mfr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
-                device=self.device
+                device=self.device,
             )
 
         # 初始化layout模型
@@ -104,16 +125,28 @@ class CustomPEKModel:
             self.layout_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Layout,
                 layout_model_name=MODEL_NAME.LAYOUTLMv3,
-                layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
-                layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
-                device=self.device
+                layout_weights=str(
+                    os.path.join(
+                        models_dir, self.configs['weights'][self.layout_model_name]
+                    )
+                ),
+                layout_config_file=str(
+                    os.path.join(
+                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
+                    )
+                ),
+                device=self.device,
             )
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             self.layout_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Layout,
                 layout_model_name=MODEL_NAME.DocLayout_YOLO,
-                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
-                device=self.device
+                doclayout_yolo_weights=str(
+                    os.path.join(
+                        models_dir, self.configs['weights'][self.layout_model_name]
+                    )
+                ),
+                device=self.device,
             )
         # 初始化ocr
         if self.apply_ocr:
@@ -121,23 +154,22 @@ class CustomPEKModel:
                 atom_model_name=AtomicModel.OCR,
                 ocr_show_log=show_log,
                 det_db_box_thresh=0.3,
-                lang=self.lang
+                lang=self.lang,
             )
         # init table model
         if self.apply_table:
-            table_model_dir = self.configs["weights"][self.table_model_name]
+            table_model_dir = self.configs['weights'][self.table_model_name]
             self.table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.Table,
                 table_model_name=self.table_model_name,
                 table_model_path=str(os.path.join(models_dir, table_model_dir)),
                 table_max_time=self.table_max_time,
-                device=self.device
+                device=self.device,
             )
 
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
-
         page_start = time.time()
 
         # layout检测
@@ -150,7 +182,7 @@ class CustomPEKModel:
             # doclayout_yolo
             layout_res = self.layout_model.predict(image)
         layout_cost = round(time.time() - layout_start, 2)
-        logger.info(f"layout detection time: {layout_cost}")
+        logger.info(f'layout detection time: {layout_cost}')
 
         pil_img = Image.fromarray(image)
 
@@ -158,32 +190,40 @@ class CustomPEKModel:
             # 公式检测
             mfd_start = time.time()
             mfd_res = self.mfd_model.predict(image)
-            logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
+            logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
 
             # 公式识别
             mfr_start = time.time()
             formula_list = self.mfr_model.predict(mfd_res, image)
             layout_res.extend(formula_list)
             mfr_cost = round(time.time() - mfr_start, 2)
-            logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
+            logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
 
         # 清理显存
         clean_vram(self.device, vram_threshold=8)
 
         # 从layout_res中获取ocr区域、表格区域、公式区域
-        ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
+        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
+            get_res_list_from_layout_res(layout_res)
+        )
 
         # ocr识别
         if self.apply_ocr:
             ocr_start = time.time()
             # Process each area that requires OCR processing
             for res in ocr_res_list:
-                new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
-                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
+                new_image, useful_list = crop_img(
+                    res, pil_img, crop_paste_x=50, crop_paste_y=50
+                )
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                    single_page_mfdetrec_res, useful_list
+                )
 
                 # OCR recognition
                 new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
+                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[
+                    0
+                ]
 
                 # Integration results
                 if ocr_res:
@@ -191,7 +231,7 @@ class CustomPEKModel:
                     layout_res.extend(ocr_result_list)
 
             ocr_cost = round(time.time() - ocr_start, 2)
-            logger.info(f"ocr time: {ocr_cost}")
+            logger.info(f'ocr time: {ocr_cost}')
 
         # 表格识别 table recognition
         if self.apply_table:
@@ -202,27 +242,37 @@ class CustomPEKModel:
                 html_code = None
                 if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
-                        table_result = self.table_model.predict(new_image, "html")
+                        table_result = self.table_model.predict(new_image, 'html')
                         if len(table_result) > 0:
                             html_code = table_result[0]
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
+                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                        new_image
+                    )
                 run_time = time.time() - single_table_start_time
                 if run_time > self.table_max_time:
-                    logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
+                    logger.warning(
+                        f'table recognition processing exceeds max time {self.table_max_time}s'
+                    )
                 # 判断是否返回正常
                 if html_code:
-                    expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
+                    expected_ending = html_code.strip().endswith(
+                        '</html>'
+                    ) or html_code.strip().endswith('</table>')
                     if expected_ending:
-                        res["html"] = html_code
+                        res['html'] = html_code
                     else:
-                        logger.warning(f"table recognition processing fails, not found expected HTML table end")
+                        logger.warning(
+                            'table recognition processing fails, not found expected HTML table end'
+                        )
                 else:
-                    logger.warning(f"table recognition processing fails, not get html return")
-            logger.info(f"table time: {round(time.time() - table_start, 2)}")
+                    logger.warning(
+                        'table recognition processing fails, not get html return'
+                    )
+            logger.info(f'table time: {round(time.time() - table_start, 2)}')
 
-        logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
+        logger.info(f'-----page total time: {round(time.time() - page_start, 2)}-----')
 
         return layout_res

+ 39 - 34
magic_pdf/model/sub_modules/model_init.py

@@ -1,17 +1,22 @@
 from loguru import logger
 
-from magic_pdf.libs.Constants import MODEL_NAME
+from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
-from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
+    DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
+    Layoutlmv3_Predictor
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
-
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
-from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
+from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
+    ModifiedPaddleOCR
+from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
+    RapidTableModel
 # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
-from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
-from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
-from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
+    StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
+    TableMasterPaddleModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
         config = {
-            "model_dir": model_path,
-            "device": _device_
+            'model_dir': model_path,
+            'device': _device_
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
         table_model = RapidTableModel()
     else:
-        logger.error("table model type not allow")
+        logger.error('table model type not allow')
         exit(1)
 
     return table_model
@@ -87,8 +92,8 @@ class AtomModelSingleton:
         return cls._instance
 
     def get_atom_model(self, atom_model_name: str, **kwargs):
-        lang = kwargs.get("lang", None)
-        layout_model_name = kwargs.get("layout_model_name", None)
+        lang = kwargs.get('lang', None)
+        layout_model_name = kwargs.get('layout_model_name', None)
         key = (atom_model_name, layout_model_name, lang)
         if key not in self._models:
             self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
@@ -98,47 +103,47 @@ class AtomModelSingleton:
 def atom_model_init(model_name: str, **kwargs):
     atom_model = None
     if model_name == AtomicModel.Layout:
-        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
+        if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
             atom_model = layout_model_init(
-                kwargs.get("layout_weights"),
-                kwargs.get("layout_config_file"),
-                kwargs.get("device")
+                kwargs.get('layout_weights'),
+                kwargs.get('layout_config_file'),
+                kwargs.get('device')
             )
-        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
+        elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
             atom_model = doclayout_yolo_model_init(
-                kwargs.get("doclayout_yolo_weights"),
-                kwargs.get("device")
+                kwargs.get('doclayout_yolo_weights'),
+                kwargs.get('device')
             )
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
-            kwargs.get("mfd_weights"),
-            kwargs.get("device")
+            kwargs.get('mfd_weights'),
+            kwargs.get('device')
         )
     elif model_name == AtomicModel.MFR:
         atom_model = mfr_model_init(
-            kwargs.get("mfr_weight_dir"),
-            kwargs.get("mfr_cfg_path"),
-            kwargs.get("device")
+            kwargs.get('mfr_weight_dir'),
+            kwargs.get('mfr_cfg_path'),
+            kwargs.get('device')
         )
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
-            kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh"),
-            kwargs.get("lang")
+            kwargs.get('ocr_show_log'),
+            kwargs.get('det_db_box_thresh'),
+            kwargs.get('lang')
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
-            kwargs.get("table_model_name"),
-            kwargs.get("table_model_path"),
-            kwargs.get("table_max_time"),
-            kwargs.get("device")
+            kwargs.get('table_model_name'),
+            kwargs.get('table_model_path'),
+            kwargs.get('table_max_time'),
+            kwargs.get('device')
         )
     else:
-        logger.error("model name not allow")
+        logger.error('model name not allow')
         exit(1)
 
     if atom_model is None:
-        logger.error("model init failed")
+        logger.error('model init failed')
         exit(1)
     else:
         return atom_model

+ 30 - 28
magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py

@@ -1,23 +1,25 @@
+import os
+
 import cv2
+import numpy as np
 from paddleocr.ppstructure.table.predict_table import TableSystem
 from paddleocr.ppstructure.utility import init_args
-from magic_pdf.libs.Constants import *
-import os
 from PIL import Image
-import numpy as np
+
+from magic_pdf.config.constants import *  # noqa: F403
 
 
 class TableMasterPaddleModel(object):
-    """
-        This class is responsible for converting image of table into HTML format using a pre-trained model.
+    """This class is responsible for converting image of table into HTML format
+    using a pre-trained model.
 
-        Attributes:
-        - table_sys: An instance of TableSystem initialized with parsed arguments.
+    Attributes:
+    - table_sys: An instance of TableSystem initialized with parsed arguments.
 
-        Methods:
-        - __init__(config): Initializes the model with configuration parameters.
-        - img2html(image): Converts a PIL Image or NumPy array to HTML string.
-        - parse_args(**kwargs): Parses configuration arguments.
+    Methods:
+    - __init__(config): Initializes the model with configuration parameters.
+    - img2html(image): Converts a PIL Image or NumPy array to HTML string.
+    - parse_args(**kwargs): Parses configuration arguments.
     """
 
     def __init__(self, config):
@@ -40,30 +42,30 @@ class TableMasterPaddleModel(object):
             image = np.asarray(image)
             image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
         pred_res, _ = self.table_sys(image)
-        pred_html = pred_res["html"]
+        pred_html = pred_res['html']
         # res = '<td><table  border="1">' + pred_html.replace("<html><body><table>", "").replace(
         # "</table></body></html>","") + "</table></td>\n"
         return pred_html
 
     def parse_args(self, **kwargs):
         parser = init_args()
-        model_dir = kwargs.get("model_dir")
-        table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)
-        table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)
-        det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)
-        rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
-        rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
-        device = kwargs.get("device", "cpu")
-        use_gpu = True if device.startswith("cuda") else False
+        model_dir = kwargs.get('model_dir')
+        table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)  # noqa: F405
+        table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)  # noqa: F405
+        det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)  # noqa: F405
+        rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)  # noqa: F405
+        rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)  # noqa: F405
+        device = kwargs.get('device', 'cpu')
+        use_gpu = True if device.startswith('cuda') else False
         config = {
-            "use_gpu": use_gpu,
-            "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
-            "table_algorithm": "TableMaster",
-            "table_model_dir": table_model_dir,
-            "table_char_dict_path": table_char_dict_path,
-            "det_model_dir": det_model_dir,
-            "rec_model_dir": rec_model_dir,
-            "rec_char_dict_path": rec_char_dict_path,
+            'use_gpu': use_gpu,
+            'table_max_len': kwargs.get('table_max_len', TABLE_MAX_LEN),  # noqa: F405
+            'table_algorithm': 'TableMaster',
+            'table_model_dir': table_model_dir,
+            'table_char_dict_path': table_char_dict_path,
+            'det_model_dir': det_model_dir,
+            'rec_model_dir': rec_model_dir,
+            'rec_char_dict_path': rec_char_dict_path,
         }
         parser.set_defaults(**config)
         return parser.parse_args([])

Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 408 - 247
magic_pdf/para/para_split.py


+ 352 - 182
magic_pdf/para/para_split_v2.py

@@ -1,15 +1,16 @@
 import copy
+import re
 
-from sklearn.cluster import DBSCAN
 import numpy as np
 from loguru import logger
-import re
-from magic_pdf.libs.boxbase import _is_in_or_part_overlap_with_area_ratio as is_in_layout
-from magic_pdf.libs.ocr_content_type import ContentType, BlockType
-from magic_pdf.model.magic_model import MagicModel
-from magic_pdf.libs.Constants import *
+from sklearn.cluster import DBSCAN
+
+from magic_pdf.config.constants import *  # noqa: F403
+from magic_pdf.config.ocr_content_type import BlockType, ContentType
+from magic_pdf.libs.boxbase import \
+    _is_in_or_part_overlap_with_area_ratio as is_in_layout
 
-LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ":", ":", ")", ")", ";"]
+LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ':', ':', ')', ')', ';']
 INLINE_EQUATION = ContentType.InlineEquation
 INTERLINE_EQUATION = ContentType.InterlineEquation
 TEXT = ContentType.Text
@@ -36,7 +37,9 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
         ones_indices = []
         i = 0
         while i < len(lst):  # Loop through the entire list
-            if lst[i] == 1:  # If we encounter a '1', we might be at the start of a pattern
+            if (
+                lst[i] == 1
+            ):  # If we encounter a '1', we might be at the start of a pattern
                 start = i
                 ones_in_this_interval = [i]
                 i += 1
@@ -46,7 +49,10 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
                         ones_in_this_interval.append(i)
                     i += 1
                 if len(ones_in_this_interval) > 1 or (
-                        start < len(lst) - 1 and ones_in_this_interval and lst[start + 1] in [2, 3]):
+                    start < len(lst) - 1
+                    and ones_in_this_interval
+                    and lst[start + 1] in [2, 3]
+                ):
                     indices.append((start, i - 1))
                     ones_indices.append(ones_in_this_interval)
             else:
@@ -65,7 +71,12 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
                 while i < len(lst) and lst[i] in [2, 3]:
                     i += 1
                 # 验证下一个序列是否符合条件
-                if i < len(lst) - 1 and lst[i] == 1 and lst[i + 1] in [2, 3] and lst[i - 1] in [2, 3]:
+                if (
+                    i < len(lst) - 1
+                    and lst[i] == 1
+                    and lst[i + 1] in [2, 3]
+                    and lst[i - 1] in [2, 3]
+                ):
                     while i < len(lst) and lst[i] in [1, 2, 3]:
                         if lst[i] == 1:
                             ones_in_this_interval.append(i)
@@ -114,7 +125,7 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
     """
     if len(lines) > 0:
         x_map_tag_dict, min_x_tag = cluster_line_x(lines)
-    for l in lines:
+    for l in lines:  # noqa: E741
         span_text = __get_span_text(l['spans'][0])
         if not span_text:
             line_fea_encode.append(0)
@@ -142,28 +153,26 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
     list_indice, list_start_idx = find_repeating_patterns2(line_fea_encode)
     if len(list_indice) > 0:
         if debug_able:
-            logger.info(f"发现了列表,列表行数:{list_indice}, {list_start_idx}")
+            logger.info(f'发现了列表,列表行数:{list_indice}, {list_start_idx}')
 
     # TODO check一下这个特列表里缩进的行左侧是不是对齐的。
-    segments = []
+
     for start, end in list_indice:
         for i in range(start, end + 1):
             if i > 0:
                 if line_fea_encode[i] == 4:
                     if debug_able:
-                        logger.info(f"列表行的第{i}行不是顶格的")
+                        logger.info(f'列表行的第{i}行不是顶格的')
                     break
         else:
             if debug_able:
-                logger.info(f"列表行的第{start}到第{end}行是列表")
+                logger.info(f'列表行的第{start}到第{end}行是列表')
 
     return split_indices(total_lines, list_indice), list_start_idx
 
 
 def cluster_line_x(lines: list) -> dict:
-    """
-    对一个block内所有lines的bbox的x0聚类
-    """
+    """对一个block内所有lines的bbox的x0聚类."""
     min_distance = 5
     min_sample = 1
     x0_lst = np.array([[round(line['bbox'][0]), 0] for line in lines])
@@ -171,14 +180,16 @@ def cluster_line_x(lines: list) -> dict:
     x0_uniq_label = np.unique(x0_clusters.labels_)
     # x1_lst = np.array([[line['bbox'][2], 0] for line in lines])
     x0_2_new_val = {}  # 存储旧值对应的新值映射
-    min_x0 = round(lines[0]["bbox"][0])
+    min_x0 = round(lines[0]['bbox'][0])
     for label in x0_uniq_label:
         if label == -1:
             continue
         x0_index_of_label = np.where(x0_clusters.labels_ == label)
         x0_raw_val = x0_lst[x0_index_of_label][:, 0]
         x0_new_val = np.min(x0_lst[x0_index_of_label][:, 0])
-        x0_2_new_val.update({round(raw_val): round(x0_new_val) for raw_val in x0_raw_val})
+        x0_2_new_val.update(
+            {round(raw_val): round(x0_new_val) for raw_val in x0_raw_val}
+        )
         if x0_new_val < min_x0:
             min_x0 = x0_new_val
     return x0_2_new_val, min_x0
@@ -193,27 +204,41 @@ def if_match_reference_list(text: str) -> bool:
 
 
 def __valign_lines(blocks, layout_bboxes):
-    """
-    在一个layoutbox内对齐行的左侧和右侧。
-    扫描行的左侧和右侧,如果x0, x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
-    3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。
-    """
+    """在一个layoutbox内对齐行的左侧和右侧。 扫描行的左侧和右侧,如果x0,
+    x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
+    3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。"""
 
     min_distance = 3
     min_sample = 2
     new_layout_bboxes = []
     # add bbox_fs for para split calculation
     for block in blocks:
-        block["bbox_fs"] = copy.deepcopy(block["bbox"])
+        block['bbox_fs'] = copy.deepcopy(block['bbox'])
     for layout_box in layout_bboxes:
-        blocks_in_layoutbox = [b for b in blocks if
-                               b["type"] == BlockType.Text and is_in_layout(b['bbox'], layout_box['layout_bbox'])]
-        if len(blocks_in_layoutbox) == 0 or len(blocks_in_layoutbox[0]["lines"]) == 0:
+        blocks_in_layoutbox = [
+            b
+            for b in blocks
+            if b['type'] == BlockType.Text
+            and is_in_layout(b['bbox'], layout_box['layout_bbox'])
+        ]
+        if len(blocks_in_layoutbox) == 0 or len(blocks_in_layoutbox[0]['lines']) == 0:
             new_layout_bboxes.append(layout_box['layout_bbox'])
             continue
 
-        x0_lst = np.array([[line['bbox'][0], 0] for block in blocks_in_layoutbox for line in block['lines']])
-        x1_lst = np.array([[line['bbox'][2], 0] for block in blocks_in_layoutbox for line in block['lines']])
+        x0_lst = np.array(
+            [
+                [line['bbox'][0], 0]
+                for block in blocks_in_layoutbox
+                for line in block['lines']
+            ]
+        )
+        x1_lst = np.array(
+            [
+                [line['bbox'][2], 0]
+                for block in blocks_in_layoutbox
+                for line in block['lines']
+            ]
+        )
         x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst)
         x1_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x1_lst)
         x0_uniq_label = np.unique(x0_clusters.labels_)
@@ -248,11 +273,13 @@ def __valign_lines(blocks, layout_bboxes):
 
         # 由于修改了block里的line长度,现在需要重新计算block的bbox
         for block in blocks_in_layoutbox:
-            if len(block["lines"]) > 0:
-                block['bbox_fs'] = [min([line['bbox'][0] for line in block['lines']]),
-                                    min([line['bbox'][1] for line in block['lines']]),
-                                    max([line['bbox'][2] for line in block['lines']]),
-                                    max([line['bbox'][3] for line in block['lines']])]
+            if len(block['lines']) > 0:
+                block['bbox_fs'] = [
+                    min([line['bbox'][0] for line in block['lines']]),
+                    min([line['bbox'][1] for line in block['lines']]),
+                    max([line['bbox'][2] for line in block['lines']]),
+                    max([line['bbox'][3] for line in block['lines']]),
+                ]
         """新计算layout的bbox,因为block的bbox变了。"""
         layout_x0 = min([block['bbox_fs'][0] for block in blocks_in_layoutbox])
         layout_y0 = min([block['bbox_fs'][1] for block in blocks_in_layoutbox])
@@ -264,18 +291,19 @@ def __valign_lines(blocks, layout_bboxes):
 
 
 def __align_text_in_layout(blocks, layout_bboxes):
-    """
-    由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。
-    """
+    """由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。"""
     for layout in layout_bboxes:
         lb = layout['layout_bbox']
-        blocks_in_layoutbox = [block for block in blocks if
-                               block["type"] == BlockType.Text and is_in_layout(block['bbox'], lb)]
+        blocks_in_layoutbox = [
+            block
+            for block in blocks
+            if block['type'] == BlockType.Text and is_in_layout(block['bbox'], lb)
+        ]
         if len(blocks_in_layoutbox) == 0:
             continue
 
         for block in blocks_in_layoutbox:
-            for line in block.get("lines", []):
+            for line in block.get('lines', []):
                 x0, x1 = line['bbox'][0], line['bbox'][2]
                 if x0 < lb[0]:
                     line['bbox'][0] = lb[0]
@@ -284,9 +312,7 @@ def __align_text_in_layout(blocks, layout_bboxes):
 
 
 def __common_pre_proc(blocks, layout_bboxes):
-    """
-    不分语言的,对文本进行预处理
-    """
+    """不分语言的,对文本进行预处理."""
     # __add_line_period(blocks, layout_bboxes)
     __align_text_in_layout(blocks, layout_bboxes)
     aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)
@@ -295,32 +321,30 @@ def __common_pre_proc(blocks, layout_bboxes):
 
 
 def __pre_proc_zh_blocks(blocks, layout_bboxes):
-    """
-    对中文文本进行分段预处理
-    """
+    """对中文文本进行分段预处理."""
     pass
 
 
 def __pre_proc_en_blocks(blocks, layout_bboxes):
-    """
-    对英文文本进行分段预处理
-    """
+    """对英文文本进行分段预处理."""
     pass
 
 
 def __group_line_by_layout(blocks, layout_bboxes):
-    """
-    每个layout内的行进行聚合
-    """
+    """每个layout内的行进行聚合."""
     # 因为只是一个block一行目前, 一个block就是一个段落
     blocks_group = []
     for lyout in layout_bboxes:
-        blocks_in_layout = [block for block in blocks if is_in_layout(block.get('bbox_fs', None), lyout['layout_bbox'])]
+        blocks_in_layout = [
+            block
+            for block in blocks
+            if is_in_layout(block.get('bbox_fs', None), lyout['layout_bbox'])
+        ]
         blocks_group.append(blocks_in_layout)
     return blocks_group
 
 
-def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
+def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang='en'):
     """
     lines_group 进行行分段——layout内部进行分段。lines_group内每个元素是一个Layoutbox内的所有行。
     1. 先计算每个group的左右边界。
@@ -336,17 +360,20 @@ def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
         if len(blocks) == 0:
             list_info.append([False, False])
             continue
-        if blocks[0]["type"] != BlockType.Text and blocks[-1]["type"] != BlockType.Text:
+        if blocks[0]['type'] != BlockType.Text and blocks[-1]['type'] != BlockType.Text:
             list_info.append([False, False])
             continue
-        if blocks[0]["type"] != BlockType.Text:
+        if blocks[0]['type'] != BlockType.Text:
             is_start_list = False
-        if blocks[-1]["type"] != BlockType.Text:
+        if blocks[-1]['type'] != BlockType.Text:
             is_end_list = False
 
-        lines = [line for block in blocks if
-                 block["type"] == BlockType.Text for line in
-                 block['lines']]
+        lines = [
+            line
+            for block in blocks
+            if block['type'] == BlockType.Text
+            for line in block['lines']
+        ]
         total_lines = len(lines)
         if total_lines == 1 or total_lines == 0:
             list_info.append([False, False])
@@ -359,7 +386,9 @@ def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
                     2. 左对齐的列表块(其特点是左侧顶格的行数小于等于非顶格的行数,非定格首字母会有小写,顶格90%是大写。并且左侧顶格行数大于1,大于1是为了这种模式连续出现才能称之为列表)
                         这样的文本块,顶格的为一个段落开头,紧随其后非顶格的行属于这个段落。
         """
-        text_segments, list_start_line = __detect_list_lines(lines, new_layout_bbox, lang)
+        text_segments, list_start_line = __detect_list_lines(
+            lines, new_layout_bbox, lang
+        )
         """根据list_range,把lines分成几个部分
 
         """
@@ -368,10 +397,17 @@ def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
                 for i in range(0, len(list_start)):
                     index = list_start[i] - 1
                     if index >= 0:
-                        if "content" in lines[index]["spans"][-1] and lines[index]["spans"][-1].get('type', '') not in [
-                            ContentType.InlineEquation, ContentType.InterlineEquation]:
-                            lines[index]["spans"][-1]["content"] += '\n\n'
-        layout_list_info = [False, False]  # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
+                        if 'content' in lines[index]['spans'][-1] and lines[index][
+                            'spans'
+                        ][-1].get('type', '') not in [
+                            ContentType.InlineEquation,
+                            ContentType.InterlineEquation,
+                        ]:
+                            lines[index]['spans'][-1]['content'] += '\n\n'
+        layout_list_info = [
+            False,
+            False,
+        ]  # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
         for content_type, start, end in text_segments:
             if content_type == 'list':
                 if start == 0 and is_start_list is None:
@@ -388,8 +424,7 @@ def __split_para_lines(lines: list, text_blocks: list) -> list:
     other_paras = []
     text_lines = []
     for line in lines:
-
-        spans_types = [span["type"] for span in line]
+        spans_types = [span['type'] for span in line]
         if ContentType.Table in spans_types:
             other_paras.append([line])
             continue
@@ -402,20 +437,22 @@ def __split_para_lines(lines: list, text_blocks: list) -> list:
         text_lines.append(line)
 
     for block in text_blocks:
-        block_bbox = block["bbox"]
+        block_bbox = block['bbox']
         para = []
         for line in text_lines:
-            bbox = line["bbox"]
+            bbox = line['bbox']
             if is_in_layout(bbox, block_bbox):
                 para.append(line)
         if len(para) > 0:
             text_paras.append(para)
     paras = other_paras.extend(text_paras)
-    paras_sorted = sorted(paras, key=lambda x: x[0]["bbox"][1])
+    paras_sorted = sorted(paras, key=lambda x: x[0]['bbox'][1])
     return paras_sorted
 
 
-def __connect_list_inter_layout(blocks_group, new_layout_bbox, layout_list_info, page_num, lang):
+def __connect_list_inter_layout(
+    blocks_group, new_layout_bbox, layout_list_info, page_num, lang
+):
     global debug_able
     """
     如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。
@@ -429,74 +466,108 @@ def __connect_list_inter_layout(blocks_group, new_layout_bbox, layout_list_info,
             continue
         pre_layout_list_info = layout_list_info[i - 1]
         next_layout_list_info = layout_list_info[i]
-        pre_last_para = blocks_group[i - 1][-1].get("lines", [])
+        pre_last_para = blocks_group[i - 1][-1].get('lines', [])
         next_paras = blocks_group[i]
         next_first_para = next_paras[0]
 
-        if pre_layout_list_info[1] and not next_layout_list_info[0] and next_first_para[
-            "type"] == BlockType.Text:  # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
+        if (
+            pre_layout_list_info[1]
+            and not next_layout_list_info[0]
+            and next_first_para['type'] == BlockType.Text
+        ):  # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
             if debug_able:
-                logger.info(f"连接page {page_num} 内的list")
+                logger.info(f'连接page {page_num} 内的list')
             # 向layout_paras[i] 寻找开头具有相同缩进的连续的行
             may_list_lines = []
-            lines = next_first_para.get("lines", [])
+            lines = next_first_para.get('lines', [])
 
             for line in lines:
-                if line['bbox'][0] > __find_layout_bbox_by_line(line['bbox'], new_layout_bbox)[0]:
+                if (
+                    line['bbox'][0]
+                    > __find_layout_bbox_by_line(line['bbox'], new_layout_bbox)[0]
+                ):
                     may_list_lines.append(line)
                 else:
                     break
             # 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
-            if len(may_list_lines) > 0 and len(set([x['bbox'][0] for x in may_list_lines])) == 1:
+            if (
+                len(may_list_lines) > 0
+                and len(set([x['bbox'][0] for x in may_list_lines])) == 1
+            ):
                 pre_last_para.extend(may_list_lines)
-                next_first_para["lines"] = next_first_para["lines"][len(may_list_lines):]
-
-    return blocks_group, [layout_list_info[0][0], layout_list_info[-1][1]]  # 同时还返回了这个页面级别的开头、结尾是不是列表的信息
-
-
-def __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox,
-                              pre_page_list_info, next_page_list_info, page_num, lang):
-    """
-    如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。
-    根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。
-    """
-    if len(pre_page_paras) == 0 or len(next_page_paras) == 0:  # 0的时候最后的return 会出错
+                next_first_para['lines'] = next_first_para['lines'][
+                    len(may_list_lines) :
+                ]
+
+    return blocks_group, [
+        layout_list_info[0][0],
+        layout_list_info[-1][1],
+    ]  # 同时还返回了这个页面级别的开头、结尾是不是列表的信息
+
+
+def __connect_list_inter_page(
+    pre_page_paras,
+    next_page_paras,
+    pre_page_layout_bbox,
+    next_page_layout_bbox,
+    pre_page_list_info,
+    next_page_list_info,
+    page_num,
+    lang,
+):
+    """如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
+    因为没有区分列表和段落,所以这个方法暂时不实现。
+    根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
+    if (
+        len(pre_page_paras) == 0 or len(next_page_paras) == 0
+    ):  # 0的时候最后的return 会出错
         return False
     if len(pre_page_paras[-1]) == 0 or len(next_page_paras[0]) == 0:
         return False
-    if pre_page_paras[-1][-1]["type"] != BlockType.Text or next_page_paras[0][0]["type"] != BlockType.Text:
+    if (
+        pre_page_paras[-1][-1]['type'] != BlockType.Text
+        or next_page_paras[0][0]['type'] != BlockType.Text
+    ):
         return False
-    if pre_page_list_info[1] and not next_page_list_info[0]:  # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
+    if (
+        pre_page_list_info[1] and not next_page_list_info[0]
+    ):  # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
         if debug_able:
-            logger.info(f"连接page {page_num} 内的list")
+            logger.info(f'连接page {page_num} 内的list')
         # 向layout_paras[i] 寻找开头具有相同缩进的连续的行
         may_list_lines = []
         next_page_first_para = next_page_paras[0][0]
-        if next_page_first_para["type"] == BlockType.Text:
-            lines = next_page_first_para["lines"]
+        if next_page_first_para['type'] == BlockType.Text:
+            lines = next_page_first_para['lines']
             for line in lines:
-                if line['bbox'][0] > __find_layout_bbox_by_line(line['bbox'], next_page_layout_bbox)[0]:
+                if (
+                    line['bbox'][0]
+                    > __find_layout_bbox_by_line(line['bbox'], next_page_layout_bbox)[0]
+                ):
                     may_list_lines.append(line)
                 else:
                     break
         # 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
-        if len(may_list_lines) > 0 and len(set([x['bbox'][0] for x in may_list_lines])) == 1:
+        if (
+            len(may_list_lines) > 0
+            and len(set([x['bbox'][0] for x in may_list_lines])) == 1
+        ):
             # pre_page_paras[-1].append(may_list_lines)
             # 下一页合并到上一页最后一段,打一个cross_page的标签
             for line in may_list_lines:
-                for span in line["spans"]:
-                    span[CROSS_PAGE] = True
-            pre_page_paras[-1][-1]["lines"].extend(may_list_lines)
-            next_page_first_para["lines"] = next_page_first_para["lines"][len(may_list_lines):]
+                for span in line['spans']:
+                    span[CROSS_PAGE] = True  # noqa: F405
+            pre_page_paras[-1][-1]['lines'].extend(may_list_lines)
+            next_page_first_para['lines'] = next_page_first_para['lines'][
+                len(may_list_lines) :
+            ]
             return True
 
     return False
 
 
 def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
-    """
-    根据line找到所在的layout
-    """
+    """根据line找到所在的layout."""
     for layout in layout_bboxes:
         if is_in_layout(line_bbox, layout):
             return layout
@@ -525,39 +596,59 @@ def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox):
                 connected_layout_blocks.append(blocks_group[i])
                 continue
             # text类型的段才需要考虑layout间的合并
-            if blocks_group[i - 1][-1]["type"] != BlockType.Text or blocks_group[i][0]["type"] != BlockType.Text:
+            if (
+                blocks_group[i - 1][-1]['type'] != BlockType.Text
+                or blocks_group[i][0]['type'] != BlockType.Text
+            ):
                 connected_layout_blocks.append(blocks_group[i])
                 continue
-            if len(blocks_group[i - 1][-1]["lines"]) == 0 or len(blocks_group[i][0]["lines"]) == 0:
+            if (
+                len(blocks_group[i - 1][-1]['lines']) == 0
+                or len(blocks_group[i][0]['lines']) == 0
+            ):
                 connected_layout_blocks.append(blocks_group[i])
                 continue
-            pre_last_line = blocks_group[i - 1][-1]["lines"][-1]
-            next_first_line = blocks_group[i][0]["lines"][0]
-        except Exception as e:
-            logger.error(f"page layout {i} has no line")
+            pre_last_line = blocks_group[i - 1][-1]['lines'][-1]
+            next_first_line = blocks_group[i][0]['lines'][0]
+        except Exception:
+            logger.error(f'page layout {i} has no line')
             continue
-        pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
+        pre_last_line_text = ''.join(
+            [__get_span_text(span) for span in pre_last_line['spans']]
+        )
         pre_last_line_type = pre_last_line['spans'][-1]['type']
-        next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
+        next_first_line_text = ''.join(
+            [__get_span_text(span) for span in next_first_line['spans']]
+        )
         next_first_line_type = next_first_line['spans'][0]['type']
-        if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
+        if pre_last_line_type not in [
+            TEXT,
+            INLINE_EQUATION,
+        ] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
             connected_layout_blocks.append(blocks_group[i])
             continue
         pre_layout = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)
-        next_layout = __find_layout_bbox_by_line(next_first_line['bbox'], new_layout_bbox)
+        next_layout = __find_layout_bbox_by_line(
+            next_first_line['bbox'], new_layout_bbox
+        )
 
         pre_x2_max = pre_layout[2] if pre_layout else -1
         next_x0_min = next_layout[0] if next_layout else -1
 
         pre_last_line_text = pre_last_line_text.strip()
         next_first_line_text = next_first_line_text.strip()
-        if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text and pre_last_line_text[
-            -1] not in LINE_STOP_FLAG and \
-                next_first_line['bbox'][0] == next_x0_min:  # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
+        if (
+            pre_last_line['bbox'][2] == pre_x2_max
+            and pre_last_line_text
+            and pre_last_line_text[-1] not in LINE_STOP_FLAG
+            and next_first_line['bbox'][0] == next_x0_min
+        ):  # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
             """连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
-            connected_layout_blocks[-1][-1]["lines"].extend(blocks_group[i][0]["lines"])
-            blocks_group[i][0]["lines"] = []  # 删除后一个layout第一个段落中的lines,因为他已经被合并到前一个layout的最后一个段落了
-            blocks_group[i][0][LINES_DELETED] = True
+            connected_layout_blocks[-1][-1]['lines'].extend(blocks_group[i][0]['lines'])
+            blocks_group[i][0][
+                'lines'
+            ] = []  # 删除后一个layout第一个段落中的lines,因为他已经被合并到前一个layout的最后一个段落了
+            blocks_group[i][0][LINES_DELETED] = True  # noqa: F405
             # if len(layout_paras[i]) == 0:
             #     layout_paras.pop(i)
             # else:
@@ -569,8 +660,14 @@ def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox):
     return connected_layout_blocks
 
 
-def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num,
-                              lang):
+def __connect_para_inter_page(
+    pre_page_paras,
+    next_page_paras,
+    pre_page_layout_bbox,
+    next_page_layout_bbox,
+    page_num,
+    lang,
+):
     """
     连接起来相邻两个页面的段落——前一个页面最后一个段落和后一个页面的第一个段落。
     是否可以连接的条件:
@@ -578,33 +675,53 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
     2. 后一个页面的第一个段落第一行没有空白开头。
     """
     # 有的页面可能压根没有文字
-    if len(pre_page_paras) == 0 or len(next_page_paras) == 0 or len(pre_page_paras[0]) == 0 or len(
-            next_page_paras[0]) == 0:  # TODO [[]]为什么出现在pre_page_paras里?
+    if (
+        len(pre_page_paras) == 0
+        or len(next_page_paras) == 0
+        or len(pre_page_paras[0]) == 0
+        or len(next_page_paras[0]) == 0
+    ):  # TODO [[]]为什么出现在pre_page_paras里?
         return False
     pre_last_block = pre_page_paras[-1][-1]
     next_first_block = next_page_paras[0][0]
-    if pre_last_block["type"] != BlockType.Text or next_first_block["type"] != BlockType.Text:
+    if (
+        pre_last_block['type'] != BlockType.Text
+        or next_first_block['type'] != BlockType.Text
+    ):
         return False
-    if len(pre_last_block["lines"]) == 0 or len(next_first_block["lines"]) == 0:
+    if len(pre_last_block['lines']) == 0 or len(next_first_block['lines']) == 0:
         return False
-    pre_last_para = pre_last_block["lines"]
-    next_first_para = next_first_block["lines"]
+    pre_last_para = pre_last_block['lines']
+    next_first_para = next_first_block['lines']
     pre_last_line = pre_last_para[-1]
     next_first_line = next_first_para[0]
-    pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
+    pre_last_line_text = ''.join(
+        [__get_span_text(span) for span in pre_last_line['spans']]
+    )
     pre_last_line_type = pre_last_line['spans'][-1]['type']
-    next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
+    next_first_line_text = ''.join(
+        [__get_span_text(span) for span in next_first_line['spans']]
+    )
     next_first_line_type = next_first_line['spans'][0]['type']
 
-    if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT,
-                                                                                         INLINE_EQUATION]:  # TODO,真的要做好,要考虑跨table, image, 行间的情况
+    if pre_last_line_type not in [
+        TEXT,
+        INLINE_EQUATION,
+    ] or next_first_line_type not in [
+        TEXT,
+        INLINE_EQUATION,
+    ]:  # TODO,真的要做好,要考虑跨table, image, 行间的情况
         # 不是文本,不连接
         return False
 
-    pre_x2_max_bbox = __find_layout_bbox_by_line(pre_last_line['bbox'], pre_page_layout_bbox)
+    pre_x2_max_bbox = __find_layout_bbox_by_line(
+        pre_last_line['bbox'], pre_page_layout_bbox
+    )
     if not pre_x2_max_bbox:
         return False
-    next_x0_min_bbox = __find_layout_bbox_by_line(next_first_line['bbox'], next_page_layout_bbox)
+    next_x0_min_bbox = __find_layout_bbox_by_line(
+        next_first_line['bbox'], next_page_layout_bbox
+    )
     if not next_x0_min_bbox:
         return False
 
@@ -613,18 +730,21 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
 
     pre_last_line_text = pre_last_line_text.strip()
     next_first_line_text = next_first_line_text.strip()
-    if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and \
-            next_first_line['bbox'][0] == next_x0_min:  # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
+    if (
+        pre_last_line['bbox'][2] == pre_x2_max
+        and pre_last_line_text[-1] not in LINE_STOP_FLAG
+        and next_first_line['bbox'][0] == next_x0_min
+    ):  # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
         """连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
         # 下一页合并到上一页最后一段,打一个cross_page的标签
         for line in next_first_para:
-            for span in line["spans"]:
-                span[CROSS_PAGE] = True
+            for span in line['spans']:
+                span[CROSS_PAGE] = True  # noqa: F405
         pre_last_para.extend(next_first_para)
 
         # next_page_paras[0].pop(0)  # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。
-        next_page_paras[0][0]["lines"] = []
-        next_page_paras[0][0][LINES_DELETED] = True
+        next_page_paras[0][0]['lines'] = []
+        next_page_paras[0][0][LINES_DELETED] = True  # noqa: F405
         return True
     else:
         return False
@@ -667,38 +787,73 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang):
         single_line_paras_tag = []
         for i in range(len(layout_para)):
             # single_line_paras_tag.append(len(layout_para[i]) == 1 and layout_para[i][0]['spans'][0]['type'] == TEXT)
-            single_line_paras_tag.append(layout_para[i]['type'] == BlockType.Text and len(layout_para[i]["lines"]) == 1)
+            single_line_paras_tag.append(
+                layout_para[i]['type'] == BlockType.Text
+                and len(layout_para[i]['lines']) == 1
+            )
         """找出来连续的单行文本,如果连续行高度相同,那么合并为一个段落。"""
-        consecutive_single_line_indices = find_consecutive_true_regions(single_line_paras_tag)
+        consecutive_single_line_indices = find_consecutive_true_regions(
+            single_line_paras_tag
+        )
         if len(consecutive_single_line_indices) > 0:
-            """检查这些行是否是高度相同的,居中的"""
+            """检查这些行是否是高度相同的,居中的."""
             for start, end in consecutive_single_line_indices:
                 # start += index_offset
                 # end += index_offset
-                line_hi = np.array([block["lines"][0]['bbox'][3] - block["lines"][0]['bbox'][1] for block in
-                                    layout_para[start:end + 1]])
-                first_line_text = ''.join([__get_span_text(span) for span in layout_para[start]["lines"][0]['spans']])
-                if "Table" in first_line_text or "Figure" in first_line_text:
+                line_hi = np.array(
+                    [
+                        block['lines'][0]['bbox'][3] - block['lines'][0]['bbox'][1]
+                        for block in layout_para[start : end + 1]
+                    ]
+                )
+                first_line_text = ''.join(
+                    [
+                        __get_span_text(span)
+                        for span in layout_para[start]['lines'][0]['spans']
+                    ]
+                )
+                if 'Table' in first_line_text or 'Figure' in first_line_text:
                     pass
                 if debug_able:
                     logger.info(line_hi.std())
 
                 if line_hi.std() < 2:
-                    """行高度相同,那么判断是否居中"""
-                    all_left_x0 = [block["lines"][0]['bbox'][0] for block in layout_para[start:end + 1]]
-                    all_right_x1 = [block["lines"][0]['bbox'][2] for block in layout_para[start:end + 1]]
+                    """行高度相同,那么判断是否居中."""
+                    all_left_x0 = [
+                        block['lines'][0]['bbox'][0]
+                        for block in layout_para[start : end + 1]
+                    ]
+                    all_right_x1 = [
+                        block['lines'][0]['bbox'][2]
+                        for block in layout_para[start : end + 1]
+                    ]
                     layout_center = (layout_box[0] + layout_box[2]) / 2
-                    if all([x0 < layout_center < x1 for x0, x1 in zip(all_left_x0, all_right_x1)]) \
-                            and not all([x0 == layout_box[0] for x0 in all_left_x0]) \
-                            and not all([x1 == layout_box[2] for x1 in all_right_x1]):
-                        merge_para = [block["lines"][0] for block in layout_para[start:end + 1]]
-                        para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']])
+                    if (
+                        all(
+                            [
+                                x0 < layout_center < x1
+                                for x0, x1 in zip(all_left_x0, all_right_x1)
+                            ]
+                        )
+                        and not all([x0 == layout_box[0] for x0 in all_left_x0])
+                        and not all([x1 == layout_box[2] for x1 in all_right_x1])
+                    ):
+                        merge_para = [
+                            block['lines'][0] for block in layout_para[start : end + 1]
+                        ]
+                        para_text = ''.join(
+                            [
+                                __get_span_text(span)
+                                for line in merge_para
+                                for span in line['spans']
+                            ]
+                        )
                         if debug_able:
                             logger.info(para_text)
-                        layout_para[start]["lines"] = merge_para
+                        layout_para[start]['lines'] = merge_para
                         for i_para in range(start + 1, end + 1):
-                            layout_para[i_para]["lines"] = []
-                            layout_para[i_para][LINES_DELETED] = True
+                            layout_para[i_para]['lines'] = []
+                            layout_para[i_para][LINES_DELETED] = True  # noqa: F405
                         # layout_para[start:end + 1] = [merge_para]
 
                         # index_offset -= end - start
@@ -707,18 +862,13 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang):
 
 
 def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang):
-    """
-    找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。
-    """
+    """找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。"""
 
     pass
 
 
 def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
-    """
-    根据line和layout情况进行分段
-    先实现一个根据行末尾特征分段的简单方法。
-    """
+    """根据line和layout情况进行分段 先实现一个根据行末尾特征分段的简单方法。"""
     """
     算法思路:
     1. 扫描layout里每一行,找出来行尾距离layout有边界有一定距离的行。
@@ -727,15 +877,20 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
     4. 图、表,目前独占一行,不考虑分段。
     """
     blocks_group = __group_line_by_layout(blocks, layout_bboxes)  # block内分段
-    layout_list_info = __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang)  # layout内分段
-    blocks_group, page_list_info = __connect_list_inter_layout(blocks_group, new_layout_bbox, layout_list_info,
-                                                               page_num, lang)  # layout之间连接列表段落
-    connected_layout_blocks = __connect_para_inter_layoutbox(blocks_group, new_layout_bbox)  # layout间链接段落
+    layout_list_info = __split_para_in_layoutbox(
+        blocks_group, new_layout_bbox, lang
+    )  # layout内分段
+    blocks_group, page_list_info = __connect_list_inter_layout(
+        blocks_group, new_layout_bbox, layout_list_info, page_num, lang
+    )  # layout之间连接列表段落
+    connected_layout_blocks = __connect_para_inter_layoutbox(
+        blocks_group, new_layout_bbox
+    )  # layout间链接段落
 
     return connected_layout_blocks, page_list_info
 
 
-def para_split(pdf_info_dict, debug_mode, lang="en"):
+def para_split(pdf_info_dict, debug_mode, lang='en'):
     global debug_able
     debug_able = debug_mode
     new_layout_of_pages = []  # 数组的数组,每个元素是一个页面的layoutS
@@ -745,7 +900,9 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
         layout_bboxes = page['layout_bboxes']
         new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
         new_layout_of_pages.append(new_layout_bbox)
-        splited_blocks, page_list_info = __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang)
+        splited_blocks, page_list_info = __do_split_page(
+            blocks, layout_bboxes, new_layout_bbox, page_num, lang
+        )
         all_page_list_info.append(page_list_info)
         page['para_blocks'] = splited_blocks
 
@@ -759,18 +916,31 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
         pre_page_layout_bbox = new_layout_of_pages[page_num - 1]
         next_page_layout_bbox = new_layout_of_pages[page_num]
 
-        is_conn = __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
-                                            next_page_layout_bbox, page_num, lang)
+        is_conn = __connect_para_inter_page(
+            pre_page_paras,
+            next_page_paras,
+            pre_page_layout_bbox,
+            next_page_layout_bbox,
+            page_num,
+            lang,
+        )
         if debug_able:
             if is_conn:
-                logger.info(f"连接了第{page_num - 1}页和第{page_num}页的段落")
-
-        is_list_conn = __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
-                                                 next_page_layout_bbox, all_page_list_info[page_num - 1],
-                                                 all_page_list_info[page_num], page_num, lang)
+                logger.info(f'连接了第{page_num - 1}页和第{page_num}页的段落')
+
+        is_list_conn = __connect_list_inter_page(
+            pre_page_paras,
+            next_page_paras,
+            pre_page_layout_bbox,
+            next_page_layout_bbox,
+            all_page_list_info[page_num - 1],
+            all_page_list_info[page_num],
+            page_num,
+            lang,
+        )
         if debug_able:
             if is_list_conn:
-                logger.info(f"连接了第{page_num - 1}页和第{page_num}页的列表段落")
+                logger.info(f'连接了第{page_num - 1}页和第{page_num}页的列表段落')
 
     """接下来可能会漏掉一些特别的一些可以合并的内容,对他们进行段落连接
     1. 正文中有时出现一个行顶格,接下来几行缩进的情况。
@@ -786,4 +956,4 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
     for page_num, page in enumerate(pdf_info_dict.values()):
         page_paras = page['para_blocks']
         page_blocks = [block for layout in page_paras for block in layout]
-        page["para_blocks"] = page_blocks
+        page['para_blocks'] = page_blocks

+ 81 - 46
magic_pdf/para/para_split_v3.py

@@ -1,17 +1,30 @@
 import copy
 
-from loguru import logger
-
-from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
-from magic_pdf.libs.ocr_content_type import BlockType, ContentType
-
-LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
+from magic_pdf.config.constants import CROSS_PAGE, LINES_DELETED
+from magic_pdf.config.ocr_content_type import BlockType, ContentType
+
+LINE_STOP_FLAG = (
+    '.',
+    '!',
+    '?',
+    '。',
+    '!',
+    '?',
+    ')',
+    ')',
+    '"',
+    '”',
+    ':',
+    ':',
+    ';',
+    ';',
+)
 LIST_END_FLAG = ('.', '。', ';', ';')
 
 
 class ListLineTag:
-    IS_LIST_START_LINE = "is_list_start_line"
-    IS_LIST_END_LINE = "is_list_end_line"
+    IS_LIST_START_LINE = 'is_list_start_line'
+    IS_LIST_END_LINE = 'is_list_end_line'
 
 
 def __process_blocks(blocks):
@@ -27,12 +40,14 @@ def __process_blocks(blocks):
 
         # 如果当前块是 text 类型
         if current_block['type'] == 'text':
-            current_block["bbox_fs"] = copy.deepcopy(current_block["bbox"])
-            if 'lines' in current_block and len(current_block["lines"]) > 0:
-                current_block['bbox_fs'] = [min([line['bbox'][0] for line in current_block['lines']]),
-                                            min([line['bbox'][1] for line in current_block['lines']]),
-                                            max([line['bbox'][2] for line in current_block['lines']]),
-                                            max([line['bbox'][3] for line in current_block['lines']])]
+            current_block['bbox_fs'] = copy.deepcopy(current_block['bbox'])
+            if 'lines' in current_block and len(current_block['lines']) > 0:
+                current_block['bbox_fs'] = [
+                    min([line['bbox'][0] for line in current_block['lines']]),
+                    min([line['bbox'][1] for line in current_block['lines']]),
+                    max([line['bbox'][2] for line in current_block['lines']]),
+                    max([line['bbox'][3] for line in current_block['lines']]),
+                ]
             current_group.append(current_block)
 
         # 检查下一个块是否存在
@@ -83,9 +98,10 @@ def __is_list_or_index_block(block):
         # logger.info(f"block_weight_radio: {block_weight_radio}")
 
         # 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
-        if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
-                abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
-                block['bbox_fs'][2] - last_line['bbox'][2] > line_height
+        if (
+            first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
+            and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
+            and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
         ):
             multiple_para_flag = True
 
@@ -93,14 +109,14 @@ def __is_list_or_index_block(block):
             line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
             block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
             if (
-                    line['bbox'][0] - block['bbox_fs'][0] > 0.8 * line_height and
-                    block['bbox_fs'][2] - line['bbox'][2] > 0.8 * line_height
+                line['bbox'][0] - block['bbox_fs'][0] > 0.8 * line_height
+                and block['bbox_fs'][2] - line['bbox'][2] > 0.8 * line_height
             ):
                 external_sides_not_close_num += 1
             if abs(line_mid_x - block_mid_x) < line_height / 2:
                 center_close_num += 1
 
-            line_text = ""
+            line_text = ''
 
             for span in line['spans']:
                 span_type = span['type']
@@ -148,15 +164,19 @@ def __is_list_or_index_block(block):
                     if line_text[-1].isdigit():
                         num_end_count += 1
 
-            if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
+            if (
+                num_start_count / len(lines_text_list) >= 0.8
+                or num_end_count / len(lines_text_list) >= 0.8
+            ):
                 line_num_flag = True
             if flag_end_count / len(lines_text_list) >= 0.8:
                 line_end_flag = True
 
         # 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
-        if ((left_close_num / len(block['lines']) >= 0.8 or right_close_num / len(block['lines']) >= 0.8)
-                and line_num_flag
-        ):
+        if (
+            left_close_num / len(block['lines']) >= 0.8
+            or right_close_num / len(block['lines']) >= 0.8
+        ) and line_num_flag:
             for line in block['lines']:
                 line[ListLineTag.IS_LIST_START_LINE] = True
             return BlockType.Index
@@ -164,20 +184,20 @@ def __is_list_or_index_block(block):
         # 全部line都居中的特殊list识别,每行都需要换行,特征是多行,且大多数行都前后not_close,每line中点x坐标接近
         # 补充条件block的长宽比有要求
         elif (
-                external_sides_not_close_num >= 2 and
-                center_close_num == len(block['lines']) and
-                external_sides_not_close_num / len(block['lines']) >= 0.5 and
-                block_height / block_weight > 0.4
+            external_sides_not_close_num >= 2
+            and center_close_num == len(block['lines'])
+            and external_sides_not_close_num / len(block['lines']) >= 0.5
+            and block_height / block_weight > 0.4
         ):
             for line in block['lines']:
                 line[ListLineTag.IS_LIST_START_LINE] = True
             return BlockType.List
 
         elif (
-                left_close_num >= 2
-                and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
-                and not multiple_para_flag
-                # and block_weight_radio > 0.27
+            left_close_num >= 2
+            and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
+            and not multiple_para_flag
+            # and block_weight_radio > 0.27
         ):
             # 处理一种特殊的没有缩进的list,所有行都贴左边,通过右边的空隙判断是否是item尾
             if left_close_num / len(block['lines']) > 0.8:
@@ -189,10 +209,15 @@ def __is_list_or_index_block(block):
                 # 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
                 elif line_end_flag:
                     for i, line in enumerate(block['lines']):
-                        if len(lines_text_list[i]) > 0 and lines_text_list[i][-1] in LIST_END_FLAG:
+                        if (
+                            len(lines_text_list[i]) > 0
+                            and lines_text_list[i][-1] in LIST_END_FLAG
+                        ):
                             line[ListLineTag.IS_LIST_END_LINE] = True
                             if i + 1 < len(block['lines']):
-                                block['lines'][i + 1][ListLineTag.IS_LIST_START_LINE] = True
+                                block['lines'][i + 1][
+                                    ListLineTag.IS_LIST_START_LINE
+                                ] = True
                 # line item基本没有结束标识符,而且也没有缩进,按右侧空隙判断哪些是item end
                 else:
                     line_start_flag = False
@@ -201,7 +226,10 @@ def __is_list_or_index_block(block):
                             line[ListLineTag.IS_LIST_START_LINE] = True
                             line_start_flag = False
 
-                        if abs(block['bbox_fs'][2] - line['bbox'][2]) > 0.1 * block_weight:
+                        if (
+                            abs(block['bbox_fs'][2] - line['bbox'][2])
+                            > 0.1 * block_weight
+                        ):
                             line[ListLineTag.IS_LIST_END_LINE] = True
                             line_start_flag = True
             # 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_FLAG 结尾且数量和start line 一致
@@ -243,11 +271,13 @@ def __merge_2_text_blocks(block1, block2):
                     first_span = first_line['spans'][0]
                     if len(first_span['content']) > 0:
                         span_start_with_num = first_span['content'][0].isdigit()
-                        if (abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height
-                                and not last_span['content'].endswith(LINE_STOP_FLAG)
-                                # 两个block宽度差距超过2倍也不合并
-                                and abs(block1_weight - block2_weight) < min_block_weight
-                                and not span_start_with_num
+                        if (
+                            abs(block2['bbox_fs'][2] - last_line['bbox'][2])
+                            < line_height
+                            and not last_span['content'].endswith(LINE_STOP_FLAG)
+                            # 两个block宽度差距超过2倍也不合并
+                            and abs(block1_weight - block2_weight) < min_block_weight
+                            and not span_start_with_num
                         ):
                             if block1['page_num'] != block2['page_num']:
                                 for line in block1['lines']:
@@ -284,7 +314,6 @@ def __is_list_group(text_blocks_group):
 def __para_merge_page(blocks):
     page_text_blocks_groups = __process_blocks(blocks)
     for text_blocks_group in page_text_blocks_groups:
-
         if len(text_blocks_group) > 0:
             # 需要先在合并前对所有block判断是否为list or index block
             for block in text_blocks_group:
@@ -293,7 +322,6 @@ def __para_merge_page(blocks):
                 # logger.info(f"{block['type']}:{block}")
 
         if len(text_blocks_group) > 1:
-
             # 在合并前判断这个group 是否是一个 list group
             is_list_group = __is_list_group(text_blocks_group)
 
@@ -305,11 +333,18 @@ def __para_merge_page(blocks):
                 if i - 1 >= 0:
                     prev_block = text_blocks_group[i - 1]
 
-                    if current_block['type'] == 'text' and prev_block['type'] == 'text' and not is_list_group:
+                    if (
+                        current_block['type'] == 'text'
+                        and prev_block['type'] == 'text'
+                        and not is_list_group
+                    ):
                         __merge_2_text_blocks(current_block, prev_block)
                     elif (
-                            (current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List) or
-                            (current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index)
+                        current_block['type'] == BlockType.List
+                        and prev_block['type'] == BlockType.List
+                    ) or (
+                        current_block['type'] == BlockType.Index
+                        and prev_block['type'] == BlockType.Index
                     ):
                         __merge_2_list_blocks(current_block, prev_block)
 
@@ -339,4 +374,4 @@ if __name__ == '__main__':
     # 调用函数
     groups = __process_blocks(input_blocks)
     for group_index, group in enumerate(groups):
-        print(f"Group {group_index}: {group}")
+        print(f'Group {group_index}: {group}')

+ 174 - 100
magic_pdf/pdf_parse_union_core.py

@@ -2,38 +2,47 @@ import time
 
 from loguru import logger
 
+from magic_pdf.config.drop_reason import DropReason
+from magic_pdf.config.ocr_content_type import ContentType
+from magic_pdf.layout.layout_sort import (LAYOUT_UNPROC, get_bboxes_layout,
+                                          get_columns_cnt_of_layout)
 from magic_pdf.libs.commons import fitz, get_delta_time
-from magic_pdf.layout.layout_sort import get_bboxes_layout, LAYOUT_UNPROC, get_columns_cnt_of_layout
 from magic_pdf.libs.convert_utils import dict_to_list
-from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.local_math import float_equal
-from magic_pdf.libs.ocr_content_type import ContentType
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.para.para_split_v2 import para_split
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
-from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
+from magic_pdf.pre_proc.construct_page_dict import \
+    ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
-from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \
-    combine_chars_to_pymudict
-from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split
-from magic_pdf.pre_proc.ocr_dict_merge import sort_blocks_by_layout, fill_spans_in_blocks, fix_block_spans, \
-    fix_discarded_block
-from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
-    remove_overlaps_low_confidence_spans
-from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
+from magic_pdf.pre_proc.equations_replace import (
+    combine_chars_to_pymudict, remove_chars_in_text_blocks,
+    replace_equations_in_textblock)
+from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
+    ocr_prepare_bboxes_for_layout_split
+from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
+                                               fix_block_spans,
+                                               fix_discarded_block,
+                                               sort_blocks_by_layout)
+from magic_pdf.pre_proc.ocr_span_list_modify import (
+    get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
+    remove_overlaps_min_spans)
+from magic_pdf.pre_proc.resolve_bbox_conflict import \
+    check_useful_block_horizontal_overlap
 
 
 def remove_horizontal_overlap_block_which_smaller(all_bboxes):
     useful_blocks = []
     for bbox in all_bboxes:
-        useful_blocks.append({
-            "bbox": bbox[:4]
-        })
-    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks)
+        useful_blocks.append({'bbox': bbox[:4]})
+    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
+        check_useful_block_horizontal_overlap(useful_blocks)
+    )
     if is_useful_block_horz_overlap:
         logger.warning(
-            f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}")
+            f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
+        )
         for bbox in all_bboxes.copy():
             if smaller_bbox == bbox[:4]:
                 all_bboxes.remove(bbox)
@@ -41,27 +50,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
     return is_useful_block_horz_overlap, all_bboxes
 
 
-def __replace_STX_ETX(text_str:str):
-    """ Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
-Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
+def __replace_STX_ETX(text_str: str):
+    """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
+    Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
 
-    Args:
-        text_str (str): raw text
+        Args:
+            text_str (str): raw text
 
-    Returns:
-        _type_: replaced text
+        Returns:
+            _type_: replaced text
     """
     if text_str:
         s = text_str.replace('\u0002', "'")
-        s = s.replace("\u0003", "'")
+        s = s.replace('\u0003', "'")
         return s
     return text_str
 
 
 def txt_spans_extract(pdf_page, inline_equations, interline_equations):
-    text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
-    char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
-        "blocks"
+    text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
+    char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
+        'blocks'
     ]
     text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
     text_blocks = replace_equations_in_textblock(
@@ -71,189 +80,254 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
     text_blocks = remove_chars_in_text_blocks(text_blocks)
     spans = []
     for v in text_blocks:
-        for line in v["lines"]:
-            for span in line["spans"]:
-                bbox = span["bbox"]
+        for line in v['lines']:
+            for span in line['spans']:
+                bbox = span['bbox']
                 if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
                     continue
-                if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation):
+                if span.get('type') not in (
+                    ContentType.InlineEquation,
+                    ContentType.InterlineEquation,
+                ):
                     spans.append(
                         {
-                            "bbox": list(span["bbox"]),
-                            "content": __replace_STX_ETX(span["text"]),
-                            "type": ContentType.Text,
-                            "score": 1.0,
+                            'bbox': list(span['bbox']),
+                            'content': __replace_STX_ETX(span['text']),
+                            'type': ContentType.Text,
+                            'score': 1.0,
                         }
                     )
     return spans
 
 
 def replace_text_span(pymu_spans, ocr_spans):
-    return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
+    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
 
 
-def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode):
+def parse_page_core(
+    pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+):
     need_drop = False
     drop_reason = []
 
-    '''从magic_model对象中获取后面会用到的区块信息'''
+    """从magic_model对象中获取后面会用到的区块信息"""
     img_blocks = magic_model.get_imgs(page_id)
     table_blocks = magic_model.get_tables(page_id)
     discarded_blocks = magic_model.get_discarded(page_id)
     text_blocks = magic_model.get_text_blocks(page_id)
     title_blocks = magic_model.get_title_blocks(page_id)
-    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
+    inline_equations, interline_equations, interline_equation_blocks = (
+        magic_model.get_equations(page_id)
+    )
 
     page_w, page_h = magic_model.get_page_size(page_id)
 
     spans = magic_model.get_all_spans(page_id)
 
-    '''根据parse_mode,构造spans'''
-    if parse_mode == "txt":
+    """根据parse_mode,构造spans"""
+    if parse_mode == 'txt':
         """ocr 中文本类的 span 用 pymu spans 替换!"""
         pymu_spans = txt_spans_extract(
             pdf_docs[page_id], inline_equations, interline_equations
         )
         spans = replace_text_span(pymu_spans, spans)
-    elif parse_mode == "ocr":
+    elif parse_mode == 'ocr':
         pass
     else:
-        raise Exception("parse_mode must be txt or ocr")
+        raise Exception('parse_mode must be txt or ocr')
 
-    '''删除重叠spans中置信度较低的那些'''
+    """删除重叠spans中置信度较低的那些"""
     spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
-    '''删除重叠spans中较小的那些'''
+    """删除重叠spans中较小的那些"""
     spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
-    '''对image和table截图'''
-    spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
+    """对image和table截图"""
+    spans = ocr_cut_image_and_table(
+        spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter
+    )
 
-    '''将所有区块的bbox整理到一起'''
+    """将所有区块的bbox整理到一起"""
     # interline_equation_blocks参数不够准,后面切换到interline_equations上
     interline_equation_blocks = []
     if len(interline_equation_blocks) > 0:
-        all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
-            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
-            interline_equation_blocks, page_w, page_h)
+        all_bboxes, all_discarded_blocks, drop_reasons = (
+            ocr_prepare_bboxes_for_layout_split(
+                img_blocks,
+                table_blocks,
+                discarded_blocks,
+                text_blocks,
+                title_blocks,
+                interline_equation_blocks,
+                page_w,
+                page_h,
+            )
+        )
     else:
-        all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
-            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
-            interline_equations, page_w, page_h)
+        all_bboxes, all_discarded_blocks, drop_reasons = (
+            ocr_prepare_bboxes_for_layout_split(
+                img_blocks,
+                table_blocks,
+                discarded_blocks,
+                text_blocks,
+                title_blocks,
+                interline_equations,
+                page_w,
+                page_h,
+            )
+        )
 
     if len(drop_reasons) > 0:
         need_drop = True
         drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)
 
-    '''先处理不需要排版的discarded_blocks'''
-    discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
+    """先处理不需要排版的discarded_blocks"""
+    discarded_block_with_spans, spans = fill_spans_in_blocks(
+        all_discarded_blocks, spans, 0.4
+    )
     fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
 
-    '''如果当前页面没有bbox则跳过'''
+    """如果当前页面没有bbox则跳过"""
     if len(all_bboxes) == 0:
-        logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
-        return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
-                                               [], [], interline_equations, fix_discarded_blocks,
-                                               need_drop, drop_reason)
+        logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
+        return ocr_construct_page_component_v2(
+            [],
+            [],
+            page_id,
+            page_w,
+            page_h,
+            [],
+            [],
+            [],
+            interline_equations,
+            fix_discarded_blocks,
+            need_drop,
+            drop_reason,
+        )
 
     """在切分之前,先检查一下bbox是否有左右重叠的情况,如果有,那么就认为这个pdf暂时没有能力处理好,这种左右重叠的情况大概率是由于pdf里的行间公式、表格没有被正确识别出来造成的 """
 
     while True:  # 循环检查左右重叠的情况,如果存在就删除掉较小的那个bbox,直到不存在左右重叠的情况
-        is_useful_block_horz_overlap, all_bboxes = remove_horizontal_overlap_block_which_smaller(all_bboxes)
+        is_useful_block_horz_overlap, all_bboxes = (
+            remove_horizontal_overlap_block_which_smaller(all_bboxes)
+        )
         if is_useful_block_horz_overlap:
             need_drop = True
             drop_reason.append(DropReason.USEFUL_BLOCK_HOR_OVERLAP)
         else:
             break
 
-    '''根据区块信息计算layout'''
+    """根据区块信息计算layout"""
     page_boundry = [0, 0, page_w, page_h]
     layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)
 
     if len(text_blocks) > 0 and len(all_bboxes) > 0 and len(layout_bboxes) == 0:
         logger.warning(
-            f"skip this page, page_id: {page_id}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
+            f'skip this page, page_id: {page_id}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}'
+        )
         need_drop = True
         drop_reason.append(DropReason.CAN_NOT_DETECT_PAGE_LAYOUT)
 
     """以下去掉复杂的布局和超过2列的布局"""
-    if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]):  # 复杂的布局
+    if any(
+        [lay['layout_label'] == LAYOUT_UNPROC for lay in layout_bboxes]
+    ):  # 复杂的布局
         logger.warning(
-            f"skip this page, page_id: {page_id}, reason: {DropReason.COMPLICATED_LAYOUT}")
+            f'skip this page, page_id: {page_id}, reason: {DropReason.COMPLICATED_LAYOUT}'
+        )
         need_drop = True
         drop_reason.append(DropReason.COMPLICATED_LAYOUT)
 
     layout_column_width = get_columns_cnt_of_layout(layout_tree)
     if layout_column_width > 2:  # 去掉超过2列的布局pdf
         logger.warning(
-            f"skip this page, page_id: {page_id}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
+            f'skip this page, page_id: {page_id}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}'
+        )
         need_drop = True
         drop_reason.append(DropReason.TOO_MANY_LAYOUT_COLUMNS)
 
-    '''根据layout顺序,对当前页面所有需要留下的block进行排序'''
+    """根据layout顺序,对当前页面所有需要留下的block进行排序"""
     sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)
 
-    '''将span填入排好序的blocks中'''
+    """将span填入排好序的blocks中"""
     block_with_spans, spans = fill_spans_in_blocks(sorted_blocks, spans, 0.3)
 
-    '''对block进行fix操作'''
+    """对block进行fix操作"""
     fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
 
-    '''获取QA需要外置的list'''
+    """获取QA需要外置的list"""
     images, tables, interline_equations = get_qa_need_list_v2(fix_blocks)
 
-    '''构造pdf_info_dict'''
-    page_info = ocr_construct_page_component_v2(fix_blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
-                                                images, tables, interline_equations, fix_discarded_blocks,
-                                                need_drop, drop_reason)
+    """构造pdf_info_dict"""
+    page_info = ocr_construct_page_component_v2(
+        fix_blocks,
+        layout_bboxes,
+        page_id,
+        page_w,
+        page_h,
+        layout_tree,
+        images,
+        tables,
+        interline_equations,
+        fix_discarded_blocks,
+        need_drop,
+        drop_reason,
+    )
     return page_info
 
 
-def pdf_parse_union(pdf_bytes,
-                    model_list,
-                    imageWriter,
-                    parse_mode,
-                    start_page_id=0,
-                    end_page_id=None,
-                    debug_mode=False,
-                    ):
+def pdf_parse_union(
+    pdf_bytes,
+    model_list,
+    imageWriter,
+    parse_mode,
+    start_page_id=0,
+    end_page_id=None,
+    debug_mode=False,
+):
     pdf_bytes_md5 = compute_md5(pdf_bytes)
-    pdf_docs = fitz.open("pdf", pdf_bytes)
+    pdf_docs = fitz.open('pdf', pdf_bytes)
 
-    '''初始化空的pdf_info_dict'''
+    """初始化空的pdf_info_dict"""
     pdf_info_dict = {}
 
-    '''用model_list和docs对象初始化magic_model'''
+    """用model_list和docs对象初始化magic_model"""
     magic_model = MagicModel(model_list, pdf_docs)
 
-    '''根据输入的起始范围解析pdf'''
+    """根据输入的起始范围解析pdf"""
     # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
-    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(pdf_docs) - 1
+    )
 
     if end_page_id > len(pdf_docs) - 1:
-        logger.warning("end_page_id is out of range, use pdf_docs length")
+        logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(pdf_docs) - 1
 
-    '''初始化启动时间'''
+    """初始化启动时间"""
     start_time = time.time()
 
     for page_id, page in enumerate(pdf_docs):
-        '''debug时输出每页解析的耗时'''
+        """debug时输出每页解析的耗时."""
         if debug_mode:
             time_now = time.time()
             logger.info(
-                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
+                f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
             )
             start_time = time_now
 
-        '''解析pdf中的每一页'''
+        """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
-            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+            page_info = parse_page_core(
+                pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
+            )
         else:
             page_w = page.rect.width
             page_h = page.rect.height
-            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
-                                                [], [], [], [],
-                                                True, "skip page")
-        pdf_info_dict[f"page_{page_id}"] = page_info
+            page_info = ocr_construct_page_component_v2(
+                [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
+            )
+        pdf_info_dict[f'page_{page_id}'] = page_info
 
     """分段"""
     para_split(pdf_info_dict, debug_mode=debug_mode)
@@ -261,7 +335,7 @@ def pdf_parse_union(pdf_bytes,
     """dict转list"""
     pdf_info_list = dict_to_list(pdf_info_dict)
     new_pdf_info_dict = {
-        "pdf_info": pdf_info_list,
+        'pdf_info': pdf_info_list,
     }
 
     return new_pdf_info_dict

+ 10 - 8
magic_pdf/pdf_parse_union_core_v2.py

@@ -7,17 +7,17 @@ from typing import List
 import torch
 from loguru import logger
 
+from magic_pdf.config.drop_reason import DropReason
 from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.commons import fitz, get_delta_time
 from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
 from magic_pdf.libs.convert_utils import dict_to_list
-from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.local_math import float_equal
-from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
@@ -30,8 +30,8 @@ from magic_pdf.pre_proc.equations_replace import (
 from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
     ocr_prepare_bboxes_for_layout_split_v2
 from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
-                                               fix_discarded_block,
-                                               fix_block_spans_v2)
+                                               fix_block_spans_v2,
+                                               fix_discarded_block)
 from magic_pdf.pre_proc.ocr_span_list_modify import (
     get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
     remove_overlaps_min_spans)
@@ -164,8 +164,8 @@ class ModelSingleton:
 
 
 def do_predict(boxes: List[List[int]], model) -> List[int]:
-    from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (boxes2inputs, parse_logits,
-                                                                                 prepare_inputs)
+    from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (
+        boxes2inputs, parse_logits, prepare_inputs)
 
     inputs = boxes2inputs(boxes)
     inputs = prepare_inputs(inputs, model)
@@ -206,7 +206,9 @@ def cal_block_index(fix_blocks, sorted_bboxes):
                 del block['real_lines']
 
         import numpy as np
-        from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import recursive_xy_cut
+
+        from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import \
+            recursive_xy_cut
 
         random_boxes = np.array(block_bboxes)
         np.random.shuffle(random_boxes)
@@ -291,7 +293,7 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
                     page_line_list.append(bbox)
         elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
             bbox = block['bbox']
-            block["real_lines"] = copy.deepcopy(block['lines'])
+            block['real_lines'] = copy.deepcopy(block['lines'])
             lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
             block['lines'] = []
             for line in lines:

+ 2 - 2
magic_pdf/pipe/AbsPipe.py

@@ -1,12 +1,12 @@
 from abc import ABC, abstractmethod
 
+from magic_pdf.config.drop_reason import DropReason
+from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.dict2md.ocr_mkcontent import union_make
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
-from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.json_compressor import JsonCompressor
-from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 
 
 class AbsPipe(ABC):

+ 1 - 1
magic_pdf/pipe/OCRPipe.py

@@ -1,7 +1,7 @@
 from loguru import logger
 
+from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
-from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_ocr_pdf

+ 1 - 1
magic_pdf/pipe/TXTPipe.py

@@ -1,7 +1,7 @@
 from loguru import logger
 
+from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
-from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_txt_pdf

+ 1 - 1
magic_pdf/pipe/UNIPipe.py

@@ -2,9 +2,9 @@ import json
 
 from loguru import logger
 
+from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.commons import join_path
-from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf

+ 7 - 14
magic_pdf/post_proc/pdf_post_filter.py

@@ -1,19 +1,17 @@
 from loguru import logger
 
+from magic_pdf.config.drop_reason import DropReason
 from magic_pdf.layout.layout_sort import get_columns_cnt_of_layout
-from magic_pdf.libs.drop_reason import DropReason
 
 
 def __is_pseudo_single_column(page_info) -> bool:
-    """
-    判断一个页面是否伪单列。
+    """判断一个页面是否伪单列。
 
     Args:
         page_info (dict): 页面信息字典,包括'_layout_tree'和'preproc_blocks'。
 
     Returns:
         Tuple[bool, Optional[str]]: 如果页面伪单列返回(True, extra_info),否则返回(False, None)。
-
     """
     layout_tree = page_info['_layout_tree']
     layout_column_width = get_columns_cnt_of_layout(layout_tree)
@@ -41,27 +39,22 @@ def __is_pseudo_single_column(page_info) -> bool:
             if num_lines > 20:
                 radio = num_satisfying_lines / num_lines
                 if radio >= 0.5:
-                    extra_info = f"{{num_lines: {num_lines}, num_satisfying_lines: {num_satisfying_lines}}}"
+                    extra_info = f'{{num_lines: {num_lines}, num_satisfying_lines: {num_satisfying_lines}}}'
                     block_text = []
                     for line in lines:
                         if line['spans']:
                             for span in line['spans']:
                                 block_text.append(span['text'])
-                    logger.warning(f"pseudo_single_column block_text: {block_text}")
+                    logger.warning(f'pseudo_single_column block_text: {block_text}')
                     return True, extra_info
 
     return False, None
 
 
 def pdf_post_filter(page_info) -> tuple:
-    """
-    return:(True|False, err_msg)
-        True, 如果pdf符合要求
-        False, 如果pdf不符合要求
-
-    """
+    """return:(True|False, err_msg) True, 如果pdf符合要求 False, 如果pdf不符合要求."""
     bool_is_pseudo_single_column, extra_info = __is_pseudo_single_column(page_info)
     if bool_is_pseudo_single_column:
-        return False, {"_need_drop": True, "_drop_reason": DropReason.PSEUDO_SINGLE_COLUMN, "extra_info": extra_info}
+        return False, {'_need_drop': True, '_drop_reason': DropReason.PSEUDO_SINGLE_COLUMN, 'extra_info': extra_info}
 
-    return True, None
+    return True, None

+ 9 - 11
magic_pdf/pre_proc/cut_image.py

@@ -1,7 +1,7 @@
 from loguru import logger
 
+from magic_pdf.config.ocr_content_type import ContentType
 from magic_pdf.libs.commons import join_path
-from magic_pdf.libs.ocr_content_type import ContentType
 from magic_pdf.libs.pdf_image_tools import cut_image
 
 
@@ -29,9 +29,7 @@ def txt_save_images_by_bboxes(page_num: int, page, pdf_bytes_md5: str,
                               image_bboxes: list, images_overlap_backup: list, table_bboxes: list,
                               equation_inline_bboxes: list,
                               equation_interline_bboxes: list, imageWriter) -> dict:
-    """
-    返回一个dict, key为bbox, 值是图片地址
-    """
+    """返回一个dict, key为bbox, 值是图片地址."""
     image_info = []
     image_backup_info = []
     table_info = []
@@ -46,26 +44,26 @@ def txt_save_images_by_bboxes(page_num: int, page, pdf_bytes_md5: str,
     for bbox in image_bboxes:
         if not check_img_bbox(bbox):
             continue
-        image_path = cut_image(bbox, page_num, page, return_path("images"), imageWriter)
-        image_info.append({"bbox": bbox, "image_path": image_path})
+        image_path = cut_image(bbox, page_num, page, return_path('images'), imageWriter)
+        image_info.append({'bbox': bbox, 'image_path': image_path})
 
     for bbox in images_overlap_backup:
         if not check_img_bbox(bbox):
             continue
-        image_path = cut_image(bbox, page_num, page, return_path("images"), imageWriter)
-        image_backup_info.append({"bbox": bbox, "image_path": image_path})
+        image_path = cut_image(bbox, page_num, page, return_path('images'), imageWriter)
+        image_backup_info.append({'bbox': bbox, 'image_path': image_path})
 
     for bbox in table_bboxes:
         if not check_img_bbox(bbox):
             continue
-        image_path = cut_image(bbox, page_num, page, return_path("tables"), imageWriter)
-        table_info.append({"bbox": bbox, "image_path": image_path})
+        image_path = cut_image(bbox, page_num, page, return_path('tables'), imageWriter)
+        table_info.append({'bbox': bbox, 'image_path': image_path})
 
     return image_info, image_backup_info, table_info, inline_eq_info, interline_eq_info
 
 
 def check_img_bbox(bbox) -> bool:
     if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
-        logger.warning(f"image_bboxes: 错误的box, {bbox}")
+        logger.warning(f'image_bboxes: 错误的box, {bbox}')
         return False
     return True

+ 203 - 212
magic_pdf/pre_proc/equations_replace.py

@@ -1,49 +1,45 @@
-"""
-对pymupdf返回的结构里的公式进行替换,替换为模型识别的公式结果
-"""
+"""对pymupdf返回的结构里的公式进行替换,替换为模型识别的公式结果."""
 
-from magic_pdf.libs.commons import fitz
 import json
 import os
 from pathlib import Path
+
 from loguru import logger
-from magic_pdf.libs.ocr_content_type import ContentType
+
+from magic_pdf.config.ocr_content_type import ContentType
+from magic_pdf.libs.commons import fitz
 
 TYPE_INLINE_EQUATION = ContentType.InlineEquation
 TYPE_INTERLINE_EQUATION = ContentType.InterlineEquation
 
 
 def combine_chars_to_pymudict(block_dict, char_dict):
-    """
-    把block级别的pymupdf 结构里加入char结构
-    """
+    """把block级别的pymupdf 结构里加入char结构."""
     # 因为block_dict 被裁剪过,因此先把他和char_dict文字块对齐,才能进行补充
-    char_map = {tuple(item["bbox"]): item for item in char_dict}
+    char_map = {tuple(item['bbox']): item for item in char_dict}
 
-    for i in range(len(block_dict)):  # blcok
+    for i in range(len(block_dict)):  # block
         block = block_dict[i]
-        key = block["bbox"]
+        key = block['bbox']
         char_dict_item = char_map[tuple(key)]
-        char_dict_map = {tuple(item["bbox"]): item for item in char_dict_item["lines"]}
-        for j in range(len(block["lines"])):
-            lines = block["lines"][j]
-            with_char_lines = char_dict_map[lines["bbox"]]
-            for k in range(len(lines["spans"])):
-                spans = lines["spans"][k]
+        char_dict_map = {tuple(item['bbox']): item for item in char_dict_item['lines']}
+        for j in range(len(block['lines'])):
+            lines = block['lines'][j]
+            with_char_lines = char_dict_map[lines['bbox']]
+            for k in range(len(lines['spans'])):
+                spans = lines['spans'][k]
                 try:
-                    chars = with_char_lines["spans"][k]["chars"]
-                except Exception as e:
-                    logger.error(char_dict[i]["lines"][j])
+                    chars = with_char_lines['spans'][k]['chars']
+                except Exception:
+                    logger.error(char_dict[i]['lines'][j])
 
-                spans["chars"] = chars
+                spans['chars'] = chars
 
     return block_dict
 
 
 def calculate_overlap_area_2_minbox_area_ratio(bbox1, min_bbox):
-    """
-    计算box1和box2的重叠面积占最小面积的box的比例
-    """
+    """计算box1和box2的重叠面积占最小面积的box的比例."""
     # Determine the coordinates of the intersection rectangle
     x_left = max(bbox1[0], min_bbox[0])
     y_top = max(bbox1[1], min_bbox[1])
@@ -74,13 +70,13 @@ def _is_xin(bbox1, bbox2):
 
 
 def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks):
-    """消除掉整个块都在行间公式块内部的文本块"""
+    """消除掉整个块都在行间公式块内部的文本块."""
     for eq_bbox in interline_bboxes:
         removed_txt_blk = []
         for text_blk in text_blocks:
-            text_bbox = text_blk["bbox"]
+            text_bbox = text_blk['bbox']
             if (
-                calculate_overlap_area_2_minbox_area_ratio(eq_bbox["bbox"], text_bbox)
+                calculate_overlap_area_2_minbox_area_ratio(eq_bbox['bbox'], text_bbox)
                 >= 0.7
             ):
                 removed_txt_blk.append(text_blk)
@@ -91,9 +87,7 @@ def remove_text_block_in_interline_equation_bbox(interline_bboxes, text_blocks):
 
 
 def _is_in_or_part_overlap(box1, box2) -> bool:
-    """
-    两个bbox是否有部分重叠或者包含
-    """
+    """两个bbox是否有部分重叠或者包含."""
     if box1 is None or box2 is None:
         return False
 
@@ -111,62 +105,65 @@ def _is_in_or_part_overlap(box1, box2) -> bool:
 def remove_text_block_overlap_interline_equation_bbox(
     interline_eq_bboxes, pymu_block_list
 ):
-
-    """消除掉行行内公式有部分重叠的文本块的内容。
-    同时重新计算消除重叠之后文本块的大小"""
+    """消除掉行行内公式有部分重叠的文本块的内容。 同时重新计算消除重叠之后文本块的大小."""
     deleted_block = []
     for text_block in pymu_block_list:
         deleted_line = []
-        for line in text_block["lines"]:
+        for line in text_block['lines']:
             deleted_span = []
-            for span in line["spans"]:
+            for span in line['spans']:
                 deleted_chars = []
-                for char in span["chars"]:
+                for char in span['chars']:
                     if any(
-                            [
-                                (calculate_overlap_area_2_minbox_area_ratio(eq_bbox["bbox"], char["bbox"]) > 0.5)
-                                for eq_bbox in interline_eq_bboxes
-                            ]
+                        [
+                            (
+                                calculate_overlap_area_2_minbox_area_ratio(
+                                    eq_bbox['bbox'], char['bbox']
+                                )
+                                > 0.5
+                            )
+                            for eq_bbox in interline_eq_bboxes
+                        ]
                     ):
                         deleted_chars.append(char)
                 # 检查span里没有char则删除这个span
                 for char in deleted_chars:
-                    span["chars"].remove(char)
+                    span['chars'].remove(char)
                 # 重新计算这个span的大小
-                if len(span["chars"]) == 0:  # 删除这个span
+                if len(span['chars']) == 0:  # 删除这个span
                     deleted_span.append(span)
                 else:
-                    span["bbox"] = (
-                        min([b["bbox"][0] for b in span["chars"]]),
-                        min([b["bbox"][1] for b in span["chars"]]),
-                        max([b["bbox"][2] for b in span["chars"]]),
-                        max([b["bbox"][3] for b in span["chars"]]),
+                    span['bbox'] = (
+                        min([b['bbox'][0] for b in span['chars']]),
+                        min([b['bbox'][1] for b in span['chars']]),
+                        max([b['bbox'][2] for b in span['chars']]),
+                        max([b['bbox'][3] for b in span['chars']]),
                     )
 
             # 检查这个span
             for span in deleted_span:
-                line["spans"].remove(span)
-            if len(line["spans"]) == 0:  # 删除这个line
+                line['spans'].remove(span)
+            if len(line['spans']) == 0:  # 删除这个line
                 deleted_line.append(line)
             else:
-                line["bbox"] = (
-                    min([b["bbox"][0] for b in line["spans"]]),
-                    min([b["bbox"][1] for b in line["spans"]]),
-                    max([b["bbox"][2] for b in line["spans"]]),
-                    max([b["bbox"][3] for b in line["spans"]]),
+                line['bbox'] = (
+                    min([b['bbox'][0] for b in line['spans']]),
+                    min([b['bbox'][1] for b in line['spans']]),
+                    max([b['bbox'][2] for b in line['spans']]),
+                    max([b['bbox'][3] for b in line['spans']]),
                 )
 
         # 检查这个block是否可以删除
         for line in deleted_line:
-            text_block["lines"].remove(line)
-        if len(text_block["lines"]) == 0:  # 删除block
+            text_block['lines'].remove(line)
+        if len(text_block['lines']) == 0:  # 删除block
             deleted_block.append(text_block)
         else:
-            text_block["bbox"] = (
-                min([b["bbox"][0] for b in text_block["lines"]]),
-                min([b["bbox"][1] for b in text_block["lines"]]),
-                max([b["bbox"][2] for b in text_block["lines"]]),
-                max([b["bbox"][3] for b in text_block["lines"]]),
+            text_block['bbox'] = (
+                min([b['bbox'][0] for b in text_block['lines']]),
+                min([b['bbox'][1] for b in text_block['lines']]),
+                max([b['bbox'][2] for b in text_block['lines']]),
+                max([b['bbox'][3] for b in text_block['lines']]),
             )
 
     # 检查text block删除
@@ -179,33 +176,33 @@ def remove_text_block_overlap_interline_equation_bbox(
 
 
 def insert_interline_equations_textblock(interline_eq_bboxes, pymu_block_list):
-    """在行间公式对应的地方插上一个伪造的block"""
+    """在行间公式对应的地方插上一个伪造的block."""
     for eq in interline_eq_bboxes:
-        bbox = eq["bbox"]
-        latex_content = eq["latex"]
+        bbox = eq['bbox']
+        latex_content = eq['latex']
         text_block = {
-            "number": len(pymu_block_list),
-            "type": 0,
-            "bbox": bbox,
-            "lines": [
+            'number': len(pymu_block_list),
+            'type': 0,
+            'bbox': bbox,
+            'lines': [
                 {
-                    "spans": [
+                    'spans': [
                         {
-                            "size": 9.962599754333496,
-                            "type": TYPE_INTERLINE_EQUATION,
-                            "flags": 4,
-                            "font": TYPE_INTERLINE_EQUATION,
-                            "color": 0,
-                            "ascender": 0.9409999847412109,
-                            "descender": -0.3050000071525574,
-                            "latex": latex_content,
-                            "origin": [bbox[0], bbox[1]],
-                            "bbox": bbox,
+                            'size': 9.962599754333496,
+                            'type': TYPE_INTERLINE_EQUATION,
+                            'flags': 4,
+                            'font': TYPE_INTERLINE_EQUATION,
+                            'color': 0,
+                            'ascender': 0.9409999847412109,
+                            'descender': -0.3050000071525574,
+                            'latex': latex_content,
+                            'origin': [bbox[0], bbox[1]],
+                            'bbox': bbox,
                         }
                     ],
-                    "wmode": 0,
-                    "dir": [1.0, 0.0],
-                    "bbox": bbox,
+                    'wmode': 0,
+                    'dir': [1.0, 0.0],
+                    'bbox': bbox,
                 }
             ],
         }
@@ -250,53 +247,52 @@ def __y_overlap_ratio(box1, box2):
 
 
 def replace_line_v2(eqinfo, line):
-    """
-    扫描这一行所有的和公式框X方向重叠的char,然后计算char的左、右x0, x1,位于这个区间内的span删除掉。
-    最后与这个x0,x1有相交的span0, span1内部进行分割。
-    """
+    """扫描这一行所有的和公式框X方向重叠的char,然后计算char的左、右x0, x1,位于这个区间内的span删除掉。
+    最后与这个x0,x1有相交的span0, span1内部进行分割。"""
     first_overlap_span = -1
     first_overlap_span_idx = -1
     last_overlap_span = -1
     delete_chars = []
-    for i in range(0, len(line["spans"])):
-        if "chars" not in line["spans"][i]:
+    for i in range(0, len(line['spans'])):
+        if 'chars' not in line['spans'][i]:
             continue
 
-        if line["spans"][i].get("_type", None) is not None:
+        if line['spans'][i].get('_type', None) is not None:
             continue  # 忽略,因为已经是插入的伪造span公式了
 
-        for char in line["spans"][i]["chars"]:
-            if __is_x_dir_overlap(eqinfo["bbox"], char["bbox"]):
-                line_txt = ""
-                for span in line["spans"]:
-                    span_txt = "<span>"
-                    for ch in span["chars"]:
-                        span_txt = span_txt + ch["c"]
+        for char in line['spans'][i]['chars']:
+            if __is_x_dir_overlap(eqinfo['bbox'], char['bbox']):
+                line_txt = ''
+                for span in line['spans']:
+                    span_txt = '<span>'
+                    for ch in span['chars']:
+                        span_txt = span_txt + ch['c']
 
-                    span_txt = span_txt + "</span>"
+                    span_txt = span_txt + '</span>'
 
                     line_txt = line_txt + span_txt
 
                 if first_overlap_span_idx == -1:
-                    first_overlap_span = line["spans"][i]
+                    first_overlap_span = line['spans'][i]
                     first_overlap_span_idx = i
-                last_overlap_span = line["spans"][i]
+                last_overlap_span = line['spans'][i]
                 delete_chars.append(char)
 
     # 第一个和最后一个char要进行检查,到底属于公式多还是属于正常span多
     if len(delete_chars) > 0:
-        ch0_bbox = delete_chars[0]["bbox"]
-        if x_overlap_ratio(eqinfo["bbox"], ch0_bbox) < 0.51:
+        ch0_bbox = delete_chars[0]['bbox']
+        if x_overlap_ratio(eqinfo['bbox'], ch0_bbox) < 0.51:
             delete_chars.remove(delete_chars[0])
     if len(delete_chars) > 0:
-        ch0_bbox = delete_chars[-1]["bbox"]
-        if x_overlap_ratio(eqinfo["bbox"], ch0_bbox) < 0.51:
+        ch0_bbox = delete_chars[-1]['bbox']
+        if x_overlap_ratio(eqinfo['bbox'], ch0_bbox) < 0.51:
             delete_chars.remove(delete_chars[-1])
 
     # 计算x方向上被删除区间内的char的真实x0, x1
     if len(delete_chars):
-        x0, x1 = min([b["bbox"][0] for b in delete_chars]), max(
-            [b["bbox"][2] for b in delete_chars]
+        x0, x1 = (
+            min([b['bbox'][0] for b in delete_chars]),
+            max([b['bbox'][2] for b in delete_chars]),
         )
     else:
         # logger.debug(f"行内公式替换没有发生,尝试下一行匹配, eqinfo={eqinfo}")
@@ -304,101 +300,101 @@ def replace_line_v2(eqinfo, line):
 
     # 删除位于x0, x1这两个中间的span
     delete_span = []
-    for span in line["spans"]:
-        span_box = span["bbox"]
+    for span in line['spans']:
+        span_box = span['bbox']
         if x0 <= span_box[0] and span_box[2] <= x1:
             delete_span.append(span)
     for span in delete_span:
-        line["spans"].remove(span)
+        line['spans'].remove(span)
 
     equation_span = {
-        "size": 9.962599754333496,
-        "type": TYPE_INLINE_EQUATION,
-        "flags": 4,
-        "font": TYPE_INLINE_EQUATION,
-        "color": 0,
-        "ascender": 0.9409999847412109,
-        "descender": -0.3050000071525574,
-        "latex": "",
-        "origin": [337.1410153102337, 216.0205245153934],
-        "bbox": eqinfo["bbox"]
+        'size': 9.962599754333496,
+        'type': TYPE_INLINE_EQUATION,
+        'flags': 4,
+        'font': TYPE_INLINE_EQUATION,
+        'color': 0,
+        'ascender': 0.9409999847412109,
+        'descender': -0.3050000071525574,
+        'latex': '',
+        'origin': [337.1410153102337, 216.0205245153934],
+        'bbox': eqinfo['bbox'],
     }
     # equation_span = line['spans'][0].copy()
-    equation_span["latex"] = eqinfo['latex']
-    equation_span["bbox"] = [x0, equation_span["bbox"][1], x1, equation_span["bbox"][3]]
-    equation_span["origin"] = [equation_span["bbox"][0], equation_span["bbox"][1]]
-    equation_span["chars"] = delete_chars
-    equation_span["type"] = TYPE_INLINE_EQUATION
-    equation_span["_eq_bbox"] = eqinfo["bbox"]
-    line["spans"].insert(first_overlap_span_idx + 1, equation_span)  # 放入公式
+    equation_span['latex'] = eqinfo['latex']
+    equation_span['bbox'] = [x0, equation_span['bbox'][1], x1, equation_span['bbox'][3]]
+    equation_span['origin'] = [equation_span['bbox'][0], equation_span['bbox'][1]]
+    equation_span['chars'] = delete_chars
+    equation_span['type'] = TYPE_INLINE_EQUATION
+    equation_span['_eq_bbox'] = eqinfo['bbox']
+    line['spans'].insert(first_overlap_span_idx + 1, equation_span)  # 放入公式
 
     # logger.info(f"==>text is 【{line_txt}】, equation is 【{eqinfo['latex_text']}】")
 
     # 第一个、和最后一个有overlap的span进行分割,然后插入对应的位置
     first_span_chars = [
         char
-        for char in first_overlap_span["chars"]
-        if (char["bbox"][2] + char["bbox"][0]) / 2 < x0
+        for char in first_overlap_span['chars']
+        if (char['bbox'][2] + char['bbox'][0]) / 2 < x0
     ]
     tail_span_chars = [
         char
-        for char in last_overlap_span["chars"]
-        if (char["bbox"][0] + char["bbox"][2]) / 2 > x1
+        for char in last_overlap_span['chars']
+        if (char['bbox'][0] + char['bbox'][2]) / 2 > x1
     ]
 
     if len(first_span_chars) > 0:
-        first_overlap_span["chars"] = first_span_chars
-        first_overlap_span["text"] = "".join([char["c"] for char in first_span_chars])
-        first_overlap_span["bbox"] = (
-            first_overlap_span["bbox"][0],
-            first_overlap_span["bbox"][1],
-            max([chr["bbox"][2] for chr in first_span_chars]),
-            first_overlap_span["bbox"][3],
+        first_overlap_span['chars'] = first_span_chars
+        first_overlap_span['text'] = ''.join([char['c'] for char in first_span_chars])
+        first_overlap_span['bbox'] = (
+            first_overlap_span['bbox'][0],
+            first_overlap_span['bbox'][1],
+            max([chr['bbox'][2] for chr in first_span_chars]),
+            first_overlap_span['bbox'][3],
         )
         # first_overlap_span['_type'] = "first"
     else:
         # 删掉
         if first_overlap_span not in delete_span:
-            line["spans"].remove(first_overlap_span)
+            line['spans'].remove(first_overlap_span)
 
     if len(tail_span_chars) > 0:
-        min_of_tail_span_x0 = min([chr["bbox"][0] for chr in tail_span_chars])
-        min_of_tail_span_y0 = min([chr["bbox"][1] for chr in tail_span_chars])
-        max_of_tail_span_x1 = max([chr["bbox"][2] for chr in tail_span_chars])
-        max_of_tail_span_y1 = max([chr["bbox"][3] for chr in tail_span_chars])
+        min_of_tail_span_x0 = min([chr['bbox'][0] for chr in tail_span_chars])
+        min_of_tail_span_y0 = min([chr['bbox'][1] for chr in tail_span_chars])
+        max_of_tail_span_x1 = max([chr['bbox'][2] for chr in tail_span_chars])
+        max_of_tail_span_y1 = max([chr['bbox'][3] for chr in tail_span_chars])
 
         if last_overlap_span == first_overlap_span:  # 这个时候应该插入一个新的
-            tail_span_txt = "".join([char["c"] for char in tail_span_chars])
+            tail_span_txt = ''.join([char['c'] for char in tail_span_chars])  # noqa: F841
             last_span_to_insert = last_overlap_span.copy()
-            last_span_to_insert["chars"] = tail_span_chars
-            last_span_to_insert["text"] = "".join(
-                [char["c"] for char in tail_span_chars]
+            last_span_to_insert['chars'] = tail_span_chars
+            last_span_to_insert['text'] = ''.join(
+                [char['c'] for char in tail_span_chars]
             )
-            if equation_span["bbox"][2] >= last_overlap_span["bbox"][2]:
-                last_span_to_insert["bbox"] = (
+            if equation_span['bbox'][2] >= last_overlap_span['bbox'][2]:
+                last_span_to_insert['bbox'] = (
                     min_of_tail_span_x0,
                     min_of_tail_span_y0,
                     max_of_tail_span_x1,
-                    max_of_tail_span_y1
+                    max_of_tail_span_y1,
                 )
             else:
-                last_span_to_insert["bbox"] = (
-                    min([chr["bbox"][0] for chr in tail_span_chars]),
-                    last_overlap_span["bbox"][1],
-                    last_overlap_span["bbox"][2],
-                    last_overlap_span["bbox"][3],
+                last_span_to_insert['bbox'] = (
+                    min([chr['bbox'][0] for chr in tail_span_chars]),
+                    last_overlap_span['bbox'][1],
+                    last_overlap_span['bbox'][2],
+                    last_overlap_span['bbox'][3],
                 )
             # 插入到公式对象之后
-            equation_idx = line["spans"].index(equation_span)
-            line["spans"].insert(equation_idx + 1, last_span_to_insert)  # 放入公式
+            equation_idx = line['spans'].index(equation_span)
+            line['spans'].insert(equation_idx + 1, last_span_to_insert)  # 放入公式
         else:  # 直接修改原来的span
-            last_overlap_span["chars"] = tail_span_chars
-            last_overlap_span["text"] = "".join([char["c"] for char in tail_span_chars])
-            last_overlap_span["bbox"] = (
-                min([chr["bbox"][0] for chr in tail_span_chars]),
-                last_overlap_span["bbox"][1],
-                last_overlap_span["bbox"][2],
-                last_overlap_span["bbox"][3],
+            last_overlap_span['chars'] = tail_span_chars
+            last_overlap_span['text'] = ''.join([char['c'] for char in tail_span_chars])
+            last_overlap_span['bbox'] = (
+                min([chr['bbox'][0] for chr in tail_span_chars]),
+                last_overlap_span['bbox'][1],
+                last_overlap_span['bbox'][2],
+                last_overlap_span['bbox'][3],
             )
     else:
         # 删掉
@@ -406,15 +402,15 @@ def replace_line_v2(eqinfo, line):
             last_overlap_span not in delete_span
             and last_overlap_span != first_overlap_span
         ):
-            line["spans"].remove(last_overlap_span)
+            line['spans'].remove(last_overlap_span)
 
-    remain_txt = ""
-    for span in line["spans"]:
-        span_txt = "<span>"
-        for char in span["chars"]:
-            span_txt = span_txt + char["c"]
+    remain_txt = ''
+    for span in line['spans']:
+        span_txt = '<span>'
+        for char in span['chars']:
+            span_txt = span_txt + char['c']
 
-        span_txt = span_txt + "</span>"
+        span_txt = span_txt + '</span>'
 
         remain_txt = remain_txt + span_txt
 
@@ -424,17 +420,15 @@ def replace_line_v2(eqinfo, line):
 
 
 def replace_eq_blk(eqinfo, text_block):
-    """替换行内公式"""
-    for line in text_block["lines"]:
-        line_bbox = line["bbox"]
+    """替换行内公式."""
+    for line in text_block['lines']:
+        line_bbox = line['bbox']
         if (
-            _is_xin(eqinfo["bbox"], line_bbox)
-            or __y_overlap_ratio(eqinfo["bbox"], line_bbox) > 0.6
+            _is_xin(eqinfo['bbox'], line_bbox)
+            or __y_overlap_ratio(eqinfo['bbox'], line_bbox) > 0.6
         ):  # 定位到行, 使用y方向重合率是因为有的时候,一个行的宽度会小于公式位置宽度:行很高,公式很窄,
             replace_succ = replace_line_v2(eqinfo, line)
-            if (
-                not replace_succ
-            ):  # 有的时候,一个pdf的line高度从API里会计算的有问题,因此在行内span级别会替换不成功,这就需要继续重试下一行
+            if not replace_succ:  # 有的时候,一个pdf的line高度从API里会计算的有问题,因此在行内span级别会替换不成功,这就需要继续重试下一行
                 continue
             else:
                 break
@@ -444,13 +438,13 @@ def replace_eq_blk(eqinfo, text_block):
 
 
 def replace_inline_equations(inline_equation_bboxes, raw_text_blocks):
-    """替换行内公式"""
+    """替换行内公式."""
     for eqinfo in inline_equation_bboxes:
-        eqbox = eqinfo["bbox"]
+        eqbox = eqinfo['bbox']
         for blk in raw_text_blocks:
-            if _is_xin(eqbox, blk["bbox"]):
+            if _is_xin(eqbox, blk['bbox']):
                 if not replace_eq_blk(eqinfo, blk):
-                    logger.warning(f"行内公式没有替换成功:{eqinfo} ")
+                    logger.warning(f'行内公式没有替换成功:{eqinfo} ')
                 else:
                     break
 
@@ -458,20 +452,18 @@ def replace_inline_equations(inline_equation_bboxes, raw_text_blocks):
 
 
 def remove_chars_in_text_blocks(text_blocks):
-    """删除text_blocks里的char"""
+    """删除text_blocks里的char."""
     for blk in text_blocks:
-        for line in blk["lines"]:
-            for span in line["spans"]:
-                _ = span.pop("chars", "no such key")
+        for line in blk['lines']:
+            for span in line['spans']:
+                _ = span.pop('chars', 'no such key')
     return text_blocks
 
 
 def replace_equations_in_textblock(
     raw_text_blocks, inline_equation_bboxes, interline_equation_bboxes
 ):
-    """
-    替换行间和和行内公式为latex
-    """
+    """替换行间和和行内公式为latex."""
     raw_text_blocks = remove_text_block_in_interline_equation_bbox(
         interline_equation_bboxes, raw_text_blocks
     )  # 消除重叠:第一步,在公式内部的
@@ -486,22 +478,22 @@ def replace_equations_in_textblock(
 
 
 def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path):
-    """ """
-    new_pdf = f"{Path(pdf_path).parent}/{Path(pdf_path).stem}.step3-消除行内公式text_block.pdf"
-    with open(json_path, "r", encoding="utf-8") as f:
+    """"""
+    new_pdf = f'{Path(pdf_path).parent}/{Path(pdf_path).stem}.step3-消除行内公式text_block.pdf'
+    with open(json_path, 'r', encoding='utf-8') as f:
         obj = json.loads(f.read())
 
     if os.path.exists(new_pdf):
         os.remove(new_pdf)
-    new_doc = fitz.open("")
+    new_doc = fitz.open('')
 
-    doc = fitz.open(pdf_path)
+    doc = fitz.open(pdf_path)  # noqa: F841
     new_doc = fitz.open(pdf_path)
     for i in range(len(new_doc)):
         page = new_doc[i]
-        inline_equation_bboxes = obj[f"page_{i}"]["inline_equations"]
-        interline_equation_bboxes = obj[f"page_{i}"]["interline_equations"]
-        raw_text_blocks = obj[f"page_{i}"]["preproc_blocks"]
+        inline_equation_bboxes = obj[f'page_{i}']['inline_equations']
+        interline_equation_bboxes = obj[f'page_{i}']['interline_equations']
+        raw_text_blocks = obj[f'page_{i}']['preproc_blocks']
         raw_text_blocks = remove_text_block_in_interline_equation_bbox(
             interline_equation_bboxes, raw_text_blocks
         )  # 消除重叠:第一步,在公式内部的
@@ -514,11 +506,10 @@ def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path):
         )
 
         # 为了检验公式是否重复,把每一行里,含有公式的span背景改成黄色的
-        color_map = [fitz.pdfcolor["blue"], fitz.pdfcolor["green"]]
-        j = 0
+        color_map = [fitz.pdfcolor['blue'], fitz.pdfcolor['green']]  # noqa: F841
+        j = 0  # noqa: F841
         for blk in raw_text_blocks:
-            for i, line in enumerate(blk["lines"]):
-
+            for i, line in enumerate(blk['lines']):
                 # line_box = line['bbox']
                 # shape = page.new_shape()
                 # shape.draw_rect(line_box)
@@ -526,34 +517,34 @@ def draw_block_on_pdf_with_txt_replace_eq_bbox(json_path, pdf_path):
                 # shape.commit()
                 # j = j+1
 
-                for i, span in enumerate(line["spans"]):
+                for i, span in enumerate(line['spans']):
                     shape_page = page.new_shape()
-                    span_type = span.get("_type")
-                    color = fitz.pdfcolor["blue"]
-                    if span_type == "first":
-                        color = fitz.pdfcolor["blue"]
-                    elif span_type == "tail":
-                        color = fitz.pdfcolor["green"]
+                    span_type = span.get('_type')
+                    color = fitz.pdfcolor['blue']
+                    if span_type == 'first':
+                        color = fitz.pdfcolor['blue']
+                    elif span_type == 'tail':
+                        color = fitz.pdfcolor['green']
                     elif span_type == TYPE_INLINE_EQUATION:
-                        color = fitz.pdfcolor["black"]
+                        color = fitz.pdfcolor['black']
                     else:
                         color = None
 
-                    b = span["bbox"]
+                    b = span['bbox']
                     shape_page.draw_rect(b)
 
                     shape_page.finish(color=None, fill=color, fill_opacity=0.3)
                     shape_page.commit()
 
     new_doc.save(new_pdf)
-    logger.info(f"save ok {new_pdf}")
+    logger.info(f'save ok {new_pdf}')
     final_json = json.dumps(obj, ensure_ascii=False, indent=2)
-    with open("equations_test/final_json.json", "w") as f:
+    with open('equations_test/final_json.json', 'w') as f:
         f.write(final_json)
 
     return new_pdf
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     # draw_block_on_pdf_with_txt_replace_eq_bbox(new_json_path, equation_color_pdf)
     pass

+ 235 - 49
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -1,60 +1,181 @@
-from loguru import logger
 
-from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio, calculate_overlap_area_in_bbox1_area_ratio, \
-    calculate_iou, calculate_vertical_projection_overlap_ratio
-from magic_pdf.libs.drop_tag import DropTag
-from magic_pdf.libs.ocr_content_type import BlockType
-from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox_for_block
-
-
-def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_blocks, text_blocks,
-                                        title_blocks, interline_equation_blocks, page_w, page_h):
+from magic_pdf.config.ocr_content_type import BlockType
+from magic_pdf.libs.boxbase import (
+    calculate_iou, calculate_overlap_area_in_bbox1_area_ratio,
+    calculate_vertical_projection_overlap_ratio,
+    get_minbox_if_overlap_by_ratio)
+from magic_pdf.pre_proc.remove_bbox_overlap import \
+    remove_overlap_between_bbox_for_block
+
+
+def ocr_prepare_bboxes_for_layout_split(
+    img_blocks,
+    table_blocks,
+    discarded_blocks,
+    text_blocks,
+    title_blocks,
+    interline_equation_blocks,
+    page_w,
+    page_h,
+):
     all_bboxes = []
     all_discarded_blocks = []
     for image in img_blocks:
         x0, y0, x1, y1 = image['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])
+        all_bboxes.append(
+            [
+                x0,
+                y0,
+                x1,
+                y1,
+                None,
+                None,
+                None,
+                BlockType.Image,
+                None,
+                None,
+                None,
+                None,
+                image['score'],
+            ]
+        )
 
     for table in table_blocks:
         x0, y0, x1, y1 = table['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
+        all_bboxes.append(
+            [
+                x0,
+                y0,
+                x1,
+                y1,
+                None,
+                None,
+                None,
+                BlockType.Table,
+                None,
+                None,
+                None,
+                None,
+                table['score'],
+            ]
+        )
 
     for text in text_blocks:
         x0, y0, x1, y1 = text['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
+        all_bboxes.append(
+            [
+                x0,
+                y0,
+                x1,
+                y1,
+                None,
+                None,
+                None,
+                BlockType.Text,
+                None,
+                None,
+                None,
+                None,
+                text['score'],
+            ]
+        )
 
     for title in title_blocks:
         x0, y0, x1, y1 = title['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])
+        all_bboxes.append(
+            [
+                x0,
+                y0,
+                x1,
+                y1,
+                None,
+                None,
+                None,
+                BlockType.Title,
+                None,
+                None,
+                None,
+                None,
+                title['score'],
+            ]
+        )
 
     for interline_equation in interline_equation_blocks:
         x0, y0, x1, y1 = interline_equation['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])
-
-    '''block嵌套问题解决'''
-    '''文本框与标题框重叠,优先信任文本框'''
+        all_bboxes.append(
+            [
+                x0,
+                y0,
+                x1,
+                y1,
+                None,
+                None,
+                None,
+                BlockType.InterlineEquation,
+                None,
+                None,
+                None,
+                None,
+                interline_equation['score'],
+            ]
+        )
+
+    """block嵌套问题解决"""
+    """文本框与标题框重叠,优先信任文本框"""
     all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
-    '''任何框体与舍弃框重叠,优先信任舍弃框'''
+    """任何框体与舍弃框重叠,优先信任舍弃框"""
     all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
 
     # interline_equation 与title或text框冲突的情况,分两种情况处理
-    '''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
+    """interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框"""
     all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
-    '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
+    """interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框"""
     # 通过后续大框套小框逻辑删除
 
-    '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
+    """discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)"""
     for discarded in discarded_blocks:
         x0, y0, x1, y1 = discarded['bbox']
-        all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
+        all_discarded_blocks.append(
+            [
+                x0,
+                y0,
+                x1,
+                y1,
+                None,
+                None,
+                None,
+                BlockType.Discarded,
+                None,
+                None,
+                None,
+                None,
+                discarded['score'],
+            ]
+        )
         # 将footnote加入到all_bboxes中,用来计算layout
         if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
-            all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
-
-    '''经过以上处理后,还存在大框套小框的情况,则删除小框'''
+            all_bboxes.append(
+                [
+                    x0,
+                    y0,
+                    x1,
+                    y1,
+                    None,
+                    None,
+                    None,
+                    BlockType.Footnote,
+                    None,
+                    None,
+                    None,
+                    None,
+                    discarded['score'],
+                ]
+            )
+
+    """经过以上处理后,还存在大框套小框的情况,则删除小框"""
     all_bboxes = remove_overlaps_min_blocks(all_bboxes)
     all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
-    '''将剩余的bbox做分离处理,防止后面分layout时出错'''
+    """将剩余的bbox做分离处理,防止后面分layout时出错"""
     all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
 
     return all_bboxes, all_discarded_blocks, drop_reasons
@@ -64,18 +185,64 @@ def add_bboxes(blocks, block_type, bboxes):
     for block in blocks:
         x0, y0, x1, y1 = block['bbox']
         if block_type in [
-            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
-            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+            BlockType.ImageBody,
+            BlockType.ImageCaption,
+            BlockType.ImageFootnote,
+            BlockType.TableBody,
+            BlockType.TableCaption,
+            BlockType.TableFootnote,
         ]:
-            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
+            bboxes.append(
+                [
+                    x0,
+                    y0,
+                    x1,
+                    y1,
+                    None,
+                    None,
+                    None,
+                    block_type,
+                    None,
+                    None,
+                    None,
+                    None,
+                    block['score'],
+                    block['group_id'],
+                ]
+            )
         else:
-            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
+            bboxes.append(
+                [
+                    x0,
+                    y0,
+                    x1,
+                    y1,
+                    None,
+                    None,
+                    None,
+                    block_type,
+                    None,
+                    None,
+                    None,
+                    None,
+                    block['score'],
+                ]
+            )
 
 
 def ocr_prepare_bboxes_for_layout_split_v2(
-        img_body_blocks, img_caption_blocks, img_footnote_blocks,
-        table_body_blocks, table_caption_blocks, table_footnote_blocks,
-        discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
+    img_body_blocks,
+    img_caption_blocks,
+    img_footnote_blocks,
+    table_body_blocks,
+    table_caption_blocks,
+    table_footnote_blocks,
+    discarded_blocks,
+    text_blocks,
+    title_blocks,
+    interline_equation_blocks,
+    page_w,
+    page_h,
 ):
     all_bboxes = []
 
@@ -89,40 +256,40 @@ def ocr_prepare_bboxes_for_layout_split_v2(
     add_bboxes(title_blocks, BlockType.Title, all_bboxes)
     add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
 
-    '''block嵌套问题解决'''
-    '''文本框与标题框重叠,优先信任文本框'''
+    """block嵌套问题解决"""
+    """文本框与标题框重叠,优先信任文本框"""
     all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
-    '''任何框体与舍弃框重叠,优先信任舍弃框'''
+    """任何框体与舍弃框重叠,优先信任舍弃框"""
     all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
 
     # interline_equation 与title或text框冲突的情况,分两种情况处理
-    '''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
+    """interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框"""
     all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
-    '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
+    """interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框"""
     # 通过后续大框套小框逻辑删除
 
-    '''discarded_blocks'''
+    """discarded_blocks"""
     all_discarded_blocks = []
     add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
 
-    '''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
+    """footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的"""
     footnote_blocks = []
     for discarded in discarded_blocks:
         x0, y0, x1, y1 = discarded['bbox']
         if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
             footnote_blocks.append([x0, y0, x1, y1])
 
-    '''移除在footnote下面的任何框'''
+    """移除在footnote下面的任何框"""
     need_remove_blocks = find_blocks_under_footnote(all_bboxes, footnote_blocks)
     if len(need_remove_blocks) > 0:
         for block in need_remove_blocks:
             all_bboxes.remove(block)
             all_discarded_blocks.append(block)
 
-    '''经过以上处理后,还存在大框套小框的情况,则删除小框'''
+    """经过以上处理后,还存在大框套小框的情况,则删除小框"""
     all_bboxes = remove_overlaps_min_blocks(all_bboxes)
     all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
-    '''将剩余的bbox做分离处理,防止后面分layout时出错'''
+    """将剩余的bbox做分离处理,防止后面分layout时出错"""
     all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
 
     return all_bboxes, all_discarded_blocks
@@ -135,7 +302,13 @@ def find_blocks_under_footnote(all_bboxes, footnote_blocks):
         for footnote_bbox in footnote_blocks:
             footnote_x0, footnote_y0, footnote_x1, footnote_y1 = footnote_bbox
             # 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
-            if block_y0 >= footnote_y1 and calculate_vertical_projection_overlap_ratio((block_x0, block_y0, block_x1, block_y1), footnote_bbox) >= 0.8:
+            if (
+                block_y0 >= footnote_y1
+                and calculate_vertical_projection_overlap_ratio(
+                    (block_x0, block_y0, block_x1, block_y1), footnote_bbox
+                )
+                >= 0.8
+            ):
                 if block not in need_remove_blocks:
                     need_remove_blocks.append(block)
                     break
@@ -203,7 +376,12 @@ def remove_need_drop_blocks(all_bboxes, discarded_blocks):
     for block in all_bboxes:
         for discarded_block in discarded_blocks:
             block_bbox = block[:4]
-            if calculate_overlap_area_in_bbox1_area_ratio(block_bbox, discarded_block['bbox']) > 0.6:
+            if (
+                calculate_overlap_area_in_bbox1_area_ratio(
+                    block_bbox, discarded_block['bbox']
+                )
+                > 0.6
+            ):
                 if block not in need_remove:
                     need_remove.append(block)
                     break
@@ -223,10 +401,18 @@ def remove_overlaps_min_blocks(all_bboxes):
             if block1 != block2:
                 block1_bbox = block1[:4]
                 block2_bbox = block2[:4]
-                overlap_box = get_minbox_if_overlap_by_ratio(block1_bbox, block2_bbox, 0.8)
+                overlap_box = get_minbox_if_overlap_by_ratio(
+                    block1_bbox, block2_bbox, 0.8
+                )
                 if overlap_box is not None:
-                    block_to_remove = next((block for block in all_bboxes if block[:4] == overlap_box), None)
-                    if block_to_remove is not None and block_to_remove not in need_remove:
+                    block_to_remove = next(
+                        (block for block in all_bboxes if block[:4] == overlap_box),
+                        None,
+                    )
+                    if (
+                        block_to_remove is not None
+                        and block_to_remove not in need_remove
+                    ):
                         large_block = block1 if block1 != block_to_remove else block2
                         x1, y1, x2, y2 = large_block[:4]
                         sx1, sy1, sx2, sy2 = block_to_remove[:4]

+ 3 - 3
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -1,8 +1,8 @@
+from magic_pdf.config.drop_tag import DropTag
+from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.libs.boxbase import (__is_overlaps_y_exceeds_threshold,
                                     _is_in_or_part_overlap_with_area_ratio,
                                     calculate_overlap_area_in_bbox1_area_ratio)
-from magic_pdf.libs.drop_tag import DropTag
-from magic_pdf.libs.ocr_content_type import BlockType, ContentType
 
 
 # 将每一个line中的span从左到右排序
@@ -157,7 +157,7 @@ def fill_spans_in_blocks(blocks, spans, radio):
             BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
             BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
         ]:
-            block_dict["group_id"] = block[-1]
+            block_dict['group_id'] = block[-1]
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']

+ 123 - 60
magic_pdf/pre_proc/ocr_span_list_modify.py

@@ -1,9 +1,10 @@
-from loguru import logger
 
-from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, get_minbox_if_overlap_by_ratio, \
-    __is_overlaps_y_exceeds_threshold, calculate_iou
-from magic_pdf.libs.drop_tag import DropTag
-from magic_pdf.libs.ocr_content_type import ContentType, BlockType
+from magic_pdf.config.drop_tag import DropTag
+from magic_pdf.config.ocr_content_type import BlockType, ContentType
+from magic_pdf.libs.boxbase import (__is_overlaps_y_exceeds_threshold,
+                                    calculate_iou,
+                                    calculate_overlap_area_in_bbox1_area_ratio,
+                                    get_minbox_if_overlap_by_ratio)
 
 
 def remove_overlaps_low_confidence_spans(spans):
@@ -21,7 +22,10 @@ def remove_overlaps_low_confidence_spans(spans):
                             span_need_remove = span1
                         else:
                             span_need_remove = span2
-                        if span_need_remove is not None and span_need_remove not in dropped_spans:
+                        if (
+                            span_need_remove is not None
+                            and span_need_remove not in dropped_spans
+                        ):
                             dropped_spans.append(span_need_remove)
 
     if len(dropped_spans) > 0:
@@ -38,10 +42,17 @@ def remove_overlaps_min_spans(spans):
     for span1 in spans:
         for span2 in spans:
             if span1 != span2:
-                overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
+                overlap_box = get_minbox_if_overlap_by_ratio(
+                    span1['bbox'], span2['bbox'], 0.65
+                )
                 if overlap_box is not None:
-                    span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
-                    if span_need_remove is not None and span_need_remove not in dropped_spans:
+                    span_need_remove = next(
+                        (span for span in spans if span['bbox'] == overlap_box), None
+                    )
+                    if (
+                        span_need_remove is not None
+                        and span_need_remove not in dropped_spans
+                    ):
                         dropped_spans.append(span_need_remove)
 
     if len(dropped_spans) > 0:
@@ -58,7 +69,10 @@ def remove_spans_by_bboxes(spans, need_remove_spans_bboxes):
     need_remove_spans = []
     for span in spans:
         for removed_bbox in need_remove_spans_bboxes:
-            if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], removed_bbox) > 0.5:
+            if (
+                calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], removed_bbox)
+                > 0.5
+            ):
                 if span not in need_remove_spans:
                     need_remove_spans.append(span)
                     break
@@ -78,12 +92,22 @@ def remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict):
         for span in spans:
             # 通过判断span的bbox是否在removed_bboxes中, 判断是否需要删除该span
             for removed_bbox in removed_bboxes:
-                if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], removed_bbox) > 0.5:
+                if (
+                    calculate_overlap_area_in_bbox1_area_ratio(
+                        span['bbox'], removed_bbox
+                    )
+                    > 0.5
+                ):
                     need_remove_spans.append(span)
                     break
                 # 当drop_tag为DropTag.FOOTNOTE时, 判断span是否在removed_bboxes中任意一个的下方,如果是,则删除该span
-                elif drop_tag == DropTag.FOOTNOTE and (span['bbox'][1] + span['bbox'][3]) / 2 > removed_bbox[3] and \
-                        removed_bbox[0] < (span['bbox'][0] + span['bbox'][2]) / 2 < removed_bbox[2]:
+                elif (
+                    drop_tag == DropTag.FOOTNOTE
+                    and (span['bbox'][1] + span['bbox'][3]) / 2 > removed_bbox[3]
+                    and removed_bbox[0]
+                    < (span['bbox'][0] + span['bbox'][2]) / 2
+                    < removed_bbox[2]
+                ):
                     need_remove_spans.append(span)
                     break
 
@@ -98,11 +122,18 @@ def remove_spans_by_bboxes_dict(spans, need_remove_spans_bboxes_dict):
 def adjust_bbox_for_standalone_block(spans):
     # 对tpye=["interline_equation", "image", "table"]进行额外处理,如果左边有字的话,将该span的bbox中y0调整至不高于文字的y0
     for sb_span in spans:
-        if sb_span['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table]:
+        if sb_span['type'] in [
+            ContentType.InterlineEquation,
+            ContentType.Image,
+            ContentType.Table,
+        ]:
             for text_span in spans:
                 if text_span['type'] in [ContentType.Text, ContentType.InlineEquation]:
                     # 判断span2的纵向高度是否被span所覆盖
-                    if sb_span['bbox'][1] < text_span['bbox'][1] and sb_span['bbox'][3] > text_span['bbox'][3]:
+                    if (
+                        sb_span['bbox'][1] < text_span['bbox'][1]
+                        and sb_span['bbox'][3] > text_span['bbox'][3]
+                    ):
                         # 判断span2是否在span左边
                         if text_span['bbox'][0] < sb_span['bbox'][0]:
                             # 调整span的y0和span2的y0一致
@@ -120,11 +151,15 @@ def modify_y_axis(spans: list, displayed_list: list, text_inline_lines: list):
 
         lines = []
         current_line = [spans[0]]
-        if spans[0]["type"] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table]:
+        if spans[0]['type'] in [
+            ContentType.InterlineEquation,
+            ContentType.Image,
+            ContentType.Table,
+        ]:
             displayed_list.append(spans[0])
 
-        line_first_y0 = spans[0]["bbox"][1]
-        line_first_y = spans[0]["bbox"][3]
+        line_first_y0 = spans[0]['bbox'][1]
+        line_first_y = spans[0]['bbox'][3]
         # 用于给行间公式搜索
         # text_inline_lines = []
         for span in spans[1:]:
@@ -132,26 +167,43 @@ def modify_y_axis(spans: list, displayed_list: list, text_inline_lines: list):
             #     print("debug")
             # 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation"
             # image和table类型,同上
-            if span['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] or any(
-                    s['type'] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table] for s in
-                    current_line):
+            if span['type'] in [
+                ContentType.InterlineEquation,
+                ContentType.Image,
+                ContentType.Table,
+            ] or any(
+                s['type']
+                in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table]
+                for s in current_line
+            ):
                 # 传入
-                if span["type"] in [ContentType.InterlineEquation, ContentType.Image, ContentType.Table]:
+                if span['type'] in [
+                    ContentType.InterlineEquation,
+                    ContentType.Image,
+                    ContentType.Table,
+                ]:
                     displayed_list.append(span)
                 # 则开始新行
                 lines.append(current_line)
-                if len(current_line) > 1 or current_line[0]["type"] in [ContentType.Text, ContentType.InlineEquation]:
-                    text_inline_lines.append((current_line, (line_first_y0, line_first_y)))
+                if len(current_line) > 1 or current_line[0]['type'] in [
+                    ContentType.Text,
+                    ContentType.InlineEquation,
+                ]:
+                    text_inline_lines.append(
+                        (current_line, (line_first_y0, line_first_y))
+                    )
                 current_line = [span]
-                line_first_y0 = span["bbox"][1]
-                line_first_y = span["bbox"][3]
+                line_first_y0 = span['bbox'][1]
+                line_first_y = span['bbox'][3]
                 continue
 
             # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
-            if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']):
-                if span["type"] == "text":
-                    line_first_y0 = span["bbox"][1]
-                    line_first_y = span["bbox"][3]
+            if __is_overlaps_y_exceeds_threshold(
+                span['bbox'], current_line[-1]['bbox']
+            ):
+                if span['type'] == 'text':
+                    line_first_y0 = span['bbox'][1]
+                    line_first_y = span['bbox'][3]
                 current_line.append(span)
 
             else:
@@ -159,13 +211,16 @@ def modify_y_axis(spans: list, displayed_list: list, text_inline_lines: list):
                 lines.append(current_line)
                 text_inline_lines.append((current_line, (line_first_y0, line_first_y)))
                 current_line = [span]
-                line_first_y0 = span["bbox"][1]
-                line_first_y = span["bbox"][3]
+                line_first_y0 = span['bbox'][1]
+                line_first_y = span['bbox'][3]
 
             # 添加最后一行
         if current_line:
             lines.append(current_line)
-            if len(current_line) > 1 or current_line[0]["type"] in [ContentType.Text, ContentType.InlineEquation]:
+            if len(current_line) > 1 or current_line[0]['type'] in [
+                ContentType.Text,
+                ContentType.InlineEquation,
+            ]:
                 text_inline_lines.append((current_line, (line_first_y0, line_first_y)))
         for line in text_inline_lines:
             # 按照x0坐标排序
@@ -176,8 +231,8 @@ def modify_y_axis(spans: list, displayed_list: list, text_inline_lines: list):
         for line in text_inline_lines:
             current_line, (line_first_y0, line_first_y) = line
             for span in current_line:
-                span["bbox"][1] = line_first_y0
-                span["bbox"][3] = line_first_y
+                span['bbox'][1] = line_first_y0
+                span['bbox'][3] = line_first_y
 
         # return spans, displayed_list, text_inline_lines
 
@@ -189,34 +244,42 @@ def modify_inline_equation(spans: list, displayed_list: list, text_inline_lines:
         # if i == 8:
         #     print("debug")
         span = displayed_list[i]
-        span_y0, span_y = span["bbox"][1], span["bbox"][3]
+        span_y0, span_y = span['bbox'][1], span['bbox'][3]
 
         while j < len(text_inline_lines):
             text_line = text_inline_lines[j]
             y0, y1 = text_line[1]
             if (
-                    span_y0 < y0 < span_y or span_y0 < y1 < span_y or span_y0 < y0 and span_y > y1
-            ) and __is_overlaps_y_exceeds_threshold(
-                span['bbox'], (0, y0, 0, y1)
-            ):
+                span_y0 < y0 < span_y
+                or span_y0 < y1 < span_y
+                or span_y0 < y0
+                and span_y > y1
+            ) and __is_overlaps_y_exceeds_threshold(span['bbox'], (0, y0, 0, y1)):
                 # 调整公式类型
-                if span["type"] == ContentType.InterlineEquation:
+                if span['type'] == ContentType.InterlineEquation:
                     # 最后一行是行间公式
                     if j + 1 >= len(text_inline_lines):
-                        span["type"] = ContentType.InlineEquation
-                        span["bbox"][1] = y0
-                        span["bbox"][3] = y1
+                        span['type'] = ContentType.InlineEquation
+                        span['bbox'][1] = y0
+                        span['bbox'][3] = y1
                     else:
                         # 行间公式旁边有多行文字或者行间公式比文字高3倍则不转换
                         y0_next, y1_next = text_inline_lines[j + 1][1]
-                        if not __is_overlaps_y_exceeds_threshold(span['bbox'], (0, y0_next, 0, y1_next)) and 3 * (
-                                y1 - y0) > span_y - span_y0:
-                            span["type"] = ContentType.InlineEquation
-                            span["bbox"][1] = y0
-                            span["bbox"][3] = y1
+                        if (
+                            not __is_overlaps_y_exceeds_threshold(
+                                span['bbox'], (0, y0_next, 0, y1_next)
+                            )
+                            and 3 * (y1 - y0) > span_y - span_y0
+                        ):
+                            span['type'] = ContentType.InlineEquation
+                            span['bbox'][1] = y0
+                            span['bbox'][3] = y1
                 break
-            elif span_y < y0 or span_y0 < y0 < span_y and not __is_overlaps_y_exceeds_threshold(span['bbox'],
-                                                                                                (0, y0, 0, y1)):
+            elif (
+                span_y < y0
+                or span_y0 < y0 < span_y
+                and not __is_overlaps_y_exceeds_threshold(span['bbox'], (0, y0, 0, y1))
+            ):
                 break
             else:
                 j += 1
@@ -232,15 +295,15 @@ def get_qa_need_list(blocks):
     inline_equations = []
 
     for block in blocks:
-        for line in block["lines"]:
-            for span in line["spans"]:
-                if span["type"] == ContentType.Image:
+        for line in block['lines']:
+            for span in line['spans']:
+                if span['type'] == ContentType.Image:
                     images.append(span)
-                elif span["type"] == ContentType.Table:
+                elif span['type'] == ContentType.Table:
                     tables.append(span)
-                elif span["type"] == ContentType.InlineEquation:
+                elif span['type'] == ContentType.InlineEquation:
                     inline_equations.append(span)
-                elif span["type"] == ContentType.InterlineEquation:
+                elif span['type'] == ContentType.InterlineEquation:
                     interline_equations.append(span)
                 else:
                     continue
@@ -254,10 +317,10 @@ def get_qa_need_list_v2(blocks):
     interline_equations = []
 
     for block in blocks:
-        if block["type"] == BlockType.Image:
+        if block['type'] == BlockType.Image:
             images.append(block)
-        elif block["type"] == BlockType.Table:
+        elif block['type'] == BlockType.Table:
             tables.append(block)
-        elif block["type"] == BlockType.InterlineEquation:
+        elif block['type'] == BlockType.InterlineEquation:
             interline_equations.append(block)
     return images, tables, interline_equations

+ 37 - 33
magic_pdf/pre_proc/pdf_pre_filter.py

@@ -1,58 +1,65 @@
-from magic_pdf.libs.commons import fitz
+from magic_pdf.config.drop_reason import DropReason
 from magic_pdf.libs.boxbase import _is_in, _is_in_or_part_overlap
-from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.libs.commons import fitz
 
 
 def __area(box):
     return (box[2] - box[0]) * (box[3] - box[1])
 
-def __is_contain_color_background_rect(page:fitz.Page, text_blocks, image_bboxes) -> bool:
-    """
-    检查page是包含有颜色背景的矩形
-    """
+
+def __is_contain_color_background_rect(
+    page: fitz.Page, text_blocks, image_bboxes
+) -> bool:
+    """检查page是包含有颜色背景的矩形."""
     color_bg_rect = []
     p_width, p_height = page.rect.width, page.rect.height
-    
+
     # 先找到最大的带背景矩形
     blocks = page.get_cdrawings()
     for block in blocks:
-        
-        if 'fill' in block and block['fill']: # 过滤掉透明的
+        if 'fill' in block and block['fill']:  # 过滤掉透明的
             fill = list(block['fill'])
             fill[0], fill[1], fill[2] = int(fill[0]), int(fill[1]), int(fill[2])
-            if fill==(1.0,1.0,1.0):
+            if fill == (1.0, 1.0, 1.0):
                 continue
             rect = block['rect']
             # 过滤掉特别小的矩形
-            if __area(rect) < 10*10:
+            if __area(rect) < 10 * 10:
                 continue
             # 为了防止是svg图片上的色块,这里过滤掉这类
-            
-            if any([_is_in_or_part_overlap(rect, img_bbox) for img_bbox in image_bboxes]):
+
+            if any(
+                [_is_in_or_part_overlap(rect, img_bbox) for img_bbox in image_bboxes]
+            ):
                 continue
             color_bg_rect.append(rect)
-            
+
     # 找到最大的背景矩形
     if len(color_bg_rect) > 0:
-        max_rect = max(color_bg_rect, key=lambda x:__area(x))
-        max_rect_int = (int(max_rect[0]), int(max_rect[1]), int(max_rect[2]), int(max_rect[3]))
+        max_rect = max(color_bg_rect, key=lambda x: __area(x))
+        max_rect_int = (
+            int(max_rect[0]),
+            int(max_rect[1]),
+            int(max_rect[2]),
+            int(max_rect[3]),
+        )
         # 判断最大的背景矩形是否包含超过3行文字,或者50个字 TODO
-        if max_rect[2]-max_rect[0] > 0.2*p_width and  max_rect[3]-max_rect[1] > 0.1*p_height:#宽度符合
-            #看是否有文本块落入到这个矩形中
+        if (
+            max_rect[2] - max_rect[0] > 0.2 * p_width
+            and max_rect[3] - max_rect[1] > 0.1 * p_height
+        ):  # 宽度符合
+            # 看是否有文本块落入到这个矩形中
             for text_block in text_blocks:
                 box = text_block['bbox']
                 box_int = (int(box[0]), int(box[1]), int(box[2]), int(box[3]))
                 if _is_in(box_int, max_rect_int):
                     return True
-    
+
     return False
 
 
 def __is_table_overlap_text_block(text_blocks, table_bbox):
-    """
-    检查table_bbox是否覆盖了text_blocks里的文本块
-    TODO
-    """
+    """检查table_bbox是否覆盖了text_blocks里的文本块 TODO."""
     for text_block in text_blocks:
         box = text_block['bbox']
         if _is_in_or_part_overlap(table_bbox, box):
@@ -60,15 +67,12 @@ def __is_table_overlap_text_block(text_blocks, table_bbox):
     return False
 
 
-def pdf_filter(page:fitz.Page, text_blocks, table_bboxes, image_bboxes) -> tuple:
-    """
-    return:(True|False, err_msg)
-        True, 如果pdf符合要求
-        False, 如果pdf不符合要求
-        
-    """
+def pdf_filter(page: fitz.Page, text_blocks, table_bboxes, image_bboxes) -> tuple:
+    """return:(True|False, err_msg) True, 如果pdf符合要求 False, 如果pdf不符合要求."""
     if __is_contain_color_background_rect(page, text_blocks, image_bboxes):
-        return False, {"_need_drop": True, "_drop_reason": DropReason.COLOR_BACKGROUND_TEXT_BOX}
+        return False, {
+            '_need_drop': True,
+            '_drop_reason': DropReason.COLOR_BACKGROUND_TEXT_BOX,
+        }
 
-    
-    return True, None
+    return True, None

+ 20 - 18
magic_pdf/pre_proc/remove_bbox_overlap.py

@@ -1,8 +1,9 @@
-from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in, _is_part_overlap
-from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.config.drop_reason import DropReason
+from magic_pdf.libs.boxbase import _is_in, _is_part_overlap
+
 
 def _remove_overlap_between_bbox(bbox1, bbox2):
-   if _is_part_overlap(bbox1, bbox2):
+    if _is_part_overlap(bbox1, bbox2):
         ix0, iy0, ix1, iy1 = bbox1
         x0, y0, x1, y1 = bbox2
 
@@ -22,10 +23,10 @@ def _remove_overlap_between_bbox(bbox1, bbox2):
             if y1 >= iy1:
                 mid = (y0 + iy1) // 2
                 y0 = max(mid + 0.25, y0)
-                iy1 = min(iy1, mid-0.25)
+                iy1 = min(iy1, mid - 0.25)
             else:
                 mid = (iy0 + y1) // 2
-                y1 = min(y1, mid-0.25)
+                y1 = min(y1, mid - 0.25)
                 iy0 = max(mid + 0.25, iy0)
 
         if ix1 > ix0 and iy1 > iy0 and y1 > y0 and x1 > x0:
@@ -34,8 +35,8 @@ def _remove_overlap_between_bbox(bbox1, bbox2):
             return bbox1, bbox2, None
         else:
             return bbox1, bbox2, DropReason.NEGATIVE_BBOX_AREA
-   else:
-       return bbox1, bbox2, None
+    else:
+        return bbox1, bbox2, None
 
 
 def _remove_overlap_between_bboxes(arr):
@@ -47,7 +48,7 @@ def _remove_overlap_between_bboxes(arr):
         for j in range(N):
             if i == j:
                 continue
-            if _is_in(arr[i]["bbox"], arr[j]["bbox"]):
+            if _is_in(arr[i]['bbox'], arr[j]['bbox']):
                 keeps[i] = False
 
     for idx, v in enumerate(arr):
@@ -56,13 +57,15 @@ def _remove_overlap_between_bboxes(arr):
         for i in range(N):
             if res[i] is None:
                 continue
-        
-            bbox1, bbox2, drop_reason = _remove_overlap_between_bbox(v["bbox"], res[i]["bbox"])
+
+            bbox1, bbox2, drop_reason = _remove_overlap_between_bbox(
+                v['bbox'], res[i]['bbox']
+            )
             if drop_reason is None:
-                v["bbox"] = bbox1
-                res[i]["bbox"] = bbox2
+                v['bbox'] = bbox1
+                res[i]['bbox'] = bbox2
             else:
-                if v["score"] > res[i]["score"]:
+                if v['score'] > res[i]['score']:
                     keeps[i] = False
                     res[i] = None
                 else:
@@ -74,25 +77,24 @@ def _remove_overlap_between_bboxes(arr):
 
 
 def remove_overlap_between_bbox_for_span(spans):
-    arr = [{"bbox": span["bbox"], "score": span.get("score", 0.1)} for span in spans ]
+    arr = [{'bbox': span['bbox'], 'score': span.get('score', 0.1)} for span in spans]
     res, drop_reasons = _remove_overlap_between_bboxes(arr)
     ret = []
     for i in range(len(res)):
         if res[i] is None:
             continue
-        spans[i]["bbox"] = res[i]["bbox"]
+        spans[i]['bbox'] = res[i]['bbox']
         ret.append(spans[i])
     return ret, drop_reasons
 
 
 def remove_overlap_between_bbox_for_block(all_bboxes):
-    arr = [{"bbox": bbox[:4], "score": bbox[-1]} for bbox in all_bboxes ]
+    arr = [{'bbox': bbox[:4], 'score': bbox[-1]} for bbox in all_bboxes]
     res, drop_reasons = _remove_overlap_between_bboxes(arr)
     ret = []
     for i in range(len(res)):
         if res[i] is None:
             continue
-        all_bboxes[i][:4] = res[i]["bbox"]
+        all_bboxes[i][:4] = res[i]['bbox']
         ret.append(all_bboxes[i])
     return ret, drop_reasons
-

+ 36 - 14
magic_pdf/pre_proc/remove_colored_strip_bbox.py

@@ -1,7 +1,8 @@
-from magic_pdf.libs.boxbase import _is_in, _is_in_or_part_overlap, calculate_overlap_area_2_minbox_area_ratio
 from loguru import logger
 
-from magic_pdf.libs.drop_tag import COLOR_BG_HEADER_TXT_BLOCK
+from magic_pdf.config.drop_tag import COLOR_BG_HEADER_TXT_BLOCK
+from magic_pdf.libs.boxbase import (_is_in, _is_in_or_part_overlap,
+                                    calculate_overlap_area_2_minbox_area_ratio)
 
 
 def __area(box):
@@ -9,8 +10,7 @@ def __area(box):
 
 
 def rectangle_position_determination(rect, p_width):
-    """
-    判断矩形是否在页面中轴线附近。
+    """判断矩形是否在页面中轴线附近。
 
     Args:
         rect (list): 矩形坐标,格式为[x1, y1, x2, y2]。
@@ -34,9 +34,10 @@ def rectangle_position_determination(rect, p_width):
         else:
             return False
 
+
 def remove_colored_strip_textblock(remain_text_blocks, page):
-    """
-    根据页面中特定颜色和大小过滤文本块,将符合条件的文本块从remain_text_blocks中移除,并返回移除的文本块列表colored_strip_textblock。
+    """根据页面中特定颜色和大小过滤文本块,将符合条件的文本块从remain_text_blocks中移除,并返回移除的文本块列表colored_str
+    ip_textblock。
 
     Args:
         remain_text_blocks (list): 剩余文本块列表。
@@ -51,22 +52,44 @@ def remove_colored_strip_textblock(remain_text_blocks, page):
         blocks = page.get_cdrawings()
         colored_strip_bg_rect = []
         for block in blocks:
-            is_filled = 'fill' in block and block['fill'] and block['fill'] != (1.0, 1.0, 1.0)  # 过滤掉透明的
+            is_filled = (
+                'fill' in block and block['fill'] and block['fill'] != (1.0, 1.0, 1.0)
+            )  # 过滤掉透明的
             rect = block['rect']
             area_is_large_enough = __area(rect) > 100  # 过滤掉特别小的矩形
-            rectangle_position_determination_result = rectangle_position_determination(rect, p_width)
-            in_upper_half_page = rect[3] < p_height * 0.3  # 找到位于页面上半部分的矩形,下边界小于页面高度的30%
-            aspect_ratio_exceeds_4 = (rect[2] - rect[0]) > (rect[3] - rect[1]) * 4  # 找到长宽比超过4的矩形
+            rectangle_position_determination_result = rectangle_position_determination(
+                rect, p_width
+            )
+            in_upper_half_page = (
+                rect[3] < p_height * 0.3
+            )  # 找到位于页面上半部分的矩形,下边界小于页面高度的30%
+            aspect_ratio_exceeds_4 = (rect[2] - rect[0]) > (
+                rect[3] - rect[1]
+            ) * 4  # 找到长宽比超过4的矩形
 
-            if is_filled and area_is_large_enough and rectangle_position_determination_result and in_upper_half_page and aspect_ratio_exceeds_4:
+            if (
+                is_filled
+                and area_is_large_enough
+                and rectangle_position_determination_result
+                and in_upper_half_page
+                and aspect_ratio_exceeds_4
+            ):
                 colored_strip_bg_rect.append(rect)
 
         if len(colored_strip_bg_rect) > 0:
             for colored_strip_block_bbox in colored_strip_bg_rect:
                 for text_block in remain_text_blocks:
                     text_bbox = text_block['bbox']
-                    if _is_in(text_bbox, colored_strip_block_bbox) or (_is_in_or_part_overlap(text_bbox, colored_strip_block_bbox) and calculate_overlap_area_2_minbox_area_ratio(text_bbox, colored_strip_block_bbox) > 0.6):
-                        logger.info(f'remove_colored_strip_textblock: {text_bbox}, {colored_strip_block_bbox}')
+                    if _is_in(text_bbox, colored_strip_block_bbox) or (
+                        _is_in_or_part_overlap(text_bbox, colored_strip_block_bbox)
+                        and calculate_overlap_area_2_minbox_area_ratio(
+                            text_bbox, colored_strip_block_bbox
+                        )
+                        > 0.6
+                    ):
+                        logger.info(
+                            f'remove_colored_strip_textblock: {text_bbox}, {colored_strip_block_bbox}'
+                        )
                         text_block['tag'] = COLOR_BG_HEADER_TXT_BLOCK
                         colored_strip_textblocks.append(text_block)
 
@@ -76,4 +99,3 @@ def remove_colored_strip_textblock(remain_text_blocks, page):
                             remain_text_blocks.remove(colored_strip_textblock)
 
     return remain_text_blocks, colored_strip_textblocks
-

+ 2 - 5
magic_pdf/pre_proc/remove_footer_header.py

@@ -1,15 +1,12 @@
 import re
 
+from magic_pdf.config.drop_tag import CONTENT_IN_FOOT_OR_HEADER, PAGE_NO
 from magic_pdf.libs.boxbase import _is_in_or_part_overlap
-from magic_pdf.libs.drop_tag import CONTENT_IN_FOOT_OR_HEADER, PAGE_NO
 
 
 def remove_headder_footer_one_page(text_raw_blocks, image_bboxes, table_bboxes, header_bboxs, footer_bboxs,
                                    page_no_bboxs, page_w, page_h):
-    """
-    删除页眉页脚,页码
-    从line级别进行删除,删除之后观察这个text-block是否是空的,如果是空的,则移动到remove_list中
-    """
+    """删除页眉页脚,页码 从line级别进行删除,删除之后观察这个text-block是否是空的,如果是空的,则移动到remove_list中."""
     header = []
     footer = []
     if len(header) == 0:

+ 111 - 63
magic_pdf/pre_proc/remove_rotate_bbox.py

@@ -1,19 +1,21 @@
 import math
+import re
 
+from magic_pdf.config.drop_tag import (EMPTY_SIDE_BLOCK, ROTATE_TEXT,
+                                       VERTICAL_TEXT)
 from magic_pdf.libs.boxbase import is_vbox_on_side
-from magic_pdf.libs.drop_tag import EMPTY_SIDE_BLOCK, ROTATE_TEXT, VERTICAL_TEXT
 
 
 def detect_non_horizontal_texts(result_dict):
-    """
-    This function detects watermarks and vertical margin notes in the document.
+    """This function detects watermarks and vertical margin notes in the
+    document.
 
     Watermarks are identified by finding blocks with the same coordinates and frequently occurring identical texts across multiple pages.
     If these conditions are met, the blocks are highly likely to be watermarks, as opposed to headers or footers, which can change from page to page.
     If the direction of these blocks is not horizontal, they are definitely considered to be watermarks.
 
     Vertical margin notes are identified by finding blocks with the same coordinates and frequently occurring identical texts across multiple pages.
-    If these conditions are met, the blocks are highly likely to be vertical margin notes, which typically appear on the left and right sides of the page.
+    If these conditions are met, the blocks are highly likely to be vertical margin notes, which typically appear on the left and right sides of the page. # noqa: E501
     If the direction of these blocks is vertical, they are definitely considered to be vertical margin notes.
 
 
@@ -32,13 +34,16 @@ def detect_non_horizontal_texts(result_dict):
     potential_margin_notes = {}
 
     for page_id, page_content in result_dict.items():
-        if page_id.startswith("page_"):
+        if page_id.startswith('page_'):
             for block_id, block_data in page_content.items():
-                if block_id.startswith("block_"):
-                    if "dir" in block_data:
-                        coordinates_text = (block_data["bbox"], block_data["text"])  # Tuple of coordinates and text
-
-                        angle = math.atan2(block_data["dir"][1], block_data["dir"][0])
+                if block_id.startswith('block_'):
+                    if 'dir' in block_data:
+                        coordinates_text = (
+                            block_data['bbox'],
+                            block_data['text'],
+                        )  # Tuple of coordinates and text
+
+                        angle = math.atan2(block_data['dir'][1], block_data['dir'][0])
                         angle = abs(math.degrees(angle))
 
                         if angle > 5 and angle < 85:  # Check if direction is watermarks
@@ -49,32 +54,40 @@ def detect_non_horizontal_texts(result_dict):
 
                         if angle > 85 and angle < 105:  # Check if direction is vertical
                             if coordinates_text in potential_margin_notes:
-                                potential_margin_notes[coordinates_text] += 1  # Increment count
+                                potential_margin_notes[coordinates_text] += (
+                                    1  # Increment count
+                                )
                             else:
-                                potential_margin_notes[coordinates_text] = 1  # Initialize count
+                                potential_margin_notes[coordinates_text] = (
+                                    1  # Initialize count
+                                )
 
     # Identify watermarks by finding entries with counts higher than a threshold (e.g., appearing on more than half of the pages)
     watermark_threshold = len(result_dict) // 2
-    watermarks = {k: v for k, v in potential_watermarks.items() if v > watermark_threshold}
+    watermarks = {
+        k: v for k, v in potential_watermarks.items() if v > watermark_threshold
+    }
 
     # Identify margin notes by finding entries with counts higher than a threshold (e.g., appearing on more than half of the pages)
     margin_note_threshold = len(result_dict) // 2
-    margin_notes = {k: v for k, v in potential_margin_notes.items() if v > margin_note_threshold}
+    margin_notes = {
+        k: v for k, v in potential_margin_notes.items() if v > margin_note_threshold
+    }
 
     # Add watermark information to the result dictionary
     for page_id, blocks in result_dict.items():
-        if page_id.startswith("page_"):
+        if page_id.startswith('page_'):
             for block_id, block_data in blocks.items():
-                coordinates_text = (block_data["bbox"], block_data["text"])
+                coordinates_text = (block_data['bbox'], block_data['text'])
                 if coordinates_text in watermarks:
-                    block_data["is_watermark"] = 1
+                    block_data['is_watermark'] = 1
                 else:
-                    block_data["is_watermark"] = 0
+                    block_data['is_watermark'] = 0
 
                 if coordinates_text in margin_notes:
-                    block_data["is_vertical_margin_note"] = 1
+                    block_data['is_vertical_margin_note'] = 1
                 else:
-                    block_data["is_vertical_margin_note"] = 0
+                    block_data['is_vertical_margin_note'] = 0
 
     return result_dict
 
@@ -83,21 +96,21 @@ def detect_non_horizontal_texts(result_dict):
 1. 当一个block里全部文字都不是dir=(1,0),这个block整体去掉
 2. 当一个block里全部文字都是dir=(1,0),但是每行只有一个字,这个block整体去掉。这个block必须出现在页面的四周,否则不去掉
 """
-import re
+
 
 def __is_a_word(sentence):
     # 如果输入是中文并且长度为1,则返回True
     if re.fullmatch(r'[\u4e00-\u9fa5]', sentence):
         return True
     # 判断是否为单个英文单词或字符(包括ASCII标点)
-    elif re.fullmatch(r'[a-zA-Z0-9]+', sentence) and len(sentence) <=2:
+    elif re.fullmatch(r'[a-zA-Z0-9]+', sentence) and len(sentence) <= 2:
         return True
     else:
         return False
 
 
 def __get_text_color(num):
-    """获取字体的颜色RGB值"""
+    """获取字体的颜色RGB值."""
     blue = num & 255
     green = (num >> 8) & 255
     red = (num >> 16) & 255
@@ -105,84 +118,119 @@ def __get_text_color(num):
 
 
 def __is_empty_side_box(text_block):
-    """
-    是否是边缘上的空白没有任何内容的block
-    """
+    """是否是边缘上的空白没有任何内容的block."""
     for line in text_block['lines']:
         for span in line['spans']:
             font_color = span['color']
-            r,g,b = __get_text_color(font_color)
-            if len(span['text'].strip())>0 and (r,g,b)!=(255,255,255):
+            r, g, b = __get_text_color(font_color)
+            if len(span['text'].strip()) > 0 and (r, g, b) != (255, 255, 255):
                 return False
-            
+
     return True
 
 
 def remove_rotate_side_textblock(pymu_text_block, page_width, page_height):
-    """
-    返回删除了垂直,水印,旋转的textblock
-    删除的内容打上tag返回
-    """
+    """返回删除了垂直,水印,旋转的textblock 删除的内容打上tag返回."""
     removed_text_block = []
-    
-    for i, block in enumerate(pymu_text_block): # 格式参考test/assets/papre/pymu_textblocks.json
+
+    for i, block in enumerate(
+        pymu_text_block
+    ):  # 格式参考test/assets/papre/pymu_textblocks.json
         lines = block['lines']
         block_bbox = block['bbox']
-        if not is_vbox_on_side(block_bbox, page_width, page_height, 0.2): # 保证这些box必须在页面的两边
-           continue
-        
-        if all([__is_a_word(line['spans'][0]["text"]) for line in lines if len(line['spans'])>0]) and len(lines)>1 and all([len(line['spans'])==1 for line in lines]):
-            is_box_valign = (len(set([int(line['spans'][0]['bbox'][0] ) for line in lines if len(line['spans'])>0]))==1) and (len([int(line['spans'][0]['bbox'][0] ) for line in lines if len(line['spans'])>0])>1)  # 测试bbox在垂直方向是不是x0都相等,也就是在垂直方向排列.同时必须大于等于2个字
-            
+        if not is_vbox_on_side(
+            block_bbox, page_width, page_height, 0.2
+        ):  # 保证这些box必须在页面的两边
+            continue
+
+        if (
+            all(
+                [
+                    __is_a_word(line['spans'][0]['text'])
+                    for line in lines
+                    if len(line['spans']) > 0
+                ]
+            )
+            and len(lines) > 1
+            and all([len(line['spans']) == 1 for line in lines])
+        ):
+            is_box_valign = (
+                (
+                    len(
+                        set(
+                            [
+                                int(line['spans'][0]['bbox'][0])
+                                for line in lines
+                                if len(line['spans']) > 0
+                            ]
+                        )
+                    )
+                    == 1
+                )
+                and (
+                    len(
+                        [
+                            int(line['spans'][0]['bbox'][0])
+                            for line in lines
+                            if len(line['spans']) > 0
+                        ]
+                    )
+                    > 1
+                )
+            )  # 测试bbox在垂直方向是不是x0都相等,也就是在垂直方向排列.同时必须大于等于2个字
+
             if is_box_valign:
                 block['tag'] = VERTICAL_TEXT
                 removed_text_block.append(block)
                 continue
-        
+
         for line in lines:
-            if line['dir']!=(1,0):
+            if line['dir'] != (1, 0):
                 block['tag'] = ROTATE_TEXT
-                removed_text_block.append(block) # 只要有一个line不是dir=(1,0),就把整个block都删掉
+                removed_text_block.append(
+                    block
+                )  # 只要有一个line不是dir=(1,0),就把整个block都删掉
                 break
-        
+
     for block in removed_text_block:
         pymu_text_block.remove(block)
-    
+
     return pymu_text_block, removed_text_block
 
+
 def get_side_boundry(rotate_bbox, page_width, page_height):
-    """
-    根据rotate_bbox,返回页面的左右正文边界
-    """
+    """根据rotate_bbox,返回页面的左右正文边界."""
     left_x = 0
     right_x = page_width
     for x in rotate_bbox:
         box = x['bbox']
-        if box[2]<page_width/2:
+        if box[2] < page_width / 2:
             left_x = max(left_x, box[2])
         else:
             right_x = min(right_x, box[0])
-            
-    return left_x+1, right_x-1
+
+    return left_x + 1, right_x - 1
 
 
 def remove_side_blank_block(pymu_text_block, page_width, page_height):
-    """
-    删除页面两侧的空白block
-    """
+    """删除页面两侧的空白block."""
     removed_text_block = []
-    
-    for i, block in enumerate(pymu_text_block): # 格式参考test/assets/papre/pymu_textblocks.json
+
+    for i, block in enumerate(
+        pymu_text_block
+    ):  # 格式参考test/assets/papre/pymu_textblocks.json
         block_bbox = block['bbox']
-        if not is_vbox_on_side(block_bbox, page_width, page_height, 0.2): # 保证这些box必须在页面的两边
-           continue
-            
+        if not is_vbox_on_side(
+            block_bbox, page_width, page_height, 0.2
+        ):  # 保证这些box必须在页面的两边
+            continue
+
         if __is_empty_side_box(block):
             block['tag'] = EMPTY_SIDE_BLOCK
             removed_text_block.append(block)
             continue
-        
+
     for block in removed_text_block:
         pymu_text_block.remove(block)
-    
-    return pymu_text_block, removed_text_block
+
+    return pymu_text_block, removed_text_block

+ 10 - 17
magic_pdf/pre_proc/resolve_bbox_conflict.py

@@ -4,8 +4,9 @@
 2. 然后去掉出现在文字blcok上的图片bbox
 """
 
-from magic_pdf.libs.boxbase import _is_in, _is_in_or_part_overlap, _is_left_overlap
-from magic_pdf.libs.drop_tag import ON_IMAGE_TEXT, ON_TABLE_TEXT
+from magic_pdf.config.drop_tag import ON_IMAGE_TEXT, ON_TABLE_TEXT
+from magic_pdf.libs.boxbase import (_is_in, _is_in_or_part_overlap,
+                                    _is_left_overlap)
 
 
 def resolve_bbox_overlap_conflict(images: list, tables: list, interline_equations: list, inline_equations: list,
@@ -26,14 +27,14 @@ def resolve_bbox_overlap_conflict(images: list, tables: list, interline_equation
     # 去掉位于图片上的文字block
     for image_box in images:
         for text_block in text_raw_blocks:
-            text_bbox = text_block["bbox"]
+            text_bbox = text_block['bbox']
             if _is_in(text_bbox, image_box):
                 text_block['tag'] = ON_IMAGE_TEXT
                 text_block_removed.append(text_block)
     # 去掉table上的文字block
     for table_box in tables:
         for text_block in text_raw_blocks:
-            text_bbox = text_block["bbox"]
+            text_bbox = text_block['bbox']
             if _is_in(text_bbox, table_box):
                 text_block['tag'] = ON_TABLE_TEXT
                 text_block_removed.append(text_block)
@@ -77,7 +78,7 @@ def resolve_bbox_overlap_conflict(images: list, tables: list, interline_equation
     # 图片和文字重叠,丢掉图片
     for image_box in images:
         for text_block in text_raw_blocks:
-            text_bbox = text_block["bbox"]
+            text_bbox = text_block['bbox']
             if _is_in_or_part_overlap(image_box, text_bbox):
                 images_backup.append(image_box)
                 break
@@ -122,11 +123,7 @@ def resolve_bbox_overlap_conflict(images: list, tables: list, interline_equation
 
 
 def check_text_block_horizontal_overlap(text_blocks: list, header, footer) -> bool:
-    """
-    检查文本block之间的水平重叠情况,这种情况如果发生,那么这个pdf就不再继续处理了。
-    因为这种情况大概率发生了公式没有被检测出来。
-    
-    """
+    """检查文本block之间的水平重叠情况,这种情况如果发生,那么这个pdf就不再继续处理了。 因为这种情况大概率发生了公式没有被检测出来。"""
     if len(text_blocks) == 0:
         return False
 
@@ -148,7 +145,7 @@ def check_text_block_horizontal_overlap(text_blocks: list, header, footer) -> bo
 
     txt_bboxes = []
     for text_block in text_blocks:
-        bbox = text_block["bbox"]
+        bbox = text_block['bbox']
         if bbox[1] >= clip_y0 and bbox[3] <= clip_y1:
             txt_bboxes.append(bbox)
 
@@ -161,11 +158,7 @@ def check_text_block_horizontal_overlap(text_blocks: list, header, footer) -> bo
 
 
 def check_useful_block_horizontal_overlap(useful_blocks: list) -> bool:
-    """
-    检查文本block之间的水平重叠情况,这种情况如果发生,那么这个pdf就不再继续处理了。
-    因为这种情况大概率发生了公式没有被检测出来。
-
-    """
+    """检查文本block之间的水平重叠情况,这种情况如果发生,那么这个pdf就不再继续处理了。 因为这种情况大概率发生了公式没有被检测出来。"""
     if len(useful_blocks) == 0:
         return False
 
@@ -174,7 +167,7 @@ def check_useful_block_horizontal_overlap(useful_blocks: list) -> bool:
 
     useful_bboxes = []
     for text_block in useful_blocks:
-        bbox = text_block["bbox"]
+        bbox = text_block['bbox']
         if bbox[1] >= page_min_y and bbox[3] <= page_max_y:
             useful_bboxes.append(bbox)
 

+ 15 - 17
magic_pdf/spark/spark_api.py

@@ -1,51 +1,49 @@
 from loguru import logger
 
-from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.config.drop_reason import DropReason
 
 
 def get_data_source(jso: dict):
-    data_source = jso.get("data_source")
+    data_source = jso.get('data_source')
     if data_source is None:
-        data_source = jso.get("file_source")
+        data_source = jso.get('file_source')
     return data_source
 
 
 def get_data_type(jso: dict):
-    data_type = jso.get("data_type")
+    data_type = jso.get('data_type')
     if data_type is None:
-        data_type = jso.get("file_type")
+        data_type = jso.get('file_type')
     return data_type
 
 
 def get_bookid(jso: dict):
-    book_id = jso.get("bookid")
+    book_id = jso.get('bookid')
     if book_id is None:
-        book_id = jso.get("original_file_id")
+        book_id = jso.get('original_file_id')
     return book_id
 
 
 def exception_handler(jso: dict, e):
     logger.exception(e)
-    jso["_need_drop"] = True
-    jso["_drop_reason"] = DropReason.Exception
-    jso["_exception"] = f"ERROR: {e}"
+    jso['_need_drop'] = True
+    jso['_drop_reason'] = DropReason.Exception
+    jso['_exception'] = f'ERROR: {e}'
     return jso
 
 
 def get_bookname(jso: dict):
     data_source = get_data_source(jso)
-    file_id = jso.get("file_id")
-    book_name = f"{data_source}/{file_id}"
+    file_id = jso.get('file_id')
+    book_name = f'{data_source}/{file_id}'
     return book_name
 
 
 def spark_json_extractor(jso: dict) -> dict:
 
-    """
-    从json中提取数据,返回一个dict
-    """
+    """从json中提取数据,返回一个dict."""
 
     return {
-        "_pdf_type": jso["_pdf_type"],
-        "model_list": jso["doc_layout_result"],
+        '_pdf_type': jso['_pdf_type'],
+        'model_list': jso['doc_layout_result'],
     }

+ 1 - 1
magic_pdf/tools/common.py

@@ -7,10 +7,10 @@ import fitz
 from loguru import logger
 
 import magic_pdf.model as model_config
+from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
 from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
                                       draw_model_bbox, draw_span_bbox)
-from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.pipe.UNIPipe import UNIPipe

Daži faili netika attēloti, jo izmaiņu fails ir pārāk liels