visualization_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. """
  2. 可视化工具模块
  3. 提供文档处理结果的可视化功能:
  4. - Layout 布局可视化
  5. - OCR 结果可视化
  6. - 图片元素保存
  7. """
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
from loguru import logger
from PIL import Image, ImageColor, ImageDraw, ImageFont
  14. class VisualizationUtils:
  15. """可视化工具类"""
  16. # 颜色映射(与 MinerU BlockType / EnhancedDocPipeline 类别保持一致)
  17. COLOR_MAP = {
  18. # 文本类元素 (TEXT_CATEGORIES)
  19. 'title': (102, 102, 255), # 蓝色
  20. 'text': (153, 0, 76), # 深红
  21. 'ocr_text': (153, 0, 76), # 深红(同 text)
  22. 'low_score_text': (200, 100, 100), # 浅红
  23. 'header': (128, 128, 128), # 灰色
  24. 'footer': (128, 128, 128), # 灰色
  25. 'page_number': (160, 160, 160), # 浅灰
  26. 'ref_text': (180, 180, 180), # 浅灰
  27. 'aside_text': (180, 180, 180), # 浅灰
  28. 'page_footnote': (200, 200, 200), # 浅灰
  29. # 表格相关元素
  30. 'table': (204, 204, 0), # 黄色
  31. 'table_body': (204, 204, 0), # 黄色
  32. 'table_caption': (255, 255, 102), # 浅黄
  33. 'table_footnote': (229, 255, 204), # 浅黄绿
  34. # 图片相关元素
  35. 'image': (153, 255, 51), # 绿色
  36. 'image_body': (153, 255, 51), # 绿色
  37. 'figure': (153, 255, 51), # 绿色
  38. 'image_caption': (102, 178, 255), # 浅蓝
  39. 'image_footnote': (255, 178, 102), # 橙色
  40. # 公式类元素
  41. 'interline_equation': (0, 255, 0), # 亮绿
  42. 'inline_equation': (0, 200, 0), # 绿色
  43. 'equation': (0, 220, 0), # 绿色
  44. 'interline_equation_yolo': (0, 180, 0),
  45. 'interline_equation_number': (0, 160, 0),
  46. # 代码类元素
  47. 'code': (102, 0, 204), # 紫色
  48. 'code_body': (102, 0, 204), # 紫色
  49. 'code_caption': (153, 51, 255), # 浅紫
  50. 'algorithm': (128, 0, 255), # 紫色
  51. # 列表类元素
  52. 'list': (40, 169, 92), # 青绿
  53. 'index': (60, 180, 100), # 青绿
  54. # 图表 / 印章
  55. 'chart': (0, 200, 200),
  56. 'seal': (255, 140, 0), # 亮橙(RGB),debug 与最终 layout 图一致
  57. # 丢弃类元素
  58. 'abandon': (100, 100, 100), # 深灰
  59. 'discarded': (100, 100, 100), # 深灰
  60. # 错误
  61. 'error': (255, 0, 0), # 红色
  62. }
  63. # OCR 框颜色(与 module_debug_viz.OCR_BOX_COLOR_BGR 一致:亮蓝 BGR→RGB)
  64. OCR_BOX_COLOR = (0, 0, 255)
  65. CELL_BOX_COLOR = (0, 0, 255)
  66. DISCARD_COLOR = (128, 128, 128) # 灰色
  67. @staticmethod
  68. def save_image_elements(
  69. results: Dict[str, Any],
  70. images_dir: Path,
  71. doc_name: str,
  72. is_pdf: bool = True
  73. ) -> List[str]:
  74. """
  75. 保存图片元素
  76. 命名规则:
  77. - PDF输入: 文件名_page_001_image_1.png
  78. - 图片输入(单页): 文件名_image_1.png
  79. Args:
  80. results: 处理结果
  81. images_dir: 图片输出目录
  82. doc_name: 文档名称
  83. is_pdf: 是否为 PDF 输入
  84. Returns:
  85. 保存的图片路径列表
  86. """
  87. saved_paths = []
  88. image_count = 0
  89. total_pages = len(results.get('pages', []))
  90. for page in results.get('pages', []):
  91. page_idx = page.get('page_idx', 0)
  92. for element in page.get('elements', []):
  93. if element.get('type') in ['image', 'image_body', 'figure']:
  94. content = element.get('content', {})
  95. image_data = content.get('image_data')
  96. if image_data is not None:
  97. image_count += 1
  98. # 根据输入类型决定命名
  99. if is_pdf or total_pages > 1:
  100. image_filename = f"{doc_name}_page_{page_idx + 1}_image_{image_count}.png"
  101. else:
  102. image_filename = f"{doc_name}_image_{image_count}.png"
  103. image_path = images_dir / image_filename
  104. try:
  105. if isinstance(image_data, np.ndarray):
  106. cv2.imwrite(str(image_path), image_data)
  107. else:
  108. Image.fromarray(image_data).save(image_path)
  109. # 更新路径(只保存文件名)
  110. content['image_path'] = image_filename
  111. content.pop('image_data', None)
  112. saved_paths.append(str(image_path))
  113. logger.debug(f"🖼️ Image saved: {image_path}")
  114. except Exception as e:
  115. logger.warning(f"Failed to save image: {e}")
  116. if image_count > 0:
  117. logger.info(f"🖼️ {image_count} images saved to: {images_dir}")
  118. return saved_paths
  119. @staticmethod
  120. def save_layout_images(
  121. results: Dict[str, Any],
  122. output_dir: Path,
  123. doc_name: str,
  124. draw_type_label: bool = True,
  125. draw_bbox_number: bool = True,
  126. is_pdf: bool = True
  127. ) -> List[str]:
  128. """
  129. 保存 Layout 可视化图片
  130. 命名规则:
  131. - PDF输入: 文件名_page_001_layout.png
  132. - 图片输入(单页): 文件名_layout.png
  133. Args:
  134. results: 处理结果
  135. output_dir: 输出目录
  136. doc_name: 文档名称
  137. draw_type_label: 是否绘制类型标签
  138. draw_bbox_number: 是否绘制序号
  139. is_pdf: 是否为 PDF 输入
  140. Returns:
  141. 保存的图片路径列表
  142. """
  143. layout_paths = []
  144. total_pages = len(results.get('pages', []))
  145. for page in results.get('pages', []):
  146. page_idx = page.get('page_idx', 0)
  147. processed_image = page.get('original_image')
  148. if processed_image is None:
  149. processed_image = page.get('processed_image')
  150. if processed_image is None:
  151. logger.warning(f"Page {page_idx}: No image data found for layout visualization")
  152. continue
  153. if isinstance(processed_image, np.ndarray):
  154. image = Image.fromarray(processed_image).convert('RGB')
  155. elif isinstance(processed_image, Image.Image):
  156. image = processed_image.convert('RGB')
  157. else:
  158. continue
  159. draw = ImageDraw.Draw(image, 'RGBA')
  160. font = VisualizationUtils._get_font(14)
  161. # 绘制普通元素
  162. for idx, element in enumerate(page.get('elements', []), 1):
  163. elem_type = element.get('type', '')
  164. bbox = element.get('bbox', [0, 0, 0, 0])
  165. if len(bbox) < 4:
  166. continue
  167. x0, y0, x1, y1 = map(int, bbox[:4])
  168. color = VisualizationUtils.COLOR_MAP.get(elem_type, (255, 0, 0))
  169. # 半透明填充
  170. overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
  171. overlay_draw = ImageDraw.Draw(overlay)
  172. overlay_draw.rectangle([x0, y0, x1, y1], fill=(*color, 50))
  173. image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
  174. draw = ImageDraw.Draw(image)
  175. # 边框
  176. draw.rectangle([x0, y0, x1, y1], outline=color, width=2)
  177. # 类型标签
  178. if draw_type_label:
  179. label = elem_type.replace('_', ' ').title()
  180. bbox_label = draw.textbbox((x0 + 2, y0 + 2), label, font=font)
  181. draw.rectangle(bbox_label, fill=color)
  182. draw.text((x0 + 2, y0 + 2), label, fill='white', font=font)
  183. # 序号
  184. if draw_bbox_number:
  185. number_text = str(idx)
  186. bbox_number = draw.textbbox((x1 - 25, y0 + 2), number_text, font=font)
  187. draw.rectangle(bbox_number, fill=(255, 0, 0))
  188. draw.text((x1 - 25, y0 + 2), number_text, fill='white', font=font)
  189. # 绘制丢弃元素(灰色样式)
  190. for idx, element in enumerate(page.get('discarded_blocks', []), 1):
  191. original_category = element.get('original_category', 'unknown')
  192. bbox = element.get('bbox', [0, 0, 0, 0])
  193. if len(bbox) < 4:
  194. continue
  195. x0, y0, x1, y1 = map(int, bbox[:4])
  196. # 半透明填充
  197. overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
  198. overlay_draw = ImageDraw.Draw(overlay)
  199. overlay_draw.rectangle([x0, y0, x1, y1], fill=(*VisualizationUtils.DISCARD_COLOR, 30))
  200. image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
  201. draw = ImageDraw.Draw(image)
  202. # 灰色边框
  203. draw.rectangle([x0, y0, x1, y1], outline=VisualizationUtils.DISCARD_COLOR, width=1)
  204. # 类型标签
  205. if draw_type_label:
  206. label = f"D:{original_category}"
  207. bbox_label = draw.textbbox((x0 + 2, y0 + 2), label, font=font)
  208. draw.rectangle(bbox_label, fill=VisualizationUtils.DISCARD_COLOR)
  209. draw.text((x0 + 2, y0 + 2), label, fill='white', font=font)
  210. # 根据输入类型决定命名
  211. if is_pdf or total_pages > 1:
  212. layout_path = output_dir / f"{doc_name}_page_{page_idx + 1:03d}_layout.png"
  213. else:
  214. layout_path = output_dir / f"{doc_name}_layout.png"
  215. image.save(layout_path)
  216. layout_paths.append(str(layout_path))
  217. logger.info(f"🖼️ Layout image saved: {layout_path}")
  218. return layout_paths
  219. @staticmethod
  220. def save_ocr_images(
  221. results: Dict[str, Any],
  222. output_dir: Path,
  223. doc_name: str,
  224. is_pdf: bool = True
  225. ) -> List[str]:
  226. """
  227. 保存 OCR 可视化图片
  228. 命名规则:
  229. - PDF输入: 文件名_page_001_ocr.png
  230. - 图片输入(单页): 文件名_ocr.png
  231. Args:
  232. results: 处理结果
  233. output_dir: 输出目录
  234. doc_name: 文档名称
  235. is_pdf: 是否为 PDF 输入
  236. Returns:
  237. 保存的图片路径列表
  238. """
  239. ocr_paths = []
  240. total_pages = len(results.get('pages', []))
  241. for page in results.get('pages', []):
  242. page_idx = page.get('page_idx', 0)
  243. processed_image = page.get('original_image')
  244. if processed_image is None:
  245. processed_image = page.get('processed_image')
  246. if processed_image is None:
  247. logger.warning(f"Page {page_idx}: No image data found for OCR visualization")
  248. continue
  249. if isinstance(processed_image, np.ndarray):
  250. image = Image.fromarray(processed_image).convert('RGB')
  251. elif isinstance(processed_image, Image.Image):
  252. image = processed_image.convert('RGB')
  253. else:
  254. continue
  255. draw = ImageDraw.Draw(image)
  256. font = VisualizationUtils._get_font(10)
  257. for element in page.get('elements', []):
  258. content = element.get('content', {})
  259. # OCR 文本框
  260. ocr_details = content.get('ocr_details', [])
  261. for ocr_item in ocr_details:
  262. ocr_bbox = ocr_item.get('bbox', [])
  263. if ocr_bbox:
  264. VisualizationUtils._draw_polygon(
  265. draw, ocr_bbox, VisualizationUtils.OCR_BOX_COLOR, width=1
  266. )
  267. # 表格单元格
  268. cells = content.get('cells', [])
  269. for cell in cells:
  270. cell_bbox = cell.get('bbox', [])
  271. if cell_bbox and len(cell_bbox) >= 4:
  272. x0, y0, x1, y1 = map(int, cell_bbox[:4])
  273. draw.rectangle(
  274. [x0, y0, x1, y1],
  275. outline=VisualizationUtils.CELL_BOX_COLOR,
  276. width=2
  277. )
  278. cell_text = cell.get('text', '')[:10]
  279. if cell_text:
  280. draw.text(
  281. (x0 + 2, y0 + 2),
  282. cell_text,
  283. fill=VisualizationUtils.CELL_BOX_COLOR,
  284. font=font
  285. )
  286. # OCR 框
  287. ocr_boxes = content.get('ocr_boxes', [])
  288. for ocr_box in ocr_boxes:
  289. bbox = ocr_box.get('bbox', [])
  290. if bbox:
  291. VisualizationUtils._draw_polygon(
  292. draw, bbox, VisualizationUtils.OCR_BOX_COLOR, width=1
  293. )
  294. # 绘制丢弃元素的 OCR 框
  295. for element in page.get('discarded_blocks', []):
  296. bbox = element.get('bbox', [0, 0, 0, 0])
  297. content = element.get('content', {})
  298. if len(bbox) >= 4:
  299. x0, y0, x1, y1 = map(int, bbox[:4])
  300. draw.rectangle(
  301. [x0, y0, x1, y1],
  302. outline=VisualizationUtils.DISCARD_COLOR,
  303. width=1
  304. )
  305. ocr_details = content.get('ocr_details', [])
  306. for ocr_item in ocr_details:
  307. ocr_bbox = ocr_item.get('bbox', [])
  308. if ocr_bbox:
  309. VisualizationUtils._draw_polygon(
  310. draw, ocr_bbox, VisualizationUtils.DISCARD_COLOR, width=1
  311. )
  312. # 根据输入类型决定命名
  313. if is_pdf or total_pages > 1:
  314. ocr_path = output_dir / f"{doc_name}_page_{page_idx + 1:03d}_ocr.png"
  315. else:
  316. ocr_path = output_dir / f"{doc_name}_ocr.png"
  317. image.save(ocr_path)
  318. ocr_paths.append(str(ocr_path))
  319. logger.info(f"🖼️ OCR image saved: {ocr_path}")
  320. return ocr_paths
  321. @staticmethod
  322. def _draw_polygon(
  323. draw: ImageDraw.Draw,
  324. bbox: List,
  325. color: Tuple[int, int, int],
  326. width: int = 1
  327. ):
  328. """
  329. 绘制多边形或矩形
  330. Args:
  331. draw: ImageDraw 对象
  332. bbox: 坐标(4点多边形或矩形)
  333. color: 颜色
  334. width: 线宽
  335. """
  336. if isinstance(bbox[0], (list, tuple)):
  337. points = [(int(p[0]), int(p[1])) for p in bbox]
  338. points.append(points[0])
  339. draw.line(points, fill=color, width=width)
  340. elif len(bbox) >= 4:
  341. x0, y0, x1, y1 = map(int, bbox[:4])
  342. draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
  343. @staticmethod
  344. def _get_font(size: int) -> ImageFont.FreeTypeFont:
  345. """
  346. 获取字体
  347. Args:
  348. size: 字体大小
  349. Returns:
  350. 字体对象
  351. """
  352. font_paths = [
  353. "/System/Library/Fonts/Helvetica.ttc",
  354. "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
  355. "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
  356. ]
  357. for font_path in font_paths:
  358. try:
  359. return ImageFont.truetype(font_path, size)
  360. except:
  361. continue
  362. return ImageFont.load_default()
  363. @staticmethod
  364. def draw_bbox_on_image(image: Image.Image, bbox: List[int], color: str = "red", width: int = 3) -> Image.Image:
  365. """
  366. 在图片上绘制bbox框
  367. Args:
  368. image: PIL Image 对象
  369. bbox: 边界框坐标 [x1, y1, x2, y2]
  370. color: 边框颜色(字符串,如 "red", "blue", "green")
  371. width: 边框宽度
  372. Returns:
  373. 绘制了 bbox 的图像副本
  374. """
  375. img_copy = image.copy()
  376. draw = ImageDraw.Draw(img_copy)
  377. x1, y1, x2, y2 = bbox
  378. # 绘制矩形框
  379. draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
  380. # 添加半透明填充
  381. overlay = Image.new('RGBA', img_copy.size, (0, 0, 0, 0))
  382. overlay_draw = ImageDraw.Draw(overlay)
  383. color_map = {
  384. "red": (255, 0, 0, 30),
  385. "blue": (0, 0, 255, 30),
  386. "green": (0, 255, 0, 30)
  387. }
  388. fill_color = color_map.get(color, (255, 255, 0, 30))
  389. overlay_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
  390. img_copy = Image.alpha_composite(img_copy.convert('RGBA'), overlay).convert('RGB')
  391. return img_copy