Selaa lähdekoodia

refactor: streamline resolution grouping and padding logic in batch_analyze.py

myhloli 2 viikkoa sitten
vanhempi
commit
d975836b25
1 muutettua tiedostoa jossa 20 lisäystä ja 43 poistoa
  1. 20 43
      mineru/backend/pipeline/batch_analyze.py

+ 20 - 43
mineru/backend/pipeline/batch_analyze.py

@@ -281,28 +281,20 @@ class BatchAnalyze:
 
                 # 按分辨率分组并同时完成padding
                 # RESOLUTION_GROUP_STRIDE = 32
-                RESOLUTION_GROUP_STRIDE = 64  # 定义分辨率分组的步进值
+                RESOLUTION_GROUP_STRIDE = 64
 
                 resolution_groups = defaultdict(list)
                 for crop_info in lang_crop_list:
                     cropped_img = crop_info[0]
                     h, w = cropped_img.shape[:2]
-                    # 使用更大的分组容差,减少分组数量
-                    # 将尺寸标准化到32的倍数
-                    normalized_h = ((h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # 向上取整到32的倍数
-                    normalized_w = ((w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
-                    group_key = (normalized_h, normalized_w)
+                    # 直接计算目标尺寸并用作分组键
+                    target_h = ((h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+                    target_w = ((w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+                    group_key = (target_h, target_w)
                     resolution_groups[group_key].append(crop_info)
 
                 # 对每个分辨率组进行批处理
-                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
-
-                    # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
-                    max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
-                    max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
-                    target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
-                    target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
-
+                for (target_h, target_w), group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
                     # 对所有图像进行padding到统一尺寸
                     batch_images = []
                     for crop_info in group_crops:
@@ -310,49 +302,34 @@ class BatchAnalyze:
                         h, w = img.shape[:2]
                         # 创建目标尺寸的白色背景
                         padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
-                        # 将原图像粘贴到左上角
                         padded_img[:h, :w] = img
                         batch_images.append(padded_img)
 
                     # 批处理检测
-                    det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)  # 增加批处理大小
-                    # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}")
+                    det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
                     batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
 
                     # 处理批处理结果
-                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
+                    for crop_info, (dt_boxes, _) in zip(group_crops, batch_results):
                         bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
 
                         if dt_boxes is not None and len(dt_boxes) > 0:
-                            # 直接应用原始OCR流程中的关键处理步骤
-
-                            # 1. 排序检测框
-                            if len(dt_boxes) > 0:
-                                dt_boxes_sorted = sorted_boxes(dt_boxes)
-                            else:
-                                dt_boxes_sorted = []
-
-                            # 2. 合并相邻检测框
-                            if dt_boxes_sorted:
-                                dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
-                            else:
-                                dt_boxes_merged = []
-
-                            # 3. 根据公式位置更新检测框(关键步骤!)
-                            if dt_boxes_merged and adjusted_mfdetrec_res:
-                                dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
-                            else:
-                                dt_boxes_final = dt_boxes_merged
-
-                            # 构造OCR结果格式
-                            ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
-
-                            if ocr_res:
+                            # 处理检测框
+                            dt_boxes_sorted = sorted_boxes(dt_boxes)
+                            dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) if dt_boxes_sorted else []
+
+                            # 根据公式位置更新检测框
+                            dt_boxes_final = (update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
+                                              if dt_boxes_merged and adjusted_mfdetrec_res
+                                              else dt_boxes_merged)
+
+                            if dt_boxes_final:
+                                ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
                                 ocr_result_list = get_ocr_result_list(
                                     ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
                                 )
-
                                 ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+
         else:
             # 原始单张处理模式
             for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):