drawing.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. """
  2. 表格线绘制功能
  3. """
  4. import streamlit as st
  5. from PIL import Image, ImageDraw, ImageFont
  6. import json
  7. import hashlib
  8. def draw_table_lines_with_numbers(image, structure, line_width=2, show_numbers=True):
  9. """
  10. 绘制带编号的表格线(使用线坐标列表)
  11. Args:
  12. image: PIL Image 对象
  13. structure: 表格结构字典(包含 horizontal_lines 和 vertical_lines)
  14. line_width: 线条宽度
  15. show_numbers: 是否显示编号
  16. Returns:
  17. 绘制了表格线和编号的图片
  18. """
  19. img_with_lines = image.copy()
  20. draw = ImageDraw.Draw(img_with_lines)
  21. # 尝试加载字体
  22. try:
  23. font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 20)
  24. except:
  25. font = ImageFont.load_default()
  26. # 使用线坐标列表
  27. horizontal_lines = structure.get('horizontal_lines', [])
  28. vertical_lines = structure.get('vertical_lines', [])
  29. modified_h_lines = structure.get('modified_h_lines', set())
  30. modified_v_lines = structure.get('modified_v_lines', set())
  31. # 计算绘制范围
  32. x_start = vertical_lines[0] if vertical_lines else 0
  33. x_end = vertical_lines[-1] if vertical_lines else img_with_lines.width
  34. y_start = horizontal_lines[0] if horizontal_lines else 0
  35. y_end = horizontal_lines[-1] if horizontal_lines else img_with_lines.height
  36. # 绘制横线
  37. for idx, y in enumerate(horizontal_lines):
  38. color = (255, 0, 0) if idx in modified_h_lines else (0, 0, 255)
  39. draw.line([(x_start, y), (x_end, y)], fill=color, width=line_width)
  40. # 绘制行编号
  41. if show_numbers:
  42. text = f"R{idx+1}"
  43. bbox = draw.textbbox((x_start - 35, y - 10), text, font=font)
  44. draw.rectangle(bbox, fill='white', outline='black')
  45. draw.text((x_start - 35, y - 10), text, fill=color, font=font)
  46. # 绘制竖线
  47. for idx, x in enumerate(vertical_lines):
  48. color = (255, 0, 0) if idx in modified_v_lines else (0, 0, 255)
  49. draw.line([(x, y_start), (x, y_end)], fill=color, width=line_width)
  50. # 绘制列编号
  51. if show_numbers:
  52. text = f"C{idx+1}"
  53. bbox = draw.textbbox((x - 10, y_start - 25), text, font=font)
  54. draw.rectangle(bbox, fill='white', outline='black')
  55. draw.text((x - 10, y_start - 25), text, fill=color, font=font)
  56. bbox = draw.textbbox((x - 10, y_end + 25), text, font=font)
  57. draw.rectangle(bbox, fill='white', outline='black')
  58. draw.text((x - 10, y_end + 25), text, fill=color, font=font)
  59. return img_with_lines
  60. def draw_clean_table_lines(image, structure, line_width=2, line_color=(0, 0, 0)):
  61. """
  62. 绘制纯净的表格线(用于保存)
  63. - 所有线用统一颜色
  64. - 不显示编号
  65. Args:
  66. image: PIL Image 对象
  67. structure: 表格结构字典
  68. line_width: 线条宽度
  69. line_color: 线条颜色,默认黑色 (0, 0, 0)
  70. Returns:
  71. 绘制了纯净表格线的图片
  72. """
  73. img_with_lines = image.copy()
  74. draw = ImageDraw.Draw(img_with_lines)
  75. horizontal_lines = structure.get('horizontal_lines', [])
  76. vertical_lines = structure.get('vertical_lines', [])
  77. if not horizontal_lines or not vertical_lines:
  78. return img_with_lines
  79. # 计算绘制范围
  80. x_start = vertical_lines[0]
  81. x_end = vertical_lines[-1]
  82. y_start = horizontal_lines[0]
  83. y_end = horizontal_lines[-1]
  84. # 绘制横线
  85. for y in horizontal_lines:
  86. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  87. # 绘制竖线
  88. for x in vertical_lines:
  89. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  90. return img_with_lines
  91. def get_structure_hash(structure, line_width, show_numbers):
  92. """生成结构的哈希值,用于判断是否需要重新绘制"""
  93. key_data = {
  94. 'horizontal_lines': structure.get('horizontal_lines', []),
  95. 'vertical_lines': structure.get('vertical_lines', []),
  96. 'modified_h_lines': sorted(list(structure.get('modified_h_lines', set()))),
  97. 'modified_v_lines': sorted(list(structure.get('modified_v_lines', set()))),
  98. 'line_width': line_width,
  99. 'show_numbers': show_numbers
  100. }
  101. key_str = json.dumps(key_data, sort_keys=True)
  102. return hashlib.md5(key_str.encode()).hexdigest()
  103. def get_cached_table_lines_image(image, structure, line_width, show_numbers):
  104. """
  105. 获取缓存的表格线图片,如果缓存不存在或失效则重新绘制
  106. Args:
  107. image: PIL Image 对象
  108. structure: 表格结构字典
  109. line_width: 线条宽度
  110. show_numbers: 是否显示编号
  111. Returns:
  112. 绘制了表格线和编号的图片
  113. """
  114. # 初始化缓存
  115. if 'cached_table_image' not in st.session_state:
  116. st.session_state.cached_table_image = None
  117. if 'cached_table_hash' not in st.session_state:
  118. st.session_state.cached_table_hash = None
  119. # 计算当前结构的哈希
  120. current_hash = get_structure_hash(structure, line_width, show_numbers)
  121. # 检查缓存是否有效
  122. if (st.session_state.cached_table_hash == current_hash and
  123. st.session_state.cached_table_image is not None):
  124. return st.session_state.cached_table_image
  125. # 缓存失效,重新绘制
  126. img_with_lines = draw_table_lines_with_numbers(
  127. image,
  128. structure,
  129. line_width=line_width,
  130. show_numbers=show_numbers
  131. )
  132. # 更新缓存
  133. st.session_state.cached_table_image = img_with_lines
  134. st.session_state.cached_table_hash = current_hash
  135. return img_with_lines
  136. def clear_table_image_cache():
  137. """清除表格图片缓存"""
  138. if 'cached_table_image' in st.session_state:
  139. st.session_state.cached_table_image = None
  140. if 'cached_table_hash' in st.session_state:
  141. st.session_state.cached_table_hash = None