Browse Source

feat(pdf_parse): implement multi-threaded page processing

- Add ThreadPoolExecutor to process PDF pages in parallel
- Create separate function for page processing to improve readability and maintainability
- Include error handling for individual page processing tasks
- Log total page processing time for performance monitoring
myhloli 8 tháng trước
mục cha
commit
6ec440d6f1
1 tập tin đã thay đổi với 63 bổ sung và 8 xóa
  1. 63 8
      magic_pdf/pdf_parse_union_core_v2.py

+ 63 - 8
magic_pdf/pdf_parse_union_core_v2.py

@@ -24,6 +24,8 @@ from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
 
+from concurrent.futures import ThreadPoolExecutor
+
 try:
     import torchtext
 
@@ -937,16 +939,33 @@ def pdf_parse_union(
     """初始化启动时间"""
     start_time = time.time()
 
-    for page_id, page in enumerate(dataset):
-        """debug时输出每页解析的耗时."""
+    # for page_id, page in enumerate(dataset):
+    #     """debug时输出每页解析的耗时."""
+    #     if debug_mode:
+    #         time_now = time.time()
+    #         logger.info(
+    #             f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
+    #         )
+    #         start_time = time_now
+    #
+    #     """解析pdf中的每一页"""
+    #     if start_page_id <= page_id <= end_page_id:
+    #         page_info = parse_page_core(
+    #             page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
+    #         )
+    #     else:
+    #         page_info = page.get_page_info()
+    #         page_w = page_info.w
+    #         page_h = page_info.h
+    #         page_info = ocr_construct_page_component_v2(
+    #             [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
+    #         )
+    #     pdf_info_dict[f'page_{page_id}'] = page_info
+    def process_page(page_id, page, dataset_len, start_page_id, end_page_id, magic_model, pdf_bytes_md5, imageWriter,
+                     parse_mode, lang, debug_mode, start_time):
         if debug_mode:
             time_now = time.time()
-            logger.info(
-                f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
-            )
-            start_time = time_now
 
-        """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
             page_info = parse_page_core(
                 page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
@@ -958,7 +977,43 @@ def pdf_parse_union(
             page_info = ocr_construct_page_component_v2(
                 [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
             )
-        pdf_info_dict[f'page_{page_id}'] = page_info
+        return page_id, page_info
+
+    # Fixed small worker count to limit resource usage; consider deriving from os.cpu_count() in the future
+    max_workers = 2
+    pdf_info_dict = {}
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = {
+            executor.submit(
+                process_page,
+                page_id,
+                page,
+                len(dataset),
+                start_page_id,
+                end_page_id,
+                magic_model,
+                pdf_bytes_md5,
+                imageWriter,
+                parse_mode,
+                lang,
+                debug_mode,
+                time.time()
+            ): page_id
+            for page_id, page in enumerate(dataset)
+        }
+
+        for page_id in range(len(dataset)):
+            future = [f for f in futures if futures[f] == page_id][0]
+            try:
+                page_id, page_info = future.result()
+                pdf_info_dict[f'page_{page_id}'] = page_info
+            except Exception as e:
+                logger.exception(f"Error processing page {page_id}: {e}")
+
+    logger.info(
+        f'page_process_time: {round(time.time() - start_time, 2)}'
+    )
 
     """分段"""
     para_split(pdf_info_dict)