Browse Source

fix: adjust resolution grouping stride for improved image normalization

myhloli 3 months ago
parent
commit
7676543ff8

+ 7 - 4
mineru/backend/pipeline/batch_analyze.py

@@ -145,14 +145,17 @@ class BatchAnalyze:
                 )
 
                 # 按分辨率分组并同时完成padding
+                # RESOLUTION_GROUP_STRIDE = 32
+                RESOLUTION_GROUP_STRIDE = 64  # 定义分辨率分组的步进值
+
                 resolution_groups = defaultdict(list)
                 for crop_info in lang_crop_list:
                     cropped_img = crop_info[0]
                     h, w = cropped_img.shape[:2]
                     # 使用更大的分组容差,减少分组数量
                     # Bucket each dimension to a multiple of RESOLUTION_GROUP_STRIDE (an exact multiple moves up to the next one)
-                    normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
-                    normalized_w = ((w + 32) // 32) * 32
+                    normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # next multiple of RESOLUTION_GROUP_STRIDE strictly above h
+                    normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
                     group_key = (normalized_h, normalized_w)
                     resolution_groups[group_key].append(crop_info)
 
@@ -162,8 +165,8 @@ class BatchAnalyze:
                     # Compute the target size (largest size in the group, rounded up to a multiple of RESOLUTION_GROUP_STRIDE)
                     max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
                     max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
-                    target_h = ((max_h + 32 - 1) // 32) * 32
-                    target_w = ((max_w + 32 - 1) // 32) * 32
+                    target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+                    target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
 
                     # 对所有图像进行padding到统一尺寸
                     batch_images = []

+ 5 - 1
mineru/backend/pipeline/pipeline_analyze.py

@@ -190,10 +190,14 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
-    if str(device).startswith('mps'):
+    # 检测torch的版本号
+    import torch
+    from packaging import version
+    if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
         enable_ocr_det_batch = False
     else:
         enable_ocr_det_batch = True
+
     batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
     results = batch_model(images_with_extra_info)