Эх сурвалжийг харах

ocr模式下删除header/page number/footnote/footer

赵小蒙 1 жил өмнө
parent
commit
388223f2e0

+ 5 - 4
demo/draw_bbox.py

@@ -1,20 +1,21 @@
 from magic_pdf.libs.commons import fitz  # PyMuPDF
 
 # PDF文件路径
-pdf_path = "D:\\project\\20231108code-clean\\code-clean\\tmp\\unittest\\download-pdfs\\scihub\\scihub_53700000\\libgen.scimag53724000-53724999.zip_10.1097\\00129191-200509000-00018.pdf"
+pdf_path = r"D:\project\20231108code-clean\magic_pdf\tmp\unittest\download-pdfs\ocr_1.json.pdf"
 
 doc = fitz.open(pdf_path)  # Open the PDF
 # 你的数据
-data = [[[-2, 0, 603, 80, 24]], [[-3, 0, 602, 80, 24]]]
+data = [[(294.7569528415961, 776.8430953398889, 300.8827085852479, 786.922616502779), (460.1523579201934, 776.8430953398889, 509.51874244256345, 787.2825994014537)], [(294.03627569528413, 779.7229585292861, 301.24304715840384, 788.3625480974777), (85.76058041112454, 781.882855921334, 156.74727932285367, 789.8024796921762)], [(293.6759371221282, 779.7229585292861, 301.60338573155985, 788.7225309961523), (459.43168077388145, 779.7229585292861, 508.7980652962515, 789.8024796921762)], [(295.8379685610641, 780.0829414279607, 301.24304715840384, 788.0025651988029), (85.76058041112454, 781.5228730226593, 156.74727932285367, 790.1624625908509)], [(294.03627569528413, 779.7229585292861, 301.60338573155985, 789.0825138948269), (459.79201934703747, 779.7229585292861, 508.4377267230955, 789.4424967935015)], [(86.4812575574365, 781.882855921334, 156.0266021765417, 789.8024796921762)], [(294.39661426844015, 779.7229585292861, 301.24304715840384, 788.3625480974777), (459.43168077388145, 779.7229585292861, 508.7980652962515, 789.4424967935015)], [(294.03627569528413, 779.7229585292861, 301.24304715840384, 788.3625480974777), (85.76058041112454, 781.5228730226593, 156.74727932285367, 789.8024796921762)], [(294.39661426844015, 779.7229585292861, 300.8827085852479, 788.3625480974777)]]
 
 # 对每个页面进行处理
 for i, page in enumerate(doc):
     # 获取当前页面的数据
     page_data = data[i]
     for img in page_data:
-        x0, y0, x1, y1, _ = img
+        # x0, y0, x1, y1, _ = img
+        x0, y0, x1, y1 = img
         rect_coords = fitz.Rect(x0, y0, x1, y1)  # Define the rectangle
         page.draw_rect(rect_coords, color=(1, 0, 0), fill=None, width=1.5, overlay=True)  # Draw the rectangle
 
 # Save the PDF
-doc.save("D:\\project\\20231108code-clean\\code-clean\\tmp\\unittest\\download-pdfs\\scihub\\scihub_53700000\\libgen.scimag53724000-53724999.zip_10.1097\\00129191-200509000-00018_new.pdf")
+doc.save(r"D:\project\20231108code-clean\magic_pdf\tmp\unittest\download-pdfs\ocr_1.json_new.pdf")

+ 27 - 6
demo/ocr_demo.py

@@ -2,8 +2,10 @@ import json
 import os
 
 from loguru import logger
+from pathlib import Path
 
 from magic_pdf.dict2md.ocr_mkcontent import mk_nlp_markdown
+from magic_pdf.libs.commons import join_path
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
 
 
@@ -28,12 +30,31 @@ def read_json_file(file_path):
 
 
 if __name__ == '__main__':
-    ocr_json_file_path = r"D:\project\20231108code-clean\ocr\new\demo_4\ocr_1(3).json"
+    ocr_pdf_path = r"D:\project\20231108code-clean\ocr\new\demo_4\ocr_demo\ocr_1_org.pdf"
+    ocr_json_file_path = r"D:\project\20231108code-clean\ocr\new\demo_4\ocr_demo\ocr_1.json"
     try:
-        ocr_pdf_info = read_json_file(ocr_json_file_path)
-        pdf_info_dict = parse_pdf_by_ocr(ocr_pdf_info)
-        markdown_text = mk_nlp_markdown(pdf_info_dict)
-        logger.info(markdown_text)
-        save_markdown(markdown_text, ocr_json_file_path)
+        ocr_pdf_model_info = read_json_file(ocr_json_file_path)
+        pth = Path(ocr_json_file_path)
+        book_name = pth.name
+        save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
+        save_path = join_path(save_tmp_path, "md")
+        text_content_save_path = f"{save_path}/{book_name}/book.md"
+        pdf_info_dict = parse_pdf_by_ocr(
+            ocr_pdf_path,
+            None,
+            ocr_pdf_model_info,
+            book_name,
+            debug_mode=True)
+        parent_dir = os.path.dirname(text_content_save_path)
+        if not os.path.exists(parent_dir):
+            os.makedirs(parent_dir)
+
+        markdown_content = mk_nlp_markdown(pdf_info_dict)
+
+        with open(text_content_save_path, "w", encoding="utf-8") as f:
+            f.write(markdown_content)
+
+        # logger.info(markdown_content)
+        # save_markdown(markdown_text, ocr_json_file_path)
     except Exception as e:
         logger.error(e)

+ 29 - 0
magic_pdf/libs/commons.py

@@ -1,4 +1,5 @@
 import datetime
+import json
 import os, re, configparser
 import time
 
@@ -115,6 +116,34 @@ def read_file(pdf_path: str, s3_profile):
         with open(pdf_path, "rb") as f:
             return f.read()
 
+
def get_docx_model_output(pdf_model_output, pdf_model_s3_profile, page_id):
    """Return the layout-model output for a single page.

    pdf_model_output may be either:
      * str  -- a path (local or s3) holding per-page ``page_<n>.json`` files
                (1-based), a combined ``model.json`` with a
                ``doc_layout_result`` list, or per-page ``<page_id>.json``
                files on s3 (0-based);
      * list -- the already-loaded per-page model results.

    pdf_model_s3_profile: s3 profile forwarded to ``read_file`` for remote
        fetches.
    page_id: 0-based page index.

    Raises:
        TypeError: if pdf_model_output is neither str nor list.
    """
    if isinstance(pdf_model_output, str):
        # Per-page model files on disk are numbered starting from 1.
        model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json")
        if os.path.exists(model_output_json_path):
            json_from_docx = read_file(model_output_json_path, pdf_model_s3_profile)
            return json.loads(json_from_docx)
        try:
            # Fall back to a single combined model.json holding all pages.
            model_output_json_path = join_path(pdf_model_output, "model.json")
            with open(model_output_json_path, "r", encoding="utf-8") as f:
                model_output_json = json.load(f)
            return model_output_json["doc_layout_result"][page_id]
        except (OSError, ValueError, KeyError, IndexError):
            # Last resort: fetch "<page_id>.json" (0-based) via read_file
            # (e.g. from s3). Narrowed from a bare except, which hid real
            # errors; also dropped a dead assignment that built the
            # "page_<n>.json" path and immediately overwrote it.
            s3_model_output_json_path = join_path(pdf_model_output, f"{page_id}.json")
            s = read_file(s3_model_output_json_path, pdf_model_s3_profile)
            return json.loads(s)
    elif isinstance(pdf_model_output, list):
        return pdf_model_output[page_id]
    # Previously fell through and raised UnboundLocalError on the return;
    # fail loudly with a clear message instead.
    raise TypeError(
        f"pdf_model_output must be str or list, got {type(pdf_model_output).__name__}"
    )
+
+
 def list_dir(dir_path:str, s3_profile:str):
     """
     列出dir_path下的所有文件

+ 1 - 26
magic_pdf/pdf_parse_by_model.py

@@ -2,7 +2,7 @@ import time
 
 # from anyio import Path
 
-from magic_pdf.libs.commons import fitz, get_delta_time, get_img_s3_client
+from magic_pdf.libs.commons import fitz, get_delta_time, get_img_s3_client, get_docx_model_output
 import json
 import os
 import math
@@ -68,31 +68,6 @@ paraSplitException_msg = ParaSplitException().message
 paraMergeException_msg = ParaMergeException().message
 
 
-def get_docx_model_output(pdf_model_output, pdf_model_s3_profile, page_id):
-    if isinstance(pdf_model_output, str):
-        model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json")  # 模型输出的页面编号从1开始的
-        if os.path.exists(model_output_json_path):
-            json_from_docx = read_file(model_output_json_path, pdf_model_s3_profile)
-            model_output_json = json.loads(json_from_docx)
-        else:
-            try:
-                model_output_json_path = join_path(pdf_model_output, "model.json")
-                with open(model_output_json_path, "r", encoding="utf-8") as f:
-                    model_output_json = json.load(f)
-                    model_output_json = model_output_json["doc_layout_result"][page_id]
-            except:
-                s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id + 1}.json")
-                s3_model_output_json_path = join_path(pdf_model_output, f"{page_id}.json")
-                #s3_model_output_json_path = join_path(pdf_model_output, f"page_{page_id }.json")
-                # logger.warning(f"model_output_json_path: {model_output_json_path} not found. try to load from s3: {s3_model_output_json_path}")
-
-                s = read_file(s3_model_output_json_path, pdf_model_s3_profile)
-                return json.loads(s)
-
-    elif isinstance(pdf_model_output, list):
-        model_output_json = pdf_model_output[page_id]
-
-    return model_output_json
 
 
 def parse_pdf_by_model(

+ 99 - 3
magic_pdf/pdf_parse_by_ocr.py

@@ -1,5 +1,17 @@
+import os
+import time
+
+from loguru import logger
+
+from magic_pdf.libs.commons import read_file, join_path, fitz, get_img_s3_client, get_delta_time, get_docx_model_output
+from magic_pdf.libs.safe_filename import sanitize_filename
+from magic_pdf.pre_proc.detect_footer_by_model import parse_footers
+from magic_pdf.pre_proc.detect_footnote import parse_footnotes_by_model
+from magic_pdf.pre_proc.detect_header import parse_headers
+from magic_pdf.pre_proc.detect_page_number import parse_pageNos
 from magic_pdf.pre_proc.ocr_detect_layout import layout_detect
 from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line, remove_overlaps_min_spans
+from magic_pdf.pre_proc.ocr_remove_spans import remove_spans_by_bboxes
 
 
 def construct_page_component(page_id, blocks, layout_bboxes):
@@ -12,22 +24,100 @@ def construct_page_component(page_id, blocks, layout_bboxes):
 
 
 def parse_pdf_by_ocr(
-    ocr_pdf_info,
+    pdf_path,
+    s3_pdf_profile,
+    pdf_model_output,
+    book_name,
+    pdf_model_profile=None,
+    image_s3_config=None,
     start_page_id=0,
     end_page_id=None,
+    debug_mode=False,
 ):
+    pdf_bytes = read_file(pdf_path, s3_pdf_profile)
+    save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
+    book_name = sanitize_filename(book_name)
+    md_bookname_save_path = ""
+    if debug_mode:
+        save_path = join_path(save_tmp_path, "md")
+        pdf_local_path = join_path(save_tmp_path, "download-pdfs", book_name)
+
+        if not os.path.exists(os.path.dirname(pdf_local_path)):
+            # 如果目录不存在,创建它
+            os.makedirs(os.path.dirname(pdf_local_path))
+
+        md_bookname_save_path = join_path(save_tmp_path, "md", book_name)
+        if not os.path.exists(md_bookname_save_path):
+            # 如果目录不存在,创建它
+            os.makedirs(md_bookname_save_path)
+
+        with open(pdf_local_path + ".pdf", "wb") as pdf_file:
+            pdf_file.write(pdf_bytes)
+
 
+    pdf_docs = fitz.open("pdf", pdf_bytes)
+    # 初始化空的pdf_info_dict
     pdf_info_dict = {}
-    end_page_id = end_page_id if end_page_id else len(ocr_pdf_info) - 1
+    img_s3_client = get_img_s3_client(save_path, image_s3_config)
+
+    start_time = time.time()
+
+    remove_bboxes = []
+
+    end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
     for page_id in range(start_page_id, end_page_id + 1):
-        ocr_page_info = ocr_pdf_info[page_id]
+
+        # 获取当前页的page对象
+        page = pdf_docs[page_id]
+
+        if debug_mode:
+            time_now = time.time()
+            logger.info(f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}")
+            start_time = time_now
+
+        # 获取当前页的模型数据
+        ocr_page_info = get_docx_model_output(pdf_model_output, pdf_model_profile, page_id)
+
+        """从json中获取每页的页码、页眉、页脚的bbox"""
+        page_no_bboxes = parse_pageNos(page_id, page, ocr_page_info)
+        header_bboxes = parse_headers(page_id, page, ocr_page_info)
+        footer_bboxes = parse_footers(page_id, page, ocr_page_info)
+        footnote_bboxes =  parse_footnotes_by_model(page_id, page, ocr_page_info, md_bookname_save_path, debug_mode=debug_mode)
+
+        # 构建需要remove的bbox列表
+        need_remove_spans_bboxes = []
+        need_remove_spans_bboxes.extend(page_no_bboxes)
+        need_remove_spans_bboxes.extend(header_bboxes)
+        need_remove_spans_bboxes.extend(footer_bboxes)
+        need_remove_spans_bboxes.extend(footnote_bboxes)
+        remove_bboxes.append(need_remove_spans_bboxes)
+
+
+
         layout_dets = ocr_page_info['layout_dets']
         spans = []
+
+        # 将模型坐标转换成pymu格式下的未缩放坐标
+        DPI = 72  # use this resolution
+        pix = page.get_pixmap(dpi=DPI)
+        pageL = 0
+        pageR = int(pix.w)
+        pageU = 0
+        pageD = int(pix.h)
+        width_from_json = ocr_page_info['page_info']['width']
+        height_from_json = ocr_page_info['page_info']['height']
+        LR_scaleRatio = width_from_json / (pageR - pageL)
+        UD_scaleRatio = height_from_json / (pageD - pageU)
+
         for layout_det in layout_dets:
             category_id = layout_det['category_id']
             allow_category_id_list = [1, 7, 13, 14, 15]
             if category_id in allow_category_id_list:
                 x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
+                x0 = x0 / LR_scaleRatio
+                y0 = y0 / UD_scaleRatio
+                x1 = x1 / LR_scaleRatio
+                y1 = y1 / UD_scaleRatio
                 bbox = [int(x0), int(y0), int(x1), int(y1)]
                 '''要删除的'''
                 #  3: 'header',      # 页眉
@@ -48,8 +138,10 @@ def parse_pdf_by_ocr(
                 }
                 if category_id == 1:
                     span['type'] = 'image'
+
                 elif category_id == 7:
                     span['type'] = 'table'
+
                 elif category_id == 13:
                     span['content'] = layout_det['latex']
                     span['type'] = 'inline_equation'
@@ -67,6 +159,9 @@ def parse_pdf_by_ocr(
         # 删除重叠spans中较小的那些
         spans = remove_overlaps_min_spans(spans)
 
+        # 删除remove_span_block_bboxes中的bbox
+        spans = remove_spans_by_bboxes(spans, need_remove_spans_bboxes)
+
         # 对tpye=["displayed_equation", "image", "table"]进行额外处理,如果左边有字的话,将该span的bbox中y0调整低于文字的y0
 
 
@@ -89,5 +184,6 @@ def parse_pdf_by_ocr(
         page_info = construct_page_component(page_id, blocks, layout_bboxes)
         pdf_info_dict[f"page_{page_id}"] = page_info
 
+    # logger.info(remove_bboxes)
     return pdf_info_dict
 

+ 18 - 0
magic_pdf/pre_proc/ocr_remove_spans.py

@@ -0,0 +1,18 @@
+from magic_pdf.libs.boxbase import _is_in_or_part_overlap
+
+
def remove_spans_by_bboxes(spans, need_remove_spans_bboxes):
    """Drop every span whose bbox overlaps any bbox to be removed.

    Used to strip spans that fall inside detected header / footer /
    page-number / footnote regions.

    spans: list of dicts, each with a 'bbox' entry.
    need_remove_spans_bboxes: bboxes of regions whose spans are discarded.

    Returns the same list object, mutated in place, for call-chaining.
    """
    # Single filtering pass. The previous version collected hits and then
    # called spans.remove() per hit -- O(len(spans)) per removal (quadratic
    # overall) and ambiguous when equal span dicts occur.
    kept = [
        span
        for span in spans
        if not any(
            _is_in_or_part_overlap(span['bbox'], bbox)
            for bbox in need_remove_spans_bboxes
        )
    ]
    # Slice-assign so callers holding a reference to the original list
    # observe the same in-place mutation as before.
    spans[:] = kept
    return spans