浏览代码

perf(model): adjust batch size for layout and formula detection

- Reduce YOLO_LAYOUT_BASE_BATCH_SIZE from 4 to 1
- Simplify batch ratio calculation for formula detection
- Remove unused conditional logic in batch ratio determination
myhloli 10 月之前
父节点
当前提交
49d140c55b
共有 2 个文件被更改,包括 7 次插入9 次删除
  1. 5 3
      magic_pdf/model/batch_analyze.py
  2. 2 6
      magic_pdf/model/doc_analyze_by_custom_model.py

+ 5 - 3
magic_pdf/model/batch_analyze.py

@@ -19,7 +19,7 @@ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
 # from magic_pdf.operators.models import InferenceResult
 
-YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 
@@ -56,7 +56,8 @@ class BatchAnalyze:
                 layout_images.append(pil_img)
 
             images_layout_res += self.model.layout_model.batch_predict(
-                layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+                # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+                layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )
 
             for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ class BatchAnalyze:
             # 公式检测
             mfd_start_time = time.time()
             images_mfd_res = self.model.mfd_model.batch_predict(
-                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                images, MFD_BASE_BATCH_SIZE
             )
             logger.info(
                 f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'

+ 2 - 6
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -176,12 +176,8 @@ def doc_analyze(
 
     if torch.cuda.is_available() and device != 'cpu' or npu_support:
         gpu_memory = get_vram(device)
-        if gpu_memory is not None and gpu_memory >= 7:
-            # batch_ratio = int((gpu_memory-3) // 1.5)
-            batch_ratio = 2
-            if 8 < gpu_memory:
-                batch_ratio = 4
-
+        if gpu_memory is not None and gpu_memory >= 7.5:
+            batch_ratio = int((gpu_memory-5) // 1)
             if batch_ratio >= 1:
                 logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
                 batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)