| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674 |
- from abc import ABC, abstractmethod
- from typing import Dict, Any, List, Union, Optional, Tuple
- import numpy as np
- from PIL import Image
- from loguru import logger
- from pathlib import Path
- import cv2
- import json
class BaseAdapter(ABC):
    """Base adapter interface.

    Holds the adapter configuration and requires concrete subclasses to
    implement the ``initialize`` and ``cleanup`` hooks.
    """

    def __init__(self, config: Dict[str, Any]):
        # Stored by reference; the dict passed in is not copied.
        self.config = config

    @abstractmethod
    def initialize(self):
        """Initialize the model (must be implemented by subclasses)."""

    @abstractmethod
    def cleanup(self):
        """Clean up resources (must be implemented by subclasses)."""
class BasePreprocessor(BaseAdapter):
    """Base class for image preprocessors.

    Provides the page-level preprocessing flow (orientation correction plus
    watermark removal, in a configurable order) together with helpers for
    saving watermark-removal debug images. The default implementations are
    deliberately minimal so that subclasses can override individual steps.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Injected per page by the pipeline at runtime (same convention as
        # layout detection). While None, the helpers fall back to config.
        self.debug_mode: Optional[bool] = None
        self.output_dir: Optional[str] = None
        self.page_name: Optional[str] = None

    def _watermark_debug_options(self) -> Dict[str, Any]:
        """Return ``watermark_removal.debug_options`` from config, or ``{}``.

        Any non-dict value at either level (for example a ``null`` section
        in the config file) is treated as "no options" instead of raising.
        """
        wm_cfg = self.config.get('watermark_removal', {})
        if not isinstance(wm_cfg, dict):
            return {}
        opts = wm_cfg.get('debug_options', {})
        return opts if isinstance(opts, dict) else {}

    def _is_watermark_debug_enabled(self) -> bool:
        """Return whether watermark debug output is enabled.

        A runtime ``debug_mode`` attribute (when not None) takes precedence
        over the ``enabled`` flag in the debug options.
        """
        debug_mode = getattr(self, 'debug_mode', None)
        if debug_mode is not None:
            return bool(debug_mode)
        return bool(self._watermark_debug_options().get('enabled', False))

    def _resolve_watermark_debug_paths(self) -> Tuple[Optional[str], str]:
        """Resolve ``(output_dir, page_name)`` for watermark debug files.

        Runtime attributes win over the configured debug options. When a
        ``prefix`` option is set and the page name does not already start
        with it, the prefix is prepended as ``"{prefix}_{page_name}"``.
        """
        opts = self._watermark_debug_options()

        output_dir = getattr(self, 'output_dir', None)
        if output_dir is None:
            output_dir = opts.get('output_dir')

        page_name = getattr(self, 'page_name', None)
        if not page_name:
            page_name = opts.get('prefix') or 'watermark'

        prefix = opts.get('prefix', '')
        if prefix and page_name and not str(page_name).startswith(str(prefix)):
            page_name = f"{prefix}_{page_name}"
        return output_dir, str(page_name)

    def _save_watermark_debug_images(
        self,
        before: np.ndarray,
        after: np.ndarray,
        threshold: int,
        morph_close_kernel: int,
        contrast_cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save watermark debug images (delegates to ocr_utils.watermark_utils).

        Does nothing when no output directory can be resolved. A failure
        while saving is logged as a warning and never propagated.

        Args:
            before: Image before watermark removal.
            after: Image after watermark removal.
            threshold: Threshold value recorded in the processing params.
            morph_close_kernel: Kernel size recorded in the processing params.
            contrast_cfg: Optional contrast-enhancement settings; recorded
                under ``contrast_enhancement`` when truthy.
        """
        from ocr_utils.watermark_utils import save_watermark_removal_debug

        output_dir, page_name = self._resolve_watermark_debug_paths()
        if not output_dir:
            return

        opts = self._watermark_debug_options()
        params: Dict[str, Any] = {
            "threshold": threshold,
            "morph_close_kernel": morph_close_kernel,
        }
        if contrast_cfg:
            params["contrast_enhancement"] = contrast_cfg

        try:
            save_watermark_removal_debug(
                before,
                after,
                output_dir,
                page_name,
                processing_params=params,
                image_format=opts.get("image_format") or "png",
                save_compare=opts.get("save_compare", True),
            )
        except Exception as e:
            # Debug output is best-effort: never let it break preprocessing.
            logger.warning(f"Watermark debug save failed: {e}")

    def remove_watermark(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
        """Page-level watermark removal (no-op by default; subclasses may override).

        A PIL image is converted to a numpy array; an array is returned as-is.
        """
        if isinstance(image, Image.Image):
            return np.array(image)
        return image

    def _preprocess_order(self) -> str:
        """Return the preprocessing order: ``orient_first`` (default) or ``watermark_first``.

        Unknown values are logged and replaced by ``orient_first``.
        """
        order = str(self.config.get('order', 'orient_first')).strip().lower()
        if order not in ('orient_first', 'watermark_first'):
            logger.warning(
                f"Unknown preprocessor.order={order!r}, fallback to orient_first"
            )
            return 'orient_first'
        return order

    def correct_orientation(
        self,
        image: Union[np.ndarray, Image.Image],
        *,
        pdf_rotate_angle: Optional[int] = None,
        use_orientation_classifier: bool = True,
    ) -> tuple[np.ndarray, int]:
        """Correct orientation only, without watermark removal.

        Intended for cases such as table crops where the page has already
        been preprocessed.

        Args:
            image: Input image (numpy array or PIL image).
            pdf_rotate_angle: Page-level rotation of a text PDF
                (counter-clockwise angle, consistent with the pipeline).
            use_orientation_classifier: Whether to use an orientation
                classifier (True for scanned documents). The base
                implementation does not consult this flag.

        Returns:
            ``(image, angle)``: the rotated image and the applied angle, or
            the unchanged image and ``0`` when ``pdf_rotate_angle`` is falsy.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        if pdf_rotate_angle:
            # expand=True grows the canvas so the rotated page is not cropped.
            pil_rotated = Image.fromarray(image).rotate(pdf_rotate_angle, expand=True)
            return np.array(pil_rotated), int(pdf_rotate_angle)
        return image, 0

    def prepare_detection_image(
        self,
        image: Union[np.ndarray, Image.Image],
        *,
        pdf_rotate_angle: Optional[int] = None,
        use_orientation_classifier: bool = True,
    ) -> tuple[np.ndarray, int]:
        """Run full page-level preprocessing.

        Orientation correction and watermark removal are executed in the
        order given by ``preprocessor.order``.

        Returns:
            ``(detection_image, rotate_angle)``
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        order = self._preprocess_order()

        def _orient(img: np.ndarray) -> tuple[np.ndarray, int]:
            return self.correct_orientation(
                img,
                pdf_rotate_angle=pdf_rotate_angle,
                use_orientation_classifier=use_orientation_classifier,
            )

        if order == 'watermark_first':
            cleaned = self.remove_watermark(image)
            return _orient(cleaned)

        oriented, rotate_angle = _orient(image)
        return self.remove_watermark(oriented), rotate_angle

    def process(
        self,
        image: Union[np.ndarray, Image.Image],
        skip_watermark: bool = False,
    ) -> tuple[np.ndarray, int]:
        """Preprocess a cropped block.

        With ``skip_watermark=True`` only orientation correction is applied;
        otherwise the full ``prepare_detection_image()`` flow runs. For
        page-level input, call ``prepare_detection_image()`` directly.
        """
        if skip_watermark:
            return self.correct_orientation(image, use_orientation_classifier=True)
        return self.prepare_detection_image(image, use_orientation_classifier=True)

    def _apply_rotation(self, image: np.ndarray, rotation_angle: int) -> np.ndarray:
        """Apply a 90/180/270 degree rotation; other angles return the image unchanged."""
        import cv2

        if rotation_angle == 90:  # 90 degrees
            return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif rotation_angle == 180:  # 180 degrees
            return cv2.rotate(image, cv2.ROTATE_180)
        elif rotation_angle == 270:  # 270 degrees
            return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
        return image
class BaseLayoutDetector(BaseAdapter):
    """Base class for layout detectors.

    ``detect()`` is a template method: it calls the subclass hook
    ``_detect_raw()``, optionally saves debug visualizations, and then runs
    ``post_process()`` on the raw results.
    """

    # Category-id -> category-name table used by _map_category_id().
    _CATEGORY_ID_MAP = {
        0: 'title',
        1: 'text',
        2: 'abandon',
        3: 'image_body',
        4: 'image_caption',
        5: 'table_body',
        6: 'table_caption',
        7: 'table_footnote',
        8: 'interline_equation',
        9: 'interline_equation_number',
        13: 'inline_equation',
        14: 'interline_equation_yolo',
        15: 'ocr_text',
        16: 'low_score_text',
        101: 'image_footnote',
    }

    def __init__(self, config: Dict[str, Any]):
        """Initialize the layout detector.

        Args:
            config: Configuration dictionary.
        """
        super().__init__(config)
        # Debug attributes can be set at runtime; while they are None the
        # resolver helpers used by detect() fall back to the configuration.
        self.debug_mode = None
        self.output_dir = None
        self.page_name = None

    def detect(
        self,
        image: Union[np.ndarray, Image.Image],
        ocr_spans: Optional[List[Dict[str, Any]]] = None
    ) -> List[Dict[str, Any]]:
        """Detect the page layout (template method with automatic post-processing).

        This method:
        1. calls the subclass implementation ``_detect_raw()`` for raw detection;
        2. runs post-processing (overlap removal, text-to-table conversion, ...).

        Args:
            image: Input image.
            ocr_spans: OCR results (optional; some detectors may need them).

        Returns:
            The post-processed layout detection results.
        """
        # Raw detection is delegated to the subclass.
        layout_results = self._detect_raw(image, ocr_spans)

        debug_mode = self._is_layout_debug_enabled()
        output_dir, page_name = self._resolve_layout_debug_paths()
        dbg_opts = self._layout_debug_options()

        if debug_mode:
            logger.debug(
                f"Layout detection raw results (before post-processing): "
                f"{len(layout_results)} elements"
            )
            if output_dir and dbg_opts.get('save_raw', True):
                self._visualize_layout_results(
                    image, layout_results, output_dir, page_name, suffix='raw'
                )

        # Post-processing runs automatically whenever there is something to process.
        if layout_results:
            layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
            layout_results = self.post_process(layout_results, image, layout_config)
            if debug_mode and output_dir and dbg_opts.get('save_post_processed', True):
                self._visualize_layout_results(
                    image, layout_results, output_dir, page_name, suffix='post'
                )
        return layout_results

    @abstractmethod
    def _detect_raw(
        self,
        image: Union[np.ndarray, Image.Image],
        ocr_spans: Optional[List[Dict[str, Any]]] = None
    ) -> List[Dict[str, Any]]:
        """Raw detection hook (must be implemented by subclasses).

        Args:
            image: Input image.
            ocr_spans: OCR results (optional).

        Returns:
            Raw detection results (not post-processed).
        """
        pass

    def post_process(
        self,
        layout_results: List[Dict[str, Any]],
        image: Union[np.ndarray, Image.Image],
        config: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Post-process layout detection results.

        The default implementation:
        1. removes overlapping boxes;
        2. converts large text blocks to tables (when enabled in config).

        Subclasses may override this method to customize post-processing.

        Args:
            layout_results: Raw detection results.
            image: Input image.
            config: Post-processing config (optional). When None, the
                ``post_process`` section of ``self.config`` is used.

        Returns:
            The post-processed layout results. If ``CoordinateUtils`` cannot
            be imported, a warning is logged and the input is returned
            unchanged.
        """
        if not layout_results:
            return layout_results

        if config is None:
            config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}

        # CoordinateUtils supplies the IoU / overlap-ratio computations.
        try:
            from ocr_utils.coordinate_utils import CoordinateUtils
        except ImportError:
            try:
                from ocr_utils import CoordinateUtils
            except ImportError:
                # Without the helper nothing can be de-duplicated; make the
                # skipped step visible instead of returning silently.
                logger.warning(
                    "CoordinateUtils is unavailable; skipping layout post-processing"
                )
                return layout_results

        # 1. Remove overlapping boxes.
        layout_results_removed_overlapping = self._remove_overlapping_boxes(layout_results, CoordinateUtils)

        # 2. Convert large text blocks to tables (when enabled in config).
        # The configured section may itself be None, hence the second guard.
        layout_config = config if config is not None else {}
        if not layout_config.get('convert_large_text_to_table', False):
            return layout_results_removed_overlapping

        # Image size as (height, width); PIL reports size as (width, height).
        if isinstance(image, Image.Image):
            h, w = image.size[1], image.size[0]
        else:
            h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)

        return self._convert_large_text_to_table(
            layout_results_removed_overlapping,
            (h, w),
            min_area_ratio=layout_config.get('min_text_area_ratio', 0.25),
            min_width_ratio=layout_config.get('min_text_width_ratio', 0.4),
            min_height_ratio=layout_config.get('min_text_height_ratio', 0.3)
        )

    def _convert_large_text_to_table(
        self,
        layout_results: List[Dict[str, Any]],
        image_shape: Tuple[int, int],
        min_area_ratio: float = 0.25,
        min_width_ratio: float = 0.4,
        min_height_ratio: float = 0.3
    ) -> List[Dict[str, Any]]:
        """Convert large text blocks into tables.

        Rules:
        1. Area ratio: the block covers at least ``min_area_ratio`` of the
           page area (default 25%).
        2. Size ratio: both width and height exceed their minimum ratios
           (avoids long thin strips).
        3. No existing tables: if the results already contain a table,
           nothing is converted.

        Args:
            layout_results: Layout items to inspect.
            image_shape: ``(height, width)`` of the page image.
            min_area_ratio: Minimum block-area / page-area ratio.
            min_width_ratio: Minimum block-width / page-width ratio.
            min_height_ratio: Minimum block-height / page-height ratio.

        Returns:
            The input list when nothing can be converted (empty input,
            zero-area page, or a table already present); otherwise a list of
            shallow item copies in which qualifying text items have
            ``category`` set to ``'table'`` and ``original_category`` set to
            their previous category.
        """
        if not layout_results:
            return layout_results

        img_height, img_width = image_shape
        img_area = img_height * img_width
        if img_area == 0:
            return layout_results

        # If a table already exists, skip conversion to avoid misclassification.
        has_table = any(
            item.get('category', '').lower() in ['table', 'table_body']
            for item in layout_results
        )
        if has_table:
            return layout_results

        # Copy the items so the caller's dicts are not modified.
        results = [item.copy() for item in layout_results]
        converted_count = 0

        for item in results:
            category = item.get('category', '').lower()

            # Only text-like elements are candidates.
            if category not in ['text', 'ocr_text']:
                continue

            bbox = item.get('bbox', [0, 0, 0, 0])
            if len(bbox) < 4:
                continue

            x1, y1, x2, y2 = bbox[:4]
            width = x2 - x1
            height = y2 - y1
            area = width * height

            area_ratio = area / img_area if img_area > 0 else 0
            width_ratio = width / img_width if img_width > 0 else 0
            height_ratio = height / img_height if img_height > 0 else 0

            if (area_ratio >= min_area_ratio and
                    width_ratio >= min_width_ratio and
                    height_ratio >= min_height_ratio):
                item['category'] = 'table'
                item['original_category'] = category  # keep the original category
                converted_count += 1

        if converted_count:
            logger.debug(f"Converted {converted_count} large text block(s) to table")

        return results

    def _map_category_id(self, category_id: int) -> str:
        """Map a category id to its name; unknown ids become ``unknown_<id>``."""
        return self._CATEGORY_ID_MAP.get(category_id, f'unknown_{category_id}')

    def _layout_debug_options(self) -> Dict[str, Any]:
        """Return the ``debug_options`` config section, or ``{}`` if it is not a dict."""
        opts = self.config.get('debug_options', {})
        return opts if isinstance(opts, dict) else {}

    def _is_layout_debug_enabled(self) -> bool:
        """Return whether layout debug output is enabled.

        Precedence: runtime ``debug_mode`` attribute (when not None), then
        ``config['debug_mode']``, then ``debug_options['enabled']``.
        """
        debug_mode = getattr(self, 'debug_mode', None)
        if debug_mode is not None:
            return bool(debug_mode)
        if self.config.get('debug_mode', False):
            return True
        return bool(self._layout_debug_options().get('enabled', False))

    def _resolve_layout_debug_paths(self) -> Tuple[Optional[str], str]:
        """Resolve ``(output_dir, page_name)`` for layout debug files.

        Each value is taken from the runtime attribute first, then from the
        top-level config, then from the debug options. The page name falls
        back to the configured ``prefix`` or ``'layout_detection'``.
        """
        output_dir = getattr(self, 'output_dir', None)
        if output_dir is None:
            output_dir = self.config.get('output_dir')
        if output_dir is None:
            output_dir = self._layout_debug_options().get('output_dir')

        page_name = getattr(self, 'page_name', None)
        if page_name is None:
            page_name = self.config.get('page_name')
        if not page_name:
            prefix = self._layout_debug_options().get('prefix', '')
            page_name = prefix if prefix else 'layout_detection'
        return output_dir, str(page_name)

    def _visualize_layout_results(
        self,
        image: Union[np.ndarray, Image.Image],
        layout_results: List[Dict[str, Any]],
        output_dir: str,
        page_name: str,
        suffix: str = 'raw',
    ) -> None:
        """Save layout debug output (drawn on the inference / detection input image).

        Does nothing for empty results. A failure while saving is logged as
        a warning and never propagated, so debug output cannot abort
        detection.
        """
        if not layout_results:
            return

        from ocr_utils.module_debug_viz import save_layout_debug

        opts = self._layout_debug_options()
        try:
            save_layout_debug(
                image,
                layout_results,
                output_dir,
                page_name,
                suffix=suffix,
                subdir=opts.get('subdir', 'layout_detection'),
                image_format=opts.get('image_format', 'jpg'),
                save_json=bool(opts.get('save_json', True)),
            )
        except Exception as e:
            # Same best-effort policy as the watermark debug saver.
            logger.warning(f"Layout debug save failed: {e}")

    def _remove_overlapping_boxes(
        self,
        layout_results: List[Dict[str, Any]],
        coordinate_utils: Any,
        iou_threshold: float = 0.8,
        overlap_ratio_threshold: float = 0.8
    ) -> List[Dict[str, Any]]:
        """Remove overlapping boxes using category priorities and decision rules.

        Strategy:
        1. Category priority (abandon < text/image < table_body).
        2. One shared set of decision rules for every overlapping pair.
        3. Boxes are processed in descending score order so that large
           aggregate boxes tend to be kept.

        Args:
            layout_results: Layout detection results.
            coordinate_utils: Object providing ``calculate_iou`` and
                ``calculate_overlap_ratio``.
            iou_threshold: IoU threshold (default 0.8).
            overlap_ratio_threshold: Overlap-ratio threshold (default 0.8).

        Returns:
            The de-duplicated results as shallow copies in the original
            order, or the input list itself when it has at most one item.
        """
        if not layout_results or len(layout_results) <= 1:
            return layout_results

        # Constants.
        CATEGORY_PRIORITY = {
            'abandon': 0,
            'text': 1,
            'image_body': 1,
            'title': 2,
            'footer': 2,
            'header': 2,
            'table_body': 3,
        }
        AGGREGATE_LABELS = {'key-value region', 'form'}
        MAX_AREA = 4000000.0  # used to normalize areas
        AREA_WEIGHT = 0.5
        CONFIDENCE_WEIGHT = 0.5
        AGGREGATE_BONUS = 0.1
        AREA_RATIO_THRESHOLD = 3.0  # how many times larger the big box must be

        def get_bbox_area(bbox: List[float]) -> float:
            """Return the bbox area (0.0 for a bbox with fewer than 4 values)."""
            if len(bbox) < 4:
                return 0.0
            return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

        def is_aggregate_type(box: Dict[str, Any]) -> bool:
            """Return whether the box's original label is an aggregate type."""
            original_label = box.get('raw', {}).get('original_label', '').lower()
            return original_label in AGGREGATE_LABELS

        def is_bbox_inside(inner: List[float], outer: List[float]) -> bool:
            """Return whether ``inner`` lies completely within ``outer``."""
            if len(inner) < 4 or len(outer) < 4:
                return False
            return (inner[0] >= outer[0] and inner[1] >= outer[1] and
                    inner[2] <= outer[2] and inner[3] <= outer[3])

        def calculate_composite_score(box: Dict[str, Any], area: float) -> float:
            """Composite score (area + confidence) for text boxes.

            Non-text boxes simply use their confidence / score.
            """
            if box.get('category') != 'text':
                return box.get('confidence', box.get('score', 0))

            normalized_area = min(area / MAX_AREA, 1.0)
            area_score = (normalized_area ** 0.5) * AREA_WEIGHT
            confidence_score = box.get('confidence', box.get('score', 0)) * CONFIDENCE_WEIGHT
            bonus = AGGREGATE_BONUS if is_aggregate_type(box) else 0.0
            return area_score + confidence_score + bonus

        def should_keep_box1(box1: Dict[str, Any], box2: Dict[str, Any],
                             iou: float, overlap_ratio: float,
                             contained_1_in_2: bool, contained_2_in_1: bool) -> bool:
            """Decide whether box1 should be kept in preference to box2."""
            cat1, cat2 = box1.get('category', ''), box2.get('category', '')
            score1 = box1.get('confidence', box1.get('score', 0))
            score2 = box2.get('confidence', box2.get('score', 0))
            bbox1, bbox2 = box1.get('bbox', [0, 0, 0, 0]), box2.get('bbox', [0, 0, 0, 0])
            area1, area2 = get_bbox_area(bbox1), get_bbox_area(bbox2)
            is_agg1, is_agg2 = is_aggregate_type(box1), is_aggregate_type(box2)

            # Rule 1: category priority (unlisted categories get priority 1).
            priority1 = CATEGORY_PRIORITY.get(cat1, 1)
            priority2 = CATEGORY_PRIORITY.get(cat2, 1)
            if priority1 != priority2:
                return priority1 > priority2

            # Rule 2: containment + aggregate type wins.
            if contained_2_in_1 and is_agg1 and not is_agg2:
                return True
            if contained_1_in_2 and is_agg2 and not is_agg1:
                return False

            # Rule 3: containment + area ratio.
            if contained_2_in_1 and area1 > area2 * AREA_RATIO_THRESHOLD:
                return True
            if contained_1_in_2 and area2 > area1 * AREA_RATIO_THRESHOLD:
                return False

            # Rule 4: text boxes are compared by composite score.
            if cat1 == 'text' or cat2 == 'text':
                comp_score1 = calculate_composite_score(box1, area1)
                comp_score2 = calculate_composite_score(box2, area2)
                if abs(comp_score1 - comp_score2) > 0.05:
                    return comp_score1 > comp_score2

            # Rule 5: confidence comparison.
            if abs(score1 - score2) > 0.1:
                return score1 > score2

            # Rule 6: area comparison.
            return area1 >= area2

        # Main processing.
        results = [item.copy() for item in layout_results]
        need_remove = set()

        def get_sort_key(i: int) -> float:
            """Negated score, so an ascending sort puts the highest score first."""
            item = results[i]
            if item.get('category') == 'text':
                return -calculate_composite_score(item, get_bbox_area(item.get('bbox', [])))
            return -item.get('confidence', item.get('score', 0))

        sorted_indices = sorted(range(len(results)), key=get_sort_key)

        # Compare every box against the boxes ranked before it.
        for idx_i, i in enumerate(sorted_indices):
            if i in need_remove:
                continue

            # Only higher-ranked boxes are candidates, so slice the prefix
            # instead of scanning the whole list and skipping the tail.
            for j in sorted_indices[:idx_i]:
                if j in need_remove:
                    continue

                bbox1, bbox2 = results[i].get('bbox', []), results[j].get('bbox', [])
                if len(bbox1) < 4 or len(bbox2) < 4:
                    continue

                # Overlap metrics.
                iou = coordinate_utils.calculate_iou(bbox1, bbox2)
                overlap_ratio = coordinate_utils.calculate_overlap_ratio(bbox1, bbox2)
                contained_1_in_2 = is_bbox_inside(bbox1, bbox2)
                contained_2_in_1 = is_bbox_inside(bbox2, bbox1)

                # Skip pairs without significant overlap.
                if not (iou > iou_threshold or overlap_ratio > overlap_ratio_threshold or
                        contained_1_in_2 or contained_2_in_1):
                    continue

                # Apply the decision rules.
                if should_keep_box1(results[i], results[j], iou, overlap_ratio,
                                    contained_1_in_2, contained_2_in_1):
                    need_remove.add(j)
                else:
                    # Box i lost; no need to compare it any further.
                    need_remove.add(i)
                    break

        return [box for idx, box in enumerate(results) if idx not in need_remove]
class BaseVLRecognizer(BaseAdapter):
    """Base class for VL recognizers."""

    @abstractmethod
    def recognize_table(
        self,
        image: Union[np.ndarray, Image.Image],
        **kwargs,
    ) -> Dict[str, Any]:
        """Recognize a table (must be implemented by subclasses)."""

    @abstractmethod
    def recognize_formula(
        self,
        image: Union[np.ndarray, Image.Image],
        **kwargs,
    ) -> Dict[str, Any]:
        """Recognize a formula (must be implemented by subclasses)."""
class BaseOCRRecognizer(BaseAdapter):
    """Base class for OCR recognizers."""

    @abstractmethod
    def recognize_text(
        self,
        image: Union[np.ndarray, Image.Image],
    ) -> List[Dict[str, Any]]:
        """Recognize text (must be implemented by subclasses)."""

    @abstractmethod
    def detect_text_boxes(
        self,
        image: Union[np.ndarray, Image.Image],
    ) -> List[Dict[str, Any]]:
        """Detect text boxes only, without recognizing the text content.

        Subclasses must implement this method. Prefer running only the
        detection model (not the recognition model) for better performance.
        If that optimization is not possible, at least provide a fallback
        implementation that calls ``recognize_text()``.

        Returns:
            A list of text boxes; each item contains 'bbox' and 'poly', and
            may contain 'confidence'.
        """
|