Переглянути джерело

feat(优化文本填充与OCR识别逻辑): 更新TextFiller类,新增短文本最小字符配置,重构识别逻辑以支持更灵活的文本解析和分数归一化,优化单元格对比度调整与增强功能,提升OCR处理的准确性与灵活性。

zhch158_admin 3 днів тому
батько
коміт
df98998bd5

+ 107 - 23
ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

@@ -63,6 +63,7 @@ class TextFiller:
         self.second_pass_row_peer_min_nonempty: int = int(
             sp_cfg.get("row_peer_min_nonempty", 5)
         )
+        _short_min = sp_cfg.get("suspicious_short_min_chars")
         cpp = sp_cfg.get("cell_preprocess") or {}
         if not isinstance(cpp, dict):
             cpp = {}
@@ -70,16 +71,18 @@ class TextFiller:
         if not isinstance(light, dict):
             light = {}
         self.second_pass_light_upscale_min: int = int(
-            light.get("upscale_min_side", 64)
+            light.get("upscale_min_side", 192)
         )
-        er = cpp.get("enhance_retry") or {}
+        er = sp_cfg.get("enhance_retry") or cpp.get("enhance_retry") or {}
         if not isinstance(er, dict):
             er = {}
+        if _short_min is None:
+            _short_min = er.get("min_chars", 4)
+        self.second_pass_suspicious_short_min_chars: int = int(_short_min)
         self.second_pass_enhance_retry_enabled: bool = bool(er.get("enabled", True))
         self.second_pass_enhance_score_below: float = float(
             er.get("score_below", 0.90)
         )
-        self.second_pass_enhance_min_chars: int = int(er.get("min_chars", 4))
         self.second_pass_enhance_short_tall: bool = bool(
             er.get("short_text_in_tall_cell", True)
         )
@@ -101,7 +104,7 @@ class TextFiller:
         denoise = cpp.get("denoise") or {}
         if not isinstance(denoise, dict):
             denoise = {}
-        self._cell_denoise_enabled: bool = bool(denoise.get("enabled", True))
+        self._cell_denoise_enabled: bool = bool(denoise.get("enabled", False))
         self._cell_denoise_method: str = str(denoise.get("method", "median"))
         cell_contrast = cpp.get("contrast") or {}
         if not isinstance(cell_contrast, dict):
@@ -245,12 +248,40 @@ class TextFiller:
         return x1, y1, x2, y2
 
     @staticmethod
+    def _normalize_rec_score(score: float) -> float:
+        """识别分归一化到 [0,1];部分引擎返回 0~100。"""
+        try:
+            sc = float(score)
+        except (TypeError, ValueError):
+            return 0.0
+        if sc != sc:  # NaN
+            return 0.0
+        if sc > 1.0:
+            if sc <= 100.0:
+                return sc / 100.0
+            return 0.0
+        if sc < 0.0:
+            return 0.0
+        return sc
+
+    @staticmethod
+    def _parse_det_rec_item(item: Any) -> Tuple[str, float]:
+        """解析 det+rec 一体结果的一项:[[box], (text, score)]。"""
+        if item is None:
+            return "", 0.0
+        if isinstance(item, (list, tuple)) and len(item) >= 2:
+            head = item[0]
+            if isinstance(head, (list, tuple)) and len(head) >= 4:
+                return TextFiller._parse_single_rec_item(item[1])
+        return TextFiller._parse_single_rec_item(item)
+
+    @staticmethod
     def _parse_single_rec_item(rec_item: Any) -> Tuple[str, float]:
         if rec_item is None:
             return "", 0.0
         if isinstance(rec_item, tuple) and len(rec_item) >= 2:
             txt = str(rec_item[0] or "").strip()
-            sc = float(rec_item[1] or 0.0)
+            sc = TextFiller._normalize_rec_score(float(rec_item[1] or 0.0))
             return txt, 0.0 if not txt else sc
         if isinstance(rec_item, list) and len(rec_item) >= 2:
             if isinstance(rec_item[0], (list, tuple, dict)):
@@ -266,15 +297,19 @@ class TextFiller:
                     total_len = sum(len(t) for t in texts_list)
                     if total_len > 0:
                         weighted = sum(len(t) * s for t, s in zip(texts_list, scores_list)) / total_len
-                        return combined, weighted
-                    return combined, sum(scores_list) / len(scores_list)
+                        return combined, TextFiller._normalize_rec_score(weighted)
+                    return combined, TextFiller._normalize_rec_score(
+                        sum(scores_list) / len(scores_list)
+                    )
                 return "", 0.0
             txt = str(rec_item[0] or "").strip()
-            sc = float(rec_item[1] or 0.0)
+            sc = TextFiller._normalize_rec_score(float(rec_item[1] or 0.0))
             return txt, 0.0 if not txt else sc
         if isinstance(rec_item, dict):
             txt = str(rec_item.get("text") or rec_item.get("label") or "").strip()
-            sc = float(rec_item.get("score") or rec_item.get("confidence") or 0.0)
+            sc = TextFiller._normalize_rec_score(
+                float(rec_item.get("score") or rec_item.get("confidence") or 0.0)
+            )
             return txt, 0.0 if not txt else sc
         return "", 0.0
 
@@ -293,7 +328,18 @@ class TextFiller:
             items = self._extract_ocr_batch_results(rec_res)
             if not items:
                 return "", 0.0
-            return self._parse_single_rec_item(items[0] if len(items) == 1 else items)
+            blocks: List[Tuple[str, float]] = []
+            for item in items:
+                text, score = self._parse_det_rec_item(item)
+                if text:
+                    blocks.append((text, score))
+            if not blocks:
+                return "", 0.0
+            return self.aggregate_line_ocr(
+                blocks,
+                line_min_score=0.0,
+                drop_low_score_blocks=False,
+            )
         except Exception as e:
             logger.warning(f"整格 OCR 失败: {e}")
             return "", 0.0
@@ -418,7 +464,11 @@ class TextFiller:
         return cell_img
 
     def _apply_cell_contrast(
-        self, cell_img: np.ndarray, contrast_cfg: Dict[str, Any]
+        self,
+        cell_img: np.ndarray,
+        contrast_cfg: Dict[str, Any],
+        *,
+        sharpen_cfg: Optional[Dict[str, Any]] = None,
     ) -> np.ndarray:
         from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
 
@@ -429,8 +479,9 @@ class TextFiller:
         else:
             gray = cell_img
         gray = apply_contrast_enhancement_config(gray, contrast_cfg)
-        if self.second_pass_enhance_sharpen.get("enabled", False):
-            amount = float(self.second_pass_enhance_sharpen.get("amount", 0.3))
+        sharpen = sharpen_cfg or {}
+        if sharpen.get("enabled", False):
+            amount = float(sharpen.get("amount", 0.3))
             blurred = cv2.GaussianBlur(gray, (0, 0), 1.0)
             gray = cv2.addWeighted(gray, 1.0 + amount, blurred, -amount, 0)
         if cell_img.ndim == 3:
@@ -451,12 +502,18 @@ class TextFiller:
             img = self._denoise_cell(img)
             stages.append("denoise")
 
-        if mode == "enhance":
+        if mode == "light":
+            if self._cell_contrast_cfg.get("enabled", False) and "wm" in stages:
+                img = self._apply_cell_contrast(img, self._cell_contrast_cfg)
+                stages.append("contrast")
+        elif mode == "enhance":
             contrast_cfg = self.second_pass_enhance_contrast
             if self._cell_contrast_cfg.get("enabled", False):
                 contrast_cfg = self._cell_contrast_cfg
             if contrast_cfg.get("enabled", False) and "wm" in stages:
-                img = self._apply_cell_contrast(img, contrast_cfg)
+                img = self._apply_cell_contrast(
+                    img, contrast_cfg, sharpen_cfg=self.second_pass_enhance_sharpen
+                )
                 stages.append("contrast")
 
         img = self._upscale_cell_if_small(img)
@@ -473,10 +530,18 @@ class TextFiller:
         strip_score: float = 0.0,
     ) -> Tuple[str, float, str]:
         """返回 (text, score, strategy)。"""
+        line_score = self._normalize_rec_score(line_score)
+        whole_score = self._normalize_rec_score(whole_score)
+        strip_score = self._normalize_rec_score(strip_score)
+
         candidates: List[Tuple[str, float, str]] = []
         if line_text:
             candidates.append((line_text, line_score, "lines"))
-        if whole_text and self.second_pass_whole_fallback:
+        if (
+            whole_text
+            and self.second_pass_whole_fallback
+            and 0.0 < whole_score <= 1.0
+        ):
             candidates.append((whole_text, whole_score, "whole"))
         if strip_text:
             candidates.append((strip_text, strip_score, "strip"))
@@ -487,6 +552,7 @@ class TextFiller:
         if (
             whole_text
             and line_text
+            and 0.0 < whole_score <= 1.0
             and line_score > whole_score
             and len(whole_text) >= len(line_text) + self.second_pass_whole_longer_extra
             and len(whole_text) > len(line_text)
@@ -567,7 +633,7 @@ class TextFiller:
         if (
             line_text
             and line_score >= base_conf_th
-            and len(line_text) < self.second_pass_enhance_min_chars
+            and len(line_text) < self.second_pass_suspicious_short_min_chars
         ):
             return True
         return False
@@ -587,7 +653,7 @@ class TextFiller:
             reasons.append("not_accepted")
         if score < self.second_pass_enhance_score_below:
             reasons.append("score_below_threshold")
-        if text and len(text) < self.second_pass_enhance_min_chars:
+        if text and len(text) < self.second_pass_suspicious_short_min_chars:
             reasons.append("suspicious_short_text")
         h, w = cell_img.shape[:2]
         if (
@@ -595,7 +661,7 @@ class TextFiller:
             and w > 0
             and h / w >= self.second_pass_strip_aspect
             and len(result.get("lines") or []) <= 1
-            and len(text) < self.second_pass_enhance_min_chars + 2
+            and len(text) < self.second_pass_suspicious_short_min_chars + 2
         ):
             reasons.append("tall_cell_single_line")
         return bool(reasons), reasons
@@ -620,7 +686,7 @@ class TextFiller:
             whole_text, whole_score = self._recognize_whole_cell(cell_img)
             whole_skipped = None
         elif line_text and line_score >= base_conf_th:
-            if len(line_text) < self.second_pass_enhance_min_chars:
+            if len(line_text) < self.second_pass_suspicious_short_min_chars:
                 whole_skipped = "short_text_high_score"
             else:
                 whole_skipped = "line_score>=%.2f" % base_conf_th
@@ -757,6 +823,7 @@ class TextFiller:
         debug_img: np.ndarray,
         result: Dict[str, Any],
         *,
+        raw_img: Optional[np.ndarray] = None,
         first_pass_text: str = "",
         first_pass_score: float = 0.0,
         trigger_reasons: Optional[List[str]] = None,
@@ -769,15 +836,31 @@ class TextFiller:
         if pass_label:
             stem += f"_{pass_label}"
         stem += f"_{strategy}_{tag}"
-        png_path = os.path.join(cell_ocr_dir, f"{stem}.png")
+        preprocessed_name = f"{stem}.png"
+        preprocessed_path = os.path.join(cell_ocr_dir, preprocessed_name)
         try:
-            cv2.imwrite(png_path, debug_img)
+            cv2.imwrite(preprocessed_path, debug_img)
         except Exception as e:
             logger.warning(f"保存单元格OCR图片失败 (cell {cell_idx}): {e}")
             return
+
+        raw_name: Optional[str] = None
+        if raw_img is not None and raw_img.size > 0:
+            raw_name = f"{stem}_raw.png"
+            raw_path = os.path.join(cell_ocr_dir, raw_name)
+            try:
+                cv2.imwrite(raw_path, raw_img)
+            except Exception as e:
+                logger.warning(f"保存单元格原图失败 (cell {cell_idx}): {e}")
+                raw_name = None
+
         payload = {
             "cell_idx": cell_idx,
             "bbox": bbox,
+            "debug_images": {
+                "raw": raw_name,
+                "preprocessed": preprocessed_name,
+            },
             "first_pass": {"text": first_pass_text, "score": first_pass_score},
             "trigger_reason": trigger_reasons or [],
             "lines": result.get("lines") or [],
@@ -828,7 +911,7 @@ class TextFiller:
         
         if text_len == 1:
             # 单字符:提高阈值 +0.05
-            return min(0.95, base_threshold + 0.1)
+            return min(0.92, base_threshold + 0.1)
         elif text_len <= 3:
             # 2-3字符:轻微提高阈值 +0.02
             return min(0.92, base_threshold + 0.02)
@@ -1456,6 +1539,7 @@ class TextFiller:
                         cell_idx,
                         debug_img,
                         result,
+                        raw_img=raw_crop,
                         first_pass_text=fp_text,
                         first_pass_score=fp_score,
                         trigger_reasons=trigger_reasons,