Parcourir la source

fix: 增强UNet预处理的缩放因子验证,优化预测结果的尺寸一致性检查,记录详细的调试信息以确保坐标转换的准确性

zhch158_admin il y a 1 jour
Parent
commit
ca720abd31
1 fichiers modifiés avec 115 ajouts et 25 suppressions
  1. 115 25
      ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

+ 115 - 25
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -216,6 +216,39 @@ class MinerUWiredTableRecognizer:
                 preprocessed_w = int(w_up_ * scale_factor + 0.5)
                 preprocessed_h = int(h_up_ * scale_factor + 0.5)
                 
+                # 关键:手动调用 resize_img 获取实际的 w_scale 和 h_scale
+                # 因为 keep_ratio=True 时,w_scale 和 h_scale 可能略有不同
+                try:
+                    from mineru.model.table.rec.unet_table.utils import resize_img
+                    img_preprocessed_test, w_scale_actual, h_scale_actual = resize_img(
+                        img_up_, (inp_height, inp_width), keep_ratio=True
+                    )
+                    # 类型检查:确保是numpy数组
+                    if isinstance(img_preprocessed_test, np.ndarray):
+                        preprocessed_h_actual, preprocessed_w_actual = img_preprocessed_test.shape[:2]
+                    else:
+                        # 如果不是numpy数组,使用计算值
+                        preprocessed_h_actual = preprocessed_h
+                        preprocessed_w_actual = preprocessed_w
+                    scale_diff = abs(w_scale_actual - h_scale_actual)
+                    logger.info(
+                        f"🔍 UNet预处理缩放因子验证: "
+                        f"w_scale={w_scale_actual:.6f}, h_scale={h_scale_actual:.6f}, "
+                        f"差异={scale_diff:.6f}, "
+                        f"预处理后实际尺寸=[{preprocessed_h_actual}, {preprocessed_w_actual}]"
+                    )
+                    if scale_diff > 1e-6:
+                        logger.warning(
+                            f"⚠️ w_scale 和 h_scale 不相等!这可能导致坐标偏移。"
+                            f"w_scale={w_scale_actual:.6f}, h_scale={h_scale_actual:.6f}"
+                        )
+                except Exception as e:
+                    logger.warning(f"无法获取实际缩放因子: {e}")
+                    w_scale_actual = scale_factor
+                    h_scale_actual = scale_factor
+                    preprocessed_h_actual = preprocessed_h
+                    preprocessed_w_actual = preprocessed_w
+                
                 img_info = wired_rec.table_structure.preprocess(img_obj)
                 pred_ = wired_rec.table_structure.infer(img_info)
                 
@@ -224,11 +257,15 @@ class MinerUWiredTableRecognizer:
                 
                 # 调试:记录尺寸信息
                 pred_h, pred_w = pred_.shape[:2]
-                logger.debug(
-                    f"UNet 推理: 上采样图像尺寸=[{h_up_}, {w_up_}], "
-                    f"预处理后尺寸=[{preprocessed_h}, {preprocessed_w}], "
-                    f"预测结果尺寸=[{pred_h}, {pred_w}], "
-                    f"缩放因子={scale_factor:.6f}, upscale={upscale:.3f}"
+                logger.info(
+                    f"🔍 UNet 推理详细日志:\n"
+                    f"  - 上采样图像尺寸: [{h_up_}, {w_up_}]\n"
+                    f"  - 计算预处理后尺寸: [{preprocessed_h}, {preprocessed_w}]\n"
+                    f"  - 实际预处理后尺寸: [{preprocessed_h_actual}, {preprocessed_w_actual}]\n"
+                    f"  - 预测结果尺寸: [{pred_h}, {pred_w}]\n"
+                    f"  - 计算缩放因子: {scale_factor:.6f}\n"
+                    f"  - 实际缩放因子: w_scale={w_scale_actual:.6f}, h_scale={h_scale_actual:.6f}\n"
+                    f"  - upscale: {upscale:.3f}"
                 )
                 
                 # 关键修复:正确地将预测结果 resize 回上采样尺寸
@@ -237,28 +274,76 @@ class MinerUWiredTableRecognizer:
                 # 所以我们应该使用 img_up_.shape 来 resize 预测结果
                 # 但是,由于预处理时改变了图像尺寸(保持长宽比),我们需要确保 resize 是正确的
                 
-                # 验证:检查预测结果尺寸是否与预处理后的尺寸一致
-                if pred_h != preprocessed_h or pred_w != preprocessed_w:
+                # 验证:检查预测结果尺寸是否与预处理后的尺寸一致(仅用于警告)
+                if pred_h != preprocessed_h_actual or pred_w != preprocessed_w_actual:
                     logger.warning(
-                        f"⚠️ 预测结果尺寸 [{pred_h}, {pred_w}] 与预处理后尺寸 [{preprocessed_h}, {preprocessed_w}] 不一致!"
-                        f"这可能导致坐标偏移。使用预处理后尺寸进行 resize。"
+                        f"⚠️ 预测结果尺寸 [{pred_h}, {pred_w}] 与预处理后实际尺寸 "
+                        f"[{preprocessed_h_actual}, {preprocessed_w_actual}] 不一致!"
+                        f"这可能导致坐标偏移。"
+                    )
+                
+                # 修复:统一将预测结果resize回上采样尺寸,避免舍入误差
+                # 理论上:target_size = pred_size / unet_scale ≈ upsampled_size
+                # 但为了确保完全一致,直接使用上采样尺寸作为目标,避免任何舍入误差
+                # 这样可以保证:mask坐标系 = 上采样坐标系,坐标转换链路清晰
+                hpred_up_ = cv2.resize(hpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                vpred_up_ = cv2.resize(vpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                
+                # 记录验证信息:检查理论target尺寸与实际上采样尺寸的差异
+                # 这些差异应该非常小(<2像素),如果差异较大说明UNet预处理有问题
+                if abs(w_scale_actual - h_scale_actual) > 1e-6:
+                    target_w_theoretical = int(pred_w / w_scale_actual + 0.5)
+                    target_h_theoretical = int(pred_h / h_scale_actual + 0.5)
+                    diff_w = abs(target_w_theoretical - w_up_)
+                    diff_h = abs(target_h_theoretical - h_up_)
+                    if diff_w > 2 or diff_h > 2:
+                        logger.warning(
+                            f"⚠️ 理论resize尺寸 [{target_h_theoretical}, {target_w_theoretical}] "
+                            f"与上采样尺寸 [{h_up_}, {w_up_}] 差异较大 (diff=[{diff_h}, {diff_w}])!"
+                            f"w_scale={w_scale_actual:.6f}, h_scale={h_scale_actual:.6f}"
+                        )
+                    else:
+                        logger.debug(
+                            f"✓ 理论resize尺寸 [{target_h_theoretical}, {target_w_theoretical}] "
+                            f"与上采样尺寸 [{h_up_}, {w_up_}] 一致 (diff=[{diff_h}, {diff_w}])"
+                        )
+                
+                # 记录resize后的mask尺寸
+                hpred_up_h, hpred_up_w = hpred_up_.shape[:2]
+                vpred_up_h, vpred_up_w = vpred_up_.shape[:2]
+                logger.info(
+                    f"🔍 Resize后mask尺寸: "
+                    f"hpred_up=[{hpred_up_h}, {hpred_up_w}], "
+                    f"vpred_up=[{vpred_up_h}, {vpred_up_w}], "
+                    f"img_up=[{h_up_}, {w_up_}]"
+                )
+                
+                # 详细的坐标转换链路日志
+                logger.info(
+                    f"🔍 UNet推理完成 - 坐标转换链路验证:\n"
+                    f"  [1] 原图尺寸: [{h}, {w}]\n"
+                    f"  [2] 上采样尺寸: [{h_up_}, {w_up_}] (upscale={upscale:.3f})\n"
+                    f"  [3] UNet输入尺寸: [{pred_h}, {pred_w}] (h_scale={h_scale_actual:.6f}, w_scale={w_scale_actual:.6f})\n"
+                    f"  [4] Mask尺寸: [{hpred_up_h}, {hpred_up_w}] (已resize回上采样尺寸)\n"
+                    f"  验证: 理论upscale = {h_up_ / h:.3f} (h), {w_up_ / w:.3f} (w)"
+                )
+                
+                # 验证mask尺寸是否与上采样图像一致
+                if hpred_up_h != h_up_ or hpred_up_w != w_up_:
+                    logger.error(
+                        f"❌ hpred_up 尺寸 [{hpred_up_h}, {hpred_up_w}] 与上采样图像尺寸 "
+                        f"[{h_up_}, {w_up_}] 不一致!"
+                    )
+                if vpred_up_h != h_up_ or vpred_up_w != w_up_:
+                    logger.error(
+                        f"❌ vpred_up 尺寸 [{vpred_up_h}, {vpred_up_w}] 与上采样图像尺寸 "
+                        f"[{h_up_}, {w_up_}] 不一致!"
                     )
-                    # 如果尺寸不一致,先 resize 到预处理后尺寸,再 resize 到上采样尺寸
-                    # 但实际上,预测结果应该就是预处理后的尺寸,所以这个警告不应该出现
-                    hpred_temp = cv2.resize(hpred_, (preprocessed_w, preprocessed_h), interpolation=cv2.INTER_NEAREST)
-                    vpred_temp = cv2.resize(vpred_, (preprocessed_w, preprocessed_h), interpolation=cv2.INTER_NEAREST)
-                    hpred_up_ = cv2.resize(hpred_temp, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
-                    vpred_up_ = cv2.resize(vpred_temp, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
-                else:
-                    # 正常情况:预测结果就是预处理后的尺寸,直接 resize 到上采样尺寸
-                    # 这相当于 UNet postprocess 中的逻辑:cv2.resize(pred, (ori_shape[1], ori_shape[0]))
-                    hpred_up_ = cv2.resize(hpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
-                    vpred_up_ = cv2.resize(vpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
                 
-                return hpred_up_, vpred_up_, img_up_
+                return hpred_up_, vpred_up_, img_up_, w_scale_actual, h_scale_actual
 
             # Step 1: 首次运行 UNet 获取初步 mask
-            hpred_up, vpred_up, img_up = run_unet(table_image)
+            hpred_up, vpred_up, img_up, w_scale_actual, h_scale_actual = run_unet(table_image)
             
             # Step 1.1: 基于 Mask 的高精度倾斜检测与矫正
             if self.skew_detector.enable_deskew:
@@ -275,7 +360,7 @@ class MinerUWiredTableRecognizer:
                     h, w = table_image.shape[:2]
                     
                     # 重新运行 UNet (确保 mask 与矫正后的图完全对齐)
-                    hpred_up, vpred_up, img_up = run_unet(table_image)
+                    hpred_up, vpred_up, img_up, w_scale_actual, h_scale_actual = run_unet(table_image)
                 else:
                     logger.debug(f"表格倾斜 {skew_angle:.3f}° 小于阈值,无需矫正")
                         
@@ -302,7 +387,7 @@ class MinerUWiredTableRecognizer:
                 debug_dir = dbg.output_dir
                 debug_prefix = f"{dbg.prefix}_grid" if dbg.prefix else "grid"
             
-            # 传入原图的实际尺寸,用于计算真实的缩放比例
+            # 传入原图的实际尺寸和UNet预处理时的缩放因子,用于计算真实的缩放比例
             # 这样可以正确处理 UNet 预处理改变图像尺寸的情况
             bboxes = self.grid_recovery.compute_cells_from_lines(
                 hpred_up, 
@@ -310,6 +395,8 @@ class MinerUWiredTableRecognizer:
                 upscale,
                 orig_h=h,
                 orig_w=w,
+                unet_w_scale=w_scale_actual,
+                unet_h_scale=h_scale_actual,
                 debug_dir=debug_dir,
                 debug_prefix=debug_prefix
             )
@@ -350,10 +437,13 @@ class MinerUWiredTableRecognizer:
             # 策略调整:默认对所有单元格进行 Cropped OCR,以解决 Header 误合并和文本分配错误问题。
             # Full-page OCR 结果仅作为 Fallback(在 text_filling.py 中逻辑是: 如果 Cropped OCR 结果为空或低分,才保留原值)
             if hasattr(self, 'ocr_engine') and self.ocr_engine:
+                # 从 debug_options 中获取输出目录
+                output_dir = dbg.output_dir if dbg and dbg.enabled else None
                 texts = self.text_filler.second_pass_ocr_fill(
                     table_image, bboxes_merged, texts, scores, 
                     need_reocr_indices=need_reocr_indices,
-                    force_all=False  # Force Per-Cell OCR
+                    force_all=False,  # Force Per-Cell OCR
+                    output_dir=output_dir
                 )
 
             for i, cell in enumerate(merged_cells):