Kaynağa Gözat

feat(新增印章补充检测器): 在SmartLayoutRouter类中添加seal补充检测功能,初始化PP-DocLayoutV3模型以提升印章区域的识别能力,并实现结果合并与调试信息保存,优化印章检测流程。

zhch158_admin 1 ay önce
ebeveyn
işleme
6e1b1bead4

+ 218 - 3
ocr_tools/universal_doc_parser/core/layout_model_router.py

@@ -45,6 +45,9 @@ class SmartLayoutRouter(BaseLayoutDetector):
         self.page_name = None  # 将在 detect 方法中设置
         # 分数差距阈值:当模型间分数差距小于此值时,优先选择 docling
         self.score_diff_threshold = config.get('score_diff_threshold', 0.05)
+        # seal 补充检测配置
+        self.seal_supplement_config = config.get('seal_supplement', {})
+        self.seal_detector = None  # PP-DocLayoutV3 用于 seal 补充检测
         
     def initialize(self):
         """初始化所有模型"""
@@ -87,6 +90,28 @@ class SmartLayoutRouter(BaseLayoutDetector):
         
         if not self.models:
             raise RuntimeError("No layout models available")
+        
+        # 初始化 seal 补充检测器(PP-DocLayoutV3)
+        if self.seal_supplement_config.get('enabled', False):
+            try:
+                seal_model_config = self.seal_supplement_config.get('model_config', {})
+                if not seal_model_config:
+                    # 尝试从 model_configs 中查找 PP-DocLayoutV3
+                    for model_name, model_config in self.model_configs.items():
+                        if model_config.get('model_name') == 'PP-DocLayoutV3':
+                            seal_model_config = model_config
+                            break
+                if seal_model_config:
+                    logger.info(f"🔧 Initializing seal supplement detector: PP-DocLayoutV3")
+                    self.seal_detector = ModelFactory.create_layout_detector(
+                        _merge_child_model_config(seal_model_config)
+                    )
+                    logger.info(f"✅ Seal supplement detector initialized")
+                else:
+                    logger.warning(f"⚠️ Seal supplement enabled but no PP-DocLayoutV3 model config found")
+            except Exception as e:
+                logger.warning(f"⚠️ Failed to initialize seal supplement detector: {e}")
+                self.seal_detector = None
     
     def cleanup(self):
         """清理所有模型资源"""
@@ -96,6 +121,12 @@ class SmartLayoutRouter(BaseLayoutDetector):
             except Exception as e:
                 logger.warning(f"⚠️ Failed to cleanup {model_name}: {e}")
         self.models.clear()
+        if self.seal_detector is not None:
+            try:
+                self.seal_detector.cleanup()
+            except Exception as e:
+                logger.warning(f"⚠️ Failed to cleanup seal detector: {e}")
+            self.seal_detector = None
     
     def set_ocr_recognizer(self, ocr_recognizer):
         """设置OCR识别器(用于ocr_eval策略)"""
@@ -160,14 +191,31 @@ class SmartLayoutRouter(BaseLayoutDetector):
         if page_name is not None:
             self.page_name = page_name
         
+        results = []
         if self.strategy == 'ocr_eval':
-            return self._ocr_eval_detect(image, ocr_spans)
+            results = self._ocr_eval_detect(image, ocr_spans)
         elif self.strategy == 'auto':
-            return self._auto_select_detect(image)
+            results = self._auto_select_detect(image)
         elif self.strategy == 'scene':
-            return self._scene_select_detect(image)
+            results = self._scene_select_detect(image)
         else:
             raise ValueError(f"Unknown strategy: {self.strategy}")
+        
+        # 补充 seal 检测结果(如果启用)
+        seal_supplement_applied = (
+            self.seal_supplement_config.get('enabled', False)
+            and self.seal_detector is not None
+        )
+        if seal_supplement_applied:
+            primary_results = results
+            results = self._supplement_seal_detections(image, results)
+            # 子模型 detect() 已在 supplement 前写出 layout_post(仅主模型、无 seal)
+            if self._is_layout_debug_enabled():
+                self._save_router_layout_debug(image, primary_results, suffix='post_primary')
+            # 覆盖 layout_post,与 pipeline 实际使用的 layout(含 seal 补充)一致
+            self._save_router_layout_debug(image, results, suffix='post')
+
+        return results
 
     def _scene_select_detect(
         self,
@@ -356,6 +404,173 @@ class SmartLayoutRouter(BaseLayoutDetector):
             results = first_model.detect(image)
         
         return results
+
+    def _save_router_layout_debug(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        layout_results: List[Dict[str, Any]],
+        suffix: str,
+    ) -> None:
+        """Write a layout debug visualization from the SmartLayoutRouter level.
+
+        Nothing is written when layout debugging is disabled, when
+        ``layout_results`` is empty, when no output directory is resolved, or
+        when the ``save_post_processed`` debug option is falsy (it defaults
+        to ``True``).
+
+        Args:
+            image: Page image handed through to the visualizer.
+            layout_results: Layout items to draw.
+            suffix: Forwarded to the visualizer as its ``suffix`` argument.
+        """
+        if not self._is_layout_debug_enabled():
+            return
+        if not layout_results:
+            return
+
+        target_dir, current_page = self._resolve_layout_debug_paths()
+        debug_options = self._layout_debug_options()
+        if not target_dir:
+            return
+        if not debug_options.get('save_post_processed', True):
+            return
+
+        self._visualize_layout_results(
+            image, layout_results, target_dir, current_page, suffix=suffix
+        )
+
+    def _run_seal_detector(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """Run the seal supplement detector with its own layout debug switched off.
+
+        While the detector runs, its ``debug_mode`` attribute is forced to
+        ``False`` and, when ``config['debug_options']`` is a dict, that dict's
+        ``enabled`` flag is forced to ``False``. Both are restored afterwards,
+        even if detection raises.
+
+        Args:
+            image: Page image handed to the detector.
+
+        Returns:
+            The detector's layout items, or an empty list when no seal
+            detector is configured.
+        """
+        seal_det = self.seal_detector
+        if seal_det is None:
+            return []
+
+        prev_debug_mode = getattr(seal_det, 'debug_mode', None)
+        debug_opts: Optional[Dict[str, Any]] = None
+        prev_debug_opts: Optional[Dict[str, Any]] = None
+        if hasattr(seal_det, 'config') and isinstance(seal_det.config, dict):
+            opts = seal_det.config.get('debug_options')
+            if isinstance(opts, dict):
+                # Keep both the live dict and a snapshot of its content so the
+                # very same object can be restored in place afterwards.
+                debug_opts = opts
+                prev_debug_opts = opts.copy()
+                opts['enabled'] = False
+        seal_det.debug_mode = False  # type: ignore[attr-defined]
+
+        try:
+            if hasattr(seal_det, '_detect_raw'):
+                # Raw detection followed by explicit post-processing; presumably
+                # this bypasses the debug output done inside detect() -- confirm
+                # against the detector implementation.
+                raw = seal_det._detect_raw(image)
+                pp_config = (
+                    seal_det.config.get('post_process', {})
+                    if hasattr(seal_det, 'config')
+                    else {}
+                )
+                return seal_det.post_process(raw, image, pp_config)
+            return seal_det.detect(image)
+        finally:
+            seal_det.debug_mode = prev_debug_mode  # type: ignore[attr-defined]
+            if debug_opts is not None and prev_debug_opts is not None:
+                # Restore the original dict object in place rather than
+                # installing a copy: other holders of this dict (e.g. a config
+                # shared with another model) would otherwise keep seeing
+                # enabled=False forever.
+                debug_opts.clear()
+                debug_opts.update(prev_debug_opts)
+                if hasattr(seal_det, 'config') and isinstance(seal_det.config, dict):
+                    seal_det.config['debug_options'] = debug_opts
+    
+    # The primary model often mislabels seals as image regions. When a supplementary
+    # seal overlaps a box of one of these categories with a high IoU, that box is
+    # replaced by the seal instead of the seal being dropped as a duplicate.
+    _SEAL_REPLACEABLE_CATEGORIES = frozenset({
+        'image_body', 'image', 'figure', 'abandon', 'discarded',
+    })
+
+    @staticmethod
+    def _bbox_iou(box_a: List[float], box_b: List[float]) -> float:
+        """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.
+
+        Only the first four entries of each box are read. The result is
+        ``0.0`` when the union area is not positive.
+        """
+        ax1, ay1, ax2, ay2 = box_a[0], box_a[1], box_a[2], box_a[3]
+        bx1, by1, bx2, by2 = box_b[0], box_b[1], box_b[2], box_b[3]
+
+        # Overlap extent per axis, clamped at zero for disjoint boxes.
+        overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
+        overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
+        intersection = overlap_w * overlap_h
+
+        total = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - intersection
+        if total > 0:
+            return intersection / total
+        return 0.0
+
+    def _supplement_seal_detections(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        existing_results: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Merge seal regions found by the supplement detector into the primary layout.
+
+        Strategy:
+        1. Run the seal supplement detector and keep only items whose
+           ``category`` is ``'seal'``. If there are none, return
+           ``existing_results`` unchanged.
+        2. ``replace_existing`` true: drop every seal already present in the
+           primary results and use all supplementary seals instead.
+        3. Otherwise, when ``replace_overlapping_image`` is on (default), a
+           supplementary seal whose IoU with a box in
+           ``_SEAL_REPLACEABLE_CATEGORIES`` is >= ``replace_iou_threshold``
+           (default 0.7) **replaces** the best-matching such box. This covers
+           the case where the primary model labelled the stamp as an image and
+           the seal would otherwise be discarded as a duplicate.
+        4. Otherwise the seal is skipped as a duplicate when its IoU with any
+           box already in the merged list is > ``duplicate_iou_threshold``
+           (default 0.3); if not, it is appended as a new item.
+
+        Any exception raised during the merge is logged as a warning and
+        ``existing_results`` is returned as-is.
+
+        Args:
+            image: Input image.
+            existing_results: Layout detection results of the primary model.
+
+        Returns:
+            The result list with seal detections merged in.
+        """
+        try:
+            seal_results = self._run_seal_detector(image)
+            seal_only_items = [item for item in seal_results if item.get('category') == 'seal']
+            
+            if not seal_only_items:
+                logger.info("🔖 Seal supplement: no seal detected by PP-DocLayoutV3")
+                return existing_results
+
+            if self.seal_supplement_config.get('replace_existing', False):
+                logger.info("🔖 Seal supplement: replacing existing seal detections with PP-DocLayoutV3 results")
+                result = [item for item in existing_results if item.get('category') != 'seal']
+                result.extend(seal_only_items)
+                return result
+
+            replace_image = self.seal_supplement_config.get('replace_overlapping_image', True)
+            replace_iou_threshold = float(
+                self.seal_supplement_config.get('replace_iou_threshold', 0.7)
+            )
+            duplicate_iou_threshold = float(
+                self.seal_supplement_config.get('duplicate_iou_threshold', 0.3)
+            )
+
+            # Work on a shallow copy so the caller's list is not modified.
+            merged = list(existing_results)
+            replaced_count = 0
+            added_count = 0
+            skipped_duplicate = 0
+
+            for seal_item in seal_only_items:
+                seal_bbox = seal_item.get('bbox', [])
+                # Seals without at least four bbox values are ignored.
+                if not seal_bbox or len(seal_bbox) < 4:
+                    continue
+                seal_bbox = seal_bbox[:4]
+
+                if replace_image:
+                    # Find the replaceable box with the highest IoU at or above
+                    # the replace threshold.
+                    best_idx = -1
+                    best_iou = 0.0
+                    for idx, existing in enumerate(merged):
+                        if existing.get('category') not in self._SEAL_REPLACEABLE_CATEGORIES:
+                            continue
+                        existing_bbox = existing.get('bbox', [])
+                        if not existing_bbox or len(existing_bbox) < 4:
+                            continue
+                        overlap = self._bbox_iou(seal_bbox, existing_bbox[:4])
+                        if overlap >= replace_iou_threshold and overlap > best_iou:
+                            best_iou = overlap
+                            best_idx = idx
+                    if best_idx >= 0:
+                        old_cat = merged[best_idx].get('category', '')
+                        # Copy the seal item so the detector's own dict is not shared.
+                        new_item = dict(seal_item)
+                        new_item['category'] = 'seal'
+                        merged[best_idx] = new_item
+                        replaced_count += 1
+                        logger.debug(
+                            f"🔖 Seal supplement: replaced {old_cat} with seal "
+                            f"(IoU={best_iou:.3f}, bbox={seal_bbox})"
+                        )
+                        continue
+
+                # Duplicate check runs against every box in the merged list,
+                # including seals replaced or appended earlier in this loop.
+                is_duplicate = False
+                for existing in merged:
+                    existing_bbox = existing.get('bbox', [])
+                    if existing_bbox and len(existing_bbox) >= 4:
+                        if self._bbox_iou(seal_bbox, existing_bbox[:4]) > duplicate_iou_threshold:
+                            is_duplicate = True
+                            break
+                if is_duplicate:
+                    skipped_duplicate += 1
+                else:
+                    merged.append(dict(seal_item))
+                    added_count += 1
+
+            logger.info(
+                f"🔖 Seal supplement: PP-DocLayoutV3 seal={len(seal_only_items)}, "
+                f"replaced={replaced_count}, added={added_count}, "
+                f"skipped_duplicate={skipped_duplicate}"
+            )
+            return merged
+            
+        except Exception as e:
+            # Best effort: a failing supplement must not break primary layout detection.
+            logger.warning(f"⚠️ Seal supplement failed: {e}")
+            return existing_results
     
     def _get_ocr_spans(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
         """