module_debug_viz.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. """
  2. 模块级 Debug 可视化(Layout / OCR)
  3. 用于 ``{output_dir}/debug/{subdir}/`` 下基于 inference_image 的调试图;
  4. 用户审计图由 VisualizationUtils + original_image 负责,不在此模块。
  5. """
  6. from __future__ import annotations
  7. import json
  8. from pathlib import Path
  9. from typing import Any, Dict, List, Optional, Union
  10. import cv2
  11. import numpy as np
  12. from loguru import logger
  13. from PIL import Image
  14. # 各模块 debug_options 默认落盘根目录(相对 pipeline output_dir)
  15. MODULE_DEBUG_ROOT = "debug"
  16. def _json_default(o: Any):
  17. """json.dumps 的兜底序列化:处理 numpy 标量/数组(如 OCR confidence 的 float32)。"""
  18. if isinstance(o, np.generic):
  19. return o.item()
  20. if isinstance(o, np.ndarray):
  21. return o.tolist()
  22. if isinstance(o, (set, tuple)):
  23. return list(o)
  24. raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")
  25. def resolve_module_debug_dir(
  26. output_dir: Union[str, Path],
  27. subdir: str,
  28. *,
  29. debug_root: str = MODULE_DEBUG_ROOT,
  30. ) -> Path:
  31. """``{output_dir}/{debug_root}/{subdir}/``,目录不存在则创建。"""
  32. debug_dir = Path(output_dir) / debug_root / subdir
  33. debug_dir.mkdir(parents=True, exist_ok=True)
  34. return debug_dir
  35. LAYOUT_CATEGORY_COLORS_BGR = {
  36. 'table_body': (0, 0, 255),
  37. 'table_caption': (0, 0, 200),
  38. 'table_footnote': (0, 0, 150),
  39. 'text': (255, 0, 0),
  40. 'title': (0, 255, 255),
  41. 'header': (255, 0, 255),
  42. 'footer': (0, 165, 255),
  43. 'image_body': (0, 255, 0),
  44. 'image_caption': (0, 200, 0),
  45. 'image_footnote': (0, 150, 0),
  46. 'chart': (255, 255, 0), # 青色(BGR 下 B=255,G=255,R=0)
  47. # 注意:OpenCV 为 BGR,(0,255,255) 在屏幕上呈黄色(与 title 相同),勿用于 seal
  48. 'seal': (0, 140, 255), # 亮橙,与红 table / 黄 title / 蓝 text 均易区分
  49. 'abandon': (128, 128, 128),
  50. }
  51. # seal 常与 table 重叠:加粗线宽 + 黑色外描边
  52. LAYOUT_HIGHLIGHT_CATEGORIES = frozenset({'seal'})
  53. LAYOUT_HIGHLIGHT_LINE_THICKNESS = 4
  54. LAYOUT_HIGHLIGHT_OUTLINE_BGR = (0, 0, 0)
  55. LAYOUT_DEFAULT_LINE_THICKNESS = 2
  56. # OCR 框线宽 (不受配色统一影响)
  57. OCR_BOX_LINE_THICKNESS = 2
  58. OCR_BOX_DASH_LENGTH = 8
  59. OCR_BOX_DASH_GAP = 6
  60. def _ocr_box_color_bgr() -> tuple:
  61. """亮蓝 OCR 框 (BGR),派生自 VisualizationUtils.COLOR_MAP['ocr_box']。"""
  62. from ocr_utils.visualization_utils import VisualizationUtils
  63. return VisualizationUtils.rgb_to_bgr(VisualizationUtils.COLOR_MAP['ocr_box'])
  64. def _seal_ocr_box_color_bgr() -> tuple:
  65. """印章 OCR 框 (BGR),派生自 VisualizationUtils.COLOR_MAP['seal_ocr_box']。"""
  66. from ocr_utils.visualization_utils import VisualizationUtils
  67. return VisualizationUtils.rgb_to_bgr(VisualizationUtils.COLOR_MAP['seal_ocr_box'])
  68. def ocr_box_color_rgb() -> tuple:
  69. """OCR 亮蓝 (RGB),供 PIL / Plotly 使用。"""
  70. from ocr_utils.visualization_utils import VisualizationUtils
  71. return VisualizationUtils.COLOR_MAP['ocr_box']
  72. def _to_bgr(image: Union[np.ndarray, Image.Image]) -> np.ndarray:
  73. if isinstance(image, Image.Image):
  74. arr = np.array(image)
  75. else:
  76. arr = image.copy()
  77. if arr.ndim == 2:
  78. return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
  79. if arr.shape[2] == 3:
  80. return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
  81. return arr
  82. def draw_layout_boxes_cv2(
  83. image: Union[np.ndarray, Image.Image],
  84. layout_results: List[Dict[str, Any]],
  85. ) -> np.ndarray:
  86. """在 BGR 图像上绘制 layout 检测框,返回新图像。"""
  87. vis = _to_bgr(image)
  88. for result in layout_results:
  89. bbox = result.get('bbox', [])
  90. if not bbox or len(bbox) < 4:
  91. continue
  92. category = result.get('category', 'unknown')
  93. color = LAYOUT_CATEGORY_COLORS_BGR.get(category, (128, 128, 128))
  94. thickness = (
  95. LAYOUT_HIGHLIGHT_LINE_THICKNESS
  96. if category in LAYOUT_HIGHLIGHT_CATEGORIES
  97. else LAYOUT_DEFAULT_LINE_THICKNESS
  98. )
  99. x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
  100. if category in LAYOUT_HIGHLIGHT_CATEGORIES:
  101. cv2.rectangle(
  102. vis, (x1, y1), (x2, y2),
  103. LAYOUT_HIGHLIGHT_OUTLINE_BGR,
  104. thickness + 2,
  105. )
  106. cv2.rectangle(vis, (x1, y1), (x2, y2), color, thickness)
  107. label = category
  108. confidence = result.get('confidence', result.get('score', 0))
  109. if confidence:
  110. label += f":{float(confidence):.2f}"
  111. font = cv2.FONT_HERSHEY_SIMPLEX
  112. font_scale = 0.5 if category in LAYOUT_HIGHLIGHT_CATEGORIES else 0.4
  113. text_thickness = 1
  114. (text_width, text_height), baseline = cv2.getTextSize(
  115. label, font, font_scale, text_thickness
  116. )
  117. text_y = max(y1 - baseline - 1, text_height + baseline)
  118. cv2.rectangle(
  119. vis,
  120. (x1, text_y - text_height - baseline - 2),
  121. (x1 + text_width, text_y),
  122. color,
  123. -1,
  124. )
  125. cv2.putText(
  126. vis, label, (x1, text_y - baseline - 1),
  127. font, font_scale, (255, 255, 255), text_thickness,
  128. )
  129. return vis
  130. def _draw_dashed_segment(
  131. vis: np.ndarray,
  132. p1: np.ndarray,
  133. p2: np.ndarray,
  134. color: tuple,
  135. thickness: int,
  136. *,
  137. dash_length: int = OCR_BOX_DASH_LENGTH,
  138. gap_length: int = OCR_BOX_DASH_GAP,
  139. ) -> None:
  140. """在 p1→p2 上绘制虚线段。"""
  141. start = p1.astype(np.float64)
  142. end = p2.astype(np.float64)
  143. vec = end - start
  144. length = float(np.linalg.norm(vec))
  145. if length < 1e-6:
  146. return
  147. direction = vec / length
  148. pos = 0.0
  149. draw = True
  150. while pos < length:
  151. seg = float(dash_length if draw else gap_length)
  152. seg_end = min(pos + seg, length)
  153. if draw:
  154. s = (start + direction * pos).astype(np.int32)
  155. e = (start + direction * seg_end).astype(np.int32)
  156. cv2.line(
  157. vis,
  158. (int(s[0]), int(s[1])),
  159. (int(e[0]), int(e[1])),
  160. color,
  161. thickness,
  162. cv2.LINE_AA,
  163. )
  164. pos = seg_end
  165. draw = not draw
  166. def _draw_span_outline(
  167. vis: np.ndarray,
  168. pts: np.ndarray,
  169. color: tuple,
  170. thickness: int,
  171. *,
  172. dashed: bool,
  173. ) -> None:
  174. n = len(pts)
  175. if n < 2:
  176. return
  177. for i in range(n):
  178. p1 = pts[i]
  179. p2 = pts[(i + 1) % n]
  180. if dashed:
  181. _draw_dashed_segment(vis, p1, p2, color, thickness)
  182. else:
  183. cv2.line(
  184. vis,
  185. (int(p1[0]), int(p1[1])),
  186. (int(p2[0]), int(p2[1])),
  187. color,
  188. thickness,
  189. cv2.LINE_AA,
  190. )
  191. def draw_ocr_spans_cv2(
  192. image: Union[np.ndarray, Image.Image],
  193. spans: List[Dict[str, Any]],
  194. *,
  195. max_label_chars: int = 12,
  196. ) -> np.ndarray:
  197. """在 BGR 图像上绘制 OCR span(poly 或 bbox);无文字用虚线框。
  198. span 可带 category='seal' 使用印章专用亮橙色,否则使用亮蓝。
  199. """
  200. vis = _to_bgr(image)
  201. for span in spans:
  202. poly = span.get('poly')
  203. bbox = span.get('bbox', [])
  204. pts = None
  205. if poly and len(poly) >= 4:
  206. pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
  207. elif bbox and len(bbox) >= 4:
  208. x0, y0, x1, y1 = map(int, bbox[:4])
  209. pts = np.array(
  210. [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.int32
  211. )
  212. if pts is not None:
  213. text_raw = str(span.get('text', '') or '').strip()
  214. color = _seal_ocr_box_color_bgr() if span.get('category') == 'seal' else _ocr_box_color_bgr()
  215. _draw_span_outline(
  216. vis,
  217. pts,
  218. color,
  219. OCR_BOX_LINE_THICKNESS,
  220. dashed=not text_raw,
  221. )
  222. text = str(span.get('text', '')).strip()[:max_label_chars]
  223. if text and pts is not None:
  224. color = _seal_ocr_box_color_bgr() if span.get('category') == 'seal' else _ocr_box_color_bgr()
  225. x, y = int(pts[0][0]), int(pts[0][1])
  226. cv2.putText(
  227. vis, text, (x, max(y - 2, 10)),
  228. cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1, cv2.LINE_AA,
  229. )
  230. return vis
  231. def save_layout_debug(
  232. image: Union[np.ndarray, Image.Image],
  233. layout_results: List[Dict[str, Any]],
  234. output_dir: Union[str, Path],
  235. page_name: str,
  236. *,
  237. suffix: str = 'raw',
  238. subdir: str = 'layout_detection',
  239. image_format: str = 'jpg',
  240. save_json: bool = True,
  241. ) -> Optional[Dict[str, str]]:
  242. """保存 layout 模块 debug 图与 JSON。"""
  243. if not layout_results or not output_dir:
  244. return None
  245. try:
  246. fmt = (image_format or 'jpg').lstrip('.')
  247. debug_dir = resolve_module_debug_dir(output_dir, subdir)
  248. vis = draw_layout_boxes_cv2(image, layout_results)
  249. img_path = debug_dir / f'{page_name}_layout_{suffix}.{fmt}'
  250. cv2.imwrite(str(img_path), vis)
  251. paths: Dict[str, str] = {'image': str(img_path)}
  252. logger.info(f"Saved layout detection image ({suffix}): {img_path}")
  253. if save_json:
  254. json_data = {
  255. 'page_name': page_name,
  256. 'suffix': suffix,
  257. 'count': len(layout_results),
  258. 'results': [
  259. {
  260. 'category': r.get('category'),
  261. 'bbox': r.get('bbox'),
  262. 'confidence': r.get('confidence', r.get('score', 0.0)),
  263. }
  264. for r in layout_results
  265. ],
  266. }
  267. json_path = debug_dir / f'{page_name}_layout_{suffix}.json'
  268. json_path.write_text(
  269. json.dumps(json_data, ensure_ascii=False, indent=2, default=_json_default),
  270. encoding='utf-8',
  271. )
  272. paths['json'] = str(json_path)
  273. logger.info(f"Saved layout detection JSON ({suffix}): {json_path}")
  274. return paths
  275. except Exception as e:
  276. logger.warning(f"Failed to save layout debug ({suffix}): {e}")
  277. return None
  278. def save_ocr_debug(
  279. image: Union[np.ndarray, Image.Image],
  280. spans: List[Dict[str, Any]],
  281. output_dir: Union[str, Path],
  282. page_name: str,
  283. *,
  284. subdir: str = 'ocr_recognition',
  285. image_format: str = 'png',
  286. save_json: bool = True,
  287. ) -> Optional[Dict[str, str]]:
  288. """保存 OCR 模块 debug 图与 JSON。"""
  289. if not output_dir:
  290. return None
  291. try:
  292. fmt = (image_format or 'png').lstrip('.')
  293. debug_dir = resolve_module_debug_dir(output_dir, subdir)
  294. vis = draw_ocr_spans_cv2(image, spans or [])
  295. img_path = debug_dir / f'{page_name}_ocr_spans.{fmt}'
  296. cv2.imwrite(str(img_path), vis)
  297. paths: Dict[str, str] = {'image': str(img_path)}
  298. logger.info(f"Saved OCR debug image: {img_path}")
  299. if save_json:
  300. json_data = {
  301. 'page_name': page_name,
  302. 'count': len(spans or []),
  303. 'spans': [
  304. {
  305. 'bbox': s.get('bbox'),
  306. 'poly': s.get('poly'),
  307. 'text': s.get('text'),
  308. 'confidence': s.get('confidence'),
  309. }
  310. for s in (spans or [])
  311. ],
  312. }
  313. json_path = debug_dir / f'{page_name}_ocr_spans.json'
  314. json_path.write_text(
  315. json.dumps(json_data, ensure_ascii=False, indent=2, default=_json_default),
  316. encoding='utf-8',
  317. )
  318. paths['json'] = str(json_path)
  319. logger.info(f"Saved OCR debug JSON: {json_path}")
  320. return paths
  321. except Exception as e:
  322. logger.warning(f"Failed to save OCR debug: {e}")
  323. return None