|
|
@@ -281,28 +281,20 @@ class BatchAnalyze:
|
|
|
|
|
|
# 按分辨率分组并同时完成padding
|
|
|
# RESOLUTION_GROUP_STRIDE = 32
|
|
|
- RESOLUTION_GROUP_STRIDE = 64 # 定义分辨率分组的步进值
|
|
|
+ RESOLUTION_GROUP_STRIDE = 64
|
|
|
|
|
|
resolution_groups = defaultdict(list)
|
|
|
for crop_info in lang_crop_list:
|
|
|
cropped_img = crop_info[0]
|
|
|
h, w = cropped_img.shape[:2]
|
|
|
- # 使用更大的分组容差,减少分组数量
|
|
|
- # 将尺寸标准化到32的倍数
|
|
|
- normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE # 向上取整到32的倍数
|
|
|
- normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
|
|
- group_key = (normalized_h, normalized_w)
|
|
|
+ # 直接计算目标尺寸并用作分组键
|
|
|
+ target_h = ((h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
|
|
+ target_w = ((w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
|
|
+ group_key = (target_h, target_w)
|
|
|
resolution_groups[group_key].append(crop_info)
|
|
|
|
|
|
# 对每个分辨率组进行批处理
|
|
|
- for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
|
|
|
-
|
|
|
- # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
|
|
|
- max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
|
|
|
- max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
|
|
|
- target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
|
|
- target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
|
|
|
-
|
|
|
+ for (target_h, target_w), group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
|
|
|
# 对所有图像进行padding到统一尺寸
|
|
|
batch_images = []
|
|
|
for crop_info in group_crops:
|
|
|
@@ -310,49 +302,34 @@ class BatchAnalyze:
|
|
|
h, w = img.shape[:2]
|
|
|
# 创建目标尺寸的白色背景
|
|
|
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
|
|
|
- # 将原图像粘贴到左上角
|
|
|
padded_img[:h, :w] = img
|
|
|
batch_images.append(padded_img)
|
|
|
|
|
|
# 批处理检测
|
|
|
- det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) # 增加批处理大小
|
|
|
- # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}")
|
|
|
+ det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
|
|
|
batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
|
|
|
|
|
|
# 处理批处理结果
|
|
|
- for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
|
|
|
+ for crop_info, (dt_boxes, _) in zip(group_crops, batch_results):
|
|
|
bgr_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
|
|
|
|
|
|
if dt_boxes is not None and len(dt_boxes) > 0:
|
|
|
- # 直接应用原始OCR流程中的关键处理步骤
|
|
|
-
|
|
|
- # 1. 排序检测框
|
|
|
- if len(dt_boxes) > 0:
|
|
|
- dt_boxes_sorted = sorted_boxes(dt_boxes)
|
|
|
- else:
|
|
|
- dt_boxes_sorted = []
|
|
|
-
|
|
|
- # 2. 合并相邻检测框
|
|
|
- if dt_boxes_sorted:
|
|
|
- dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
|
|
|
- else:
|
|
|
- dt_boxes_merged = []
|
|
|
-
|
|
|
- # 3. 根据公式位置更新检测框(关键步骤!)
|
|
|
- if dt_boxes_merged and adjusted_mfdetrec_res:
|
|
|
- dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
|
|
|
- else:
|
|
|
- dt_boxes_final = dt_boxes_merged
|
|
|
-
|
|
|
- # 构造OCR结果格式
|
|
|
- ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
|
|
|
-
|
|
|
- if ocr_res:
|
|
|
+ # 处理检测框
|
|
|
+ dt_boxes_sorted = sorted_boxes(dt_boxes)
|
|
|
+ dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) if dt_boxes_sorted else []
|
|
|
+
|
|
|
+ # 根据公式位置更新检测框
|
|
|
+ dt_boxes_final = (update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
|
|
|
+ if dt_boxes_merged and adjusted_mfdetrec_res
|
|
|
+ else dt_boxes_merged)
|
|
|
+
|
|
|
+ if dt_boxes_final:
|
|
|
+ ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
|
|
|
ocr_result_list = get_ocr_result_list(
|
|
|
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], bgr_image, _lang
|
|
|
)
|
|
|
-
|
|
|
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
|
|
+
|
|
|
else:
|
|
|
# 原始单张处理模式
|
|
|
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
|