Răsfoiți Sursa

Merge pull request #1593 from myhloli/dev

perf(magic_pdf): optimize batch ratio calculation for GPU
Xiaomeng Zhao 10 luni în urmă
părinte
comite
636d78a3d7

+ 5 - 3
magic_pdf/model/batch_analyze.py

@@ -19,7 +19,7 @@ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
 # from magic_pdf.operators.models import InferenceResult
 
-YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 
@@ -56,7 +56,8 @@ class BatchAnalyze:
                 layout_images.append(pil_img)
 
             images_layout_res += self.model.layout_model.batch_predict(
-                layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+                # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+                layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )
 
             for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ class BatchAnalyze:
             # 公式检测
             mfd_start_time = time.time()
             images_mfd_res = self.model.mfd_model.batch_predict(
-                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                images, MFD_BASE_BATCH_SIZE
             )
             logger.info(
                 f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'

+ 3 - 7
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -175,13 +175,9 @@ def doc_analyze(
             npu_support = True
 
     if torch.cuda.is_available() and device != 'cpu' or npu_support:
-        gpu_memory = get_vram(device)
-        if gpu_memory is not None and gpu_memory >= 7:
-            # batch_ratio = int((gpu_memory-3) // 1.5)
-            batch_ratio = 2
-            if 8 < gpu_memory:
-                batch_ratio = 4
-
+        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        if gpu_memory is not None and gpu_memory >= 8:
+            batch_ratio = int(gpu_memory-5)
             if batch_ratio >= 1:
                 logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
                 batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

+ 1 - 1
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -89,7 +89,7 @@ class UnimernetModel(object):
             mf_image_list.append(bbox_img)
 
         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)