Pārlūkot izejas kodu

fix: Enhance UNet preprocessing and resizing logic in MinerUWiredTableRecognizer to ensure consistent prediction dimensions and improve debugging information

zhch158_admin 2 dienas atpakaļ
vecāks
revīzija
6b063ced58

+ 52 - 5
ocr_tools/universal_doc_parser/models/adapters/mineru_wired_table.py

@@ -203,15 +203,58 @@ class MinerUWiredTableRecognizer:
                 
                 wired_rec = self.table_model.wired_table_model
                 img_obj = wired_rec.load_img(img_up_)
+                
+                # 手动计算 UNet 预处理后的图像尺寸(模拟 preprocess 中的 resize_img 逻辑)
+                # UNet 使用 scale = (inp_height, inp_width) = (1024, 1024),keep_ratio=True
+                h_up_, w_up_ = img_up_.shape[:2]
+                inp_height = 1024
+                inp_width = 1024
+                max_long_edge = max(inp_height, inp_width)  # 1024
+                max_short_edge = min(inp_height, inp_width)  # 1024
+                # 计算缩放因子(保持长宽比)
+                scale_factor = min(max_long_edge / max(h_up_, w_up_), max_short_edge / min(h_up_, w_up_))
+                preprocessed_w = int(w_up_ * scale_factor + 0.5)
+                preprocessed_h = int(h_up_ * scale_factor + 0.5)
+                
                 img_info = wired_rec.table_structure.preprocess(img_obj)
                 pred_ = wired_rec.table_structure.infer(img_info)
                 
                 hpred_ = np.where(pred_ == 1, 255, 0).astype(np.uint8)
                 vpred_ = np.where(pred_ == 2, 255, 0).astype(np.uint8)
                 
-                h_up_, w_up_ = img_up_.shape[:2]
-                hpred_up_ = cv2.resize(hpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
-                vpred_up_ = cv2.resize(vpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                # 调试:记录尺寸信息
+                pred_h, pred_w = pred_.shape[:2]
+                logger.debug(
+                    f"UNet 推理: 上采样图像尺寸=[{h_up_}, {w_up_}], "
+                    f"预处理后尺寸=[{preprocessed_h}, {preprocessed_w}], "
+                    f"预测结果尺寸=[{pred_h}, {pred_w}], "
+                    f"缩放因子={scale_factor:.6f}, upscale={upscale:.3f}"
+                )
+                
+                # 关键修复:正确地将预测结果 resize 回上采样尺寸
+                # UNet 的 postprocess 使用 ori_shape = img.shape 来 resize 预测结果
+                # 在我们的情况下,img 是 img_up_(上采样后的图像)
+                # 所以我们应该使用 img_up_.shape 来 resize 预测结果
+                # 但是,由于预处理时改变了图像尺寸(保持长宽比),我们需要确保 resize 是正确的
+                
+                # 验证:检查预测结果尺寸是否与预处理后的尺寸一致
+                if pred_h != preprocessed_h or pred_w != preprocessed_w:
+                    logger.warning(
+                        f"⚠️ 预测结果尺寸 [{pred_h}, {pred_w}] 与预处理后尺寸 [{preprocessed_h}, {preprocessed_w}] 不一致!"
+                        f"这可能导致坐标偏移。使用预处理后尺寸进行 resize。"
+                    )
+                    # 如果尺寸不一致,先 resize 到预处理后尺寸,再 resize 到上采样尺寸
+                    # 但实际上,预测结果应该就是预处理后的尺寸,所以这个警告不应该出现
+                    hpred_temp = cv2.resize(hpred_, (preprocessed_w, preprocessed_h), interpolation=cv2.INTER_NEAREST)
+                    vpred_temp = cv2.resize(vpred_, (preprocessed_w, preprocessed_h), interpolation=cv2.INTER_NEAREST)
+                    hpred_up_ = cv2.resize(hpred_temp, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                    vpred_up_ = cv2.resize(vpred_temp, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                else:
+                    # 正常情况:预测结果就是预处理后的尺寸,直接 resize 到上采样尺寸
+                    # 这相当于 UNet postprocess 中的逻辑:cv2.resize(pred, (ori_shape[1], ori_shape[0]))
+                    hpred_up_ = cv2.resize(hpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                    vpred_up_ = cv2.resize(vpred_, (w_up_, h_up_), interpolation=cv2.INTER_NEAREST)
+                
                 return hpred_up_, vpred_up_, img_up_
 
             # Step 1: 首次运行 UNet 获取初步 mask
@@ -258,11 +301,15 @@ class MinerUWiredTableRecognizer:
             if dbg and dbg.enabled and dbg.output_dir:
                 debug_dir = dbg.output_dir
                 debug_prefix = f"{dbg.prefix}_grid" if dbg.prefix else "grid"
-                
+            
+            # 传入原图的实际尺寸,用于计算真实的缩放比例
+            # 这样可以正确处理 UNet 预处理改变图像尺寸的情况
             bboxes = self.grid_recovery.compute_cells_from_lines(
                 hpred_up, 
                 vpred_up, 
                 upscale,
+                orig_h=h,
+                orig_w=w,
                 debug_dir=debug_dir,
                 debug_prefix=debug_prefix
             )
@@ -306,7 +353,7 @@ class MinerUWiredTableRecognizer:
                 texts = self.text_filler.second_pass_ocr_fill(
                     table_image, bboxes_merged, texts, scores, 
                     need_reocr_indices=need_reocr_indices,
-                    force_all=True  # Force Per-Cell OCR
+                    force_all=False  # Force Per-Cell OCR
                 )
 
             for i, cell in enumerate(merged_cells):