Sfoglia il codice sorgente

refactor: improve batch processing logic and enhance OCR result handling

myhloli 5 mesi fa
parent
commit
0039d11378
1 ha cambiato i file con 34 aggiunte e 23 eliminazioni
  1. 34 23
      mineru/backend/pipeline/batch_analyze.py

+ 34 - 23
mineru/backend/pipeline/batch_analyze.py

@@ -14,7 +14,7 @@ MFR_BASE_BATCH_SIZE = 16
 
 
 class BatchAnalyze:
-    def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = False):
+    def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
         self.batch_ratio = batch_ratio
         self.formula_enable = formula_enable
         self.table_enable = table_enable
@@ -150,17 +150,17 @@ class BatchAnalyze:
 
                 # 对每个分辨率组进行批处理
                 for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
-                    raw_images = [crop_info[0] for crop_info in group_crops]
 
                     # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
-                    max_h = max(img.shape[0] for img in raw_images)
-                    max_w = max(img.shape[1] for img in raw_images)
+                    max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
+                    max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
                     target_h = ((max_h + 32 - 1) // 32) * 32
                     target_w = ((max_w + 32 - 1) // 32) * 32
 
                     # 对所有图像进行padding到统一尺寸
                     batch_images = []
-                    for img in raw_images:
+                    for crop_info in group_crops:
+                        img = crop_info[0]
                         h, w = img.shape[:2]
                         # 创建目标尺寸的白色背景
                         padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
@@ -177,28 +177,38 @@ class BatchAnalyze:
                     for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
                         new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
 
-                        if dt_boxes is not None:
-                            # 构造OCR结果格式 - 每个box应该是4个点的列表
-                            ocr_res = [box.tolist() for box in dt_boxes]
+                        if dt_boxes is not None and len(dt_boxes) > 0:
+                            # 直接应用原始OCR流程中的关键处理步骤
+                            from mineru.utils.ocr_utils import (
+                                merge_det_boxes, update_det_boxes, sorted_boxes
+                            )
+
+                            # 1. 排序检测框
+                            if len(dt_boxes) > 0:
+                                dt_boxes_sorted = sorted_boxes(dt_boxes)
+                            else:
+                                dt_boxes_sorted = []
+
+                            # 2. 合并相邻检测框
+                            if dt_boxes_sorted:
+                                dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
+                            else:
+                                dt_boxes_merged = []
+
+                            # 3. 根据公式位置更新检测框(关键步骤!)
+                            if dt_boxes_merged and adjusted_mfdetrec_res:
+                                dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
+                            else:
+                                dt_boxes_final = dt_boxes_merged
+
+                            # 构造OCR结果格式
+                            ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
 
                             if ocr_res:
                                 ocr_result_list = get_ocr_result_list(
                                     ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
                                 )
 
-                                if res["category_id"] == 3:
-                                    # ocr_result_list中所有bbox的面积之和
-                                    ocr_res_area = sum(
-                                        get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
-                                    # 求ocr_res_area和res的面积的比值
-                                    res_area = get_coords_and_area(res)[4]
-                                    if res_area > 0:
-                                        ratio = ocr_res_area / res_area
-                                        if ratio > 0.25:
-                                            res["category_id"] = 1
-                                        else:
-                                            continue
-
                                 ocr_res_list_dict['layout_res'].extend(ocr_result_list)
         else:
             # 原始单张处理模式
@@ -227,8 +237,9 @@ class BatchAnalyze:
 
                     # Integration results
                     if ocr_res:
-                        ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
-                                                              new_image, _lang)
+                        ocr_result_list = get_ocr_result_list(
+                            ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
+                        )
 
                         ocr_res_list_dict['layout_res'].extend(ocr_result_list)