table_line_generator.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
"""
Table-line generation based on OCR bounding boxes.

Automatically analyzes the row/column structure of borderless tables and
generates the table lines.
"""
  5. import cv2
  6. import numpy as np
  7. from PIL import Image, ImageDraw
  8. from pathlib import Path
  9. from typing import List, Dict, Tuple, Optional, Union
  10. import json
  11. from bs4 import BeautifulSoup
  12. class TableLineGenerator:
  13. """表格线生成器"""
  14. def __init__(self, image: Union[str, Image.Image], ocr_data: List[Dict]):
  15. """
  16. 初始化表格线生成器
  17. Args:
  18. image: 图片路径(str) 或 PIL.Image 对象
  19. ocr_data: OCR识别结果(包含bbox)
  20. """
  21. if isinstance(image, str):
  22. self.image_path = image
  23. self.image = Image.open(image)
  24. elif isinstance(image, Image.Image):
  25. self.image_path = None
  26. self.image = image
  27. else:
  28. raise TypeError(
  29. f"image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
  30. f"实际类型: {type(image)}"
  31. )
  32. self.ocr_data = ocr_data
  33. # 表格结构参数
  34. self.rows = []
  35. self.columns = []
  36. self.row_height = 0
  37. self.col_widths = []
  38. @staticmethod
  39. def parse_mineru_table_result(mineru_result: Union[Dict, List], use_table_body: bool = True) -> Tuple[List[int], Dict]:
  40. """
  41. 解析 MinerU 格式的结果,自动提取 table 并计算行列分割线
  42. Args:
  43. mineru_result: MinerU 的完整 JSON 结果(可以是 dict 或 list)
  44. use_table_body: 是否使用 table_body 来确定准确的行列数
  45. Returns:
  46. (table_bbox, structure): 表格边界框和结构信息
  47. """
  48. # 🔑 提取 table 数据
  49. table_data = _extract_table_data(mineru_result)
  50. if not table_data:
  51. raise ValueError("未找到 MinerU 格式的表格数据 (type='table')")
  52. # 验证必要字段
  53. if 'table_cells' not in table_data:
  54. raise ValueError("表格数据中未找到 table_cells 字段")
  55. table_cells = table_data['table_cells']
  56. if not table_cells:
  57. raise ValueError("table_cells 为空")
  58. # 🔑 优先使用 table_body 确定准确的行列数
  59. if use_table_body and 'table_body' in table_data:
  60. actual_rows, actual_cols = _parse_table_body_structure(table_data['table_body'])
  61. print(f"📋 从 table_body 解析: {actual_rows} 行 × {actual_cols} 列")
  62. else:
  63. # 回退:从 table_cells 推断
  64. actual_rows = max(cell.get('row', 0) for cell in table_cells if 'row' in cell)
  65. actual_cols = max(cell.get('col', 0) for cell in table_cells if 'col' in cell)
  66. print(f"📋 从 table_cells 推断: {actual_rows} 行 × {actual_cols} 列")
  67. # 🔑 按行列索引分组单元格
  68. cells_by_row = {}
  69. cells_by_col = {}
  70. for cell in table_cells:
  71. if 'row' not in cell or 'col' not in cell or 'bbox' not in cell:
  72. continue
  73. row = cell['row']
  74. col = cell['col']
  75. bbox = cell['bbox'] # [x1, y1, x2, y2]
  76. # 仅保留在有效范围内的单元格
  77. if row <= actual_rows and col <= actual_cols:
  78. if row not in cells_by_row:
  79. cells_by_row[row] = []
  80. cells_by_row[row].append(bbox)
  81. if col not in cells_by_col:
  82. cells_by_col[col] = []
  83. cells_by_col[col].append(bbox)
  84. # 🔑 计算每行的 y 边界(考虑折行)
  85. row_boundaries = {}
  86. for row_num in range(1, actual_rows + 1):
  87. if row_num in cells_by_row:
  88. bboxes = cells_by_row[row_num]
  89. y_min = min(bbox[1] for bbox in bboxes)
  90. y_max = max(bbox[3] for bbox in bboxes)
  91. row_boundaries[row_num] = (y_min, y_max)
  92. # 🔑 分析行间距,识别记录边界
  93. horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
  94. # 🔑 计算竖线(考虑列间距)
  95. col_boundaries = {}
  96. for col_num in range(1, actual_cols + 1):
  97. if col_num in cells_by_col:
  98. bboxes = cells_by_col[col_num]
  99. x_min = min(bbox[0] for bbox in bboxes)
  100. x_max = max(bbox[2] for bbox in bboxes)
  101. col_boundaries[col_num] = (x_min, x_max)
  102. vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
  103. # 🔑 生成行区间
  104. rows = []
  105. for row_num in sorted(row_boundaries.keys()):
  106. y_min, y_max = row_boundaries[row_num]
  107. rows.append({
  108. 'y_start': y_min,
  109. 'y_end': y_max,
  110. 'bboxes': cells_by_row.get(row_num, []),
  111. 'row_index': row_num
  112. })
  113. # 🔑 生成列区间
  114. columns = []
  115. for col_num in sorted(col_boundaries.keys()):
  116. x_min, x_max = col_boundaries[col_num]
  117. columns.append({
  118. 'x_start': x_min,
  119. 'x_end': x_max,
  120. 'col_index': col_num
  121. })
  122. # 🔑 计算表格边界框
  123. all_bboxes = [
  124. cell['bbox'] for cell in table_cells
  125. if 'bbox' in cell and cell.get('row', 0) <= actual_rows and cell.get('col', 0) <= actual_cols
  126. ]
  127. if all_bboxes:
  128. x_min = min(bbox[0] for bbox in all_bboxes)
  129. y_min = min(bbox[1] for bbox in all_bboxes)
  130. x_max = max(bbox[2] for bbox in all_bboxes)
  131. y_max = max(bbox[3] for bbox in all_bboxes)
  132. table_bbox = [x_min, y_min, x_max, y_max]
  133. else:
  134. table_bbox = table_data.get('bbox', [0, 0, 2000, 2000])
  135. # 🔑 返回结构信息
  136. structure = {
  137. 'rows': rows,
  138. 'columns': columns,
  139. 'horizontal_lines': horizontal_lines,
  140. 'vertical_lines': vertical_lines,
  141. 'row_height': int(np.median([r['y_end'] - r['y_start'] for r in rows])) if rows else 0,
  142. 'col_widths': [c['x_end'] - c['x_start'] for c in columns],
  143. 'table_bbox': table_bbox,
  144. 'total_rows': actual_rows,
  145. 'total_cols': actual_cols
  146. }
  147. return table_bbox, structure
  148. @staticmethod
  149. def parse_ppstructure_result(ocr_result: Dict) -> Tuple[List[int], List[Dict]]:
  150. """
  151. 解析 PPStructure V3 的 OCR 结果
  152. Args:
  153. ocr_result: PPStructure V3 的完整 JSON 结果
  154. Returns:
  155. (table_bbox, text_boxes): 表格边界框和文本框列表
  156. """
  157. # 1. 从 parsing_res_list 中找到 table 区域
  158. table_bbox = None
  159. if 'parsing_res_list' in ocr_result:
  160. for block in ocr_result['parsing_res_list']:
  161. if block.get('block_label') == 'table':
  162. table_bbox = block.get('block_bbox')
  163. break
  164. if not table_bbox:
  165. raise ValueError("未找到表格区域 (block_label='table')")
  166. # 2. 从 overall_ocr_res 中提取文本框(使用 rec_boxes)
  167. text_boxes = []
  168. if 'overall_ocr_res' in ocr_result:
  169. rec_boxes = ocr_result['overall_ocr_res'].get('rec_boxes', [])
  170. rec_texts = ocr_result['overall_ocr_res'].get('rec_texts', [])
  171. # 过滤出表格区域内的文本框
  172. for i, bbox in enumerate(rec_boxes):
  173. if len(bbox) >= 4:
  174. # bbox 格式: [x1, y1, x2, y2]
  175. x1, y1, x2, y2 = bbox[:4]
  176. # 判断文本框是否在表格区域内
  177. if (x1 >= table_bbox[0] and y1 >= table_bbox[1] and
  178. x2 <= table_bbox[2] and y2 <= table_bbox[3]):
  179. text_boxes.append({
  180. 'bbox': [int(x1), int(y1), int(x2), int(y2)],
  181. 'text': rec_texts[i] if i < len(rec_texts) else ''
  182. })
  183. # 对text_boxes从上到下,从左到右排序
  184. text_boxes.sort(key=lambda x: (x['bbox'][1], x['bbox'][0]))
  185. return table_bbox, text_boxes
  186. def analyze_table_structure(self,
  187. y_tolerance: int = 5,
  188. x_tolerance: int = 10,
  189. min_row_height: int = 20) -> Dict:
  190. """
  191. 分析表格结构(行列分布)
  192. Args:
  193. y_tolerance: Y轴聚类容差(像素)
  194. x_tolerance: X轴聚类容差(像素)
  195. min_row_height: 最小行高(像素)
  196. Returns:
  197. 表格结构信息
  198. """
  199. if not self.ocr_data:
  200. return {}
  201. # 1. 提取所有bbox的Y坐标(用于行检测)
  202. y_coords = []
  203. for item in self.ocr_data:
  204. bbox = item.get('bbox', [])
  205. if len(bbox) >= 4:
  206. y1, y2 = bbox[1], bbox[3]
  207. y_coords.append((y1, y2, bbox))
  208. # 按Y坐标排序
  209. y_coords.sort(key=lambda x: x[0])
  210. # 2. 聚类检测行(基于Y坐标相近的bbox)
  211. self.rows = self._cluster_rows(y_coords, y_tolerance, min_row_height)
  212. # 3. 计算标准行高(中位数)
  213. row_heights = [row['y_end'] - row['y_start'] for row in self.rows]
  214. self.row_height = int(np.median(row_heights)) if row_heights else 30
  215. # 4. 提取所有bbox的X坐标(用于列检测)
  216. x_coords = []
  217. for item in self.ocr_data:
  218. bbox = item.get('bbox', [])
  219. if len(bbox) >= 4:
  220. x1, x2 = bbox[0], bbox[2]
  221. x_coords.append((x1, x2))
  222. # 5. 聚类检测列(基于X坐标相近的bbox)
  223. self.columns = self._cluster_columns(x_coords, x_tolerance)
  224. # 6. 计算各列宽度
  225. self.col_widths = [col['x_end'] - col['x_start'] for col in self.columns]
  226. # 7. 生成横线坐标列表
  227. horizontal_lines = []
  228. for row in self.rows:
  229. horizontal_lines.append(row['y_start'])
  230. if self.rows:
  231. horizontal_lines.append(self.rows[-1]['y_end'])
  232. # 8. 生成竖线坐标列表
  233. vertical_lines = []
  234. for col in self.columns:
  235. vertical_lines.append(col['x_start'])
  236. if self.columns:
  237. vertical_lines.append(self.columns[-1]['x_end'])
  238. return {
  239. 'rows': self.rows,
  240. 'columns': self.columns,
  241. 'horizontal_lines': horizontal_lines,
  242. 'vertical_lines': vertical_lines,
  243. 'row_height': self.row_height,
  244. 'col_widths': self.col_widths,
  245. 'table_bbox': self._get_table_bbox()
  246. }
  247. def _cluster_rows(self, y_coords: List[Tuple], tolerance: int, min_height: int) -> List[Dict]:
  248. """聚类检测行"""
  249. if not y_coords:
  250. return []
  251. rows = []
  252. current_row = {
  253. 'y_start': y_coords[0][0],
  254. 'y_end': y_coords[0][1],
  255. 'bboxes': [y_coords[0][2]]
  256. }
  257. for i in range(1, len(y_coords)):
  258. y1, y2, bbox = y_coords[i]
  259. if abs(y1 - current_row['y_start']) <= tolerance:
  260. current_row['y_start'] = min(current_row['y_start'], y1)
  261. current_row['y_end'] = max(current_row['y_end'], y2)
  262. current_row['bboxes'].append(bbox)
  263. else:
  264. if current_row['y_end'] - current_row['y_start'] >= min_height:
  265. rows.append(current_row)
  266. current_row = {
  267. 'y_start': y1,
  268. 'y_end': y2,
  269. 'bboxes': [bbox]
  270. }
  271. if current_row['y_end'] - current_row['y_start'] >= min_height:
  272. rows.append(current_row)
  273. return rows
  274. def _cluster_columns(self, x_coords: List[Tuple], tolerance: int) -> List[Dict]:
  275. """聚类检测列"""
  276. if not x_coords:
  277. return []
  278. all_x = []
  279. for x1, x2 in x_coords:
  280. all_x.append(x1)
  281. all_x.append(x2)
  282. all_x = sorted(set(all_x))
  283. columns = []
  284. current_x = all_x[0]
  285. for x in all_x[1:]:
  286. if x - current_x > tolerance:
  287. columns.append(current_x)
  288. current_x = x
  289. columns.append(current_x)
  290. column_regions = []
  291. for i in range(len(columns) - 1):
  292. column_regions.append({
  293. 'x_start': columns[i],
  294. 'x_end': columns[i + 1]
  295. })
  296. return column_regions
  297. def _get_table_bbox(self) -> List[int]:
  298. """获取表格整体边界框"""
  299. if not self.rows or not self.columns:
  300. return [0, 0, self.image.width, self.image.height]
  301. y_min = min(row['y_start'] for row in self.rows)
  302. y_max = max(row['y_end'] for row in self.rows)
  303. x_min = min(col['x_start'] for col in self.columns)
  304. x_max = max(col['x_end'] for col in self.columns)
  305. return [x_min, y_min, x_max, y_max]
  306. def generate_table_lines(self,
  307. line_color: Tuple[int, int, int] = (0, 0, 255),
  308. line_width: int = 2) -> Image.Image:
  309. """在原图上绘制表格线"""
  310. img_with_lines = self.image.copy()
  311. draw = ImageDraw.Draw(img_with_lines)
  312. x_start = self.columns[0]['x_start'] if self.columns else 0
  313. x_end = self.columns[-1]['x_end'] if self.columns else img_with_lines.width
  314. y_start = self.rows[0]['y_start'] if self.rows else 0
  315. y_end = self.rows[-1]['y_end'] if self.rows else img_with_lines.height
  316. # 绘制横线
  317. for row in self.rows:
  318. y = row['y_start']
  319. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  320. if self.rows:
  321. y = self.rows[-1]['y_end']
  322. draw.line([(x_start, y), (x_end, y)], fill=line_color, width=line_width)
  323. # 绘制竖线
  324. for col in self.columns:
  325. x = col['x_start']
  326. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  327. if self.columns:
  328. x = self.columns[-1]['x_end']
  329. draw.line([(x, y_start), (x, y_end)], fill=line_color, width=line_width)
  330. return img_with_lines
  331. def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
  332. """
  333. 计算横线位置(考虑行间距)
  334. Args:
  335. row_boundaries: {row_num: (y_min, y_max)}
  336. Returns:
  337. 横线 y 坐标列表
  338. """
  339. if not row_boundaries:
  340. return []
  341. sorted_rows = sorted(row_boundaries.items())
  342. # 🔑 分析相邻行之间的间隔
  343. gaps = []
  344. gap_info = [] # 保存详细信息用于调试
  345. for i in range(len(sorted_rows) - 1):
  346. row_num1, (y_min1, y_max1) = sorted_rows[i]
  347. row_num2, (y_min2, y_max2) = sorted_rows[i + 1]
  348. gap = y_min2 - y_max1 # 行间距(可能为负,表示重叠)
  349. gaps.append(gap)
  350. gap_info.append({
  351. 'row1': row_num1,
  352. 'row2': row_num2,
  353. 'gap': gap
  354. })
  355. print(f"📏 行间距详情:")
  356. for info in gap_info:
  357. status = "重叠" if info['gap'] < 0 else "正常"
  358. print(f" 行 {info['row1']} → {info['row2']}: {info['gap']:.1f}px ({status})")
  359. # 🔑 过滤掉负数 gap(重叠情况)和极小的 gap
  360. valid_gaps = [g for g in gaps if g > 2] # 至少 2px 间隔才算有效
  361. if valid_gaps:
  362. gap_median = np.median(valid_gaps)
  363. gap_std = np.std(valid_gaps)
  364. print(f"📏 行间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px")
  365. print(f" 有效间隔数: {len(valid_gaps)}/{len(gaps)}")
  366. # 🔑 生成横线坐标(在相邻行中间)
  367. horizontal_lines = []
  368. for i, (row_num, (y_min, y_max)) in enumerate(sorted_rows):
  369. if i == 0:
  370. # 第一行的上边界
  371. horizontal_lines.append(y_min)
  372. if i < len(sorted_rows) - 1:
  373. next_row_num, (next_y_min, next_y_max) = sorted_rows[i + 1]
  374. gap = next_y_min - y_max
  375. if gap > 0:
  376. # 有间隔:在间隔中间画线
  377. # separator_y = int((y_max + next_y_min) / 2)
  378. # 有间隔:更靠近下一行的位置
  379. separator_y = int(next_y_min) - int(gap / 4)
  380. horizontal_lines.append(separator_y)
  381. else:
  382. # 重叠或紧贴:在当前行的下边界画线
  383. horizontal_lines.append(y_max)
  384. else:
  385. # 最后一行的下边界
  386. horizontal_lines.append(y_max)
  387. return sorted(set(horizontal_lines))
  388. def _calculate_vertical_lines_with_spacing(col_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
  389. """
  390. 计算竖线位置(考虑列间距和重叠)
  391. Args:
  392. col_boundaries: {col_num: (x_min, x_max)}
  393. Returns:
  394. 竖线 x 坐标列表
  395. """
  396. if not col_boundaries:
  397. return []
  398. sorted_cols = sorted(col_boundaries.items())
  399. # 🔑 分析相邻列之间的间隔
  400. gaps = []
  401. gap_info = []
  402. for i in range(len(sorted_cols) - 1):
  403. col_num1, (x_min1, x_max1) = sorted_cols[i]
  404. col_num2, (x_min2, x_max2) = sorted_cols[i + 1]
  405. gap = x_min2 - x_max1 # 列间距(可能为负)
  406. gaps.append(gap)
  407. gap_info.append({
  408. 'col1': col_num1,
  409. 'col2': col_num2,
  410. 'gap': gap
  411. })
  412. print(f"📏 列间距详情:")
  413. for info in gap_info:
  414. status = "重叠" if info['gap'] < 0 else "正常"
  415. print(f" 列 {info['col1']} → {info['col2']}: {info['gap']:.1f}px ({status})")
  416. # 🔑 过滤掉负数 gap
  417. valid_gaps = [g for g in gaps if g > 2]
  418. if valid_gaps:
  419. gap_median = np.median(valid_gaps)
  420. gap_std = np.std(valid_gaps)
  421. print(f"📏 列间距统计: 中位数={gap_median:.1f}px, 标准差={gap_std:.1f}px")
  422. # 🔑 生成竖线坐标(在相邻列中间)
  423. vertical_lines = []
  424. for i, (col_num, (x_min, x_max)) in enumerate(sorted_cols):
  425. if i == 0:
  426. # 第一列的左边界
  427. vertical_lines.append(x_min)
  428. if i < len(sorted_cols) - 1:
  429. next_col_num, (next_x_min, next_x_max) = sorted_cols[i + 1]
  430. gap = next_x_min - x_max
  431. if gap > 0:
  432. # 有间隔:在间隔中间画线
  433. separator_x = int((x_max + next_x_min) / 2)
  434. vertical_lines.append(separator_x)
  435. else:
  436. # 重叠或紧贴:在当前列的右边界画线
  437. vertical_lines.append(x_max)
  438. else:
  439. # 最后一列的右边界
  440. vertical_lines.append(x_max)
  441. return sorted(set(vertical_lines))
  442. def _extract_table_data(mineru_result: Union[Dict, List]) -> Optional[Dict]:
  443. """提取 table 数据"""
  444. if isinstance(mineru_result, list):
  445. for item in mineru_result:
  446. if isinstance(item, dict) and item.get('type') == 'table':
  447. return item
  448. elif isinstance(mineru_result, dict):
  449. if mineru_result.get('type') == 'table':
  450. return mineru_result
  451. # 递归查找
  452. for value in mineru_result.values():
  453. if isinstance(value, dict) and value.get('type') == 'table':
  454. return value
  455. elif isinstance(value, list):
  456. result = _extract_table_data(value)
  457. if result:
  458. return result
  459. return None
  460. def _parse_table_body_structure(table_body: str) -> Tuple[int, int]:
  461. """从 table_body HTML 中解析准确的行列数"""
  462. try:
  463. soup = BeautifulSoup(table_body, 'html.parser')
  464. table = soup.find('table')
  465. if not table:
  466. raise ValueError("未找到 <table> 标签")
  467. rows = table.find_all('tr')
  468. if not rows:
  469. raise ValueError("未找到 <tr> 标签")
  470. num_rows = len(rows)
  471. first_row = rows[0]
  472. num_cols = len(first_row.find_all(['td', 'th']))
  473. return num_rows, num_cols
  474. except Exception as e:
  475. print(f"⚠️ 解析 table_body 失败: {e}")
  476. return 0, 0