Bläddra i källkod

Merge branch 'master' of github.com:myhloli/Magic-PDF

liukaiwen 1 år sedan
förälder
incheckning
5b86ca82fe

+ 0 - 0
demo/draw_bbox.py → magic_pdf/libs/draw_bbox.py


+ 77 - 47
magic_pdf/pdf_parse_by_ocr.py

@@ -4,8 +4,15 @@ import time
 
 from loguru import logger
 
-from demo.draw_bbox import draw_layout_bbox, draw_text_bbox
-from magic_pdf.libs.commons import read_file, join_path, fitz, get_img_s3_client, get_delta_time, get_docx_model_output
+from magic_pdf.libs.draw_bbox import draw_layout_bbox, draw_text_bbox
+from magic_pdf.libs.commons import (
+    read_file,
+    join_path,
+    fitz,
+    get_img_s3_client,
+    get_delta_time,
+    get_docx_model_output,
+)
 from magic_pdf.libs.coordinate_transform import get_scale_ratio
 from magic_pdf.libs.safe_filename import sanitize_filename
 from magic_pdf.pre_proc.detect_footer_by_model import parse_footers
@@ -14,30 +21,36 @@ from magic_pdf.pre_proc.detect_header import parse_headers
 from magic_pdf.pre_proc.detect_page_number import parse_pageNos
 from magic_pdf.pre_proc.ocr_cut_image import cut_image_and_table
 from magic_pdf.pre_proc.ocr_detect_layout import layout_detect
-from magic_pdf.pre_proc.ocr_dict_merge import remove_overlaps_min_spans, merge_spans_to_line_by_layout
+from magic_pdf.pre_proc.ocr_dict_merge import (
+    remove_overlaps_min_spans,
+    merge_spans_to_line_by_layout,
+)
 from magic_pdf.pre_proc.ocr_remove_spans import remove_spans_by_bboxes
+from magic_pdf.pre_proc.remove_bbox_overlap import remove_overlap_between_bbox
 
 
-def construct_page_component(page_id, blocks, layout_bboxes):
+def construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree):
     return_dict = {
         'preproc_blocks': blocks,
-        'page_idx': page_id,
         'layout_bboxes': layout_bboxes,
+        'page_idx': page_id,
+        'page_size': [page_w, page_h],
+        '_layout_tree': layout_tree,
     }
     return return_dict
 
 
 def parse_pdf_by_ocr(
-        pdf_path,
-        s3_pdf_profile,
-        pdf_model_output,
-        save_path,
-        book_name,
-        pdf_model_profile=None,
-        image_s3_config=None,
-        start_page_id=0,
-        end_page_id=None,
-        debug_mode=False,
+    pdf_path,
+    s3_pdf_profile,
+    pdf_model_output,
+    save_path,
+    book_name,
+    pdf_model_profile=None,
+    image_s3_config=None,
+    start_page_id=0,
+    end_page_id=None,
+    debug_mode=False,
 ):
     pdf_bytes = read_file(pdf_path, s3_pdf_profile)
     save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
@@ -73,21 +86,29 @@ def parse_pdf_by_ocr(
 
         # 获取当前页的page对象
         page = pdf_docs[page_id]
+        # 获取当前页的宽高
+        page_w = page.rect.width
+        page_h = page.rect.height
 
         if debug_mode:
             time_now = time.time()
-            logger.info(f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}")
+            logger.info(
+                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
+            )
             start_time = time_now
 
         # 获取当前页的模型数据
-        ocr_page_info = get_docx_model_output(pdf_model_output, pdf_model_profile, page_id)
+        ocr_page_info = get_docx_model_output(
+            pdf_model_output, pdf_model_profile, page_id
+        )
 
         """从json中获取每页的页码、页眉、页脚的bbox"""
         page_no_bboxes = parse_pageNos(page_id, page, ocr_page_info)
         header_bboxes = parse_headers(page_id, page, ocr_page_info)
         footer_bboxes = parse_footers(page_id, page, ocr_page_info)
-        footnote_bboxes = parse_footnotes_by_model(page_id, page, ocr_page_info, md_bookname_save_path,
-                                                   debug_mode=debug_mode)
+        footnote_bboxes = parse_footnotes_by_model(
+            page_id, page, ocr_page_info, md_bookname_save_path, debug_mode=debug_mode
+        )
 
         # 构建需要remove的bbox列表
         need_remove_spans_bboxes = []
@@ -96,51 +117,57 @@ def parse_pdf_by_ocr(
         need_remove_spans_bboxes.extend(footer_bboxes)
         need_remove_spans_bboxes.extend(footnote_bboxes)
 
-        layout_dets = ocr_page_info['layout_dets']
+        layout_dets = ocr_page_info["layout_dets"]
         spans = []
 
         # 计算模型坐标和pymu坐标的缩放比例
-        horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(ocr_page_info, page)
+        horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
+            ocr_page_info, page
+        )
 
         for layout_det in layout_dets:
-            category_id = layout_det['category_id']
+            category_id = layout_det["category_id"]
             allow_category_id_list = [1, 7, 13, 14, 15]
             if category_id in allow_category_id_list:
-                x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
-                bbox = [int(x0 / horizontal_scale_ratio), int(y0 / vertical_scale_ratio),
-                        int(x1 / horizontal_scale_ratio), int(y1 / vertical_scale_ratio)]
-                '''要删除的'''
+                x0, y0, _, _, x1, y1, _, _ = layout_det["poly"]
+                bbox = [
+                    int(x0 / horizontal_scale_ratio),
+                    int(y0 / vertical_scale_ratio),
+                    int(x1 / horizontal_scale_ratio),
+                    int(y1 / vertical_scale_ratio),
+                ]
+                """要删除的"""
                 #  3: 'header',      # 页眉
                 #  4: 'page number', # 页码
                 #  5: 'footnote',    # 脚注
                 #  6: 'footer',      # 页脚
-                '''当成span拼接的'''
+                """当成span拼接的"""
                 #  1: 'image', # 图片
                 #  7: 'table',       # 表格
                 #  13: 'inline_equation',     # 行内公式
                 #  14: 'displayed_equation',      # 行间公式
                 #  15: 'text',      # ocr识别文本
-                '''layout信息'''
+                """layout信息"""
                 #  11: 'full column',   # 单栏
                 #  12: 'sub column',    # 多栏
                 span = {
-                    'bbox': bbox,
+                    "bbox": bbox,
                 }
                 if category_id == 1:
-                    span['type'] = 'image'
+                    span["type"] = "image"
 
                 elif category_id == 7:
-                    span['type'] = 'table'
+                    span["type"] = "table"
 
                 elif category_id == 13:
-                    span['content'] = layout_det['latex']
-                    span['type'] = 'inline_equation'
+                    span["content"] = layout_det["latex"]
+                    span["type"] = "inline_equation"
                 elif category_id == 14:
-                    span['content'] = layout_det['latex']
-                    span['type'] = 'displayed_equation'
+                    span["content"] = layout_det["latex"]
+                    span["type"] = "displayed_equation"
                 elif category_id == 15:
-                    span['content'] = layout_det['text']
-                    span['type'] = 'text'
+                    span["content"] = layout_det["text"]
+                    span["type"] = "text"
                 # print(span)
                 spans.append(span)
             else:
@@ -155,42 +182,45 @@ def parse_pdf_by_ocr(
         # 对image和table截图
         spans = cut_image_and_table(spans, page, page_id, book_name, save_path)
 
-
         # 行内公式调整, 高度调整至与同行文字高度一致(优先左侧, 其次右侧)
 
         # 模型识别错误的行间公式, type类型转换成行内公式
 
         # bbox去除粘连
+        spans = remove_overlap_between_bbox(spans)
 
         # 对tpye=["displayed_equation", "image", "table"]进行额外处理,如果左边有字的话,将该span的bbox中y0调整至不高于文字的y0
 
         # 从ocr_page_info中解析layout信息(按自然阅读方向排序,并修复重叠和交错的bad case)
-        layout_bboxes = layout_detect(ocr_page_info['subfield_dets'], page, ocr_page_info)
+        layout_bboxes, layout_tree = layout_detect(ocr_page_info['subfield_dets'], page, ocr_page_info)
 
         # 将spans合并成line(在layout内,从上到下,从左到右)
         lines = merge_spans_to_line_by_layout(spans, layout_bboxes)
 
-
         # 目前不做block拼接,先做个结构,每个block中只有一个line,block的bbox就是line的bbox
         blocks = []
         for line in lines:
-            blocks.append({
-                "bbox": line['bbox'],
-                "lines": [line],
-            })
+            blocks.append(
+                {
+                    "bbox": line["bbox"],
+                    "lines": [line],
+                }
+            )
 
         # 构造pdf_info_dict
-        page_info = construct_page_component(page_id, blocks, layout_bboxes)
+        page_info = construct_page_component(blocks, layout_bboxes, page_id, page_w, page_h, layout_tree)
         pdf_info_dict[f"page_{page_id}"] = page_info
 
     # 在测试时,保存调试信息
     if debug_mode:
-        params_file_save_path = join_path(save_tmp_path, "md", book_name, "preproc_out.json")
+        params_file_save_path = join_path(
+            save_tmp_path, "md", book_name, "preproc_out.json"
+        )
         with open(params_file_save_path, "w", encoding="utf-8") as f:
             json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
+
         # drow_bbox
         draw_layout_bbox(pdf_info_dict, pdf_path, md_bookname_save_path)
         draw_text_bbox(pdf_info_dict, pdf_path, md_bookname_save_path)
 
-
     return pdf_info_dict

+ 3 - 3
magic_pdf/pre_proc/ocr_detect_layout.py

@@ -69,7 +69,7 @@ def adjust_layouts(layout_bboxes, page_boundry, page_id):
     layout_bboxes, layout_tree = get_bboxes_layout(new_bboxes, page_boundry, page_id)
 
     # 返回排序调整后的布局边界框列表
-    return layout_bboxes
+    return layout_bboxes, layout_tree
 
 
 def layout_detect(layout_info, page: fitz.Page, ocr_page_info):
@@ -127,7 +127,7 @@ def layout_detect(layout_info, page: fitz.Page, ocr_page_info):
     page_width = page.rect.width
     page_height = page.rect.height
     page_boundry = [0, 0, page_width, page_height]
-    layout_bboxes = adjust_layouts(new_layout_bboxes, page_boundry, page_id)
+    layout_bboxes, layout_tree = adjust_layouts(new_layout_bboxes, page_boundry, page_id)
 
     # 返回排序调整后的布局边界框列表
-    return layout_bboxes
+    return layout_bboxes, layout_tree

+ 1 - 1
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -9,7 +9,7 @@ def remove_overlaps_min_spans(spans):
     for span1 in spans.copy():
         for span2 in spans.copy():
             if span1 != span2:
-                overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.8)
+                overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.5)
                 if overlap_box is not None:
                     bbox_to_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
                     if bbox_to_remove is not None:

+ 43 - 0
magic_pdf/pre_proc/remove_bbox_overlap.py

@@ -0,0 +1,43 @@
+from magic_pdf.libs.boxbase import _is_in_or_part_overlap, _is_in
+
+
+def _remove_overlap_between_bbox(spans):
+    res = []
+    for v in spans:
+        for i in range(len(res)):
+            if _is_in(res[i]["bbox"], v["bbox"]):
+                continue
+            if _is_in_or_part_overlap(res[i]["bbox"], v["bbox"]):
+                ix0, iy0, ix1, iy1 = res[i]["bbox"]
+                x0, y0, x1, y1 = v["bbox"]
+
+                diff_x = min(x1, ix1) - max(x0, ix0)
+                diff_y = min(y1, iy1) - max(y0, iy0)
+
+                if diff_y > diff_x:
+                    if x1 >= ix1:
+                        mid = (x0 + ix1) // 2
+                        ix1 = min(mid, ix1)
+                        x0 = max(mid + 1, x0)
+                    else:
+                        mid = (ix0 + x1) // 2
+                        ix0 = max(mid + 1, ix0)
+                        x1 = min(mid, x1)
+                else:
+                    if y1 >= iy1:
+                        mid = (y0 + iy1) // 2
+                        y0 = max(mid + 1, y0)
+                        iy1 = min(iy1, mid)
+                    else:
+                        mid = (iy0 + y1) // 2
+                        y1 = min(y1, mid)
+                        iy0 = max(mid + 1, iy0)
+                res[i]["bbox"] = [ix0, iy0, ix1, iy1]
+                v["bbox"] = [x0, y0, x1, y1]
+
+        res.append(v)
+    return res
+
+
+def remove_overlap_between_bbox(spans):
+    return _remove_overlap_between_bbox(spans)