Selaa lähdekoodia

refactor(magic_pdf): optimize code and improve logging

- Remove unused imports and comments
- Increase MIN_BATCH_INFERENCE_SIZE from 100 to 200
- Comment out VRAM cleaning and logging in batch_analyze.py
- Simplify code in doc_analyze_by_custom_model.py
- Add tqdm progress bar in pdf_parse_union_core_v2.py
- Enable tqdm in OCR processing
myhloli 7 kuukautta sitten
vanhempi
commit
553f250fc7

+ 1 - 5
magic_pdf/model/batch_analyze.py

@@ -1,18 +1,14 @@
 import time
-
 import cv2
-import torch
 from loguru import logger
 from tqdm import tqdm
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.libs.config_reader import get_table_recog_config
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -86,7 +82,7 @@ class BatchAnalyze:
             # )
 
         # 清理显存
-        clean_vram(self.model.device, vram_threshold=8)
+        # clean_vram(self.model.device, vram_threshold=8)
 
         ocr_res_list_all_page = []
         table_res_list_all_page = []

+ 15 - 19
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -188,7 +188,7 @@ def batch_doc_analyze(
     formula_enable=None,
     table_enable=None,
 ):
-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
@@ -245,8 +245,7 @@ def may_batch_image_analyze(
 
     model_manager = ModelSingleton()
 
-    images = [image for image, _, _ in images_with_extra_info]
-    batch_analyze = False
+    # images = [image for image, _, _ in images_with_extra_info]
     batch_ratio = 1
     device = get_device()
 
@@ -269,25 +268,22 @@ def may_batch_image_analyze(
             else:
                 batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
-            # batch_analyze = True
-    elif str(device).startswith('mps'):
-        # batch_analyze = True
-        pass
 
-    doc_analyze_start = time.time()
+
+    # doc_analyze_start = time.time()
 
     batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
     results = batch_model(images_with_extra_info)
 
-    gc_start = time.time()
+    # gc_start = time.time()
     clean_memory(get_device())
-    gc_time = round(time.time() - gc_start, 2)
-    logger.info(f'gc time: {gc_time}')
-
-    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
-    logger.info(
-        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
-        f' speed: {doc_analyze_speed} pages/second'
-    )
-    return (idx, results)
+    # gc_time = round(time.time() - gc_start, 2)
+    # logger.debug(f'gc time: {gc_time}')
+
+    # doc_analyze_time = round(time.time() - doc_analyze_start, 2)
+    # doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
+    # logger.debug(
+    #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
+    #     f' speed: {doc_analyze_speed} pages/second'
+    # )
+    return idx, results

+ 13 - 11
magic_pdf/pdf_parse_union_core_v2.py

@@ -12,6 +12,7 @@ import fitz
 import torch
 import numpy as np
 from loguru import logger
+from tqdm import tqdm
 
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.ocr_content_type import BlockType, ContentType
@@ -932,17 +933,18 @@ def pdf_parse_union(
         logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(dataset) - 1
 
-    """初始化启动时间"""
-    start_time = time.time()
+    # """初始化启动时间"""
+    # start_time = time.time()
 
-    for page_id, page in enumerate(dataset):
-        """debug时输出每页解析的耗时."""
-        if debug_mode:
-            time_now = time.time()
-            logger.info(
-                f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
-            )
-            start_time = time_now
+    # for page_id, page in enumerate(dataset):
+    for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
+        # """debug时输出每页解析的耗时."""
+        # if debug_mode:
+            # time_now = time.time()
+            # logger.info(
+            #     f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
+            # )
+            # start_time = time_now
 
         """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
@@ -988,7 +990,7 @@ def pdf_parse_union(
             lang=lang
         )
         rec_start = time.time()
-        ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+        ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
         # Verify we have matching counts
         assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
         # Process OCR results for this language