| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- """
- 模块级 Debug 可视化(Layout / OCR)
- 用于 ``{output_dir}/debug/{subdir}/`` 下基于 inference_image 的调试图;
- 用户审计图由 VisualizationUtils + original_image 负责,不在此模块。
- """
- from __future__ import annotations
- import json
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Union
- import cv2
- import numpy as np
- from loguru import logger
- from PIL import Image
# Default root directory (relative to the pipeline output_dir) that every
# module's debug_options output is written under.
MODULE_DEBUG_ROOT = "debug"


def resolve_module_debug_dir(
    output_dir: Union[str, Path],
    subdir: str,
    *,
    debug_root: str = MODULE_DEBUG_ROOT,
) -> Path:
    """Return ``{output_dir}/{debug_root}/{subdir}/`` as a Path.

    The directory (and any missing parents) is created if needed; an
    existing directory is left as is.
    """
    target = Path(output_dir).joinpath(debug_root, subdir)
    target.mkdir(exist_ok=True, parents=True)
    return target
# Layout category -> box colour, in OpenCV BGR channel order.
LAYOUT_CATEGORY_COLORS_BGR = dict(
    # Tables: shades of red.
    table_body=(0, 0, 255),
    table_caption=(0, 0, 200),
    table_footnote=(0, 0, 150),
    text=(255, 0, 0),          # blue
    title=(0, 255, 255),       # yellow
    header=(255, 0, 255),      # magenta
    footer=(0, 165, 255),      # orange
    # Images: shades of green.
    image_body=(0, 255, 0),
    image_caption=(0, 200, 0),
    image_footnote=(0, 150, 0),
    abandon=(128, 128, 128),   # grey
)

# OCR span outline colour: bright blue (BGR). It is easier to read than yellow
# on white / light-grey pages and stays distinct from the red layout table
# boxes. NOTE: it is the same colour as the layout 'text' category.
OCR_BOX_COLOR_BGR = (255, 0, 0)
# OCR outline thickness, dash length and gap length, all in pixels.
OCR_BOX_LINE_THICKNESS, OCR_BOX_DASH_LENGTH, OCR_BOX_DASH_GAP = 2, 8, 6
def _to_bgr(image: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Return a new 3-channel BGR array of *image* for OpenCV drawing.

    Input is assumed to be in RGB(A) channel order, for PIL images and numpy
    arrays alike (the ndarray branch applies the same RGB->BGR swap). The
    caller's image is never modified: every path returns a fresh array.

    Accepted layouts: (H, W) grayscale, (H, W, 1), (H, W, 3) RGB and
    (H, W, 4) RGBA (alpha is dropped). Any other channel count is returned
    as an unconverted copy, as before.

    Args:
        image: PIL image or numpy array.

    Returns:
        A freshly allocated BGR array (3 channels for the layouts above).
    """
    if isinstance(image, Image.Image):
        # Normalise PIL modes that the ndarray logic below cannot handle
        # ('1' yields a bool array, 'P' yields palette indices, 'LA'/'RGBA'
        # carry alpha, 'CMYK'/'YCbCr' are not RGB) into plain RGB.
        if image.mode in ('1', 'P', 'LA', 'RGBA', 'CMYK', 'YCbCr'):
            image = image.convert('RGB')
        arr = np.array(image)
    else:
        # No copy needed here: every branch below allocates a new array.
        arr = np.asarray(image)
    if arr.ndim == 3 and arr.shape[2] == 1:
        # A single-channel image stored with an explicit channel axis.
        arr = arr[:, :, 0]
    if arr.ndim == 2:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    if arr.shape[2] == 3:
        return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
    if arr.shape[2] == 4:
        # Previously returned untouched. That swapped R/B relative to the
        # 3-channel path, and left an alpha channel that the 3-tuple drawing
        # colours set to 0 (invisible boxes in PNG output).
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
    return arr.copy()
def draw_layout_boxes_cv2(
    image: Union[np.ndarray, Image.Image],
    layout_results: List[Dict[str, Any]],
) -> np.ndarray:
    """Draw layout detection boxes and labels on a BGR copy of *image*.

    Each result dict is read for:

    * ``bbox``: ``[x1, y1, x2, y2]`` as a list, tuple or numpy array (extra
      items are ignored). Results whose bbox is missing or has fewer than
      4 items are skipped.
    * ``category``: selects the colour from ``LAYOUT_CATEGORY_COLORS_BGR``
      (grey for unknown categories) and is the label text (default
      ``'unknown'``).
    * ``confidence`` (falling back to ``score``): appended to the label as
      ``:0.93`` when truthy.

    Args:
        image: Source image; converted with ``_to_bgr`` and never modified.
        layout_results: Detection results; ``None`` is treated as empty.

    Returns:
        A new BGR image with the boxes and labels drawn.
    """
    vis = _to_bgr(image)
    # Label font settings do not depend on the result; set them up once.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.4
    text_thickness = 1
    for result in layout_results or []:
        bbox = result.get('bbox')
        # Explicit None/len checks: `not bbox` raises ValueError when the
        # bbox is a numpy array.
        if bbox is None or len(bbox) < 4:
            continue
        category = result.get('category', 'unknown')
        color = LAYOUT_CATEGORY_COLORS_BGR.get(category, (128, 128, 128))
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)

        label = str(category)
        confidence = result.get('confidence', result.get('score', 0))
        if confidence:
            label += f":{float(confidence):.2f}"
        (text_width, text_height), baseline = cv2.getTextSize(
            label, font, font_scale, text_thickness
        )
        # Put the label just above the box. Near the top edge, push it
        # down so it stays on the canvas.
        text_y = max(y1 - baseline - 1, text_height + baseline)
        cv2.rectangle(
            vis,
            (x1, text_y - text_height - baseline - 2),
            (x1 + text_width, text_y),
            color,
            -1,
        )
        cv2.putText(
            vis, label, (x1, text_y - baseline - 1),
            font, font_scale, (255, 255, 255), text_thickness,
        )
    return vis
def _draw_dashed_segment(
    vis: np.ndarray,
    p1: np.ndarray,
    p2: np.ndarray,
    color: tuple,
    thickness: int,
    *,
    dash_length: int = OCR_BOX_DASH_LENGTH,
    gap_length: int = OCR_BOX_DASH_GAP,
) -> None:
    """Draw a dashed line from *p1* to *p2* onto *vis* in place.

    Dashes of ``dash_length`` pixels alternate with gaps of ``gap_length``
    pixels, starting with a dash at *p1*; the last piece is cut off at
    *p2*. Coincident endpoints draw nothing.

    A non-positive ``dash_length`` or ``gap_length`` is treated as a solid
    line. With such values the dash/gap alternation could fail to advance
    along the segment and never terminate.

    Args:
        vis: Image drawn on (modified in place).
        p1, p2: End points ``(x, y)``; numpy arrays or plain sequences.
        color: Line colour as passed to ``cv2.line``.
        thickness: Line thickness in pixels.
        dash_length: Length of each drawn dash in pixels.
        gap_length: Length of each gap in pixels.
    """
    start = np.asarray(p1, dtype=np.float64)
    end = np.asarray(p2, dtype=np.float64)
    vec = end - start
    length = float(np.linalg.norm(vec))
    if length < 1e-6:
        return
    direction = vec / length

    def _pixel(offset: float) -> tuple:
        # Round (rather than truncate toward zero) to the nearest pixel.
        point = np.rint(start + direction * offset)
        return int(point[0]), int(point[1])

    if dash_length <= 0 or gap_length <= 0:
        cv2.line(vis, _pixel(0.0), _pixel(length), color, thickness, cv2.LINE_AA)
        return

    pos = 0.0
    draw = True
    while pos < length:
        seg_end = min(pos + float(dash_length if draw else gap_length), length)
        if draw:
            cv2.line(
                vis, _pixel(pos), _pixel(seg_end), color, thickness, cv2.LINE_AA,
            )
        pos = seg_end
        draw = not draw
def _draw_span_outline(
    vis: np.ndarray,
    pts: np.ndarray,
    color: tuple,
    thickness: int,
    *,
    dashed: bool,
) -> None:
    """Draw the closed polygon through *pts* onto *vis* in place.

    Edge ``i`` joins ``pts[i]`` to ``pts[i + 1]``, and the last point is
    joined back to the first. Edges are dashed when ``dashed`` is true,
    otherwise solid anti-aliased lines. Fewer than two points draws nothing.
    """
    count = len(pts)
    if count < 2:
        return
    for idx, here in enumerate(pts):
        there = pts[(idx + 1) % count]
        if dashed:
            _draw_dashed_segment(vis, here, there, color, thickness)
            continue
        cv2.line(
            vis,
            (int(here[0]), int(here[1])),
            (int(there[0]), int(there[1])),
            color,
            thickness,
            cv2.LINE_AA,
        )
def draw_ocr_spans_cv2(
    image: Union[np.ndarray, Image.Image],
    spans: List[Dict[str, Any]],
    *,
    max_label_chars: int = 12,
) -> np.ndarray:
    """Draw OCR spans on a BGR copy of *image* and return it.

    Each span is outlined from its ``poly`` (any point list that reshapes to
    (N, 2), used when it has at least 4 items) or, failing that, its
    ``bbox`` ``[x0, y0, x1, y1]``. Both may be lists or numpy arrays. Spans
    with neither are skipped.

    Spans with non-empty stripped ``text`` get a solid outline plus a label
    of the first ``max_label_chars`` characters near the first point.
    Spans without text (missing, empty or ``None``) get a dashed outline and
    no label.

    NOTE: the Hershey fonts used by ``cv2.putText`` only render ASCII;
    other characters (e.g. CJK) come out as ``?``.

    Args:
        image: Source image; converted with ``_to_bgr`` and never modified.
        spans: OCR span dicts; ``None`` is treated as empty.
        max_label_chars: Maximum number of label characters drawn per span.

    Returns:
        A new BGR image with the spans drawn.
    """
    vis = _to_bgr(image)
    for span in spans or []:
        poly = span.get('poly')
        bbox = span.get('bbox')
        # Explicit None/len checks: OCR engines often return numpy arrays,
        # whose truthiness raises ValueError.
        if poly is not None and len(poly) >= 4:
            pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
        elif bbox is not None and len(bbox) >= 4:
            x0, y0, x1, y1 = (int(v) for v in bbox[:4])
            pts = np.array(
                [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.int32
            )
        else:
            continue

        # Derive the label from the same normalised text as the outline
        # style. Previously text=None gave a dashed box labelled "None".
        text = str(span.get('text') or '').strip()
        _draw_span_outline(
            vis,
            pts,
            OCR_BOX_COLOR_BGR,
            OCR_BOX_LINE_THICKNESS,
            dashed=not text,
        )
        label = text[:max_label_chars]
        if label:
            x, y = int(pts[0][0]), int(pts[0][1])
            cv2.putText(
                vis, label, (x, max(y - 2, 10)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.35, OCR_BOX_COLOR_BGR, 1, cv2.LINE_AA,
            )
    return vis
def save_layout_debug(
    image: Union[np.ndarray, Image.Image],
    layout_results: List[Dict[str, Any]],
    output_dir: Union[str, Path],
    page_name: str,
    *,
    suffix: str = 'raw',
    subdir: str = 'layout_detection',
    image_format: str = 'jpg',
    save_json: bool = True,
) -> Optional[Dict[str, str]]:
    """Save the layout module's debug image and, optionally, a JSON dump.

    Files go to ``resolve_module_debug_dir(output_dir, subdir)``:

    * ``{page_name}_layout_{suffix}.{image_format}``: boxes drawn by
      ``draw_layout_boxes_cv2``.
    * ``{page_name}_layout_{suffix}.json``: page name, suffix, count and
      each result's category / bbox / confidence (falling back to score).

    Best-effort: any failure is logged as a warning and ``None`` is
    returned, so debug output never interrupts the caller.

    Args:
        image: Inference image the boxes refer to.
        layout_results: Layout detection results.
        output_dir: Pipeline output directory.
        page_name: Prefix for the output file names.
        suffix: Stage tag embedded in the file names (e.g. ``'raw'``).
        subdir: Sub-directory under the debug root.
        image_format: Image file extension, with or without a leading dot.
        save_json: Whether to also write the JSON dump.

    Returns:
        ``{'image': path}`` plus ``'json': path`` when written; ``None``
        when there is nothing to save or saving failed.
    """
    if not layout_results or not output_dir:
        return None

    def _json_default(obj: Any) -> Any:
        # json cannot encode numpy scalars (e.g. float32 scores) or arrays.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    try:
        fmt = (image_format or 'jpg').lstrip('.')
        debug_dir = resolve_module_debug_dir(output_dir, subdir)
        vis = draw_layout_boxes_cv2(image, layout_results)
        img_path = debug_dir / f'{page_name}_layout_{suffix}.{fmt}'
        # Encode in memory and write via pathlib instead of cv2.imwrite.
        # imwrite reports failure only through a return value (previously
        # ignored), and may fail on non-ASCII paths on Windows.
        ok, encoded = cv2.imencode(f'.{fmt}', vis)
        if not ok:
            raise ValueError(f"cv2.imencode failed for format '{fmt}'")
        img_path.write_bytes(encoded.tobytes())
        paths: Dict[str, str] = {'image': str(img_path)}
        logger.info(f"Saved layout detection image ({suffix}): {img_path}")
        if save_json:
            json_data = {
                'page_name': page_name,
                'suffix': suffix,
                'count': len(layout_results),
                'results': [
                    {
                        'category': r.get('category'),
                        'bbox': r.get('bbox'),
                        'confidence': r.get('confidence', r.get('score', 0.0)),
                    }
                    for r in layout_results
                ],
            }
            json_path = debug_dir / f'{page_name}_layout_{suffix}.json'
            json_path.write_text(
                json.dumps(
                    json_data, ensure_ascii=False, indent=2,
                    default=_json_default,
                ),
                encoding='utf-8',
            )
            paths['json'] = str(json_path)
            logger.info(f"Saved layout detection JSON ({suffix}): {json_path}")
        return paths
    except Exception as e:
        logger.warning(f"Failed to save layout debug ({suffix}): {e}")
        return None
def save_ocr_debug(
    image: Union[np.ndarray, Image.Image],
    spans: List[Dict[str, Any]],
    output_dir: Union[str, Path],
    page_name: str,
    *,
    subdir: str = 'ocr_recognition',
    image_format: str = 'png',
    save_json: bool = True,
) -> Optional[Dict[str, str]]:
    """Save the OCR module's debug image and, optionally, a JSON dump.

    Files go to ``resolve_module_debug_dir(output_dir, subdir)``:

    * ``{page_name}_ocr_spans.{image_format}``: spans drawn by
      ``draw_ocr_spans_cv2``.
    * ``{page_name}_ocr_spans.json``: page name, count and each span's
      bbox / poly / text / confidence.

    Unlike ``save_layout_debug``, an empty (or ``None``) span list still
    produces output, showing a page on which nothing was recognised.

    Best-effort: any failure is logged as a warning and ``None`` is
    returned, so debug output never interrupts the caller.

    Args:
        image: Inference image the spans refer to.
        spans: OCR span dicts; ``None`` is treated as empty.
        output_dir: Pipeline output directory.
        page_name: Prefix for the output file names.
        subdir: Sub-directory under the debug root.
        image_format: Image file extension, with or without a leading dot.
        save_json: Whether to also write the JSON dump.

    Returns:
        ``{'image': path}`` plus ``'json': path`` when written; ``None``
        when ``output_dir`` is empty or saving failed.
    """
    if not output_dir:
        return None

    def _json_default(obj: Any) -> Any:
        # json cannot encode numpy arrays (e.g. polys) or numpy scalars.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    spans = spans or []
    try:
        fmt = (image_format or 'png').lstrip('.')
        debug_dir = resolve_module_debug_dir(output_dir, subdir)
        vis = draw_ocr_spans_cv2(image, spans)
        img_path = debug_dir / f'{page_name}_ocr_spans.{fmt}'
        # Encode in memory and write via pathlib instead of cv2.imwrite.
        # imwrite reports failure only through a return value (previously
        # ignored), and may fail on non-ASCII paths on Windows.
        ok, encoded = cv2.imencode(f'.{fmt}', vis)
        if not ok:
            raise ValueError(f"cv2.imencode failed for format '{fmt}'")
        img_path.write_bytes(encoded.tobytes())
        paths: Dict[str, str] = {'image': str(img_path)}
        logger.info(f"Saved OCR debug image: {img_path}")
        if save_json:
            json_data = {
                'page_name': page_name,
                'count': len(spans),
                'spans': [
                    {
                        'bbox': s.get('bbox'),
                        'poly': s.get('poly'),
                        'text': s.get('text'),
                        'confidence': s.get('confidence'),
                    }
                    for s in spans
                ],
            }
            json_path = debug_dir / f'{page_name}_ocr_spans.json'
            json_path.write_text(
                json.dumps(
                    json_data, ensure_ascii=False, indent=2,
                    default=_json_default,
                ),
                encoding='utf-8',
            )
            paths['json'] = str(json_path)
            logger.info(f"Saved OCR debug JSON: {json_path}")
        return paths
    except Exception as e:
        logger.warning(f"Failed to save OCR debug: {e}")
        return None
|