|
|
@@ -25,6 +25,72 @@ class BaseAdapter(ABC):
|
|
|
|
|
|
class BasePreprocessor(BaseAdapter):
|
|
|
"""预处理器基类"""
|
|
|
+
|
|
|
+ def __init__(self, config: Dict[str, Any]):
|
|
|
+ super().__init__(config)
|
|
|
+ # 运行时由 pipeline 按页注入(与 layout_detection 一致)
|
|
|
+ self.debug_mode: Optional[bool] = None
|
|
|
+ self.output_dir: Optional[str] = None
|
|
|
+ self.page_name: Optional[str] = None
|
|
|
+
|
|
|
+ def _watermark_debug_options(self) -> Dict[str, Any]:
|
|
|
+ wm_cfg = self.config.get('watermark_removal', {})
|
|
|
+ opts = wm_cfg.get('debug_options', {})
|
|
|
+ return opts if isinstance(opts, dict) else {}
|
|
|
+
|
|
|
+ def _is_watermark_debug_enabled(self) -> bool:
|
|
|
+ debug_mode = getattr(self, 'debug_mode', None)
|
|
|
+ if debug_mode is not None:
|
|
|
+ return bool(debug_mode)
|
|
|
+ return bool(self._watermark_debug_options().get('enabled', False))
|
|
|
+
|
|
|
+ def _resolve_watermark_debug_paths(self) -> Tuple[Optional[str], str]:
|
|
|
+ output_dir = getattr(self, 'output_dir', None)
|
|
|
+ if output_dir is None:
|
|
|
+ output_dir = self._watermark_debug_options().get('output_dir')
|
|
|
+ page_name = getattr(self, 'page_name', None)
|
|
|
+ if not page_name:
|
|
|
+ page_name = self._watermark_debug_options().get('prefix') or 'watermark'
|
|
|
+ prefix = self._watermark_debug_options().get('prefix', '')
|
|
|
+ if prefix and page_name and not str(page_name).startswith(str(prefix)):
|
|
|
+ page_name = f"{prefix}_{page_name}"
|
|
|
+ return output_dir, str(page_name)
|
|
|
+
|
|
|
+ def _save_watermark_debug_images(
|
|
|
+ self,
|
|
|
+ before: np.ndarray,
|
|
|
+ after: np.ndarray,
|
|
|
+ threshold: int,
|
|
|
+ morph_close_kernel: int,
|
|
|
+ contrast_cfg: Optional[Dict[str, Any]] = None,
|
|
|
+ ) -> None:
|
|
|
+ """保存水印调试图(委托 ocr_utils.watermark_utils)。"""
|
|
|
+ from ocr_utils.watermark_utils import save_watermark_removal_debug
|
|
|
+
|
|
|
+ output_dir, page_name = self._resolve_watermark_debug_paths()
|
|
|
+ if not output_dir:
|
|
|
+ return
|
|
|
+
|
|
|
+ opts = self._watermark_debug_options()
|
|
|
+ params: Dict[str, Any] = {
|
|
|
+ "threshold": threshold,
|
|
|
+ "morph_close_kernel": morph_close_kernel,
|
|
|
+ }
|
|
|
+ if contrast_cfg:
|
|
|
+ params["contrast_enhancement"] = contrast_cfg
|
|
|
+
|
|
|
+ try:
|
|
|
+ save_watermark_removal_debug(
|
|
|
+ before,
|
|
|
+ after,
|
|
|
+ output_dir,
|
|
|
+ page_name,
|
|
|
+ processing_params=params,
|
|
|
+ image_format=opts.get("image_format") or "png",
|
|
|
+ save_compare=opts.get("save_compare", True),
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"Watermark debug save failed: {e}")
|
|
|
|
|
|
def remove_watermark(self, image: Union[np.ndarray, Image.Image]) -> np.ndarray:
|
|
|
"""页级水印去除(默认无操作,子类可覆盖)。"""
|
|
|
@@ -32,17 +98,82 @@ class BasePreprocessor(BaseAdapter):
|
|
|
return np.array(image)
|
|
|
return image
|
|
|
|
|
|
- @abstractmethod
|
|
|
+ def _preprocess_order(self) -> str:
|
|
|
+ """预处理步骤顺序:orient_first(默认)| watermark_first。"""
|
|
|
+ order = str(self.config.get('order', 'orient_first')).strip().lower()
|
|
|
+ if order not in ('orient_first', 'watermark_first'):
|
|
|
+ logger.warning(
|
|
|
+ f"Unknown preprocessor.order={order!r}, fallback to orient_first"
|
|
|
+ )
|
|
|
+ return 'orient_first'
|
|
|
+ return order
|
|
|
+
|
|
|
+ def correct_orientation(
|
|
|
+ self,
|
|
|
+ image: Union[np.ndarray, Image.Image],
|
|
|
+ *,
|
|
|
+ pdf_rotate_angle: Optional[int] = None,
|
|
|
+ use_orientation_classifier: bool = True,
|
|
|
+ ) -> tuple[np.ndarray, int]:
|
|
|
+ """
|
|
|
+ 仅方向校正,不去水印。用于表格裁剪等页级已预处理场景。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ pdf_rotate_angle: 文字 PDF 页级旋转(逆时针角度,与 pipeline 一致)
|
|
|
+ use_orientation_classifier: 是否使用方向分类器(扫描件为 True)
|
|
|
+ """
|
|
|
+ if isinstance(image, Image.Image):
|
|
|
+ image = np.array(image)
|
|
|
+
|
|
|
+ if pdf_rotate_angle:
|
|
|
+ pil_rotated = Image.fromarray(image).rotate(pdf_rotate_angle, expand=True)
|
|
|
+ return np.array(pil_rotated), int(pdf_rotate_angle)
|
|
|
+ return image, 0
|
|
|
+
|
|
|
+ def prepare_detection_image(
|
|
|
+ self,
|
|
|
+ image: Union[np.ndarray, Image.Image],
|
|
|
+ *,
|
|
|
+ pdf_rotate_angle: Optional[int] = None,
|
|
|
+ use_orientation_classifier: bool = True,
|
|
|
+ ) -> tuple[np.ndarray, int]:
|
|
|
+ """
|
|
|
+ 页级完整预处理:按 preprocessor.order 执行方向校正与水印去除。
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (detection_image, rotate_angle)
|
|
|
+ """
|
|
|
+ if isinstance(image, Image.Image):
|
|
|
+ image = np.array(image)
|
|
|
+
|
|
|
+ order = self._preprocess_order()
|
|
|
+
|
|
|
+ def _orient(img: np.ndarray) -> tuple[np.ndarray, int]:
|
|
|
+ return self.correct_orientation(
|
|
|
+ img,
|
|
|
+ pdf_rotate_angle=pdf_rotate_angle,
|
|
|
+ use_orientation_classifier=use_orientation_classifier,
|
|
|
+ )
|
|
|
+
|
|
|
+ if order == 'watermark_first':
|
|
|
+ cleaned = self.remove_watermark(image)
|
|
|
+ return _orient(cleaned)
|
|
|
+
|
|
|
+ oriented, rotate_angle = _orient(image)
|
|
|
+ return self.remove_watermark(oriented), rotate_angle
|
|
|
+
|
|
|
def process(
|
|
|
self,
|
|
|
image: Union[np.ndarray, Image.Image],
|
|
|
skip_watermark: bool = False,
|
|
|
) -> tuple[np.ndarray, int]:
|
|
|
"""
|
|
|
- 处理图像
|
|
|
- 返回处理后的图像和旋转角度
|
|
|
+ 裁剪块:仅方向校正(skip_watermark=True)。
|
|
|
+ 页级请使用 prepare_detection_image()。
|
|
|
"""
|
|
|
- pass
|
|
|
+ if skip_watermark:
|
|
|
+ return self.correct_orientation(image, use_orientation_classifier=True)
|
|
|
+ return self.prepare_detection_image(image, use_orientation_classifier=True)
|
|
|
|
|
|
def _apply_rotation(self, image: np.ndarray, rotation_angle: int) -> np.ndarray:
|
|
|
"""应用旋转"""
|
|
|
@@ -92,63 +223,30 @@ class BaseLayoutDetector(BaseAdapter):
|
|
|
# 调用子类实现的原始检测方法
|
|
|
layout_results = self._detect_raw(image, ocr_spans)
|
|
|
|
|
|
- # Debug 模式:打印和可视化后处理前的检测结果
|
|
|
- # 优先从实例属性读取(如果存在),否则从配置读取
|
|
|
- # 支持两种配置方式:debug_mode 或 debug_options.enabled
|
|
|
- debug_mode = getattr(self, 'debug_mode', None)
|
|
|
- if debug_mode is None:
|
|
|
- if hasattr(self, 'config'):
|
|
|
- # 优先从 debug_mode 读取
|
|
|
- debug_mode = self.config.get('debug_mode', False)
|
|
|
- # 如果没有 debug_mode,尝试从 debug_options.enabled 读取
|
|
|
- if not debug_mode:
|
|
|
- debug_options = self.config.get('debug_options', {})
|
|
|
- if isinstance(debug_options, dict):
|
|
|
- debug_mode = debug_options.get('enabled', False)
|
|
|
- else:
|
|
|
- debug_mode = False
|
|
|
-
|
|
|
+ debug_mode = self._is_layout_debug_enabled()
|
|
|
+ output_dir, page_name = self._resolve_layout_debug_paths()
|
|
|
+ dbg_opts = self._layout_debug_options()
|
|
|
+
|
|
|
if debug_mode:
|
|
|
- logger.debug(f"🔍 Layout detection raw results (before post-processing): {len(layout_results)} elements")
|
|
|
- # logger.debug(f"Raw layout_results: {layout_results}")
|
|
|
- # 可视化 layout 结果
|
|
|
- output_dir = getattr(self, 'output_dir', None)
|
|
|
- if output_dir is None:
|
|
|
- if hasattr(self, 'config'):
|
|
|
- # 优先从 output_dir 读取
|
|
|
- output_dir = self.config.get('output_dir', None)
|
|
|
- # 如果没有 output_dir,尝试从 debug_options.output_dir 读取
|
|
|
- if output_dir is None:
|
|
|
- debug_options = self.config.get('debug_options', {})
|
|
|
- if isinstance(debug_options, dict):
|
|
|
- output_dir = debug_options.get('output_dir', None)
|
|
|
- else:
|
|
|
- output_dir = None
|
|
|
-
|
|
|
- page_name = getattr(self, 'page_name', None)
|
|
|
- if page_name is None:
|
|
|
- if hasattr(self, 'config'):
|
|
|
- # 优先从 page_name 读取
|
|
|
- page_name = self.config.get('page_name', None)
|
|
|
- # 如果没有 page_name,尝试从 debug_options.prefix 读取
|
|
|
- if page_name is None:
|
|
|
- debug_options = self.config.get('debug_options', {})
|
|
|
- if isinstance(debug_options, dict):
|
|
|
- prefix = debug_options.get('prefix', '')
|
|
|
- page_name = prefix if prefix else 'layout_detection'
|
|
|
- if page_name is None:
|
|
|
- page_name = 'layout_detection'
|
|
|
- else:
|
|
|
- page_name = 'layout_detection'
|
|
|
-
|
|
|
- if output_dir:
|
|
|
- self._visualize_layout_results(image, layout_results, output_dir, page_name, suffix='raw')
|
|
|
-
|
|
|
+ logger.debug(
|
|
|
+ f"Layout detection raw results (before post-processing): "
|
|
|
+ f"{len(layout_results)} elements"
|
|
|
+ )
|
|
|
+ if output_dir and dbg_opts.get('save_raw', True):
|
|
|
+ self._visualize_layout_results(
|
|
|
+ image, layout_results, output_dir, page_name, suffix='raw'
|
|
|
+ )
|
|
|
+
|
|
|
# 自动执行后处理
|
|
|
if layout_results:
|
|
|
layout_config = self.config.get('post_process', {}) if hasattr(self, 'config') else {}
|
|
|
layout_results = self.post_process(layout_results, image, layout_config)
|
|
|
-
|
|
|
+
|
|
|
+ if debug_mode and output_dir and dbg_opts.get('save_post_processed', True):
|
|
|
+ self._visualize_layout_results(
|
|
|
+ image, layout_results, output_dir, page_name, suffix='post'
|
|
|
+ )
|
|
|
+
|
|
|
return layout_results
|
|
|
|
|
|
@abstractmethod
|
|
|
@@ -325,116 +423,57 @@ class BaseLayoutDetector(BaseAdapter):
|
|
|
}
|
|
|
return category_map.get(category_id, f'unknown_{category_id}')
|
|
|
|
|
|
+ def _layout_debug_options(self) -> Dict[str, Any]:
|
|
|
+ opts = self.config.get('debug_options', {})
|
|
|
+ return opts if isinstance(opts, dict) else {}
|
|
|
+
|
|
|
+ def _is_layout_debug_enabled(self) -> bool:
|
|
|
+ debug_mode = getattr(self, 'debug_mode', None)
|
|
|
+ if debug_mode is not None:
|
|
|
+ return bool(debug_mode)
|
|
|
+ if self.config.get('debug_mode', False):
|
|
|
+ return True
|
|
|
+ return bool(self._layout_debug_options().get('enabled', False))
|
|
|
+
|
|
|
+ def _resolve_layout_debug_paths(self) -> Tuple[Optional[str], str]:
|
|
|
+ output_dir = getattr(self, 'output_dir', None)
|
|
|
+ if output_dir is None:
|
|
|
+ output_dir = self.config.get('output_dir')
|
|
|
+ if output_dir is None:
|
|
|
+ output_dir = self._layout_debug_options().get('output_dir')
|
|
|
+
|
|
|
+ page_name = getattr(self, 'page_name', None)
|
|
|
+ if page_name is None:
|
|
|
+ page_name = self.config.get('page_name')
|
|
|
+ if not page_name:
|
|
|
+ prefix = self._layout_debug_options().get('prefix', '')
|
|
|
+ page_name = prefix if prefix else 'layout_detection'
|
|
|
+ return output_dir, str(page_name)
|
|
|
+
|
|
|
def _visualize_layout_results(
|
|
|
self,
|
|
|
image: Union[np.ndarray, Image.Image],
|
|
|
layout_results: List[Dict[str, Any]],
|
|
|
output_dir: str,
|
|
|
page_name: str,
|
|
|
- suffix: str = 'raw'
|
|
|
+ suffix: str = 'raw',
|
|
|
) -> None:
|
|
|
- """
|
|
|
- 可视化 layout 检测结果
|
|
|
-
|
|
|
- Args:
|
|
|
- image: 输入图像
|
|
|
- layout_results: 布局检测结果
|
|
|
- output_dir: 输出目录
|
|
|
- page_name: 页面名称
|
|
|
- suffix: 文件名后缀(如 'raw', 'postprocessed')
|
|
|
- """
|
|
|
+ """保存 layout 模块 debug(底图为 inference / detection 输入)。"""
|
|
|
if not layout_results:
|
|
|
return
|
|
|
-
|
|
|
- try:
|
|
|
- # 转换为 numpy 数组
|
|
|
- if isinstance(image, Image.Image):
|
|
|
- vis_image = np.array(image)
|
|
|
- if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
|
|
|
- # PIL RGB -> OpenCV BGR
|
|
|
- vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
|
|
|
- else:
|
|
|
- vis_image = image.copy()
|
|
|
- if len(vis_image.shape) == 3 and vis_image.shape[2] == 3:
|
|
|
- # 如果是 RGB,转换为 BGR
|
|
|
- vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)
|
|
|
-
|
|
|
- # 定义类别颜色映射 (BGR格式)
|
|
|
- category_colors = {
|
|
|
- 'table_body': (0, 0, 255), # 红色
|
|
|
- 'table_caption': (0, 0, 200), # 暗红色
|
|
|
- 'table_footnote': (0, 0, 150), # 更暗的红色
|
|
|
- 'text': (255, 0, 0), # 蓝色
|
|
|
- 'title': (0, 255, 255), # 黄色
|
|
|
- 'header': (255, 0, 255), # 紫色
|
|
|
- 'footer': (0, 165, 255), # 橙色
|
|
|
- 'image_body': (0, 255, 0), # 绿色
|
|
|
- 'image_caption': (0, 200, 0), # 暗绿色
|
|
|
- 'image_footnote': (0, 150, 0), # 更暗的绿色
|
|
|
- 'abandon': (128, 128, 128), # 灰色
|
|
|
- }
|
|
|
-
|
|
|
- # 绘制检测框
|
|
|
- for result in layout_results:
|
|
|
- bbox = result.get('bbox', [])
|
|
|
- if not bbox or len(bbox) < 4:
|
|
|
- continue
|
|
|
-
|
|
|
- category = result.get('category', 'unknown')
|
|
|
- color = category_colors.get(category, (128, 128, 128)) # 默认灰色
|
|
|
- thickness = 2
|
|
|
-
|
|
|
- x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
|
|
|
- cv2.rectangle(vis_image, (x1, y1), (x2, y2), color, thickness)
|
|
|
-
|
|
|
- # 添加类别标签
|
|
|
- label = f"{category}"
|
|
|
- confidence = result.get('confidence', result.get('score', 0))
|
|
|
- if confidence:
|
|
|
- label += f":{confidence:.2f}"
|
|
|
-
|
|
|
- # 计算文本大小
|
|
|
- font = cv2.FONT_HERSHEY_SIMPLEX
|
|
|
- font_scale = 0.4
|
|
|
- text_thickness = 1
|
|
|
- (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, text_thickness)
|
|
|
-
|
|
|
- # 在框的上方绘制文本背景
|
|
|
- text_y = max(y1 - baseline - 1, text_height + baseline)
|
|
|
- cv2.rectangle(vis_image, (x1, text_y - text_height - baseline - 2),
|
|
|
- (x1 + text_width, text_y), color, -1)
|
|
|
- # 绘制文本
|
|
|
- cv2.putText(vis_image, label, (x1, text_y - baseline - 1),
|
|
|
- font, font_scale, (255, 255, 255), text_thickness)
|
|
|
-
|
|
|
- # 保存图像
|
|
|
- debug_dir = Path(output_dir) / "debug_comparison" / "layout_detection"
|
|
|
- debug_dir.mkdir(parents=True, exist_ok=True)
|
|
|
- output_path = debug_dir / f"{page_name}_layout_{suffix}.jpg"
|
|
|
- cv2.imwrite(str(output_path), vis_image)
|
|
|
- logger.info(f"📊 Saved layout detection image ({suffix}): {output_path}")
|
|
|
-
|
|
|
- # 保存 JSON 数据
|
|
|
- json_data = {
|
|
|
- 'page_name': page_name,
|
|
|
- 'suffix': suffix,
|
|
|
- 'count': len(layout_results),
|
|
|
- 'results': [
|
|
|
- {
|
|
|
- 'category': r.get('category'),
|
|
|
- 'bbox': r.get('bbox'),
|
|
|
- 'confidence': r.get('confidence', r.get('score', 0.0))
|
|
|
- }
|
|
|
- for r in layout_results
|
|
|
- ]
|
|
|
- }
|
|
|
- json_path = debug_dir / f"{page_name}_layout_{suffix}.json"
|
|
|
- with open(json_path, 'w', encoding='utf-8') as f:
|
|
|
- json.dump(json_data, f, ensure_ascii=False, indent=2)
|
|
|
- logger.info(f"📊 Saved layout detection JSON ({suffix}): {json_path}")
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"⚠️ Failed to visualize layout results: {e}")
|
|
|
+ from ocr_utils.module_debug_viz import save_layout_debug
|
|
|
+
|
|
|
+ opts = self._layout_debug_options()
|
|
|
+ save_layout_debug(
|
|
|
+ image,
|
|
|
+ layout_results,
|
|
|
+ output_dir,
|
|
|
+ page_name,
|
|
|
+ suffix=suffix,
|
|
|
+ subdir=opts.get('subdir', 'layout_detection'),
|
|
|
+ image_format=opts.get('image_format', 'jpg'),
|
|
|
+ save_json=bool(opts.get('save_json', True)),
|
|
|
+ )
|
|
|
|
|
|
def _remove_overlapping_boxes(
|
|
|
self,
|
|
|
@@ -615,7 +654,7 @@ class BaseVLRecognizer(BaseAdapter):
|
|
|
|
|
|
class BaseOCRRecognizer(BaseAdapter):
|
|
|
"""OCR识别器基类"""
|
|
|
-
|
|
|
+
|
|
|
@abstractmethod
|
|
|
def recognize_text(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
|
|
|
"""识别文本"""
|