Selaa lähdekoodia

feat(TextFiller): 添加动态置信度阈值计算方法,优化OCR结果处理

zhch158_admin 2 viikkoa sitten
vanhempi
commit
9f84fef765

+ 52 - 6
ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

@@ -28,7 +28,7 @@ class TextFiller:
         """
         self.ocr_engine = ocr_engine
         self.cell_crop_margin: int = config.get("cell_crop_margin", 2)
-        self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.9)  # 单元格 OCR 置信度阈值
+        self.ocr_conf_threshold: float = config.get("ocr_conf_threshold", 0.9)  # 单元格 OCR 置信度阈值(基准值)
         
         # 跨单元格检测配置参数
         self.overlap_threshold_horizontal: float = config.get("overlap_threshold_horizontal", 0.2)
@@ -38,6 +38,45 @@ class TextFiller:
         self.other_cell_max_ratio: float = config.get("other_cell_max_ratio", 0.3)
     
     @staticmethod
+    def calculate_dynamic_confidence_threshold(text: str, base_threshold: float = 0.9) -> float:
+        """
+        Compute an OCR confidence threshold that depends on the recognized text's length.
+        
+        The length is measured after stripping surrounding whitespace. Each tier
+        applies an offset to ``base_threshold`` and then clamps the result:
+        - 1 character: ``min(0.95, base_threshold + 0.1)`` (stricter, e.g. to avoid "1" being misread as "l")
+        - 2-3 characters: ``min(0.92, base_threshold + 0.02)``
+        - 4-10 characters: ``max(0.85, base_threshold - 0.05)``
+        - 11-20 characters: ``max(0.80, base_threshold - 0.1)``
+        - more than 20 characters: ``max(0.75, base_threshold - 0.15)``
+        
+        NOTE(review): whitespace-only text has a stripped length of 0 and falls
+        into the "2-3 characters" tier. The clamps can also override the base:
+        a base above 0.95 is capped for short text, and a low base is raised to
+        the tier's floor for longer text. Confirm both are intended.
+        
+        Args:
+            text: The recognized text.
+            base_threshold: Baseline confidence threshold (default 0.9).
+            
+        Returns:
+            The length-adjusted threshold, or ``base_threshold`` unchanged when
+            ``text`` is empty.
+        """
+        if not text:
+            return base_threshold
+        
+        text_len = len(text.strip())
+        
+        if text_len == 1:
+            # Single character: raise the threshold by 0.1, capped at 0.95
+            return min(0.95, base_threshold + 0.1)
+        elif text_len <= 3:
+            # 2-3 characters (also stripped length 0): raise by 0.02, capped at 0.92
+            return min(0.92, base_threshold + 0.02)
+        elif text_len <= 10:
+            # 4-10 characters: lower the threshold by 0.05, floored at 0.85
+            return max(0.85, base_threshold - 0.05)
+        elif text_len <= 20:
+            # 11-20 characters: lower the threshold by 0.1, floored at 0.80
+            return max(0.80, base_threshold - 0.1)
+        else:
+            # More than 20 characters: lower the threshold by 0.15, floored at 0.75
+            return max(0.75, base_threshold - 0.15)
+    
+    @staticmethod
     def calculate_overlap_ratio(ocr_bbox: List[float], cell_bbox: List[float]) -> float:
         """
         计算 OCR box 与单元格的重叠比例(重叠面积 / OCR box 面积)
@@ -608,7 +647,7 @@ class TextFiller:
 
             # 对齐长度,避免越界
             n = min(len(results) if isinstance(results, list) else 0, len(crop_list), len(crop_indices))
-            conf_th = self.ocr_conf_threshold
+            base_conf_th = self.ocr_conf_threshold
 
             # 辅助函数:清理文件名中的非法字符
             def sanitize_filename(text: str, max_length: int = 50) -> str:
@@ -642,10 +681,17 @@ class TextFiller:
                     except Exception as e:
                         logger.warning(f"保存单元格OCR图片失败 (cell {cell_idx}): {e}")
                 
-                if text_k and score_k >= conf_th:
-                    texts[cell_idx] = text_k
-                elif text_k:
-                    logger.debug(f"单元格 {cell_idx} 二次OCR结果置信度({score_k:.2f})低于阈值({conf_th}): (文本: '{text_k[:30]}...')")
+                if text_k:
+                    # 根据文本长度动态调整置信度阈值
+                    dynamic_conf_th = self.calculate_dynamic_confidence_threshold(text_k, base_conf_th)
+                    
+                    if score_k >= dynamic_conf_th:
+                        texts[cell_idx] = text_k
+                    else:
+                        logger.debug(
+                            f"单元格 {cell_idx} 二次OCR结果置信度({score_k:.2f})低于动态阈值({dynamic_conf_th:.2f}) "
+                            f"[文本长度={len(text_k)}, 基准阈值={base_conf_th:.2f}]: '{text_k[:30]}...'"
+                        )
 
         except Exception as e:
             logger.warning(f"二次OCR失败: {e}")