浏览代码

Merge pull request #672 from myhloli/add-layoutreader

feat:add layoutreader to sort blocks
Xiaomeng Zhao 1 年之前
父节点
当前提交
bcbee130f6

+ 10 - 0
magic_pdf/libs/clean_memory.py

@@ -0,0 +1,10 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import torch
+import gc
+
+
+def clean_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    gc.collect()

+ 88 - 36
magic_pdf/libs/draw_bbox.py

@@ -33,7 +33,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
             )  # Draw the rectangle
 
 
-def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
+def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
     new_rgb = []
     for item in rgb_config:
         item = float(item) / 255
@@ -42,31 +42,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
     for j, bbox in enumerate(page_data):
         x0, y0, x1, y1 = bbox
         rect_coords = fitz.Rect(x0, y0, x1, y1)  # Define the rectangle
-        if fill_config:
-            page.draw_rect(
-                rect_coords,
-                color=None,
-                fill=new_rgb,
-                fill_opacity=0.3,
-                width=0.5,
-                overlay=True,
-            )  # Draw the rectangle
-        else:
-            page.draw_rect(
-                rect_coords,
-                color=new_rgb,
-                fill=None,
-                fill_opacity=1,
-                width=0.5,
-                overlay=True,
-            )  # Draw the rectangle
+        if draw_bbox:
+            if fill_config:
+                page.draw_rect(
+                    rect_coords,
+                    color=None,
+                    fill=new_rgb,
+                    fill_opacity=0.3,
+                    width=0.5,
+                    overlay=True,
+                )  # Draw the rectangle
+            else:
+                page.draw_rect(
+                    rect_coords,
+                    color=new_rgb,
+                    fill=None,
+                    fill_opacity=1,
+                    width=0.5,
+                    overlay=True,
+                )  # Draw the rectangle
         page.insert_text(
-            (x0, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
+            (x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
         )  # Insert the index in the top left corner of the rectangle
 
 
 def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
-    layout_bbox_list = []
     dropped_bbox_list = []
     tables_list, tables_body_list = [], []
     tables_caption_list, tables_footnote_list = [], []
@@ -76,16 +76,14 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
     texts_list = []
     interequations_list = []
     for page in pdf_info:
-        page_layout_list = []
+
         page_dropped_list = []
         tables, tables_body, tables_caption, tables_footnote = [], [], [], []
         imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
         titles = []
         texts = []
         interequations = []
-        for layout in page['layout_bboxes']:
-            page_layout_list.append(layout['layout_bbox'])
-        layout_bbox_list.append(page_layout_list)
+
         for dropped_bbox in page['discarded_blocks']:
             page_dropped_list.append(dropped_bbox['bbox'])
         dropped_bbox_list.append(page_dropped_list)
@@ -129,9 +127,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         texts_list.append(texts)
         interequations_list.append(interequations)
 
+    layout_bbox_list = []
+
+    for page in pdf_info:
+        page_block_list = []
+        for block in page['para_blocks']:
+            bbox = block['bbox']
+            page_block_list.append(bbox)
+        layout_bbox_list.append(page_block_list)
+
     pdf_docs = fitz.open('pdf', pdf_bytes)
+
     for i, page in enumerate(pdf_docs):
-        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
+
         draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
                                  True)
         draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
@@ -146,13 +154,15 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
         draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
                                  True)
-        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
+        draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102],
                               True),
         draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
         draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
         draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
                                  True)
 
+        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)
+
     # Save the PDF
     pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
 
@@ -211,9 +221,9 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
         # 构造其余useful_list
         for block in page['para_blocks']:
             if block['type'] in [
-                    BlockType.Text,
-                    BlockType.Title,
-                    BlockType.InterlineEquation,
+                BlockType.Text,
+                BlockType.Title,
+                BlockType.InterlineEquation,
             ]:
                 for line in block['lines']:
                     for span in line['spans']:
@@ -232,10 +242,8 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
     for i, page in enumerate(pdf_docs):
         # 获取当前页面的数据
         draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
-        draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0],
-                                 False)
-        draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255],
-                                 False)
+        draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
+        draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
         draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
         draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
         draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
@@ -244,7 +252,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
     pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
 
 
-def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
+def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
     dropped_bbox_list = []
     tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
     imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
@@ -279,7 +287,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
             elif layout_det['category_id'] == CategoryId.ImageCaption:
                 imgs_caption.append(bbox)
             elif layout_det[
-                    'category_id'] == CategoryId.InterlineEquation_YOLO:
+                'category_id'] == CategoryId.InterlineEquation_YOLO:
                 interequations.append(bbox)
             elif layout_det['category_id'] == CategoryId.Abandon:
                 page_dropped_list.append(bbox)
@@ -316,3 +324,47 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
 
     # Save the PDF
     pdf_docs.save(f'{out_path}/{filename}_model.pdf')
+
+
+def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
+    layout_bbox_list = []
+
+    for page in pdf_info:
+        page_line_list = []
+        for block in page['preproc_blocks']:
+            if block['type'] in ['text', 'title', 'interline_equation']:
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    index = line['index']
+                    page_line_list.append({'index': index, 'bbox': bbox})
+            if block['type'] in ['table', 'image']:
+                bbox = block['bbox']
+                index = block['index']
+                page_line_list.append({'index': index, 'bbox': bbox})
+            # for line in block['lines']:
+            #     bbox = line['bbox']
+            #     index = line['index']
+            #     page_line_list.append({'index': index, 'bbox': bbox})
+        sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
+        layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
+    pdf_docs = fitz.open('pdf', pdf_bytes)
+    for i, page in enumerate(pdf_docs):
+        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
+
+    pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')
+
+
+def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
+    layout_bbox_list = []
+
+    for page in pdf_info:
+        page_block_list = []
+        for block in page['para_blocks']:
+            bbox = block['bbox']
+            page_block_list.append(bbox)
+        layout_bbox_list.append(page_block_list)
+    pdf_docs = fitz.open('pdf', pdf_bytes)
+    for i, page in enumerate(pdf_docs):
+        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
+
+    pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')

+ 3 - 0
magic_pdf/model/pdf_extract_kit.py

@@ -3,6 +3,7 @@ import os
 import time
 
 from magic_pdf.libs.Constants import *
+from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -330,6 +331,8 @@ class CustomPEKModel:
             elif int(res['category_id']) in [5]:
                 table_res_list.append(res)
 
+        clean_memory()
+
         # ocr识别
         if self.apply_ocr:
             ocr_start = time.time()

+ 0 - 0
magic_pdf/model/v3/__init__.py


+ 125 - 0
magic_pdf/model/v3/helpers.py

@@ -0,0 +1,125 @@
+from collections import defaultdict
+from typing import List, Dict
+
+import torch
+from transformers import LayoutLMv3ForTokenClassification
+
+MAX_LEN = 510
+CLS_TOKEN_ID = 0
+UNK_TOKEN_ID = 3
+EOS_TOKEN_ID = 2
+
+
+class DataCollator:
+    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
+        bbox = []
+        labels = []
+        input_ids = []
+        attention_mask = []
+
+        # clip bbox and labels to max length, build input_ids and attention_mask
+        for feature in features:
+            _bbox = feature["source_boxes"]
+            if len(_bbox) > MAX_LEN:
+                _bbox = _bbox[:MAX_LEN]
+            _labels = feature["target_index"]
+            if len(_labels) > MAX_LEN:
+                _labels = _labels[:MAX_LEN]
+            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
+            _attention_mask = [1] * len(_bbox)
+            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
+            bbox.append(_bbox)
+            labels.append(_labels)
+            input_ids.append(_input_ids)
+            attention_mask.append(_attention_mask)
+
+        # add CLS and EOS tokens
+        for i in range(len(bbox)):
+            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
+            labels[i] = [-100] + labels[i] + [-100]
+            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
+            attention_mask[i] = [1] + attention_mask[i] + [1]
+
+        # padding to max length
+        max_len = max(len(x) for x in bbox)
+        for i in range(len(bbox)):
+            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
+            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
+            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
+            attention_mask[i] = attention_mask[i] + [0] * (
+                max_len - len(attention_mask[i])
+            )
+
+        ret = {
+            "bbox": torch.tensor(bbox),
+            "attention_mask": torch.tensor(attention_mask),
+            "labels": torch.tensor(labels),
+            "input_ids": torch.tensor(input_ids),
+        }
+        # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
+        ret["labels"][ret["labels"] > MAX_LEN] = -100
+        # set label > 0 to label-1, because original labels are 1-indexed
+        ret["labels"][ret["labels"] > 0] -= 1
+        return ret
+
+
+def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
+    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
+    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
+    attention_mask = [1] + [1] * len(boxes) + [1]
+    return {
+        "bbox": torch.tensor([bbox]),
+        "attention_mask": torch.tensor([attention_mask]),
+        "input_ids": torch.tensor([input_ids]),
+    }
+
+
+def prepare_inputs(
+    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
+) -> Dict[str, torch.Tensor]:
+    ret = {}
+    for k, v in inputs.items():
+        v = v.to(model.device)
+        if torch.is_floating_point(v):
+            v = v.to(model.dtype)
+        ret[k] = v
+    return ret
+
+
+def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
+    """
+    parse logits to orders
+
+    :param logits: logits from model
+    :param length: input length
+    :return: orders
+    """
+    logits = logits[1 : length + 1, :length]
+    orders = logits.argsort(descending=False).tolist()
+    ret = [o.pop() for o in orders]
+    while True:
+        order_to_idxes = defaultdict(list)
+        for idx, order in enumerate(ret):
+            order_to_idxes[order].append(idx)
+        # filter idxes len > 1
+        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
+        if not order_to_idxes:
+            break
+        # filter
+        for order, idxes in order_to_idxes.items():
+            # find original logits of idxes
+            idxes_to_logit = {}
+            for idx in idxes:
+                idxes_to_logit[idx] = logits[idx, order]
+            idxes_to_logit = sorted(
+                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
+            )
+            # keep the highest logit as order, set others to next candidate
+            for idx, _ in idxes_to_logit[1:]:
+                ret[idx] = orders[idx].pop()
+
+    return ret
+
+
+def check_duplicate(a: List[int]) -> bool:
+    return len(a) != len(set(a))

+ 1 - 1
magic_pdf/pdf_parse_by_ocr.py

@@ -1,4 +1,4 @@
-from magic_pdf.pdf_parse_union_core import pdf_parse_union
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
 def parse_pdf_by_ocr(pdf_bytes,

+ 1 - 1
magic_pdf/pdf_parse_by_txt.py

@@ -1,4 +1,4 @@
-from magic_pdf.pdf_parse_union_core import pdf_parse_union
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
 def parse_pdf_by_txt(

+ 451 - 0
magic_pdf/pdf_parse_union_core_v2.py

@@ -0,0 +1,451 @@
+import statistics
+import time
+
+from loguru import logger
+
+from typing import List
+
+import torch
+
+from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.libs.commons import fitz, get_delta_time
+from magic_pdf.libs.convert_utils import dict_to_list
+from magic_pdf.libs.drop_reason import DropReason
+from magic_pdf.libs.hash_utils import compute_md5
+from magic_pdf.libs.local_math import float_equal
+from magic_pdf.libs.ocr_content_type import ContentType
+from magic_pdf.model.magic_model import MagicModel
+from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
+from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
+from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
+from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \
+    combine_chars_to_pymudict
+from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
+from magic_pdf.pre_proc.ocr_dict_merge import  fill_spans_in_blocks, fix_block_spans, fix_discarded_block
+from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
+    remove_overlaps_low_confidence_spans
+from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
+
+
+def remove_horizontal_overlap_block_which_smaller(all_bboxes):
+    useful_blocks = []
+    for bbox in all_bboxes:
+        useful_blocks.append({
+            "bbox": bbox[:4]
+        })
+    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks)
+    if is_useful_block_horz_overlap:
+        logger.warning(
+            f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}")
+        for bbox in all_bboxes.copy():
+            if smaller_bbox == bbox[:4]:
+                all_bboxes.remove(bbox)
+
+    return is_useful_block_horz_overlap, all_bboxes
+
+
+def __replace_STX_ETX(text_str:str):
+    """ Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
+Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
+
+    Args:
+        text_str (str): raw text
+
+    Returns:
+        _type_: replaced text
+    """
+    if text_str:
+        s = text_str.replace('\u0002', "'")
+        s = s.replace("\u0003", "'")
+        return s
+    return text_str
+
+
+def txt_spans_extract(pdf_page, inline_equations, interline_equations):
+    text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
+    char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
+        "blocks"
+    ]
+    text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
+    text_blocks = replace_equations_in_textblock(
+        text_blocks, inline_equations, interline_equations
+    )
+    text_blocks = remove_citation_marker(text_blocks)
+    text_blocks = remove_chars_in_text_blocks(text_blocks)
+    spans = []
+    for v in text_blocks:
+        for line in v["lines"]:
+            for span in line["spans"]:
+                bbox = span["bbox"]
+                if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
+                    continue
+                if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation):
+                    spans.append(
+                        {
+                            "bbox": list(span["bbox"]),
+                            "content": __replace_STX_ETX(span["text"]),
+                            "type": ContentType.Text,
+                            "score": 1.0,
+                        }
+                    )
+    return spans
+
+
+def replace_text_span(pymu_spans, ocr_spans):
+    return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
+
+
+def model_init(model_name: str, local_path=None):
+    from transformers import LayoutLMv3ForTokenClassification
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        if torch.cuda.is_bf16_supported():
+            supports_bfloat16 = True
+        else:
+            supports_bfloat16 = False
+    else:
+        device = torch.device("cpu")
+        supports_bfloat16 = False
+
+    if model_name == "layoutreader":
+        if local_path:
+            model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
+        else:
+            model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
+        # 检查设备是否支持 bfloat16
+        if supports_bfloat16:
+            model.bfloat16()
+        model.to(device).eval()
+    else:
+        logger.error("model name not allow")
+        exit(1)
+    return model
+
+
+class ModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_model(self, model_name: str, local_path=None):
+        if model_name not in self._models:
+            if local_path:
+                self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
+            else:
+                self._models[model_name] = model_init(model_name=model_name)
+        return self._models[model_name]
+
+
+def do_predict(boxes: List[List[int]], model) -> List[int]:
+    from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
+    inputs = boxes2inputs(boxes)
+    inputs = prepare_inputs(inputs, model)
+    logits = model(**inputs).logits.cpu().squeeze(0)
+    return parse_logits(logits, len(boxes))
+
+
+def cal_block_index(fix_blocks, sorted_bboxes):
+    for block in fix_blocks:
+        # if block['type'] in ['text', 'title', 'interline_equation']:
+        #     line_index_list = []
+        #     if len(block['lines']) == 0:
+        #         block['index'] = sorted_bboxes.index(block['bbox'])
+        #     else:
+        #         for line in block['lines']:
+        #             line['index'] = sorted_bboxes.index(line['bbox'])
+        #             line_index_list.append(line['index'])
+        #         median_value = statistics.median(line_index_list)
+        #         block['index'] = median_value
+        #
+        # elif block['type'] in ['table', 'image']:
+        #     block['index'] = sorted_bboxes.index(block['bbox'])
+
+        line_index_list = []
+        if len(block['lines']) == 0:
+            block['index'] = sorted_bboxes.index(block['bbox'])
+        else:
+            for line in block['lines']:
+                line['index'] = sorted_bboxes.index(line['bbox'])
+                line_index_list.append(line['index'])
+            median_value = statistics.median(line_index_list)
+            block['index'] = median_value
+
+        # 删除图表block中的虚拟line信息
+        if block['type'] in ['table', 'image']:
+            del block['lines']
+
+    return fix_blocks
+
+
+def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
+    # block_bbox是一个元组(x0, y0, x1, y1),其中(x0, y0)是左下角坐标,(x1, y1)是右上角坐标
+    x0, y0, x1, y1 = block_bbox
+
+    block_height = y1 - y0
+    block_weight = x1 - x0
+
+    # 如果block高度小于n行正文,则直接返回block的bbox
+    if line_height*3 < block_height:
+        if block_height > page_h*0.25 and page_w*0.5 > block_weight > page_w*0.25:  # 可能是双列结构,可以切细点
+            lines = int(block_height/line_height)
+        else:
+            # 如果block的宽度超过0.4页面宽度,则将block分成3行
+            if block_weight > page_w*0.4:
+                line_height = (y1 - y0) / 3
+                lines = 3
+            elif block_weight > page_w*0.25: # 否则将block分成两行
+                line_height = (y1 - y0) / 2
+                lines = 2
+            else: # 判断长宽比
+                if block_height/block_weight > 1.2:  # 细长的不分
+                    return [[x0, y0, x1, y1]]
+                else: # 不细长的还是分成两行
+                    line_height = (y1 - y0) / 2
+                    lines = 2
+
+        # 确定从哪个y位置开始绘制线条
+        current_y = y0
+
+        # 用于存储线条的位置信息[(x0, y), ...]
+        lines_positions = []
+
+        for i in range(lines):
+            lines_positions.append([x0, current_y, x1, current_y + line_height])
+            current_y += line_height
+        return lines_positions
+
+    else:
+        return [[x0, y0, x1, y1]]
+
+
+def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
+    page_line_list = []
+    for block in fix_blocks:
+        if block['type'] in ['text', 'title', 'interline_equation']:
+            if len(block['lines']) == 0:
+                bbox = block['bbox']
+                lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
+                for line in lines:
+                    block['lines'].append({'bbox': line, 'spans': []})
+                page_line_list.extend(lines)
+            else:
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    page_line_list.append(bbox)
+        elif block['type'] in ['table', 'image']:
+            bbox = block['bbox']
+            lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
+            block['lines'] = []
+            for line in lines:
+                block['lines'].append({'bbox': line, 'spans': []})
+            page_line_list.extend(lines)
+
+    # 使用layoutreader排序
+    x_scale = 1000.0 / page_w
+    y_scale = 1000.0 / page_h
+    boxes = []
+    # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
+    for left, top, right, bottom in page_line_list:
+        if left < 0:
+            logger.warning(
+                f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            left = 0
+        if right > page_w:
+            logger.warning(
+                f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            right = page_w
+        if top < 0:
+            logger.warning(
+                f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            top = 0
+        if bottom > page_h:
+            logger.warning(
+                f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            bottom = page_h
+
+        left = round(left * x_scale)
+        top = round(top * y_scale)
+        right = round(right * x_scale)
+        bottom = round(bottom * y_scale)
+        assert (
+                1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
+        ), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
+        boxes.append([left, top, right, bottom])
+    model_manager = ModelSingleton()
+    model = model_manager.get_model("layoutreader")
+    with torch.no_grad():
+        orders = do_predict(boxes, model)
+    sorted_bboxes = [page_line_list[i] for i in orders]
+
+    return sorted_bboxes
+
+
+def get_line_height(blocks):
+    page_line_height_list = []
+    for block in blocks:
+        if block['type'] in ['text', 'title', 'interline_equation']:
+            for line in block['lines']:
+                bbox = line['bbox']
+                page_line_height_list.append(int(bbox[3]-bbox[1]))
+    if len(page_line_height_list) > 0:
+        return statistics.median(page_line_height_list)
+    else:
+        return 10
+
+
+def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode):
+    need_drop = False
+    drop_reason = []
+
+    '''从magic_model对象中获取后面会用到的区块信息'''
+    img_blocks = magic_model.get_imgs(page_id)
+    table_blocks = magic_model.get_tables(page_id)
+    discarded_blocks = magic_model.get_discarded(page_id)
+    text_blocks = magic_model.get_text_blocks(page_id)
+    title_blocks = magic_model.get_title_blocks(page_id)
+    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
+
+    page_w, page_h = magic_model.get_page_size(page_id)
+
+    spans = magic_model.get_all_spans(page_id)
+
+    '''根据parse_mode,构造spans'''
+    if parse_mode == "txt":
+        """ocr 中文本类的 span 用 pymu spans 替换!"""
+        pymu_spans = txt_spans_extract(
+            pdf_docs[page_id], inline_equations, interline_equations
+        )
+        spans = replace_text_span(pymu_spans, spans)
+    elif parse_mode == "ocr":
+        pass
+    else:
+        raise Exception("parse_mode must be txt or ocr")
+
+    '''删除重叠spans中置信度较低的那些'''
+    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
+    '''删除重叠spans中较小的那些'''
+    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
+    '''对image和table截图'''
+    spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
+
+    '''将所有区块的bbox整理到一起'''
+    # interline_equation_blocks参数不够准,后面切换到interline_equations上
+    interline_equation_blocks = []
+    if len(interline_equation_blocks) > 0:
+        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
+            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
+            interline_equation_blocks, page_w, page_h)
+    else:
+        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
+            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
+            interline_equations, page_w, page_h)
+
+    '''先处理不需要排版的discarded_blocks'''
+    discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
+    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
+
+    '''如果当前页面没有bbox则跳过'''
+    if len(all_bboxes) == 0:
+        logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
+        return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
+                                               [], [], interline_equations, fix_discarded_blocks,
+                                               need_drop, drop_reason)
+
+    '''将span填入blocks中'''
+    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.3)
+
+    '''对block进行fix操作'''
+    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
+
+    '''获取所有line并计算正文line的高度'''
+    line_height = get_line_height(fix_blocks)
+
+    '''获取所有line并对line排序'''
+    sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)
+
+    '''根据line的中位数算block的序列关系'''
+    fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
+
+    '''重排block'''
+    sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
+
+    '''获取QA需要外置的list'''
+    images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
+
+    '''构造pdf_info_dict'''
+    page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h, [],
+                                                images, tables, interline_equations, fix_discarded_blocks,
+                                                need_drop, drop_reason)
+    return page_info
+
+
+def pdf_parse_union(pdf_bytes,
+                    model_list,
+                    imageWriter,
+                    parse_mode,
+                    start_page_id=0,
+                    end_page_id=None,
+                    debug_mode=False,
+                    ):
+    pdf_bytes_md5 = compute_md5(pdf_bytes)
+    pdf_docs = fitz.open("pdf", pdf_bytes)
+
+    '''初始化空的pdf_info_dict'''
+    pdf_info_dict = {}
+
+    '''用model_list和docs对象初始化magic_model'''
+    magic_model = MagicModel(model_list, pdf_docs)
+
+    '''根据输入的起始范围解析pdf'''
+    # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
+    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1
+
+    if end_page_id > len(pdf_docs) - 1:
+        logger.warning("end_page_id is out of range, use pdf_docs length")
+        end_page_id = len(pdf_docs) - 1
+
+    '''初始化启动时间'''
+    start_time = time.time()
+
+    for page_id, page in enumerate(pdf_docs):
+        '''debug时输出每页解析的耗时'''
+        if debug_mode:
+            time_now = time.time()
+            logger.info(
+                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
+            )
+            start_time = time_now
+
+        '''解析pdf中的每一页'''
+        if start_page_id <= page_id <= end_page_id:
+            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+        else:
+            page_w = page.rect.width
+            page_h = page.rect.height
+            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
+                                                [], [], [], [],
+                                                True, "skip page")
+        pdf_info_dict[f"page_{page_id}"] = page_info
+
+    """分段"""
+    # para_split(pdf_info_dict, debug_mode=debug_mode)
+    for page_num, page in pdf_info_dict.items():
+        page['para_blocks'] = page['preproc_blocks']
+
+    """dict转list"""
+    pdf_info_list = dict_to_list(pdf_info_dict)
+    new_pdf_info_dict = {
+        "pdf_info": pdf_info_list,
+    }
+
+    clean_memory()
+
+    return new_pdf_info_dict
+
+
+if __name__ == '__main__':
+    pass

+ 53 - 0
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -60,6 +60,59 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
     return all_bboxes, all_discarded_blocks, drop_reasons
 
 
+def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks,
+                                        title_blocks, interline_equation_blocks, page_w, page_h):
+    all_bboxes = []
+    all_discarded_blocks = []
+    for image in img_blocks:
+        x0, y0, x1, y1 = image['bbox']
+        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])
+
+    for table in table_blocks:
+        x0, y0, x1, y1 = table['bbox']
+        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
+
+    for text in text_blocks:
+        x0, y0, x1, y1 = text['bbox']
+        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
+
+    for title in title_blocks:
+        x0, y0, x1, y1 = title['bbox']
+        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])
+
+    for interline_equation in interline_equation_blocks:
+        x0, y0, x1, y1 = interline_equation['bbox']
+        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])
+
+    '''block嵌套问题解决'''
+    '''文本框与标题框重叠,优先信任文本框'''
+    all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
+    '''任何框体与舍弃框重叠,优先信任舍弃框'''
+    all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)
+
+    # interline_equation 与title或text框冲突的情况,分两种情况处理
+    '''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
+    all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
+    '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
+    # 通过后续大框套小框逻辑删除
+
+    '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
+    for discarded in discarded_blocks:
+        x0, y0, x1, y1 = discarded['bbox']
+        all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
+        # 将footnote加入到all_bboxes中,用来计算layout
+        # if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
+        #     all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
+
+    '''经过以上处理后,还存在大框套小框的情况,则删除小框'''
+    all_bboxes = remove_overlaps_min_blocks(all_bboxes)
+    all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
+    '''将剩余的bbox做分离处理,防止后面分layout时出错'''
+    # all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
+
+    return all_bboxes, all_discarded_blocks
+
+
 def fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes):
     # 先提取所有text和interline block
     text_blocks = []

+ 8 - 4
magic_pdf/tools/common.py

@@ -7,7 +7,7 @@ from loguru import logger
 
 import magic_pdf.model as model_config
 from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox,
-                                      drow_model_bbox)
+                                      draw_model_bbox, draw_line_sort_bbox)
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
@@ -39,17 +39,19 @@ def do_parse(
     f_dump_middle_json=True,
     f_dump_model_json=True,
     f_dump_orig_pdf=True,
-    f_dump_content_list=False,
+    f_dump_content_list=True,
     f_make_md_mode=MakeMode.MM_MD,
     f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
     start_page_id=0,
     end_page_id=None,
     lang=None,
 ):
     if debug_able:
         logger.warning('debug mode is on')
-        f_dump_content_list = True
+        # f_dump_content_list = True
         f_draw_model_bbox = True
+        f_draw_line_sort_bbox = True
 
     orig_model_list = copy.deepcopy(model_list)
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
@@ -90,7 +92,9 @@ def do_parse(
     if f_draw_span_bbox:
         draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
     if f_draw_model_bbox:
-        drow_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+        draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+    if f_draw_line_sort_bbox:
+        draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
 
     md_content = pipe.pipe_mk_markdown(image_dir,
                                        drop_mode=DropMode.NONE,

+ 2 - 0
requirements.txt

@@ -9,4 +9,6 @@ pydantic>=2.7.2,<2.8.0
 PyMuPDF>=1.24.9
 scikit-learn>=1.0.2
 wordninja>=2.0.0
+torch>=2.2.2,<=2.3.1
+transformers
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.