浏览代码

feat(增强文本填充逻辑与边界框处理): 更新TextFiller类中的文本填充逻辑,确保在文本为空时返回0分数;新增多个静态方法以处理边界框的面积计算、嵌套框的识别和调试标签生成,提升OCR处理的准确性和可维护性。

zhch158_admin 4 天之前
父节点
当前提交
3099890b65
共有 1 个文件被更改,包括 176 次插入17 次删除
  1. 176 17
      ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

+ 176 - 17
ocr_tools/universal_doc_parser/models/adapters/wired_table/text_filling.py

@@ -192,7 +192,9 @@ class TextFiller:
         if rec_item is None:
             return "", 0.0
         if isinstance(rec_item, tuple) and len(rec_item) >= 2:
-            return str(rec_item[0] or "").strip(), float(rec_item[1] or 0.0)
+            txt = str(rec_item[0] or "").strip()
+            sc = float(rec_item[1] or 0.0)
+            return txt, 0.0 if not txt else sc
         if isinstance(rec_item, list) and len(rec_item) >= 2:
             if isinstance(rec_item[0], (list, tuple, dict)):
                 texts_list: List[str] = []
@@ -210,11 +212,13 @@ class TextFiller:
                         return combined, weighted
                     return combined, sum(scores_list) / len(scores_list)
                 return "", 0.0
-            return str(rec_item[0] or "").strip(), float(rec_item[1] or 0.0)
+            txt = str(rec_item[0] or "").strip()
+            sc = float(rec_item[1] or 0.0)
+            return txt, 0.0 if not txt else sc
         if isinstance(rec_item, dict):
             txt = str(rec_item.get("text") or rec_item.get("label") or "").strip()
             sc = float(rec_item.get("score") or rec_item.get("confidence") or 0.0)
-            return txt, sc
+            return txt, 0.0 if not txt else sc
         return "", 0.0
 
     def _extract_ocr_batch_results(self, rec_res: Any) -> List[Any]:
@@ -366,8 +370,160 @@ class TextFiller:
             return 0.0
         
         return inter_area / ocr_area
-    
-    
+
+    @staticmethod
+    def _bbox_area(bbox: List[float]) -> float:  # Area of a box: (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]), never negative.
+        if not bbox or len(bbox) < 4:  # Empty/None or fewer than 4 values: treat as zero area.
+            return 0.0
+        w = bbox[2] - bbox[0]  # Width from the values at indices 0 and 2.
+        h = bbox[3] - bbox[1]  # Height from the values at indices 1 and 3.
+        return max(0.0, w * h)  # Clamp a negative product to 0.0; NOTE(review): w and h both negative still yields a positive area.
+
+    @staticmethod
+    def _bbox_from_ocr_original_box(box: Dict[str, Any]) -> List[float]:  # Read a bbox out of an OCR box dict; [] when none is present.
+        raw = box.get("original_bbox") or box.get("bbox") or []  # Prefer "original_bbox"; fall back to "bbox" when it is missing or falsy.
+        if not raw:
+            return []  # Neither key held usable coordinates.
+        if len(raw) >= 4 and not isinstance(raw[0], (list, tuple)):  # Flat sequence (first item is not a list/tuple): take the first four values.
+            return [float(raw[0]), float(raw[1]), float(raw[2]), float(raw[3])]
+        return CoordinateUtils.poly_to_bbox(raw)  # Everything else is delegated; presumably raw is a polygon of points - confirm against CoordinateUtils.
+
+    @staticmethod
+    def _is_bbox_mostly_inside(
+        inner: List[float],
+        outer: List[float],
+        *,
+        inside_ratio: float = 0.7,  # Minimum fraction of inner's area that must fall inside outer.
+    ) -> bool:
+        """Return True when most of inner's area lies within outer and inner is clearly smaller than outer."""
+        if not inner or not outer or len(inner) < 4 or len(outer) < 4:  # Missing or short boxes never count as nested.
+            return False
+        inner_area = TextFiller._bbox_area(inner)
+        outer_area = TextFiller._bbox_area(outer)
+        if inner_area <= 0 or outer_area <= 0:  # Zero-area boxes are rejected; this also protects the division below.
+            return False
+        if inner_area >= outer_area * 0.92:  # inner must be under 92% of outer's area, so near-equal boxes are not treated as nested.
+            return False
+        inter_x1 = max(inner[0], outer[0])  # Intersection rectangle of the two boxes.
+        inter_y1 = max(inner[1], outer[1])
+        inter_x2 = min(inner[2], outer[2])
+        inter_y2 = min(inner[3], outer[3])
+        if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:  # Empty intersection: the boxes do not overlap.
+            return False
+        inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
+        return (inter_area / inner_area) >= inside_ratio  # Share of inner covered by the intersection, compared with the threshold.
+
+    @staticmethod
+    def _ocr_box_debug_tag(box: Dict[str, Any]) -> str:  # Short label for log messages: "idx=<n>" plus " bbox=[a,b,c,d]" when a bbox is available.
+        idx = box.get("paddle_bbox_index")
+        idx_part = f"idx={idx}" if idx is not None else "idx=?"  # "?" marks a box without a "paddle_bbox_index" value.
+        bbox = TextFiller._bbox_from_ocr_original_box(box)
+        if bbox and len(bbox) >= 4:
+            bb = ",".join(f"{v:.0f}" for v in bbox[:4])  # First four values, formatted with no decimals.
+            return f"{idx_part} bbox=[{bb}]"
+        return idx_part  # No usable bbox: index part only.
+
+    @staticmethod
+    def _resolve_cell_matched_boxes(
+        matched: List[Tuple[str, float, float, float, float, Dict[str, Any]]],  # Tuples of (text, y1, x1, overlap_ratio, score, original_box).
+        *,
+        cell_idx: Optional[int] = None,  # Only used as a "cell=<n> " prefix in debug logs.
+        y_tolerance: int = 5,  # Row bucket size for sorting; NOTE(review): 0 would raise ZeroDivisionError in the sort key.
+        inside_ratio: float = 0.7,  # Forwarded to _is_bbox_mostly_inside.
+    ) -> Tuple[List[Tuple[str, float, float, float, float, Dict[str, Any]]], bool]:  # (kept tuples, force_zero_score)
+        """
+        Resolve nested OCR boxes matched to the same cell:
+        - outer box has text, smaller box inside: drop the inner box and keep the outer box's text;
+        - outer box has no text, smaller box inside: drop the inner box and flag the cell score as 0 (meant to trigger a second OCR pass).
+        """
+        if not matched:
+            return matched, False  # Nothing to resolve; the input is returned as-is.
+
+        matched.sort(key=lambda x: (round(x[1] / y_tolerance), x[2]))  # Sorts the caller's list in place: by y1 bucket, then by x1.
+
+        entries: List[Dict[str, Any]] = []  # One dict per matched tuple, with the resolved bbox attached.
+        for text, y1, x1, overlap_ratio, score, original_box in matched:
+            bbox = TextFiller._bbox_from_ocr_original_box(original_box)
+            entries.append(
+                {
+                    "text": text or "",
+                    "y1": y1,
+                    "x1": x1,
+                    "overlap_ratio": overlap_ratio,
+                    "score": score,
+                    "original_box": original_box,
+                    "bbox": bbox,
+                }
+            )
+
+        remove: set = set()  # Indices into entries that will be dropped.
+        force_zero_score = False
+
+        for i, outer_e in enumerate(entries):  # Every entry is tried as the outer box, including ones already marked for removal.
+            outer_bbox = outer_e["bbox"]
+            if not outer_bbox:
+                continue
+            outer_text = (outer_e["text"] or "").strip()
+            for j, inner_e in enumerate(entries):
+                if i == j or j in remove:  # Skip the box itself and inner candidates that are already dropped.
+                    continue
+                inner_bbox = inner_e["bbox"]
+                if not inner_bbox:
+                    continue
+                if not TextFiller._is_bbox_mostly_inside(
+                    inner_bbox, outer_bbox, inside_ratio=inside_ratio
+                ):
+                    continue
+                inner_text = (inner_e["text"] or "").strip()
+                outer_tag = TextFiller._ocr_box_debug_tag(outer_e["original_box"])
+                inner_tag = TextFiller._ocr_box_debug_tag(inner_e["original_box"])
+                cell_part = f"cell={cell_idx} " if cell_idx is not None else ""
+                if not outer_text:  # Outer box has no text: the inner box is always dropped.
+                    remove.add(j)
+                    if inner_text:  # The zero-score flag is raised only when the dropped inner box carried text.
+                        force_zero_score = True
+                        logger.debug(
+                            f"{cell_part}嵌套 OCR:空大框套小框,丢弃内框并置 score=0 "
+                            f"(outer {outer_tag} text='' | inner {inner_tag} "
+                            f"text={inner_text!r} score={inner_e['score']:.3f})"
+                        )
+                    else:
+                        logger.debug(
+                            f"{cell_part}嵌套 OCR:空大框套空小框,丢弃内框 "
+                            f"(outer {outer_tag} | inner {inner_tag})"
+                        )
+                elif inner_text:  # Outer box has text: drop the inner box only if it has text too; empty inner boxes are kept.
+                    remove.add(j)
+                    logger.debug(
+                        f"{cell_part}嵌套 OCR:有字大框套小框,丢弃内框碎片 "
+                        f"(outer {outer_tag} text={outer_text!r} | inner {inner_tag} "
+                        f"text={inner_text!r} score={inner_e['score']:.3f})"
+                    )
+
+        kept = [e for idx, e in enumerate(entries) if idx not in remove]  # Survivors, still in sorted order.
+        if remove:  # Emit one summary log line when anything was dropped.
+            removed_texts = [
+                (entries[j]["text"] or "").strip()
+                for j in sorted(remove)
+            ]
+            logger.debug(
+                f"{('cell=' + str(cell_idx) + ' ') if cell_idx is not None else ''}"
+                f"嵌套 OCR 汇总: 移除 {len(remove)} 个小框 {removed_texts!r},"
+                f"保留 {len(kept)} 个框,force_zero_score={force_zero_score}"
+            )
+        resolved = [  # Rebuild tuples in the input layout; None text has become "" via the entries dict.
+            (
+                e["text"],
+                e["y1"],
+                e["x1"],
+                e["overlap_ratio"],
+                e["score"],
+                e["original_box"],
+            )
+            for e in kept
+        ]
+        return resolved, force_zero_score
+
     def fill_text_by_center_point(
         self,
         bboxes: List[List[float]],
@@ -483,18 +639,21 @@ class TextFiller:
                     ))
             
             if matched:
-                # 直接按 y1 和 x1 排序,确保文本顺序正确
-                # y_tolerance 用于将相近的 y1 归为同一行(容差范围内视为同一行)
-                # 同一行内按 x1 从左到右排序
-                y_tolerance = 5
-                matched.sort(key=lambda x: (round(x[1] / y_tolerance), x[2]))  # 先按 y_group,再按 x1
-                
-                texts[idx] = "".join([t for t, _, _, _, _, _ in matched])
-                # 计算平均置信度
-                avg_score = sum([s for _, _, _, _, s, _ in matched]) / len(matched)
-                scores[idx] = avg_score
-                # 保存匹配到的 OCR boxes
-                matched_boxes_list[idx] = [box for _, _, _, _, _, box in matched]
+                matched, force_zero_score = self._resolve_cell_matched_boxes(
+                    matched, cell_idx=idx
+                )
+                if matched:
+                    texts[idx] = "".join(
+                        [(t or "").strip() for t, _, _, _, _, _ in matched]
+                    )
+                    avg_score = sum(s for _, _, _, _, s, _ in matched) / len(matched)
+                    scores[idx] = 0.0 if force_zero_score else avg_score
+                    matched_boxes_list[idx] = [
+                        box for _, _, _, _, _, box in matched
+                    ]
+                else:
+                    texts[idx] = ""
+                    scores[idx] = 0.0
             else:
                 scores[idx] = 0.0 # 无匹配文本,置信度为0