visualization_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
"""
Visualization utilities module.

Provides visualization of document-processing results:
- Layout visualization
- OCR result visualization
- Saving of image elements
"""
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
from loguru import logger
from PIL import Image, ImageDraw, ImageFont
  14. class VisualizationUtils:
  15. """可视化工具类"""
  16. # 颜色映射(与 MinerU 保持一致)
  17. COLOR_MAP = {
  18. 'title': (102, 102, 255), # 蓝色
  19. 'text': (153, 0, 76), # 深红
  20. 'image': (153, 255, 51), # 绿色
  21. 'image_body': (153, 255, 51),
  22. 'image_caption': (102, 178, 255),
  23. 'image_footnote': (255, 178, 102),
  24. 'table': (204, 204, 0), # 黄色
  25. 'table_body': (204, 204, 0),
  26. 'table_caption': (255, 255, 102),
  27. 'table_footnote': (229, 255, 204),
  28. 'interline_equation': (0, 255, 0), # 亮绿
  29. 'inline_equation': (0, 200, 0),
  30. 'list': (40, 169, 92),
  31. 'code': (102, 0, 204), # 紫色
  32. 'header': (128, 128, 128), # 灰色
  33. 'footer': (128, 128, 128),
  34. 'ref_text': (180, 180, 180),
  35. 'ocr_text': (153, 0, 76),
  36. 'error': (255, 0, 0), # 红色
  37. }
  38. # OCR 框颜色
  39. OCR_BOX_COLOR = (0, 255, 0) # 绿色
  40. CELL_BOX_COLOR = (255, 165, 0) # 橙色
  41. DISCARD_COLOR = (128, 128, 128) # 灰色
  42. @staticmethod
  43. def save_image_elements(
  44. results: Dict[str, Any],
  45. images_dir: Path,
  46. doc_name: str
  47. ) -> List[str]:
  48. """
  49. 保存图片元素
  50. Args:
  51. results: 处理结果
  52. images_dir: 图片输出目录
  53. doc_name: 文档名称
  54. Returns:
  55. 保存的图片路径列表
  56. """
  57. saved_paths = []
  58. image_count = 0
  59. for page in results.get('pages', []):
  60. page_idx = page.get('page_idx', 0)
  61. for element in page.get('elements', []):
  62. if element.get('type') in ['image', 'image_body', 'figure']:
  63. content = element.get('content', {})
  64. image_data = content.get('image_data')
  65. if image_data is not None:
  66. image_count += 1
  67. image_filename = f"{doc_name}_page_{page_idx + 1}_image_{image_count}.png"
  68. image_path = images_dir / image_filename
  69. try:
  70. if isinstance(image_data, np.ndarray):
  71. cv2.imwrite(str(image_path), image_data)
  72. else:
  73. Image.fromarray(image_data).save(image_path)
  74. # 更新路径(只保存文件名)
  75. content['image_path'] = image_filename
  76. content.pop('image_data', None)
  77. saved_paths.append(str(image_path))
  78. logger.debug(f"🖼️ Image saved: {image_path}")
  79. except Exception as e:
  80. logger.warning(f"Failed to save image: {e}")
  81. if image_count > 0:
  82. logger.info(f"🖼️ {image_count} images saved to: {images_dir}")
  83. return saved_paths
  84. @staticmethod
  85. def save_layout_images(
  86. results: Dict[str, Any],
  87. output_dir: Path,
  88. doc_name: str,
  89. draw_type_label: bool = True,
  90. draw_bbox_number: bool = True
  91. ) -> List[str]:
  92. """
  93. 保存 Layout 可视化图片
  94. Args:
  95. results: 处理结果
  96. output_dir: 输出目录
  97. doc_name: 文档名称
  98. draw_type_label: 是否绘制类型标签
  99. draw_bbox_number: 是否绘制序号
  100. Returns:
  101. 保存的图片路径列表
  102. """
  103. layout_paths = []
  104. for page in results.get('pages', []):
  105. page_idx = page.get('page_idx', 0)
  106. processed_image = page.get('original_image')
  107. if processed_image is None:
  108. processed_image = page.get('processed_image')
  109. if processed_image is None:
  110. logger.warning(f"Page {page_idx}: No image data found for layout visualization")
  111. continue
  112. if isinstance(processed_image, np.ndarray):
  113. image = Image.fromarray(processed_image).convert('RGB')
  114. elif isinstance(processed_image, Image.Image):
  115. image = processed_image.convert('RGB')
  116. else:
  117. continue
  118. draw = ImageDraw.Draw(image, 'RGBA')
  119. font = VisualizationUtils._get_font(14)
  120. # 绘制普通元素
  121. for idx, element in enumerate(page.get('elements', []), 1):
  122. elem_type = element.get('type', '')
  123. bbox = element.get('bbox', [0, 0, 0, 0])
  124. if len(bbox) < 4:
  125. continue
  126. x0, y0, x1, y1 = map(int, bbox[:4])
  127. color = VisualizationUtils.COLOR_MAP.get(elem_type, (255, 0, 0))
  128. # 半透明填充
  129. overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
  130. overlay_draw = ImageDraw.Draw(overlay)
  131. overlay_draw.rectangle([x0, y0, x1, y1], fill=(*color, 50))
  132. image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
  133. draw = ImageDraw.Draw(image)
  134. # 边框
  135. draw.rectangle([x0, y0, x1, y1], outline=color, width=2)
  136. # 类型标签
  137. if draw_type_label:
  138. label = elem_type.replace('_', ' ').title()
  139. bbox_label = draw.textbbox((x0 + 2, y0 + 2), label, font=font)
  140. draw.rectangle(bbox_label, fill=color)
  141. draw.text((x0 + 2, y0 + 2), label, fill='white', font=font)
  142. # 序号
  143. if draw_bbox_number:
  144. number_text = str(idx)
  145. bbox_number = draw.textbbox((x1 - 25, y0 + 2), number_text, font=font)
  146. draw.rectangle(bbox_number, fill=(255, 0, 0))
  147. draw.text((x1 - 25, y0 + 2), number_text, fill='white', font=font)
  148. # 绘制丢弃元素(灰色样式)
  149. for idx, element in enumerate(page.get('discarded_blocks', []), 1):
  150. original_category = element.get('original_category', 'unknown')
  151. bbox = element.get('bbox', [0, 0, 0, 0])
  152. if len(bbox) < 4:
  153. continue
  154. x0, y0, x1, y1 = map(int, bbox[:4])
  155. # 半透明填充
  156. overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
  157. overlay_draw = ImageDraw.Draw(overlay)
  158. overlay_draw.rectangle([x0, y0, x1, y1], fill=(*VisualizationUtils.DISCARD_COLOR, 30))
  159. image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
  160. draw = ImageDraw.Draw(image)
  161. # 灰色边框
  162. draw.rectangle([x0, y0, x1, y1], outline=VisualizationUtils.DISCARD_COLOR, width=1)
  163. # 类型标签
  164. if draw_type_label:
  165. label = f"D:{original_category}"
  166. bbox_label = draw.textbbox((x0 + 2, y0 + 2), label, font=font)
  167. draw.rectangle(bbox_label, fill=VisualizationUtils.DISCARD_COLOR)
  168. draw.text((x0 + 2, y0 + 2), label, fill='white', font=font)
  169. layout_path = output_dir / f"{doc_name}_page_{page_idx + 1}_layout.png"
  170. image.save(layout_path)
  171. layout_paths.append(str(layout_path))
  172. logger.info(f"🖼️ Layout image saved: {layout_path}")
  173. return layout_paths
  174. @staticmethod
  175. def save_ocr_images(
  176. results: Dict[str, Any],
  177. output_dir: Path,
  178. doc_name: str
  179. ) -> List[str]:
  180. """
  181. 保存 OCR 可视化图片
  182. Args:
  183. results: 处理结果
  184. output_dir: 输出目录
  185. doc_name: 文档名称
  186. Returns:
  187. 保存的图片路径列表
  188. """
  189. ocr_paths = []
  190. for page in results.get('pages', []):
  191. page_idx = page.get('page_idx', 0)
  192. processed_image = page.get('original_image')
  193. if processed_image is None:
  194. processed_image = page.get('processed_image')
  195. if processed_image is None:
  196. logger.warning(f"Page {page_idx}: No image data found for OCR visualization")
  197. continue
  198. if isinstance(processed_image, np.ndarray):
  199. image = Image.fromarray(processed_image).convert('RGB')
  200. elif isinstance(processed_image, Image.Image):
  201. image = processed_image.convert('RGB')
  202. else:
  203. continue
  204. draw = ImageDraw.Draw(image)
  205. font = VisualizationUtils._get_font(10)
  206. for element in page.get('elements', []):
  207. content = element.get('content', {})
  208. # OCR 文本框
  209. ocr_details = content.get('ocr_details', [])
  210. for ocr_item in ocr_details:
  211. ocr_bbox = ocr_item.get('bbox', [])
  212. if ocr_bbox:
  213. VisualizationUtils._draw_polygon(
  214. draw, ocr_bbox, VisualizationUtils.OCR_BOX_COLOR, width=1
  215. )
  216. # 表格单元格
  217. cells = content.get('cells', [])
  218. for cell in cells:
  219. cell_bbox = cell.get('bbox', [])
  220. if cell_bbox and len(cell_bbox) >= 4:
  221. x0, y0, x1, y1 = map(int, cell_bbox[:4])
  222. draw.rectangle(
  223. [x0, y0, x1, y1],
  224. outline=VisualizationUtils.CELL_BOX_COLOR,
  225. width=2
  226. )
  227. cell_text = cell.get('text', '')[:10]
  228. if cell_text:
  229. draw.text(
  230. (x0 + 2, y0 + 2),
  231. cell_text,
  232. fill=VisualizationUtils.CELL_BOX_COLOR,
  233. font=font
  234. )
  235. # OCR 框
  236. ocr_boxes = content.get('ocr_boxes', [])
  237. for ocr_box in ocr_boxes:
  238. bbox = ocr_box.get('bbox', [])
  239. if bbox:
  240. VisualizationUtils._draw_polygon(
  241. draw, bbox, VisualizationUtils.OCR_BOX_COLOR, width=1
  242. )
  243. # 绘制丢弃元素的 OCR 框
  244. for element in page.get('discarded_blocks', []):
  245. bbox = element.get('bbox', [0, 0, 0, 0])
  246. content = element.get('content', {})
  247. if len(bbox) >= 4:
  248. x0, y0, x1, y1 = map(int, bbox[:4])
  249. draw.rectangle(
  250. [x0, y0, x1, y1],
  251. outline=VisualizationUtils.DISCARD_COLOR,
  252. width=1
  253. )
  254. ocr_details = content.get('ocr_details', [])
  255. for ocr_item in ocr_details:
  256. ocr_bbox = ocr_item.get('bbox', [])
  257. if ocr_bbox:
  258. VisualizationUtils._draw_polygon(
  259. draw, ocr_bbox, VisualizationUtils.DISCARD_COLOR, width=1
  260. )
  261. ocr_path = output_dir / f"{doc_name}_page_{page_idx + 1}_ocr.png"
  262. image.save(ocr_path)
  263. ocr_paths.append(str(ocr_path))
  264. logger.info(f"🖼️ OCR image saved: {ocr_path}")
  265. return ocr_paths
  266. @staticmethod
  267. def _draw_polygon(
  268. draw: ImageDraw.Draw,
  269. bbox: List,
  270. color: Tuple[int, int, int],
  271. width: int = 1
  272. ):
  273. """
  274. 绘制多边形或矩形
  275. Args:
  276. draw: ImageDraw 对象
  277. bbox: 坐标(4点多边形或矩形)
  278. color: 颜色
  279. width: 线宽
  280. """
  281. if isinstance(bbox[0], (list, tuple)):
  282. points = [(int(p[0]), int(p[1])) for p in bbox]
  283. points.append(points[0])
  284. draw.line(points, fill=color, width=width)
  285. elif len(bbox) >= 4:
  286. x0, y0, x1, y1 = map(int, bbox[:4])
  287. draw.rectangle([x0, y0, x1, y1], outline=color, width=width)
  288. @staticmethod
  289. def _get_font(size: int) -> ImageFont.FreeTypeFont:
  290. """
  291. 获取字体
  292. Args:
  293. size: 字体大小
  294. Returns:
  295. 字体对象
  296. """
  297. font_paths = [
  298. "/System/Library/Fonts/Helvetica.ttc",
  299. "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
  300. "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
  301. ]
  302. for font_path in font_paths:
  303. try:
  304. return ImageFont.truetype(font_path, size)
  305. except:
  306. continue
  307. return ImageFont.load_default()