# module_debug_viz.py (9.4 KB)

  1. """
  2. 模块级 Debug 可视化(Layout / OCR)
  3. 用于 ``{output_dir}/debug/{subdir}/`` 下基于 inference_image 的调试图;
  4. 用户审计图由 VisualizationUtils + original_image 负责,不在此模块。
  5. """
  6. from __future__ import annotations
  7. import json
  8. from pathlib import Path
  9. from typing import Any, Dict, List, Optional, Union
  10. import cv2
  11. import numpy as np
  12. from loguru import logger
  13. from PIL import Image
  14. # 各模块 debug_options 默认落盘根目录(相对 pipeline output_dir)
  15. MODULE_DEBUG_ROOT = "debug"
  16. def resolve_module_debug_dir(
  17. output_dir: Union[str, Path],
  18. subdir: str,
  19. *,
  20. debug_root: str = MODULE_DEBUG_ROOT,
  21. ) -> Path:
  22. """``{output_dir}/{debug_root}/{subdir}/``,目录不存在则创建。"""
  23. debug_dir = Path(output_dir) / debug_root / subdir
  24. debug_dir.mkdir(parents=True, exist_ok=True)
  25. return debug_dir
  26. LAYOUT_CATEGORY_COLORS_BGR = {
  27. 'table_body': (0, 0, 255),
  28. 'table_caption': (0, 0, 200),
  29. 'table_footnote': (0, 0, 150),
  30. 'text': (255, 0, 0),
  31. 'title': (0, 255, 255),
  32. 'header': (255, 0, 255),
  33. 'footer': (0, 165, 255),
  34. 'image_body': (0, 255, 0),
  35. 'image_caption': (0, 200, 0),
  36. 'image_footnote': (0, 150, 0),
  37. 'abandon': (128, 128, 128),
  38. }
# Bright blue (BGR): easier to make out than yellow on white / light-gray
# statement pages, and distinct from the red layout boxes.
# NOTE(review): the layout 'text' category is drawn in the same (255, 0, 0);
# confirm whether the two overlays are ever rendered on one image.
OCR_BOX_COLOR_BGR = (255, 0, 0)
# Line thickness used for OCR span outlines.
OCR_BOX_LINE_THICKNESS = 2
# Default length of each drawn dash, in the same units as the point coordinates.
OCR_BOX_DASH_LENGTH = 8
# Default length of the gap between two dashes, in the same units.
OCR_BOX_DASH_GAP = 6
  44. def _to_bgr(image: Union[np.ndarray, Image.Image]) -> np.ndarray:
  45. if isinstance(image, Image.Image):
  46. arr = np.array(image)
  47. else:
  48. arr = image.copy()
  49. if arr.ndim == 2:
  50. return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
  51. if arr.shape[2] == 3:
  52. return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
  53. return arr
  54. def draw_layout_boxes_cv2(
  55. image: Union[np.ndarray, Image.Image],
  56. layout_results: List[Dict[str, Any]],
  57. ) -> np.ndarray:
  58. """在 BGR 图像上绘制 layout 检测框,返回新图像。"""
  59. vis = _to_bgr(image)
  60. for result in layout_results:
  61. bbox = result.get('bbox', [])
  62. if not bbox or len(bbox) < 4:
  63. continue
  64. category = result.get('category', 'unknown')
  65. color = LAYOUT_CATEGORY_COLORS_BGR.get(category, (128, 128, 128))
  66. x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
  67. cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
  68. label = category
  69. confidence = result.get('confidence', result.get('score', 0))
  70. if confidence:
  71. label += f":{float(confidence):.2f}"
  72. font = cv2.FONT_HERSHEY_SIMPLEX
  73. font_scale = 0.4
  74. text_thickness = 1
  75. (text_width, text_height), baseline = cv2.getTextSize(
  76. label, font, font_scale, text_thickness
  77. )
  78. text_y = max(y1 - baseline - 1, text_height + baseline)
  79. cv2.rectangle(
  80. vis,
  81. (x1, text_y - text_height - baseline - 2),
  82. (x1 + text_width, text_y),
  83. color,
  84. -1,
  85. )
  86. cv2.putText(
  87. vis, label, (x1, text_y - baseline - 1),
  88. font, font_scale, (255, 255, 255), text_thickness,
  89. )
  90. return vis
  91. def _draw_dashed_segment(
  92. vis: np.ndarray,
  93. p1: np.ndarray,
  94. p2: np.ndarray,
  95. color: tuple,
  96. thickness: int,
  97. *,
  98. dash_length: int = OCR_BOX_DASH_LENGTH,
  99. gap_length: int = OCR_BOX_DASH_GAP,
  100. ) -> None:
  101. """在 p1→p2 上绘制虚线段。"""
  102. start = p1.astype(np.float64)
  103. end = p2.astype(np.float64)
  104. vec = end - start
  105. length = float(np.linalg.norm(vec))
  106. if length < 1e-6:
  107. return
  108. direction = vec / length
  109. pos = 0.0
  110. draw = True
  111. while pos < length:
  112. seg = float(dash_length if draw else gap_length)
  113. seg_end = min(pos + seg, length)
  114. if draw:
  115. s = (start + direction * pos).astype(np.int32)
  116. e = (start + direction * seg_end).astype(np.int32)
  117. cv2.line(
  118. vis,
  119. (int(s[0]), int(s[1])),
  120. (int(e[0]), int(e[1])),
  121. color,
  122. thickness,
  123. cv2.LINE_AA,
  124. )
  125. pos = seg_end
  126. draw = not draw
  127. def _draw_span_outline(
  128. vis: np.ndarray,
  129. pts: np.ndarray,
  130. color: tuple,
  131. thickness: int,
  132. *,
  133. dashed: bool,
  134. ) -> None:
  135. n = len(pts)
  136. if n < 2:
  137. return
  138. for i in range(n):
  139. p1 = pts[i]
  140. p2 = pts[(i + 1) % n]
  141. if dashed:
  142. _draw_dashed_segment(vis, p1, p2, color, thickness)
  143. else:
  144. cv2.line(
  145. vis,
  146. (int(p1[0]), int(p1[1])),
  147. (int(p2[0]), int(p2[1])),
  148. color,
  149. thickness,
  150. cv2.LINE_AA,
  151. )
  152. def draw_ocr_spans_cv2(
  153. image: Union[np.ndarray, Image.Image],
  154. spans: List[Dict[str, Any]],
  155. *,
  156. max_label_chars: int = 12,
  157. ) -> np.ndarray:
  158. """在 BGR 图像上绘制 OCR span(poly 或 bbox);无文字用虚线框。"""
  159. vis = _to_bgr(image)
  160. for span in spans:
  161. poly = span.get('poly')
  162. bbox = span.get('bbox', [])
  163. pts = None
  164. if poly and len(poly) >= 4:
  165. pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
  166. elif bbox and len(bbox) >= 4:
  167. x0, y0, x1, y1 = map(int, bbox[:4])
  168. pts = np.array(
  169. [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.int32
  170. )
  171. if pts is not None:
  172. text_raw = str(span.get('text', '') or '').strip()
  173. _draw_span_outline(
  174. vis,
  175. pts,
  176. OCR_BOX_COLOR_BGR,
  177. OCR_BOX_LINE_THICKNESS,
  178. dashed=not text_raw,
  179. )
  180. text = str(span.get('text', '')).strip()[:max_label_chars]
  181. if text and pts is not None:
  182. x, y = int(pts[0][0]), int(pts[0][1])
  183. cv2.putText(
  184. vis, text, (x, max(y - 2, 10)),
  185. cv2.FONT_HERSHEY_SIMPLEX, 0.35, OCR_BOX_COLOR_BGR, 1, cv2.LINE_AA,
  186. )
  187. return vis
  188. def save_layout_debug(
  189. image: Union[np.ndarray, Image.Image],
  190. layout_results: List[Dict[str, Any]],
  191. output_dir: Union[str, Path],
  192. page_name: str,
  193. *,
  194. suffix: str = 'raw',
  195. subdir: str = 'layout_detection',
  196. image_format: str = 'jpg',
  197. save_json: bool = True,
  198. ) -> Optional[Dict[str, str]]:
  199. """保存 layout 模块 debug 图与 JSON。"""
  200. if not layout_results or not output_dir:
  201. return None
  202. try:
  203. fmt = (image_format or 'jpg').lstrip('.')
  204. debug_dir = resolve_module_debug_dir(output_dir, subdir)
  205. vis = draw_layout_boxes_cv2(image, layout_results)
  206. img_path = debug_dir / f'{page_name}_layout_{suffix}.{fmt}'
  207. cv2.imwrite(str(img_path), vis)
  208. paths: Dict[str, str] = {'image': str(img_path)}
  209. logger.info(f"Saved layout detection image ({suffix}): {img_path}")
  210. if save_json:
  211. json_data = {
  212. 'page_name': page_name,
  213. 'suffix': suffix,
  214. 'count': len(layout_results),
  215. 'results': [
  216. {
  217. 'category': r.get('category'),
  218. 'bbox': r.get('bbox'),
  219. 'confidence': r.get('confidence', r.get('score', 0.0)),
  220. }
  221. for r in layout_results
  222. ],
  223. }
  224. json_path = debug_dir / f'{page_name}_layout_{suffix}.json'
  225. json_path.write_text(
  226. json.dumps(json_data, ensure_ascii=False, indent=2),
  227. encoding='utf-8',
  228. )
  229. paths['json'] = str(json_path)
  230. logger.info(f"Saved layout detection JSON ({suffix}): {json_path}")
  231. return paths
  232. except Exception as e:
  233. logger.warning(f"Failed to save layout debug ({suffix}): {e}")
  234. return None
  235. def save_ocr_debug(
  236. image: Union[np.ndarray, Image.Image],
  237. spans: List[Dict[str, Any]],
  238. output_dir: Union[str, Path],
  239. page_name: str,
  240. *,
  241. subdir: str = 'ocr_recognition',
  242. image_format: str = 'png',
  243. save_json: bool = True,
  244. ) -> Optional[Dict[str, str]]:
  245. """保存 OCR 模块 debug 图与 JSON。"""
  246. if not output_dir:
  247. return None
  248. try:
  249. fmt = (image_format or 'png').lstrip('.')
  250. debug_dir = resolve_module_debug_dir(output_dir, subdir)
  251. vis = draw_ocr_spans_cv2(image, spans or [])
  252. img_path = debug_dir / f'{page_name}_ocr_spans.{fmt}'
  253. cv2.imwrite(str(img_path), vis)
  254. paths: Dict[str, str] = {'image': str(img_path)}
  255. logger.info(f"Saved OCR debug image: {img_path}")
  256. if save_json:
  257. json_data = {
  258. 'page_name': page_name,
  259. 'count': len(spans or []),
  260. 'spans': [
  261. {
  262. 'bbox': s.get('bbox'),
  263. 'poly': s.get('poly'),
  264. 'text': s.get('text'),
  265. 'confidence': s.get('confidence'),
  266. }
  267. for s in (spans or [])
  268. ],
  269. }
  270. json_path = debug_dir / f'{page_name}_ocr_spans.json'
  271. json_path.write_text(
  272. json.dumps(json_data, ensure_ascii=False, indent=2),
  273. encoding='utf-8',
  274. )
  275. paths['json'] = str(json_path)
  276. logger.info(f"Saved OCR debug JSON: {json_path}")
  277. return paths
  278. except Exception as e:
  279. logger.warning(f"Failed to save OCR debug: {e}")
  280. return None