Przeglądaj źródła

fix(parse_pipeline): Resolve post-processing exceptions caused by partial PDFs due to file corruption or non-standard format by forcing a re-print.

myhloli 1 rok temu
rodzic
commit
918ed65bd5
1 zmienionych plików z 47 dodań i 3 usunięć
  1. 47 3
      magic_pdf/tools/common.py

+ 47 - 3
magic_pdf/tools/common.py

@@ -14,6 +14,9 @@ from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
+import fitz
+# from io import BytesIO
+# from pypdf import PdfReader, PdfWriter
 
 
 def prepare_env(output_dir, pdf_file_name, method):
@@ -26,6 +29,42 @@ def prepare_env(output_dir, pdf_file_name, method):
     return local_image_dir, local_md_dir
 
 
+# def convert_pdf_bytes_to_bytes_by_pypdf(pdf_bytes, start_page_id=0, end_page_id=None):
+#     # 将字节数据包装在 BytesIO 对象中
+#     pdf_file = BytesIO(pdf_bytes)
+#     # 读取 PDF 的字节数据
+#     reader = PdfReader(pdf_file)
+#     # 创建一个新的 PDF 写入器
+#     writer = PdfWriter()
+#     # 将所有页面添加到新的 PDF 写入器中
+#     end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(reader.pages) - 1
+#     if end_page_id > len(reader.pages) - 1:
+#         logger.warning("end_page_id is out of range, use pdf_docs length")
+#         end_page_id = len(reader.pages) - 1
+#     for i, page in enumerate(reader.pages):
+#         if start_page_id <= i <= end_page_id:
+#             writer.add_page(page)
+#     # 创建一个字节缓冲区来存储输出的 PDF 数据
+#     output_buffer = BytesIO()
+#     # 将 PDF 写入字节缓冲区
+#     writer.write(output_buffer)
+#     # 获取字节缓冲区的内容
+#     converted_pdf_bytes = output_buffer.getvalue()
+#     return converted_pdf_bytes
+
+
+def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
+    document = fitz.open("pdf", pdf_bytes)
+    output_document = fitz.open()
+    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1
+    if end_page_id > len(document) - 1:
+        logger.warning("end_page_id is out of range, use pdf_docs length")
+        end_page_id = len(document) - 1
+    output_document.insert_pdf(document, from_page=start_page_id, to_page=end_page_id)
+    output_bytes = output_document.tobytes()
+    return output_bytes
+
+
 def do_parse(
     output_dir,
     pdf_file_name,
@@ -55,6 +94,8 @@ def do_parse(
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
 
+    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id, end_page_id)
+
     orig_model_list = copy.deepcopy(model_list)
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
                                                 parse_method)
@@ -66,15 +107,18 @@ def do_parse(
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       # start_page_id=start_page_id, end_page_id=end_page_id,
+                       lang=lang,
                        layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'txt':
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       # start_page_id=start_page_id, end_page_id=end_page_id,
+                       lang=lang,
                        layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     elif parse_method == 'ocr':
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+                       # start_page_id=start_page_id, end_page_id=end_page_id,
+                       lang=lang,
                        layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     else:
         logger.error('unknown parse method')