Bladeren bron

pipeline重构

赵小蒙 1 jaar geleden
bovenliggende
commit
7f0c734ff6
3 gewijzigde bestanden met toevoegingen van 21 en 19 verwijderingen
  1. +4 -2
      demo/pdf2md.py
  2. +2 -3
      magic_pdf/pdf_parse_by_model.py
  3. +15 -14
      magic_pdf/pipeline.py

+ 4 - 2
demo/pdf2md.py

@@ -5,7 +5,7 @@ from pathlib import Path
 import click
 from loguru import logger
 
-from magic_pdf.libs.commons import join_path
+from magic_pdf.libs.commons import join_path, read_file
 from magic_pdf.dict2md.mkcontent import mk_mm_markdown
 from magic_pdf.pipeline import parse_pdf_by_model
 
@@ -21,9 +21,11 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
     text_content_save_path = f"{save_path}/{book_name}/book.md"
     # metadata_save_path = f"{save_path}/{book_name}/metadata.json"
 
+    pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
+
     try:
         paras_dict = parse_pdf_by_model(
-            s3_pdf_path, s3_pdf_profile, pdf_model_path, save_path, book_name, pdf_model_profile, start_page_num, debug_mode=debug_mode
+            pdf_bytes, pdf_model_path, save_path, book_name, pdf_model_profile, start_page_num, debug_mode=debug_mode
         )
         parent_dir = os.path.dirname(text_content_save_path)
         if not os.path.exists(parent_dir):

+ 2 - 3
magic_pdf/pdf_parse_by_model.py

@@ -71,8 +71,7 @@ paraMergeException_msg = ParaMergeException().message
 
 
 def parse_pdf_by_model(
-    s3_pdf_path,
-    s3_pdf_profile,
+    pdf_bytes,
     pdf_model_output,
     save_path,
     book_name,
@@ -83,7 +82,7 @@ def parse_pdf_by_model(
     junk_img_bojids=[],
     debug_mode=False,
 ):
-    pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
+
     save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest")
     md_bookname_save_path = ""
     book_name = sanitize_filename(book_name)

+ 15 - 14
magic_pdf/pipeline.py

@@ -304,6 +304,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
     # 开始正式逻辑
     s3_pdf_path = jso.get("file_location")
     s3_config = get_s3_config(s3_pdf_path)
+    pdf_bytes = read_file(s3_pdf_path, s3_config)
     model_output_json_list = jso.get("doc_layout_result")
     data_source = get_data_source(jso)
     file_id = jso.get("file_id")
@@ -341,8 +342,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
                 file=sys.stderr,
             )
             pdf_info_dict = parse_pdf_by_model(
-                s3_pdf_path,
-                s3_config,
+                pdf_bytes,
                 model_output_json_list,
                 save_path,
                 book_name,
@@ -373,18 +373,6 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
     return jso
 
 
-"""
-统一处理逻辑
-1.先调用parse_pdf对文本类pdf进行处理
-2.再调用ocr_dropped_parse_pdf,对之前drop的pdf进行处理
-"""
-
-
-def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
-    jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
-    jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
-    return jso
-
 def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> dict:
     # 检测debug开关
     if debug_mode:
@@ -465,6 +453,19 @@ def parse_pdf_for_model_train(jso: dict, start_page_id=0, debug_mode=False) -> d
     return jso
 
 
+"""
+统一处理逻辑
+1.先调用parse_pdf对文本类pdf进行处理
+2.再调用ocr_dropped_parse_pdf,对之前drop的pdf进行处理
+"""
+
+
+def uni_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
+    jso = parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
+    jso = ocr_dropped_parse_pdf(jso, start_page_id=start_page_id, debug_mode=debug_mode)
+    return jso
+
+
 # 专门用来跑被drop的pdf,跑完之后需要把need_drop字段置为false
 def ocr_dropped_parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
     if not jso.get("need_drop", False):