Просмотр исходного кода

feat(增强布局路由器和文档管道): 在SmartLayoutRouter中添加布局调试上下文传播功能,优化模型检测流程;在EnhancedDocPipeline中改进页面预处理,注入水印调试上下文,增强OCR调试选项,提升处理灵活性和准确性。

zhch158_admin 5 дней назад
Родитель
Сommit
92b9d902ee

+ 18 - 16
ocr_tools/universal_doc_parser/core/layout_model_router.py

@@ -97,6 +97,16 @@ class SmartLayoutRouter(BaseLayoutDetector):
     def set_scene_name(self, scene_name: Optional[str]):
         """设置场景名称(用于scene策略)"""
         self.scene_name = scene_name
+
+    def _propagate_layout_debug_context(self, model: BaseLayoutDetector) -> None:
+        """将路由器上的 debug 上下文传给子 layout 模型(scene/auto 策略需要)。"""
+        if not self._is_layout_debug_enabled():
+            return
+        model.debug_mode = True  # type: ignore[attr-defined]
+        if self.output_dir:
+            model.output_dir = self.output_dir  # type: ignore[attr-defined]
+        if self.page_name:
+            model.page_name = self.page_name  # type: ignore[attr-defined]
     
     def _detect_raw(
         self, 
@@ -177,7 +187,9 @@ class SmartLayoutRouter(BaseLayoutDetector):
             selected_model = next(iter(self.models.keys()))
 
         logger.info(f"🎯 Scene strategy selected model: {selected_model} (scene: {self.scene_name})")
-        return self.models[selected_model].detect(image)
+        model = self.models[selected_model]
+        self._propagate_layout_debug_context(model)
+        return model.detect(image)
     
     def _ocr_eval_detect(
         self, 
@@ -201,14 +213,8 @@ class SmartLayoutRouter(BaseLayoutDetector):
             if model_name == 'fallback':
                 continue  # 跳过回退模型(除非所有模型都失败)
             try:
-                # 传递 debug 模式配置给子模型(如果启用)
-                if self.debug_mode:
-                    model.debug_mode = self.debug_mode  # type: ignore
-                    if self.output_dir:
-                        model.output_dir = self.output_dir  # type: ignore
-                    if self.page_name:
-                        model.page_name = self.page_name  # type: ignore
-                
+                self._propagate_layout_debug_context(model)
+
                 # 调用 detect() 方法,基类会自动执行后处理
                 results = model.detect(image)
                 all_postprocessed_results[model_name] = results
@@ -221,13 +227,7 @@ class SmartLayoutRouter(BaseLayoutDetector):
             # 如果所有模型都失败,尝试回退模型
             if 'fallback' in self.models:
                 logger.info("🔄 All models failed, using fallback model")
-                # 传递 debug 模式配置给回退模型(如果启用)
-                if self.debug_mode:
-                    self.models['fallback'].debug_mode = self.debug_mode  # type: ignore
-                    if self.output_dir:
-                        self.models['fallback'].output_dir = self.output_dir  # type: ignore
-                    if self.page_name:
-                        self.models['fallback'].page_name = self.page_name  # type: ignore
+                self._propagate_layout_debug_context(self.models['fallback'])
                 # 回退模型使用 detect() 方法(会自动执行后处理)
                 fallback_result = self.models['fallback'].detect(image)
                 return fallback_result
@@ -337,10 +337,12 @@ class SmartLayoutRouter(BaseLayoutDetector):
         # 使用选中的模型进行检测(使用 detect() 方法,会自动执行后处理)
         if selected_model in self.models:
             model = self.models[selected_model]
+            self._propagate_layout_debug_context(model)
             results = model.detect(image)
         else:
             # 回退到第一个可用模型
             first_model = next(iter(self.models.values()))
+            self._propagate_layout_debug_context(first_model)
             results = first_model.detect(image)
         
         return results

+ 91 - 39
ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

@@ -392,50 +392,42 @@ class EnhancedDocPipeline:
             'pdf_type': pdf_type
         }
         
-        # 用于检测的图片(可能被旋转)
-        detection_image = original_image.copy()
         rotate_angle = 0
+        pdf_rotate_angle: Optional[int] = None
+        use_orientation_classifier = pdf_type == 'ocr'
 
-        # 0. 页级水印去除(全页一次;表格裁剪等下游仅做方向校正,避免重复去水印)
-        detection_image = self.preprocessor.remove_watermark(detection_image)
-        
-        # 1. 页面方向识别
-        # rotate_angle统一定义:图像需要逆时针旋转的角度(0/90/180/270)来变为正视
-        if pdf_type == 'ocr':
-            # 扫描件:使用OCR方向识别
-            try:
-                detection_image, rotate_angle = self.preprocessor.process(
-                    detection_image, skip_watermark=True
-                )
-                page_result['angle'] = rotate_angle
-                
-                if rotate_angle != 0:
-                    logger.info(f"📐 Page {page_idx}: rotated {rotate_angle}° for detection")
-            except Exception as e:
-                logger.warning(f"⚠️ Orientation detection failed: {e}")
-        elif pdf_type == 'txt' and pdf_doc is not None:
-            # 文字PDF:获取PDF页面rotation并转换为统一的rotate_angle定义
+        if pdf_type == 'txt' and pdf_doc is not None:
             try:
                 pdf_rotation_angle = PDFUtils.get_page_rotation(pdf_doc, page_idx)
                 if pdf_rotation_angle != 0:
-                    # 转换为OCR定义:图像需要逆时针旋转的角度
-                    # PDF rotation 270° 表示内容逆时针270° = 顺时针90°
-                    # 要恢复正视,需要逆时针90° (即360-270=90)
-                    rotate_angle = (360 - pdf_rotation_angle) % 360
-                    if rotate_angle == 360:
-                        rotate_angle = 0
-                    
-                    # 将图片旋转为正视(使用rotate_angle,逆时针旋转)
-                    from PIL import Image
-                    pil_rotated = Image.fromarray(detection_image).rotate(rotate_angle, expand=True)
-                    detection_image = np.array(pil_rotated)
-                    page_result['angle'] = rotate_angle
-                    logger.info(f"📐 Page {page_idx}: PDF rotation {pdf_rotation_angle}°, rotated image {rotate_angle}° to upright")
+                    pdf_rotate_angle = (360 - pdf_rotation_angle) % 360
+                    if pdf_rotate_angle == 360:
+                        pdf_rotate_angle = 0
+                    if pdf_rotate_angle:
+                        logger.info(
+                            f"📐 Page {page_idx}: PDF rotation {pdf_rotation_angle}°, "
+                            f"will rotate image {pdf_rotate_angle}° to upright"
+                        )
             except Exception as e:
                 logger.warning(f"⚠️ Failed to get PDF rotation: {e}")
 
-        
-        # 2. Layout检测
+        # 0. 页级预处理(方向校正 → 去水印,见 preprocessor.order)
+        self._inject_watermark_debug_context(output_dir, page_name)
+        try:
+            detection_image, rotate_angle = self.preprocessor.prepare_detection_image(
+                original_image.copy(),
+                pdf_rotate_angle=pdf_rotate_angle,
+                use_orientation_classifier=use_orientation_classifier,
+            )
+            page_result['angle'] = rotate_angle
+            page_result['inference_image'] = detection_image
+            if rotate_angle != 0:
+                logger.info(f"📐 Page {page_idx}: detection image upright (rotate {rotate_angle}°)")
+        except Exception as e:
+            logger.warning(f"⚠️ Page preprocessing failed, using original copy: {e}")
+            detection_image = original_image.copy()
+
+        # 1. Layout检测
         try:
             # 如果使用智能路由器且策略是ocr_eval,需要先获取OCR spans(只检测文本框,不识别文字)
             ocr_spans_for_layout = None
@@ -456,12 +448,18 @@ class EnhancedDocPipeline:
                     except Exception as e:
                         logger.warning(f"⚠️ Pre-OCR text box detection for layout evaluation failed: {e}")
             
-            # 注入每页运行时信息(output_dir/page_name 仅在 layout detector 自身 debug 开启时才有意义)
-            if hasattr(self.layout_detector, 'debug_mode') and self.layout_detector.debug_mode:  # type: ignore
-                if output_dir and hasattr(self.layout_detector, 'output_dir'):
+            # 注入每页运行时信息(SmartLayoutRouter scene 策略需传到子模型)
+            layout_dbg = (
+                getattr(self.layout_detector, '_is_layout_debug_enabled', None)
+                and self.layout_detector._is_layout_debug_enabled()  # type: ignore
+            )
+            if layout_dbg and hasattr(self.layout_detector, 'output_dir'):
+                if output_dir:
                     self.layout_detector.output_dir = output_dir  # type: ignore
                 if page_name and hasattr(self.layout_detector, 'page_name'):
                     self.layout_detector.page_name = page_name  # type: ignore
+                if hasattr(self.layout_detector, 'debug_mode'):
+                    self.layout_detector.debug_mode = True  # type: ignore
             
             # 调用layout检测(传递OCR spans如果可用)
             if ocr_spans_for_layout is not None and hasattr(self.layout_detector, 'detect'):
@@ -543,6 +541,9 @@ class EnhancedDocPipeline:
                 all_ocr_spans = SpanMatcher.remove_duplicate_spans(all_ocr_spans)
                 all_ocr_spans = self._sort_spans_by_position(all_ocr_spans)
                 logger.info(f"📝 Page {page_idx}: OCR detected {len(all_ocr_spans)} text spans")
+                self._save_page_ocr_debug_if_enabled(
+                    detection_image, all_ocr_spans, output_dir, page_name
+                )
             except Exception as e:
                 logger.warning(f"⚠️ Full-page OCR failed: {e}")                
             # 3.1 调试模式:对比 OCR 和 PDF 提取结果
@@ -608,6 +609,57 @@ class EnhancedDocPipeline:
         page_result['discarded_blocks'] = sorted_discarded
         return page_result
 
+    def _is_page_ocr_debug_enabled(self) -> bool:
+        opts = self.config.get('ocr_recognition', {}).get('debug_options', {})
+        return isinstance(opts, dict) and bool(opts.get('enabled', False))
+
+    def _save_page_ocr_debug_if_enabled(
+        self,
+        image: np.ndarray,
+        spans: List[Dict[str, Any]],
+        output_dir: Optional[str],
+        page_name: Optional[str],
+    ) -> None:
+        """整页 OCR 完成后保存 module debug(底图=inference_image,与 layout 一致)。"""
+        if not self._is_page_ocr_debug_enabled() or not output_dir or not page_name:
+            return
+        from ocr_utils.module_debug_viz import save_ocr_debug
+
+        opts = self.config.get('ocr_recognition', {}).get('debug_options', {})
+        if not isinstance(opts, dict):
+            opts = {}
+        save_ocr_debug(
+            image,
+            spans,
+            output_dir,
+            page_name,
+            subdir=opts.get('subdir', 'ocr_recognition'),
+            image_format=opts.get('image_format', 'png'),
+            save_json=bool(opts.get('save_json', True)),
+        )
+
+    def _inject_watermark_debug_context(
+        self,
+        output_dir: Optional[str],
+        page_name: Optional[str],
+    ) -> None:
+        """按页注入水印 debug 输出路径(与 layout_detection 一致)。"""
+        pre = self.preprocessor
+        if pre is None or not hasattr(pre, '_is_watermark_debug_enabled'):
+            return
+        wm_opts = (
+            self.config.get('preprocessor', {})
+            .get('watermark_removal', {})
+            .get('debug_options', {})
+        )
+        if not isinstance(wm_opts, dict) or not wm_opts.get('enabled', False):
+            return
+        if output_dir:
+            pre.output_dir = output_dir  # type: ignore[attr-defined]
+        if page_name:
+            pre.page_name = page_name  # type: ignore[attr-defined]
+        pre.debug_mode = True  # type: ignore[attr-defined]
+
     @staticmethod
     def _convert_pdf_blocks_to_spans(
         pdf_text_blocks: List[Dict[str, Any]],