# smart_generator.py (10 KB)
  1. """
  2. 智能表格线生成器
  3. 整合列检测 + 自适应行分割,提供统一入口
  4. """
  5. import json
  6. import cv2
  7. import numpy as np
  8. from pathlib import Path
  9. from typing import List, Dict, Tuple, Optional, Union
  10. from dataclasses import dataclass, asdict
  11. from .adaptive_row_splitter import AdaptiveRowSplitter, RowRegion
  12. from .column_detector import ColumnBoundaryDetector, ColumnRegion
  13. @dataclass
  14. class TableStructure:
  15. """表格结构"""
  16. table_region: Tuple[int, int, int, int] # (x1, y1, x2, y2)
  17. columns: List[ColumnRegion]
  18. rows: List[RowRegion]
  19. page_size: Tuple[int, int] # (width, height)
  20. def get_vertical_lines(self) -> List[int]:
  21. """获取所有竖线的 X 坐标"""
  22. if not self.columns:
  23. return []
  24. lines = [self.columns[0].x_left]
  25. for col in self.columns:
  26. lines.append(col.x_right)
  27. return lines
  28. def get_horizontal_lines(self) -> List[int]:
  29. """获取所有横线的 Y 坐标"""
  30. if not self.rows:
  31. return []
  32. lines = [self.rows[0].y_top]
  33. for row in self.rows:
  34. lines.append(row.y_bottom)
  35. return lines
  36. def get_cell_bboxes(self) -> List[List[Tuple[int, int, int, int]]]:
  37. """获取所有单元格的 bbox"""
  38. cells = []
  39. for row in self.rows:
  40. row_cells = []
  41. for col in self.columns:
  42. cell_bbox = (col.x_left, row.y_top, col.x_right, row.y_bottom)
  43. row_cells.append(cell_bbox)
  44. cells.append(row_cells)
  45. return cells
  46. def to_dict(self) -> Dict:
  47. return {
  48. 'table_region': self.table_region,
  49. 'columns': [c.to_dict() for c in self.columns],
  50. 'rows': [r.to_dict() for r in self.rows],
  51. 'page_size': self.page_size,
  52. 'vertical_lines': self.get_vertical_lines(),
  53. 'horizontal_lines': self.get_horizontal_lines()
  54. }
  55. class SmartTableLineGenerator:
  56. """
  57. 智能表格线生成器
  58. 功能:
  59. 1. 自动检测表格区域
  60. 2. 自动检测列边界
  61. 3. 自适应行分割(支持可变行高)
  62. 4. 生成表格线坐标
  63. 5. 可选:在图片上绘制表格线
  64. """
  65. def __init__(self,
  66. # 列检测参数
  67. x_tolerance: int = 20,
  68. min_boxes_per_column: int = 3,
  69. # 行分割参数
  70. min_gap_height: int = 6,
  71. density_threshold: float = 0.05,
  72. min_row_height: int = 15,
  73. # 其他参数
  74. table_margin: int = 10,
  75. header_detection: bool = True):
  76. self.column_detector = ColumnBoundaryDetector(
  77. x_tolerance=x_tolerance,
  78. min_boxes_per_column=min_boxes_per_column
  79. )
  80. self.row_splitter = AdaptiveRowSplitter(
  81. min_gap_height=min_gap_height,
  82. density_threshold=density_threshold,
  83. min_row_height=min_row_height
  84. )
  85. self.table_margin = table_margin
  86. self.header_detection = header_detection
  87. def generate(self,
  88. ocr_boxes: List[Dict],
  89. page_size: Tuple[int, int],
  90. table_region: Optional[Tuple[int, int, int, int]] = None,
  91. header_row_count: int = 1,
  92. debug: bool = False) -> Tuple[TableStructure, Optional[Dict]]:
  93. """
  94. 生成表格结构
  95. Args:
  96. ocr_boxes: OCR 结果 [{text, bbox: [x1,y1,x2,y2]}]
  97. page_size: (width, height)
  98. table_region: 表格区域 (可选,自动检测)
  99. header_row_count: 表头行数
  100. debug: 是否返回调试信息
  101. Returns:
  102. (TableStructure, debug_info)
  103. """
  104. width, height = page_size
  105. debug_info = {} if debug else None
  106. # 1. 自动检测表格区域
  107. if table_region is None:
  108. table_region = self._detect_table_region(ocr_boxes, width, height)
  109. if debug:
  110. debug_info['table_region'] = table_region
  111. # 2. 检测列边界
  112. columns = self.column_detector.detect(ocr_boxes, table_region, width)
  113. if debug:
  114. debug_info['columns'] = [c.to_dict() for c in columns]
  115. # 3. 自适应行分割
  116. rows, row_debug = self.row_splitter.split_rows(
  117. ocr_boxes, table_region, debug=debug
  118. )
  119. if debug and row_debug:
  120. debug_info['row_splitter'] = row_debug
  121. # 4. 构建表格结构
  122. structure = TableStructure(
  123. table_region=table_region,
  124. columns=columns,
  125. rows=rows,
  126. page_size=page_size
  127. )
  128. return structure, debug_info
  129. def generate_from_image(self,
  130. image_path: Union[str, Path],
  131. ocr_boxes: List[Dict],
  132. table_region: Optional[Tuple[int, int, int, int]] = None,
  133. debug: bool = False) -> Tuple[TableStructure, Optional[Dict]]:
  134. """
  135. 从图片生成表格结构
  136. 自动获取图片尺寸
  137. """
  138. image = cv2.imread(str(image_path))
  139. if image is None:
  140. raise ValueError(f"无法读取图片: {image_path}")
  141. height, width = image.shape[:2]
  142. return self.generate(ocr_boxes, (width, height), table_region, debug=debug)
  143. def draw_table_lines(self,
  144. image: np.ndarray,
  145. structure: TableStructure,
  146. line_color: Tuple[int, int, int] = (0, 0, 255),
  147. line_thickness: int = 1,
  148. draw_cells: bool = False) -> np.ndarray:
  149. """
  150. 在图片上绘制表格线
  151. Args:
  152. image: 原始图片
  153. structure: 表格结构
  154. line_color: 线条颜色 (BGR)
  155. line_thickness: 线条粗细
  156. draw_cells: 是否绘制单元格边框
  157. Returns:
  158. 绘制后的图片
  159. """
  160. result = image.copy()
  161. v_lines = structure.get_vertical_lines()
  162. h_lines = structure.get_horizontal_lines()
  163. if not v_lines or not h_lines:
  164. return result
  165. y_top = h_lines[0]
  166. y_bottom = h_lines[-1]
  167. x_left = v_lines[0]
  168. x_right = v_lines[-1]
  169. # 绘制竖线
  170. for x in v_lines:
  171. cv2.line(result, (x, y_top), (x, y_bottom), line_color, line_thickness)
  172. # 绘制横线
  173. for y in h_lines:
  174. cv2.line(result, (x_left, y), (x_right, y), line_color, line_thickness)
  175. return result
  176. def build_table_data(self,
  177. ocr_boxes: List[Dict],
  178. structure: TableStructure) -> List[List[str]]:
  179. """
  180. 根据表格结构构建结构化数据
  181. Returns:
  182. 二维列表 table[row_idx][col_idx] = cell_text
  183. """
  184. n_rows = len(structure.rows)
  185. n_cols = len(structure.columns)
  186. # 初始化表格
  187. table = [[[] for _ in range(n_cols)] for _ in range(n_rows)]
  188. # 将每个 box 分配到对应的单元格
  189. for box in ocr_boxes:
  190. row_idx = self._find_row_index(box, structure.rows)
  191. col_idx = self._find_column_index(box, structure.columns)
  192. if row_idx >= 0 and col_idx >= 0:
  193. table[row_idx][col_idx].append(box['text'])
  194. # 合并每个单元格的文本
  195. result = []
  196. for row in table:
  197. result.append([' '.join(texts) for texts in row])
  198. return result
  199. def _detect_table_region(self,
  200. boxes: List[Dict],
  201. width: int,
  202. height: int) -> Tuple[int, int, int, int]:
  203. """自动检测表格区域"""
  204. if not boxes:
  205. return (0, 0, width, height)
  206. # 使用所有 boxes 的边界框
  207. x1 = min(b['bbox'][0] for b in boxes)
  208. y1 = min(b['bbox'][1] for b in boxes)
  209. x2 = max(b['bbox'][2] for b in boxes)
  210. y2 = max(b['bbox'][3] for b in boxes)
  211. # 留边距
  212. return (
  213. max(0, int(x1 - self.table_margin)),
  214. max(0, int(y1 - self.table_margin)),
  215. min(width, int(x2 + self.table_margin)),
  216. min(height, int(y2 + self.table_margin))
  217. )
  218. def _find_row_index(self, box: Dict, rows: List[RowRegion]) -> int:
  219. """找到 box 所属的行索引"""
  220. cy = (box['bbox'][1] + box['bbox'][3]) / 2
  221. for i, row in enumerate(rows):
  222. if row.y_top <= cy < row.y_bottom:
  223. return i
  224. return -1
  225. def _find_column_index(self, box: Dict, columns: List[ColumnRegion]) -> int:
  226. """找到 box 所属的列索引"""
  227. cx = (box['bbox'][0] + box['bbox'][2]) / 2
  228. for i, col in enumerate(columns):
  229. if col.x_left - 10 <= cx <= col.x_right + 10:
  230. return i
  231. return -1
  232. def save_structure(self, structure: TableStructure, output_path: Union[str, Path]):
  233. """保存表格结构到 JSON"""
  234. with open(output_path, 'w', encoding='utf-8') as f:
  235. json.dump(structure.to_dict(), f, ensure_ascii=False, indent=2)
  236. def load_structure(self, input_path: Union[str, Path]) -> TableStructure:
  237. """从 JSON 加载表格结构"""
  238. with open(input_path, 'r', encoding='utf-8') as f:
  239. data = json.load(f)
  240. columns = [
  241. ColumnRegion(x_left=c['x_left'], x_right=c['x_right'], column_index=c['column_index'])
  242. for c in data['columns']
  243. ]
  244. rows = [
  245. RowRegion(y_top=r['y_top'], y_bottom=r['y_bottom'], row_index=r['row_index'], boxes=[])
  246. for r in data['rows']
  247. ]
  248. return TableStructure(
  249. table_region=tuple(data['table_region']),
  250. columns=columns,
  251. rows=rows,
  252. page_size=tuple(data['page_size'])
  253. )