Procházet zdrojové kódy

feat(增强图像预处理): 在BasePreprocessor类中添加水印调试选项和图像处理顺序配置,优化方向校正和水印去除流程,提升OCR处理的灵活性和准确性。

zhch158_admin před 5 dny
rodič
revize
1c67a0d785
1 změnil soubory, kde provedl 197 přidání a 158 odebrání
  1. 197 158
      ocr_tools/universal_doc_parser/models/adapters/base.py

+ 197 - 158
ocr_tools/universal_doc_parser/models/adapters/base.py

@@ -25,6 +25,72 @@ class BaseAdapter(ABC):
 
 class BasePreprocessor(BaseAdapter):
     """预处理器基类"""
+
+    def __init__(self, config: Dict[str, Any]):
+        """Initialize the preprocessor and reset the per-page debug state.
+
+        Args:
+            config: Preprocessor configuration, forwarded to the parent
+                initializer.
+        """
+        super().__init__(config)
+        # Injected at runtime by the pipeline, page by page (same convention
+        # as layout_detection). None means "not injected": the watermark
+        # debug helpers then fall back to the configured debug options.
+        self.debug_mode: Optional[bool] = None
+        self.output_dir: Optional[str] = None
+        self.page_name: Optional[str] = None
+
+    def _watermark_debug_options(self) -> Dict[str, Any]:
+        """Return the ``watermark_removal.debug_options`` mapping from config.
+
+        Returns:
+            The configured options dict, or an empty dict when either
+            ``watermark_removal`` or its ``debug_options`` entry is missing
+            or is not a dict (for example explicitly set to None).
+        """
+        wm_cfg = self.config.get('watermark_removal')
+        if not isinstance(wm_cfg, dict):
+            # A non-dict section (None, bool, ...) used to raise
+            # AttributeError on .get(); a debug lookup must not break
+            # preprocessing, so treat it as "no debug options".
+            return {}
+        opts = wm_cfg.get('debug_options')
+        return opts if isinstance(opts, dict) else {}
+
+    def _is_watermark_debug_enabled(self) -> bool:
+        """Return whether watermark debug output is enabled.
+
+        A ``debug_mode`` attribute that is not None takes precedence;
+        otherwise the ``enabled`` entry of the watermark debug options is
+        used, defaulting to False.
+        """
+        debug_mode = getattr(self, 'debug_mode', None)
+        if debug_mode is not None:
+            return bool(debug_mode)
+        return bool(self._watermark_debug_options().get('enabled', False))
+
+    def _resolve_watermark_debug_paths(self) -> Tuple[Optional[str], str]:
+        """Resolve the output directory and name used for watermark debug files.
+
+        Returns:
+            ``(output_dir, page_name)``. ``output_dir`` comes from the
+            instance attribute or, when that is None, from the debug options,
+            and may still be None. ``page_name`` is always a string.
+        """
+        output_dir = getattr(self, 'output_dir', None)
+        if output_dir is None:
+            output_dir = self._watermark_debug_options().get('output_dir')
+        page_name = getattr(self, 'page_name', None)
+        if not page_name:
+            # No page name injected: use the configured prefix, else 'watermark'.
+            page_name = self._watermark_debug_options().get('prefix') or 'watermark'
+        prefix = self._watermark_debug_options().get('prefix', '')
+        # Prepend the prefix unless the name already starts with it (this also
+        # covers the fallback above, where the name *is* the prefix).
+        if prefix and page_name and not str(page_name).startswith(str(prefix)):
+            page_name = f"{prefix}_{page_name}"
+        return output_dir, str(page_name)
+
+    def _save_watermark_debug_images(
+        self,
+        before: np.ndarray,
+        after: np.ndarray,
+        threshold: int,
+        morph_close_kernel: int,
+        contrast_cfg: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Save watermark debug images (delegates to ocr_utils.watermark_utils).
+
+        Does nothing when no output directory can be resolved. Any exception
+        raised by the delegate is logged as a warning and not re-raised.
+
+        Args:
+            before: Image passed to the delegate as the "before" image
+                (presumably the page prior to watermark removal).
+            after: Image passed to the delegate as the "after" image.
+            threshold: Recorded under ``"threshold"`` in the processing params.
+            morph_close_kernel: Recorded under ``"morph_close_kernel"``.
+            contrast_cfg: When truthy, recorded under
+                ``"contrast_enhancement"``.
+        """
+        from ocr_utils.watermark_utils import save_watermark_removal_debug
+
+        output_dir, page_name = self._resolve_watermark_debug_paths()
+        if not output_dir:
+            return
+
+        opts = self._watermark_debug_options()
+        params: Dict[str, Any] = {
+            "threshold": threshold,
+            "morph_close_kernel": morph_close_kernel,
+        }
+        if contrast_cfg:
+            params["contrast_enhancement"] = contrast_cfg
+
+        try:
+            save_watermark_removal_debug(
+                before,
+                after,
+                output_dir,
+                page_name,
+                processing_params=params,
+                image_format=opts.get("image_format") or "png",
+                save_compare=opts.get("save_compare", True),
+            )
+        except Exception as e:
+            # Debug output is best-effort: report the failure and carry on.
+            logger.warning(f"Watermark debug save failed: {e}")
     
     def remove_watermark(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
         """页级水印去除(默认无操作,子类可覆盖)。"""
@@ -32,17 +98,82 @@ class BasePreprocessor(BaseAdapter):
             return np.array(image)
         return image
 
-    @abstractmethod
+    def _preprocess_order(self) -> str:
+        """Return the preprocessing step order.
+
+        Reads the ``order`` config key (stripped and lower-cased). Valid
+        values are ``orient_first`` (the default) and ``watermark_first``;
+        any other value logs a warning and falls back to ``orient_first``.
+        """
+        order = str(self.config.get('order', 'orient_first')).strip().lower()
+        if order not in ('orient_first', 'watermark_first'):
+            logger.warning(
+                f"Unknown preprocessor.order={order!r}, fallback to orient_first"
+            )
+            return 'orient_first'
+        return order
+
+    def correct_orientation(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        *,
+        pdf_rotate_angle: Optional[int] = None,
+        use_orientation_classifier: bool = True,
+    ) -> tuple[np.ndarray, int]:
+        """
+        Orientation correction only, without watermark removal. Intended for
+        cases such as table crops where page-level preprocessing is already done.
+
+        Args:
+            image: Input image; a PIL image is converted to a numpy array.
+            pdf_rotate_angle: Page-level rotation of a text PDF
+                (counter-clockwise angle, consistent with the pipeline).
+            use_orientation_classifier: Whether to use the orientation
+                classifier (True for scanned documents). This base
+                implementation does not read it.
+
+        Returns:
+            ``(image, rotate_angle)``: the rotated array and the applied angle
+            when ``pdf_rotate_angle`` is truthy, otherwise the unrotated
+            array and 0.
+        """
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+
+        if pdf_rotate_angle:
+            # expand=True grows the canvas so the rotated page is not cropped.
+            pil_rotated = Image.fromarray(image).rotate(pdf_rotate_angle, expand=True)
+            return np.array(pil_rotated), int(pdf_rotate_angle)
+        return image, 0
+
+    def prepare_detection_image(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        *,
+        pdf_rotate_angle: Optional[int] = None,
+        use_orientation_classifier: bool = True,
+    ) -> tuple[np.ndarray, int]:
+        """
+        Full page-level preprocessing: run orientation correction and
+        watermark removal in the order given by preprocessor.order.
+
+        Args:
+            image: Input image; a PIL image is converted to a numpy array.
+            pdf_rotate_angle: Forwarded to correct_orientation().
+            use_orientation_classifier: Forwarded to correct_orientation().
+
+        Returns:
+            (detection_image, rotate_angle)
+        """
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+
+        order = self._preprocess_order()
+
+        # Binds the keyword arguments once so both orderings share one call.
+        def _orient(img: np.ndarray) -> tuple[np.ndarray, int]:
+            return self.correct_orientation(
+                img,
+                pdf_rotate_angle=pdf_rotate_angle,
+                use_orientation_classifier=use_orientation_classifier,
+            )
+
+        if order == 'watermark_first':
+            cleaned = self.remove_watermark(image)
+            return _orient(cleaned)
+
+        # Default (orient_first): rotate first, then remove the watermark.
+        oriented, rotate_angle = _orient(image)
+        return self.remove_watermark(oriented), rotate_angle
+
     def process(
         self,
         image: Union[np.ndarray, Image.Image],
         skip_watermark: bool = False,
     ) -> tuple[np.ndarray, int]:
         """
-        处理图像
-        返回处理后的图像和旋转角度
+        Cropped blocks: orientation correction only (skip_watermark=True).
+        For page-level images use prepare_detection_image().
+
+        Args:
+            image: Input image as a numpy array or PIL image.
+            skip_watermark: When True only correct_orientation() runs;
+                otherwise prepare_detection_image() runs.
+
+        Returns:
+            The (image, rotate_angle) pair returned by the delegated method.
         """
-        pass
+        if skip_watermark:
+            return self.correct_orientation(image, use_orientation_classifier=True)
+        return self.prepare_detection_image(image, use_orientation_classifier=True)
     
     def _apply_rotation(self, image: np.ndarray, rotation_angle: int) -> np.ndarray:
         """应用旋转"""
@@ -92,63 +223,30 @@ class BaseLayoutDetector(BaseAdapter):
         # 调用子类实现的原始检测方法
         layout_results = self._detect_raw(image, ocr_spans)
         
-        # Debug 模式:打印和可视化后处理前的检测结果
-        # 优先从实例属性读取(如果存在),否则从配置读取
-        # 支持两种配置方式:debug_mode 或 debug_options.enabled
-        debug_mode = getattr(self, 'debug_mode', None)
-        if debug_mode is None:
-            if hasattr(self, 'config'):
-                # 优先从 debug_mode 读取
-                debug_mode = self.config.get('debug_mode', False)
-                # 如果没有 debug_mode,尝试从 debug_options.enabled 读取
-                if not debug_mode:
-                    debug_options = self.config.get('debug_options', {})
-                    if isinstance(debug_options, dict):
-                        debug_mode = debug_options.get('enabled', False)
-            else:
-                debug_mode = False
-        
+        debug_mode = self._is_layout_debug_enabled()
+        output_dir, page_name = self._resolve_layout_debug_paths()
+        dbg_opts = self._layout_debug_options()
+
         if debug_mode:
-            logger.debug(f"🔍 Layout detection raw results (before post-processing): {len(layout_results)} elements")
-            # logger.debug(f"Raw layout_results: {layout_results}")
-            # 可视化 layout 结果
-            output_dir = getattr(self, 'output_dir', None)
-            if output_dir is None:
-                if hasattr(self, 'config'):
-                    # 优先从 output_dir 读取
-                    output_dir = self.config.get('output_dir', None)
-                    # 如果没有 output_dir,尝试从 debug_options.output_dir 读取
-                    if output_dir is None:
-                        debug_options = self.config.get('debug_options', {})
-                        if isinstance(debug_options, dict):
-                            output_dir = debug_options.get('output_dir', None)
-                else:
-                    output_dir = None
-            
-            page_name = getattr(self, 'page_name', None)
-            if page_name is None:
-                if hasattr(self, 'config'):
-                    # 优先从 page_name 读取
-                    page_name = self.config.get('page_name', None)
-                    # 如果没有 page_name,尝试从 debug_options.prefix 读取
-                    if page_name is None:
-                        debug_options = self.config.get('debug_options', {})
-                        if isinstance(debug_options, dict):
-                            prefix = debug_options.get('prefix', '')
-                            page_name = prefix if prefix else 'layout_detection'
-                    if page_name is None:
-                        page_name = 'layout_detection'
-                else:
-                    page_name = 'layout_detection'
-            
-            if output_dir:
-                self._visualize_layout_results(image, layout_results, output_dir, page_name, suffix='raw')
-        
+            logger.debug(
+                f"Layout detection raw results (before post-processing): "
+                f"{len(layout_results)} elements"
+            )
+            if output_dir and dbg_opts.get('save_raw', True):
+                self._visualize_layout_results(
+                    image, layout_results, output_dir, page_name, suffix='raw'
+                )
+
         # 自动执行后处理
         if layout_results:
             layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
             layout_results = self.post_process(layout_results, image, layout_config)
-        
+
+        if debug_mode and output_dir and dbg_opts.get('save_post_processed', True):
+            self._visualize_layout_results(
+                image, layout_results, output_dir, page_name, suffix='post'
+            )
+
         return layout_results
     
     @abstractmethod
@@ -325,116 +423,57 @@ class BaseLayoutDetector(BaseAdapter):
         }
         return category_map.get(category_id, f'unknown_{category_id}')
     
+    def _layout_debug_options(self) -> Dict[str, Any]:
+        """Return the ``debug_options`` config entry, or {} if it is not a dict."""
+        opts = self.config.get('debug_options', {})
+        return opts if isinstance(opts, dict) else {}
+
+    def _is_layout_debug_enabled(self) -> bool:
+        """Return whether layout debug output is enabled.
+
+        Precedence: a ``debug_mode`` instance attribute that is not None,
+        then a truthy ``debug_mode`` config entry, then the ``enabled``
+        entry of the debug options (default False).
+        """
+        debug_mode = getattr(self, 'debug_mode', None)
+        if debug_mode is not None:
+            return bool(debug_mode)
+        if self.config.get('debug_mode', False):
+            return True
+        return bool(self._layout_debug_options().get('enabled', False))
+
+    def _resolve_layout_debug_paths(self) -> Tuple[Optional[str], str]:
+        """Resolve the output directory and page name for layout debug files.
+
+        Returns:
+            ``(output_dir, page_name)``. ``output_dir`` may be None when no
+            source provides one; ``page_name`` is always a string.
+        """
+        # output_dir: instance attribute, then config, then debug options.
+        output_dir = getattr(self, 'output_dir', None)
+        if output_dir is None:
+            output_dir = self.config.get('output_dir')
+        if output_dir is None:
+            output_dir = self._layout_debug_options().get('output_dir')
+
+        # page_name: instance attribute, then config; if still empty, the
+        # debug-options prefix, else the literal 'layout_detection'.
+        page_name = getattr(self, 'page_name', None)
+        if page_name is None:
+            page_name = self.config.get('page_name')
+        if not page_name:
+            prefix = self._layout_debug_options().get('prefix', '')
+            page_name = prefix if prefix else 'layout_detection'
+        return output_dir, str(page_name)
+
     def _visualize_layout_results(
         self,
         image: Union[np.ndarray, Image.Image],
         layout_results: List[Dict[str, Any]],
         output_dir: str,
         page_name: str,
-        suffix: str = 'raw'
+        suffix: str = 'raw',
     ) -> None:
-        """
-        可视化 layout 检测结果
-        
-        Args:
-            image: 输入图像
-            layout_results: 布局检测结果
-            output_dir: 输出目录
-            page_name: 页面名称
-            suffix: 文件名后缀(如 'raw', 'postprocessed')
-        """
+        """Save layout-module debug output (base image is the inference / detection input).
+
+        Delegates to ``ocr_utils.module_debug_viz.save_layout_debug``. The
+        call is best-effort: an exception raised by the delegate is logged
+        as a warning and not re-raised, so a failing debug write cannot
+        abort layout detection. Nothing is done for empty results.
+
+        Args:
+            image: Image forwarded to the delegate.
+            layout_results: Layout detection results to visualize.
+            output_dir: Output directory forwarded to the delegate.
+            page_name: Page name forwarded to the delegate.
+            suffix: Suffix forwarded to the delegate (callers use 'raw'
+                and 'post').
+        """
         if not layout_results:
             return
-        
-        try:
-            # 转换为 numpy 数组
-            if isinstance(image, Image.Image):
-                vis_image = np.array(image)
-                if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
-                    # PIL RGB -> OpenCV BGR
-                    vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
-            else:
-                vis_image = image.copy()
-                if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
-                    # 如果是 RGB,转换为 BGR
-                    vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
-            
-            # 定义类别颜色映射 (BGR格式)
-            category_colors = {
-                'table_body': (0, 0, 255),      # 红色
-                'table_caption': (0, 0, 200),   # 暗红色
-                'table_footnote': (0, 0, 150),  # 更暗的红色
-                'text': (255, 0, 0),            # 蓝色
-                'title': (0, 255, 255),         # 黄色
-                'header': (255, 0, 255),        # 紫色
-                'footer': (0, 165, 255),        # 橙色
-                'image_body': (0, 255, 0),      # 绿色
-                'image_caption': (0, 200, 0),   # 暗绿色
-                'image_footnote': (0, 150, 0),  # 更暗的绿色
-                'abandon': (128, 128, 128),     # 灰色
-            }
-            
-            # 绘制检测框
-            for result in layout_results:
-                bbox = result.get('bbox', [])
-                if not bbox or len(bbox) < 4:
-                    continue
-                
-                category = result.get('category', 'unknown')
-                color = category_colors.get(category, (128, 128, 128))  # 默认灰色
-                thickness = 2
-                
-                x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
-                cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, thickness)
-                
-                # 添加类别标签
-                label = f"{category}"
-                confidence = result.get('confidence', result.get('score', 0))
-                if confidence:
-                    label += f":{confidence:.2f}"
-                
-                # 计算文本大小
-                font = cv2.FONT_HERSHEY_SIMPLEX
-                font_scale = 0.4
-                text_thickness = 1
-                (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, text_thickness)
-                
-                # 在框的上方绘制文本背景
-                text_y = max(y1 - baseline - 1, text_height + baseline)
-                cv2.rectangle(vis_image, (x1, text_y - text_height - baseline - 2), 
-                            (x1 + text_width, text_y), color, -1)
-                # 绘制文本
-                cv2.putText(vis_image, label, (x1, text_y - baseline - 1), 
-                          font, font_scale, (255, 255, 255), text_thickness)
-            
-            # 保存图像
-            debug_dir = Path(output_dir) / "debug_comparison" / "layout_detection"
-            debug_dir.mkdir(parents=True, exist_ok=True)
-            output_path = debug_dir / f"{page_name}_layout_{suffix}.jpg"
-            cv2.imwrite(str(output_path), vis_image)
-            logger.info(f"📊 Saved layout detection image ({suffix}): {output_path}")
-            
-            # 保存 JSON 数据
-            json_data = {
-                'page_name': page_name,
-                'suffix': suffix,
-                'count': len(layout_results),
-                'results': [
-                    {
-                        'category': r.get('category'),
-                        'bbox': r.get('bbox'),
-                        'confidence': r.get('confidence', r.get('score', 0.0))
-                    }
-                    for r in layout_results
-                ]
-            }
-            json_path = debug_dir / f"{page_name}_layout_{suffix}.json"
-            with open(json_path, 'w', encoding='utf-8') as f:
-                json.dump(json_data, f, ensure_ascii=False, indent=2)
-            logger.info(f"📊 Saved layout detection JSON ({suffix}): {json_path}")
-            
-        except Exception as e:
-            logger.warning(f"⚠️ Failed to visualize layout results: {e}")
+        from ocr_utils.module_debug_viz import save_layout_debug
+
+        opts = self._layout_debug_options()
+        try:
+            save_layout_debug(
+                image,
+                layout_results,
+                output_dir,
+                page_name,
+                suffix=suffix,
+                subdir=opts.get('subdir', 'layout_detection'),
+                image_format=opts.get('image_format', 'jpg'),
+                save_json=bool(opts.get('save_json', True)),
+            )
+        except Exception as e:
+            # Debug output is best-effort, as in the implementation this
+            # replaces and in the watermark debug saver: report and continue.
+            logger.warning(f"Layout debug visualization failed: {e}")
     
     def _remove_overlapping_boxes(
         self,
@@ -615,7 +654,7 @@ class BaseVLRecognizer(BaseAdapter):
 
 class BaseOCRRecognizer(BaseAdapter):
     """OCR识别器基类"""
-    
+
     @abstractmethod
     def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
         """识别文本"""