table_line_generator.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. """
  2. 基于 OCR bbox 的表格线生成模块
  3. 自动分析无线表格的行列结构,生成表格线
  4. """
  5. import cv2
  6. import numpy as np
  7. from PIL import Image, ImageDraw
  8. from pathlib import Path
  9. from typing import List, Dict, Tuple, Optional, Union
  10. import json
  11. class TableLineGenerator:
  12. """表格线生成器"""
  13. def __init__(self, image: Union[str, Image.Image], ocr_data: List[Dict]):
  14. """
  15. 初始化表格线生成器
  16. Args:
  17. image: 图片路径(str) 或 PIL.Image 对象
  18. ocr_data: OCR识别结果(包含bbox)
  19. """
  20. if isinstance(image, str):
  21. # 传入的是路径
  22. self.image_path = image
  23. self.image = Image.open(image)
  24. elif isinstance(image, Image.Image):
  25. # 传入的是 PIL Image 对象
  26. self.image_path = None # 没有路径
  27. self.image = image
  28. else:
  29. raise TypeError(
  30. f"image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
  31. f"实际类型: {type(image)}"
  32. )
  33. self.ocr_data = ocr_data
  34. # 表格结构参数
  35. self.rows = [] # 行坐标列表 [(y_start, y_end), ...]
  36. self.columns = [] # 列坐标列表 [(x_start, x_end), ...]
  37. self.row_height = 0 # 标准行高
  38. self.col_widths = [] # 各列宽度
  39. @staticmethod
  40. def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]:
  41. """
  42. 解析 PPStructure V3 的 OCR 结果
  43. Args:
  44. ocr_result: PPStructure V3 的完整 JSON 结果
  45. Returns:
  46. (table_bbox, text_boxes): 表格边界框和文本框列表
  47. """
  48. # 1. 从 parsing_res_list 中找到 table 区域
  49. table_bbox = None
  50. if 'parsing_res_list' in ocr_result:
  51. for block in ocr_result['parsing_res_list']:
  52. if block.get('block_label') == 'table':
  53. table_bbox = block.get('block_bbox')
  54. break
  55. if not table_bbox:
  56. raise ValueError("未找到表格区域 (block_label='table')")
  57. # 2. 从 overall_ocr_res 中提取文本框(使用 rec_boxes)
  58. text_boxes = []
  59. if 'overall_ocr_res' in ocr_result:
  60. rec_boxes = ocr_result['overall_ocr_res'].get('rec_boxes', [])
  61. rec_texts = ocr_result['overall_ocr_res'].get('rec_texts', [])
  62. # 过滤出表格区域内的文本框
  63. for i, bbox in enumerate(rec_boxes):
  64. if len(bbox) >= 4:
  65. # bbox 格式: [x1, y1, x2, y2]
  66. x1, y1, x2, y2 = bbox[:4]
  67. # 判断文本框是否在表格区域内
  68. if (x1 >= table_bbox[0] and y1 >= table_bbox[1] and
  69. x2 <= table_bbox[2] and y2 <= table_bbox[3]):
  70. text_boxes.append({
  71. 'bbox': [int(x1), int(y1), int(x2), int(y2)],
  72. 'text': rec_texts[i] if i < len(rec_texts) else ''
  73. })
  74. # 对text_boxes从上到下,从左到右排序
  75. text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
  76. return table_bbox, text_boxes
  77. def analyze_table_structure(self,
  78. y_tolerance: int = 5,
  79. x_tolerance: int = 10,
  80. min_row_height: int = 20) -> Dict:
  81. """
  82. 分析表格结构(行列分布)
  83. Args:
  84. y_tolerance: Y轴聚类容差(像素)
  85. x_tolerance: X轴聚类容差(像素)
  86. min_row_height: 最小行高(像素)
  87. Returns:
  88. 表格结构信息,包含:
  89. - rows: 行区间列表
  90. - columns: 列区间列表
  91. - horizontal_lines: 横线Y坐标列表 [y1, y2, ..., y_{n+1}]
  92. - vertical_lines: 竖线X坐标列表 [x1, x2, ..., x_{m+1}]
  93. - row_height: 标准行高
  94. - col_widths: 各列宽度
  95. - table_bbox: 表格边界框
  96. """
  97. if not self.ocr_data:
  98. return {}
  99. # 1. 提取所有bbox的Y坐标(用于行检测)
  100. y_coords = []
  101. for item in self.ocr_data:
  102. bbox = item.get('bbox', [])
  103. if len(bbox) >= 4:
  104. y1, y2 = bbox[1], bbox[3]
  105. y_coords.append((y1, y2, bbox))
  106. # 按Y坐标排序
  107. y_coords.sort(key=lambda x: x[0])
  108. # 2. 聚类检测行(基于Y坐标相近的bbox)
  109. self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height)
  110. # 3. 计算标准行高(中位数)
  111. row_heights = [row['y_end'] - row['y_start'] for row in self.rows]
  112. self.row_height = int(np.median(row_heights)) if row_heights else 30
  113. # 4. 提取所有bbox的X坐标(用于列检测)
  114. x_coords = []
  115. for item in self.ocr_data:
  116. bbox = item.get('bbox', [])
  117. if len(bbox) >= 4:
  118. x1, x2 = bbox[0], bbox[2]
  119. x_coords.append((x1, x2))
  120. # 5. 聚类检测列(基于X坐标相近的bbox)
  121. self.columns = self._cluster_columns(x_coords, x_tolerance)
  122. # 6. 计算各列宽度
  123. self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]
  124. # 🆕 7. 生成横线坐标列表(共 n+1 条)
  125. horizontal_lines = []
  126. for row in self.rows:
  127. horizontal_lines.append(row['y_start'])
  128. # 添加最后一条横线
  129. if self.rows:
  130. horizontal_lines.append(self.rows[-1]['y_end'])
  131. # 🆕 8. 生成竖线坐标列表(共 m+1 条)
  132. vertical_lines = []
  133. for col in self.columns:
  134. vertical_lines.append(col['x_start'])
  135. # 添加最后一条竖线
  136. if self.columns:
  137. vertical_lines.append(self.columns[-1]['x_end'])
  138. return {
  139. 'rows': self.rows,
  140. 'columns': self.columns,
  141. 'horizontal_lines': horizontal_lines, # 🆕 横线Y坐标列表
  142. 'vertical_lines': vertical_lines, # 🆕 竖线X坐标列表
  143. 'row_height': self.row_height,
  144. 'col_widths': self.col_widths,
  145. 'table_bbox': self._get_table_bbox()
  146. }
  147. def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
  148. """
  149. 聚类检测行
  150. 策略:
  151. 1. 按Y坐标排序
  152. 2. 相近的Y坐标(容差内)归为同一行
  153. 3. 过滤掉高度过小的行
  154. """
  155. if not y_coords:
  156. return []
  157. rows = []
  158. current_row = {
  159. 'y_start': y_coords[0][0],
  160. 'y_end': y_coords[0][1],
  161. 'bboxes': [y_coords[0][2]]
  162. }
  163. for i in range(1, len(y_coords)):
  164. y1, y2, bbox = y_coords[i]
  165. # 判断是否属于当前行(Y坐标相近)
  166. if abs(y1 - current_row['y_start']) <= tolerance:
  167. # 更新行的Y范围
  168. current_row['y_start'] = min(current_row['y_start'], y1)
  169. current_row['y_end'] = max(current_row['y_end'], y2)
  170. current_row['bboxes'].append(bbox)
  171. else:
  172. # 保存当前行(如果高度足够)
  173. if current_row['y_end'] - current_row['y_start'] >= min_height:
  174. rows.append(current_row)
  175. # 开始新行
  176. current_row = {
  177. 'y_start': y1,
  178. 'y_end': y2,
  179. 'bboxes': [bbox]
  180. }
  181. # 保存最后一行
  182. if current_row['y_end'] - current_row['y_start'] >= min_height:
  183. rows.append(current_row)
  184. return rows
  185. def _cluster_columns(self, x_coords: List[Tuple], tolerance: int) -> List[Dict]:
  186. """
  187. 聚类检测列
  188. 策略:
  189. 1. 提取所有bbox的左边界和右边界
  190. 2. 聚类相近的X坐标
  191. 3. 生成列分界线
  192. """
  193. if not x_coords:
  194. return []
  195. # 提取所有X坐标(左边界和右边界)
  196. all_x = []
  197. for x1, x2 in x_coords:
  198. all_x.append(x1)
  199. all_x.append(x2)
  200. all_x = sorted(set(all_x))
  201. # 聚类X坐标
  202. columns = []
  203. current_x = all_x[0]
  204. for x in all_x[1:]:
  205. if x - current_x > tolerance:
  206. # 新列开始
  207. columns.append(current_x)
  208. current_x = x
  209. columns.append(current_x)
  210. # 生成列区间
  211. column_regions = []
  212. for i in range(len(columns) - 1):
  213. column_regions.append({
  214. 'x_start': columns[i],
  215. 'x_end': columns[i + 1]
  216. })
  217. return column_regions
  218. def _get_table_bbox(self) -> List[int]:
  219. """获取表格整体边界框"""
  220. if not self.rows or not self.columns:
  221. return [0, 0, self.image.width, self.image.height]
  222. y_min = min(row['y_start'] for row in self.rows)
  223. y_max = max(row['y_end'] for row in self.rows)
  224. x_min = min(col['x_start'] for col in self.columns)
  225. x_max = max(col['x_end'] for col in self.columns)
  226. return [x_min, y_min, x_max, y_max]
  227. def generate_table_lines(self,
  228. line_color: Tuple[int, int, int] = (0, 0, 255),
  229. line_width: int = 2) -> Image.Image:
  230. """
  231. 在原图上绘制表格线
  232. Args:
  233. line_color: 线条颜色 (R, G, B)
  234. line_width: 线条宽度
  235. Returns:
  236. 绘制了表格线的图片
  237. """
  238. # 复制原图
  239. img_with_lines = self.image.copy()
  240. draw = ImageDraw.Draw(img_with_lines)
  241. # 🔧 简化:使用行列区间而不是重复计算
  242. x_start = self.columns[0]['x_start'] if self.columns else 0
  243. x_end = self.columns[-1]['x_end'] if self.columns else img_with_lines.width
  244. y_start = self.rows[0]['y_start'] if self.rows else 0
  245. y_end = self.rows[-1]['y_end'] if self.rows else img_with_lines.height
  246. # 绘制横线(包括最后一条)
  247. for row in self.rows:
  248. y = row['y_start']
  249. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  250. # 绘制最后一条横线
  251. if self.rows:
  252. y = self.rows[-1]['y_end']
  253. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  254. # 绘制竖线(包括最后一条)
  255. for col in self.columns:
  256. x = col['x_start']
  257. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  258. # 绘制最后一条竖线
  259. if self.columns:
  260. x = self.columns[-1]['x_end']
  261. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  262. return img_with_lines
  263. def save_table_structure(self, output_path: str):
  264. """保存表格结构配置(用于应用到其他页)"""
  265. structure = {
  266. 'row_height': self.row_height,
  267. 'col_widths': self.col_widths,
  268. 'columns': self.columns,
  269. 'first_row_y': self.rows[0]['y_start'] if self.rows else 0,
  270. 'table_bbox': self._get_table_bbox()
  271. }
  272. with open(output_path, 'w', encoding='utf-8') as f:
  273. json.dump(structure, f, indent=2, ensure_ascii=False)
  274. return structure
  275. def apply_structure_to_image(self,
  276. target_image: Union[str, Image.Image],
  277. structure: Dict,
  278. output_path: str) -> str:
  279. """
  280. 将表格结构应用到其他页
  281. Args:
  282. target_image: 目标图片路径(str) 或 PIL.Image 对象
  283. structure: 表格结构配置
  284. output_path: 输出路径
  285. Returns:
  286. 生成的有线表格图片路径
  287. """
  288. # 🔧 修改:支持传入 Image 对象或路径
  289. if isinstance(target_image, str):
  290. target_img = Image.open(target_image)
  291. elif isinstance(target_image, Image.Image):
  292. target_img = target_image
  293. else:
  294. raise TypeError(
  295. f"target_image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
  296. f"实际类型: {type(target_image)}"
  297. )
  298. draw = ImageDraw.Draw(target_img)
  299. row_height = structure['row_height']
  300. col_widths = structure['col_widths']
  301. columns = structure['columns']
  302. first_row_y = structure['first_row_y']
  303. table_bbox = structure['table_bbox']
  304. # 计算行数(根据图片高度)
  305. num_rows = int((target_img.height - first_row_y) / row_height)
  306. # 绘制横线
  307. for i in range(num_rows + 1):
  308. y = first_row_y + i * row_height
  309. draw.line([(table_bbox[0], y), (table_bbox[2], y)],
  310. fill=(0, 0, 255), width=2)
  311. # 绘制竖线
  312. for col in columns:
  313. x = col['x_start']
  314. draw.line([(x, first_row_y), (x, first_row_y + num_rows * row_height)],
  315. fill=(0, 0, 255), width=2)
  316. # 绘制最后一条竖线
  317. x = columns[-1]['x_end']
  318. draw.line([(x, first_row_y), (x, first_row_y + num_rows * row_height)],
  319. fill=(0, 0, 255), width=2)
  320. # 保存
  321. target_img.save(output_path)
  322. return output_path