from abc import ABC, abstractmethod
import json
from pathlib import Path
from typing import Dict, Any, List, Union, Optional, Tuple

import cv2
import numpy as np
from loguru import logger
from PIL import Image


class BaseAdapter(ABC):
    """Base adapter interface shared by all pipeline components."""

    def __init__(self, config: Dict[str, Any]):
        """Store the adapter configuration.

        Args:
            config: Configuration dictionary for this adapter.
        """
        self.config = config

    @abstractmethod
    def initialize(self):
        """Initialize the model."""

    @abstractmethod
    def cleanup(self):
        """Release resources."""


class BasePreprocessor(BaseAdapter):
    """Base class for page preprocessors (orientation + watermark removal)."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize the preprocessor.

        Args:
            config: Preprocessor configuration dictionary.
        """
        super().__init__(config)
        # Injected per page by the pipeline at runtime (same convention as
        # layout detection). ``None`` means "fall back to the config".
        self.debug_mode: Optional[bool] = None
        self.output_dir: Optional[str] = None
        self.page_name: Optional[str] = None

    def _watermark_debug_options(self) -> Dict[str, Any]:
        """Return ``config['watermark_removal']['debug_options']`` or ``{}``.

        Any missing or non-dict level (for example a key that is present but
        null) yields an empty dict instead of raising.
        """
        wm_cfg = self.config.get('watermark_removal')
        if not isinstance(wm_cfg, dict):
            return {}
        opts = wm_cfg.get('debug_options')
        return opts if isinstance(opts, dict) else {}

    def _is_watermark_debug_enabled(self) -> bool:
        """Tell whether watermark debug output is enabled.

        A runtime ``debug_mode`` attribute takes precedence; otherwise the
        ``enabled`` flag of the watermark debug options is used.
        """
        debug_mode = getattr(self, 'debug_mode', None)
        if debug_mode is not None:
            return bool(debug_mode)
        return bool(self._watermark_debug_options().get('enabled', False))

    def _resolve_watermark_debug_paths(self) -> Tuple[Optional[str], str]:
        """Resolve the output directory and page name for watermark debug files.

        Returns:
            ``(output_dir, page_name)``. ``output_dir`` comes from the runtime
            attribute, then from the debug options, and may be ``None``.
            ``page_name`` comes from the runtime attribute, then the configured
            prefix, then ``'watermark'``; the prefix is prepended unless the
            name already starts with it.
        """
        opts = self._watermark_debug_options()

        output_dir = getattr(self, 'output_dir', None)
        if output_dir is None:
            output_dir = opts.get('output_dir')

        page_name = getattr(self, 'page_name', None)
        if not page_name:
            page_name = opts.get('prefix') or 'watermark'

        prefix = opts.get('prefix', '')
        if prefix and page_name and not str(page_name).startswith(str(prefix)):
            page_name = f"{prefix}_{page_name}"
        return output_dir, str(page_name)

    def _save_watermark_debug_images(
        self,
        before: np.ndarray,
        after: np.ndarray,
        threshold: int,
        morph_close_kernel: int,
        contrast_cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save watermark debug images (delegates to ocr_utils.watermark_utils).

        Does nothing when no output directory can be resolved. A failure while
        saving is logged as a warning and not propagated.

        Args:
            before: Image before watermark removal.
            after: Image after watermark removal.
            threshold: Threshold value recorded in the processing parameters.
            morph_close_kernel: Kernel size recorded in the processing parameters.
            contrast_cfg: Optional contrast-enhancement settings to record.
        """
        output_dir, page_name = self._resolve_watermark_debug_paths()
        if not output_dir:
            return

        # Imported only once we know something will be written, so the helper
        # module is not required when debug output is not configured.
        from ocr_utils.watermark_utils import save_watermark_removal_debug

        opts = self._watermark_debug_options()
        params: Dict[str, Any] = {
            "threshold": threshold,
            "morph_close_kernel": morph_close_kernel,
        }
        if contrast_cfg:
            params["contrast_enhancement"] = contrast_cfg
        try:
            save_watermark_removal_debug(
                before,
                after,
                output_dir,
                page_name,
                processing_params=params,
                image_format=opts.get("image_format") or "png",
                save_compare=opts.get("save_compare", True),
                subdir=opts.get("subdir", "watermark_removal"),
            )
        except Exception as e:
            # Debug output is best-effort and must never break preprocessing.
            logger.warning(f"Watermark debug save failed: {e}")

    def remove_watermark(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
        """Page-level watermark removal (no-op by default; subclasses may override).

        Args:
            image: Input image.

        Returns:
            The image as a numpy array; PIL input is converted, array input is
            returned unchanged.
        """
        if isinstance(image, Image.Image):
            return np.array(image)
        return image

    def _preprocess_order(self) -> str:
        """Return the preprocessing order: ``orient_first`` (default) or ``watermark_first``.

        Unknown values are reported with a warning and replaced by
        ``orient_first``.
        """
        order = str(self.config.get('order', 'orient_first')).strip().lower()
        if order not in ('orient_first', 'watermark_first'):
            logger.warning(
                f"Unknown preprocessor.order={order!r}, fallback to orient_first"
            )
            return 'orient_first'
        return order

    def correct_orientation(
        self,
        image: Union[np.ndarray, Image.Image],
        *,
        pdf_rotate_angle: Optional[int] = None,
        use_orientation_classifier: bool = True,
    ) -> tuple[np.ndarray, int]:
        """Correct orientation only, without watermark removal.

        Intended for cases such as table crops where the page has already been
        preprocessed. This base implementation only applies
        ``pdf_rotate_angle``; it does not use ``use_orientation_classifier``.

        Args:
            image: Input image.
            pdf_rotate_angle: Page-level rotation of a text PDF
                (counter-clockwise angle, consistent with the pipeline).
            use_orientation_classifier: Whether to use the orientation
                classifier (True for scanned documents).

        Returns:
            ``(image, applied_angle)``; the angle is 0 when nothing was rotated.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        if pdf_rotate_angle:
            # expand=True grows the canvas so the rotated page is not cropped.
            pil_rotated = Image.fromarray(image).rotate(pdf_rotate_angle, expand=True)
            return np.array(pil_rotated), int(pdf_rotate_angle)
        return image, 0

    def prepare_detection_image(
        self,
        image: Union[np.ndarray, Image.Image],
        *,
        pdf_rotate_angle: Optional[int] = None,
        use_orientation_classifier: bool = True,
    ) -> tuple[np.ndarray, int]:
        """Full page-level preprocessing.

        Runs orientation correction and watermark removal in the order given
        by ``preprocessor.order``.

        Args:
            image: Input page image.
            pdf_rotate_angle: Forwarded to ``correct_orientation``.
            use_orientation_classifier: Forwarded to ``correct_orientation``.

        Returns:
            ``(detection_image, rotate_angle)``.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        order = self._preprocess_order()

        def _orient(img: np.ndarray) -> tuple[np.ndarray, int]:
            return self.correct_orientation(
                img,
                pdf_rotate_angle=pdf_rotate_angle,
                use_orientation_classifier=use_orientation_classifier,
            )

        if order == 'watermark_first':
            cleaned = self.remove_watermark(image)
            return _orient(cleaned)

        oriented, rotate_angle = _orient(image)
        return self.remove_watermark(oriented), rotate_angle

    def process(
        self,
        image: Union[np.ndarray, Image.Image],
        skip_watermark: bool = False,
    ) -> tuple[np.ndarray, int]:
        """Preprocess a cropped block.

        With ``skip_watermark=True`` only orientation is corrected. For whole
        pages use ``prepare_detection_image()`` instead.

        Args:
            image: Input image.
            skip_watermark: Skip watermark removal when True.

        Returns:
            ``(image, rotate_angle)``.
        """
        if skip_watermark:
            return self.correct_orientation(image, use_orientation_classifier=True)
        return self.prepare_detection_image(image, use_orientation_classifier=True)

    def _apply_rotation(self, image: np.ndarray, rotation_angle: int) -> np.ndarray:
        """Rotate ``image`` by 90, 180 or 270; any other angle returns it unchanged.

        90 maps to a counter-clockwise quarter turn and 270 to a clockwise one.
        """
        if rotation_angle == 90:
            return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif rotation_angle == 180:
            return cv2.rotate(image, cv2.ROTATE_180)
        elif rotation_angle == 270:
            return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        return image


class BaseLayoutDetector(BaseAdapter):
    """Base class for layout detectors."""

    # Numeric category id -> category name, used by ``_map_category_id``.
    _CATEGORY_ID_MAP = {
        0: 'title',
        1: 'text',
        2: 'abandon',
        3: 'image_body',
        4: 'image_caption',
        5: 'table_body',
        6: 'table_caption',
        7: 'table_footnote',
        8: 'interline_equation',
        9: 'interline_equation_number',
        13: 'inline_equation',
        14: 'interline_equation_yolo',
        15: 'ocr_text',
        16: 'low_score_text',
        101: 'image_footnote',
    }

    def __init__(self, config: Dict[str, Any]):
        """Initialize the layout detector.

        Args:
            config: Configuration dictionary.
        """
        super().__init__(config)
        # Debug attributes may be set at runtime; while they are None,
        # detect() falls back to the configuration.
        self.debug_mode = None
        self.output_dir = None
        self.page_name = None

    def detect(
        self,
        image: Union[np.ndarray, Image.Image],
        ocr_spans: Optional[List[Dict[str, Any]]] = None
    ) -> List[Dict[str, Any]]:
        """Detect the layout (template method that also runs post-processing).

        Steps:
        1. Call the subclass ``_detect_raw()`` for the raw detection.
        2. Run post-processing (overlap removal, text-to-table conversion, ...).

        When debugging is enabled and an output directory is available, the
        raw and post-processed results are visualized. Visualization is
        best-effort: a failure is logged and does not abort detection.

        Args:
            image: Input image.
            ocr_spans: OCR results (optional; some detectors may need them).

        Returns:
            Post-processed layout detection results.
        """
        layout_results = self._detect_raw(image, ocr_spans)

        debug_mode = self._is_layout_debug_enabled()
        output_dir, page_name = self._resolve_layout_debug_paths()
        dbg_opts = self._layout_debug_options()

        if debug_mode:
            logger.debug(
                f"Layout detection raw results (before post-processing): "
                f"{len(layout_results)} elements"
            )
            if output_dir and dbg_opts.get('save_raw', True):
                self._save_layout_debug_safely(
                    image, layout_results, output_dir, page_name, suffix='raw'
                )

        # Post-processing runs automatically whenever something was detected.
        if layout_results:
            layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
            layout_results = self.post_process(layout_results, image, layout_config)

        if debug_mode and output_dir and dbg_opts.get('save_post_processed', True):
            self._save_layout_debug_safely(
                image, layout_results, output_dir, page_name, suffix='post'
            )
        return layout_results

    def _save_layout_debug_safely(
        self,
        image: Union[np.ndarray, Image.Image],
        layout_results: List[Dict[str, Any]],
        output_dir: str,
        page_name: str,
        suffix: str,
    ) -> None:
        """Run ``_visualize_layout_results`` and log, rather than raise, on failure."""
        try:
            self._visualize_layout_results(
                image, layout_results, output_dir, page_name, suffix=suffix
            )
        except Exception as e:
            # Same best-effort policy as the watermark debug output.
            logger.warning(f"Layout debug save failed ({suffix}): {e}")

    @abstractmethod
    def _detect_raw(
        self,
        image: Union[np.ndarray, Image.Image],
        ocr_spans: Optional[List[Dict[str, Any]]] = None
    ) -> List[Dict[str, Any]]:
        """Raw detection (must be implemented by subclasses).

        Args:
            image: Input image.
            ocr_spans: OCR results (optional).

        Returns:
            Raw detection results, before post-processing.
        """

    def post_process(
        self,
        layout_results: List[Dict[str, Any]],
        image: Union[np.ndarray, Image.Image],
        config: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Post-process layout detection results.

        The default implementation:
        1. Removes overlapping boxes.
        2. Converts large text blocks to tables (when enabled in the config).

        Subclasses may override this method to customize post-processing.

        Args:
            layout_results: Raw detection results.
            image: Input image.
            config: Post-processing configuration (optional). When None, the
                ``post_process`` section of ``self.config`` is used.

        Returns:
            Post-processed layout results. If ``CoordinateUtils`` cannot be
            imported, a warning is logged and the input is returned unchanged.
        """
        if not layout_results:
            return layout_results

        if config is None:
            config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
        if not isinstance(config, dict):
            # Covers a null or malformed ``post_process`` section.
            config = {}

        # The package may expose CoordinateUtils at either location.
        try:
            from ocr_utils.coordinate_utils import CoordinateUtils
        except ImportError:
            try:
                from ocr_utils import CoordinateUtils
            except ImportError:
                logger.warning(
                    "CoordinateUtils could not be imported; "
                    "layout post-processing skipped"
                )
                return layout_results

        # 1. Remove overlapping boxes.
        deduplicated = self._remove_overlapping_boxes(layout_results, CoordinateUtils)

        # 2. Convert large text blocks to tables (only when enabled).
        if not config.get('convert_large_text_to_table', False):
            return deduplicated

        if isinstance(image, Image.Image):
            # PIL reports size as (width, height).
            h, w = image.size[1], image.size[0]
        else:
            h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)

        return self._convert_large_text_to_table(
            deduplicated,
            (h, w),
            min_area_ratio=config.get('min_text_area_ratio', 0.25),
            min_width_ratio=config.get('min_text_width_ratio', 0.4),
            min_height_ratio=config.get('min_text_height_ratio', 0.3)
        )

    def _convert_large_text_to_table(
        self,
        layout_results: List[Dict[str, Any]],
        image_shape: Tuple[int, int],
        min_area_ratio: float = 0.25,
        min_width_ratio: float = 0.4,
        min_height_ratio: float = 0.3
    ) -> List[Dict[str, Any]]:
        """Convert large text blocks to tables.

        Rules:
        1. Area: the block covers at least ``min_area_ratio`` of the page.
        2. Size: both width and height reach their minimum ratio (this
           excludes long thin strips).
        3. No existing table: if any ``table`` / ``table_body`` element is
           present, nothing is converted (to avoid misjudgement).

        Args:
            layout_results: Layout elements to inspect.
            image_shape: ``(height, width)`` of the page image.
            min_area_ratio: Minimum block-area / page-area ratio.
            min_width_ratio: Minimum block-width / page-width ratio.
            min_height_ratio: Minimum block-height / page-height ratio.

        Returns:
            The input list when nothing can be converted; otherwise a list of
            shallow copies in which converted items have ``category='table'``
            and their previous (lower-cased) category in ``original_category``.
        """
        if not layout_results:
            return layout_results

        img_height, img_width = image_shape
        img_area = img_height * img_width
        if img_area == 0:
            return layout_results

        def category_of(item: Dict[str, Any]) -> str:
            # Tolerate a missing or None category.
            return str(item.get('category') or '').lower()

        if any(category_of(item) in ('table', 'table_body') for item in layout_results):
            return layout_results

        # Copy the items so the caller's data is not modified.
        results = [item.copy() for item in layout_results]
        converted_count = 0

        for item in results:
            category = category_of(item)
            # Only text-like elements are candidates.
            if category not in ('text', 'ocr_text'):
                continue

            bbox = item.get('bbox') or []
            if len(bbox) < 4:
                continue

            x1, y1, x2, y2 = bbox[:4]
            width = x2 - x1
            height = y2 - y1

            # img_area != 0 here, so width and height of the page are non-zero.
            area_ratio = (width * height) / img_area
            width_ratio = width / img_width
            height_ratio = height / img_height

            if (area_ratio >= min_area_ratio
                    and width_ratio >= min_width_ratio
                    and height_ratio >= min_height_ratio):
                item['category'] = 'table'
                item['original_category'] = category  # keep the original category
                converted_count += 1

        if converted_count:
            logger.debug(f"Converted {converted_count} large text block(s) to table")
        return results

    def _map_category_id(self, category_id: int) -> str:
        """Map a numeric category id to its name (``unknown_<id>`` if unmapped)."""
        return self._CATEGORY_ID_MAP.get(category_id, f'unknown_{category_id}')

    def _layout_debug_options(self) -> Dict[str, Any]:
        """Return ``config['debug_options']``, or ``{}`` when it is not a dict."""
        opts = self.config.get('debug_options', {})
        return opts if isinstance(opts, dict) else {}

    def _is_layout_debug_enabled(self) -> bool:
        """Tell whether layout debug output is enabled.

        Precedence: runtime ``debug_mode`` attribute, then
        ``config['debug_mode']``, then ``debug_options['enabled']``.
        """
        debug_mode = getattr(self, 'debug_mode', None)
        if debug_mode is not None:
            return bool(debug_mode)
        if self.config.get('debug_mode', False):
            return True
        return bool(self._layout_debug_options().get('enabled', False))

    def _resolve_layout_debug_paths(self) -> Tuple[Optional[str], str]:
        """Resolve the output directory and page name for layout debug files.

        Returns:
            ``(output_dir, page_name)``. ``output_dir`` comes from the runtime
            attribute, then ``config['output_dir']``, then the debug options,
            and is converted to ``str`` when set. ``page_name`` comes from the
            runtime attribute, then ``config['page_name']``, then the debug
            prefix, then ``'layout_detection'``.
        """
        output_dir = getattr(self, 'output_dir', None)
        if output_dir is None:
            output_dir = self.config.get('output_dir')
        if output_dir is None:
            output_dir = self._layout_debug_options().get('output_dir')
        if output_dir is not None:
            output_dir = str(output_dir)

        page_name = getattr(self, 'page_name', None)
        if page_name is None:
            page_name = self.config.get('page_name')
        if not page_name:
            prefix = self._layout_debug_options().get('prefix', '')
            page_name = prefix if prefix else 'layout_detection'
        return output_dir, str(page_name)

    def _visualize_layout_results(
        self,
        image: Union[np.ndarray, Image.Image],
        layout_results: List[Dict[str, Any]],
        output_dir: str,
        page_name: str,
        suffix: str = 'raw',
    ) -> None:
        """Save layout debug output (the base image is the detection input).

        Does nothing when ``layout_results`` is empty.

        Args:
            image: Image the layout was detected on.
            layout_results: Layout elements to visualize.
            output_dir: Directory to write to.
            page_name: Page name passed to the debug writer.
            suffix: Suffix passed to the debug writer (e.g. ``raw``, ``post``).
        """
        if not layout_results:
            return

        from ocr_utils.module_debug_viz import save_layout_debug

        opts = self._layout_debug_options()
        save_layout_debug(
            image,
            layout_results,
            output_dir,
            page_name,
            suffix=suffix,
            subdir=opts.get('subdir', 'layout_detection'),
            image_format=opts.get('image_format', 'png'),
            save_json=bool(opts.get('save_json', True)),
        )

    def _remove_overlapping_boxes(
        self,
        layout_results: List[Dict[str, Any]],
        coordinate_utils: Any,
        iou_threshold: float = 0.8,
        overlap_ratio_threshold: float = 0.8
    ) -> List[Dict[str, Any]]:
        """Remove overlapping boxes using category priorities and decision rules.

        Strategy:
        1. Category priority (abandon < text/image < table_body).
        2. One unified set of decision rules for every overlapping pair.
        3. Boxes are visited in descending score order, which favours large
           aggregate boxes.

        Args:
            layout_results: Layout detection results.
            coordinate_utils: Object providing ``calculate_iou`` and
                ``calculate_overlap_ratio``.
            iou_threshold: IoU threshold (default 0.8).
            overlap_ratio_threshold: Overlap-ratio threshold (default 0.8).

        Returns:
            De-duplicated layout results (shallow copies, original order). The
            input list itself is returned when it has at most one element.
        """
        if not layout_results or len(layout_results) <= 1:
            return layout_results

        CATEGORY_PRIORITY = {
            'abandon': 0,
            'text': 1,
            'image_body': 1,
            'title': 2,
            'footer': 2,
            'header': 2,
            'table_body': 3,
        }
        AGGREGATE_LABELS = {'key-value region', 'form'}
        MAX_AREA = 4000000.0  # used to normalize areas
        AREA_WEIGHT = 0.5
        CONFIDENCE_WEIGHT = 0.5
        AGGREGATE_BONUS = 0.1
        AREA_RATIO_THRESHOLD = 3.0  # how many times larger the big box must be

        def get_bbox_area(bbox: List[float]) -> float:
            """Return the bbox area (0.0 for an incomplete bbox)."""
            if len(bbox) < 4:
                return 0.0
            return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

        def get_confidence(box: Dict[str, Any]) -> float:
            """Return ``confidence``, falling back to ``score`` and then 0."""
            return box.get('confidence', box.get('score', 0))

        def is_aggregate_type(box: Dict[str, Any]) -> bool:
            """Tell whether the box's original label is an aggregate type."""
            raw = box.get('raw')
            # ``raw`` or its label may be missing or None.
            label = raw.get('original_label') if isinstance(raw, dict) else None
            return str(label or '').lower() in AGGREGATE_LABELS

        def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
            """Tell whether ``inner`` lies completely within ``outer``."""
            if len(inner) < 4 or len(outer) < 4:
                return False
            return (inner[0] >= outer[0] and inner[1] >= outer[1]
                    and inner[2] <= outer[2] and inner[3] <= outer[3])

        def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
            """Composite score (area + confidence) for text; confidence otherwise."""
            if box.get('category') != 'text':
                return get_confidence(box)
            normalized_area = min(area / MAX_AREA, 1.0)
            area_score = (normalized_area ** 0.5) * AREA_WEIGHT
            confidence_score = get_confidence(box) * CONFIDENCE_WEIGHT
            bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
            return area_score + confidence_score + bonus

        def should_keep_box1(
            box1: Dict[str, Any],
            box2: Dict[str, Any],
            contained_1_in_2: bool,
            contained_2_in_1: bool,
        ) -> bool:
            """Decide whether box1 (rather than box2) should be kept."""
            cat1, cat2 = box1.get('category', ''), box2.get('category', '')
            score1, score2 = get_confidence(box1), get_confidence(box2)
            area1 = get_bbox_area(box1.get('bbox', [0, 0, 0, 0]))
            area2 = get_bbox_area(box2.get('bbox', [0, 0, 0, 0]))
            is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)

            # Rule 1: category priority.
            priority1 = CATEGORY_PRIORITY.get(cat1, 1)
            priority2 = CATEGORY_PRIORITY.get(cat2, 1)
            if priority1 != priority2:
                return priority1 > priority2

            # Rule 2: containment, with the aggregate type winning.
            if contained_2_in_1 and is_agg1 and not is_agg2:
                return True
            if contained_1_in_2 and is_agg2 and not is_agg1:
                return False

            # Rule 3: containment, with a clearly larger box winning.
            if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
                return True
            if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
                return False

            # Rule 4: composite score when either box is text.
            if cat1 == 'text' or cat2 == 'text':
                comp_score1 = calculate_composite_score(box1, area1)
                comp_score2 = calculate_composite_score(box2, area2)
                if abs(comp_score1 - comp_score2) > 0.05:
                    return comp_score1 > comp_score2

            # Rule 5: confidence.
            if abs(score1 - score2) > 0.1:
                return score1 > score2

            # Rule 6: area.
            return area1 >= area2

        results = [item.copy() for item in layout_results]
        need_remove = set()

        def get_sort_key(i: int) -> float:
            # calculate_composite_score already falls back to the plain
            # confidence for non-text boxes; negate for descending order.
            item = results[i]
            return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))

        sorted_indices = sorted(range(len(results)), key=get_sort_key)

        # Compare each box only against the boxes ranked before it.
        for idx_i, i in enumerate(sorted_indices):
            for j in sorted_indices[:idx_i]:
                if j in need_remove:
                    continue

                bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
                if len(bbox1) < 4 or len(bbox2) < 4:
                    continue

                iou = coordinate_utils.calculate_iou(bbox1, bbox2)
                overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
                contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
                contained_2_in_1 = is_bbox_inside(bbox2, bbox1)

                # Skip pairs without significant overlap.
                if not (iou > iou_threshold
                        or overlap_ratio > overlap_ratio_threshold
                        or contained_1_in_2
                        or contained_2_in_1):
                    continue

                if should_keep_box1(results[i], results[j],
                                    contained_1_in_2, contained_2_in_1):
                    need_remove.add(j)
                else:
                    # Box i lost; there is no need to compare it any further.
                    need_remove.add(i)
                    break

        return [box for idx, box in enumerate(results) if idx not in need_remove]


class BaseVLRecognizer(BaseAdapter):
    """Base class for VL recognizers."""

    @abstractmethod
    def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize a table."""

    @abstractmethod
    def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
        """Recognize a formula."""


class BaseOCRRecognizer(BaseAdapter):
    """Base class for OCR recognizers."""

    @abstractmethod
    def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Recognize text."""

    @abstractmethod
    def detect_text_boxes(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
        """Detect text boxes only (without recognizing the text content).

        Subclasses must implement this method. Running only the detection
        model (not the recognition model) is recommended for performance. If
        that is not possible, at least provide a fallback that calls
        ``recognize_text()``.

        Returns:
            List of text boxes; each item contains 'bbox' and 'poly', and may
            contain 'confidence'.
        """