column_detector.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. """
  2. 列边界检测器
  3. 基于 X 坐标聚类分析,自动检测表格列边界
  4. """
  5. import numpy as np
  6. from typing import List, Dict, Tuple, Optional
  7. from dataclasses import dataclass
  8. @dataclass
  9. class ColumnRegion:
  10. """列区域"""
  11. x_left: int
  12. x_right: int
  13. column_index: int = 0
  14. @property
  15. def width(self) -> int:
  16. return self.x_right - self.x_left
  17. @property
  18. def x_center(self) -> float:
  19. return (self.x_left + self.x_right) / 2
  20. def to_tuple(self) -> Tuple[int, int]:
  21. return (self.x_left, self.x_right)
  22. def to_dict(self) -> Dict:
  23. return {
  24. 'x_left': self.x_left,
  25. 'x_right': self.x_right,
  26. 'column_index': self.column_index,
  27. 'width': self.width
  28. }
  29. class ColumnBoundaryDetector:
  30. """
  31. 列边界检测器
  32. 核心思路:
  33. 1. 收集所有 OCR box 的左边界 X 坐标
  34. 2. 通过聚类找到稳定的列起始位置
  35. 3. 根据每列的 box 右边界确定列宽
  36. 适用场景:
  37. - 银行流水(列宽固定)
  38. - 规范的无线表格
  39. """
  40. def __init__(self,
  41. x_tolerance: int = 20, # X 坐标聚类容差
  42. min_boxes_per_column: int = 3, # 每列最少 box 数量
  43. min_column_width: int = 30, # 最小列宽
  44. column_gap: int = 5): # 列间最小间隙
  45. self.x_tolerance = x_tolerance
  46. self.min_boxes_per_column = min_boxes_per_column
  47. self.min_column_width = min_column_width
  48. self.column_gap = column_gap
  49. def detect(self,
  50. ocr_boxes: List[Dict],
  51. table_region: Optional[Tuple[int, int, int, int]] = None,
  52. page_width: Optional[int] = None) -> List[ColumnRegion]:
  53. """
  54. 检测列边界
  55. Args:
  56. ocr_boxes: OCR 识别结果 [{text, bbox: [x1,y1,x2,y2]}]
  57. table_region: 表格区域 (可选)
  58. page_width: 页面宽度 (可选)
  59. Returns:
  60. 列区域列表 (按 x 坐标升序)
  61. """
  62. if not ocr_boxes:
  63. return []
  64. # 1. 筛选表格区域内的 boxes
  65. if table_region:
  66. table_boxes = [
  67. b for b in ocr_boxes
  68. if self._box_in_region(b['bbox'], table_region)
  69. ]
  70. else:
  71. table_boxes = ocr_boxes
  72. if not table_boxes:
  73. return []
  74. # 2. 收集所有 box 的左边界 X 坐标
  75. x_lefts = [b['bbox'][0] for b in table_boxes]
  76. # 3. 聚类找到列起始位置
  77. column_centers = self._cluster_x_positions(x_lefts)
  78. if not column_centers:
  79. return []
  80. # 4. 为每个聚类计算列边界
  81. columns = []
  82. for i, center in enumerate(sorted(column_centers)):
  83. # 找到属于这个聚类的 boxes
  84. cluster_boxes = [
  85. b for b in table_boxes
  86. if abs(b['bbox'][0] - center) < self.x_tolerance * 2
  87. ]
  88. if len(cluster_boxes) >= self.min_boxes_per_column:
  89. # 左边界:聚类中心
  90. col_left = int(min(b['bbox'][0] for b in cluster_boxes))
  91. # 右边界:这些 box 的最大 x_right
  92. col_right = int(max(b['bbox'][2] for b in cluster_boxes))
  93. # 确保最小列宽
  94. if col_right - col_left < self.min_column_width:
  95. col_right = col_left + self.min_column_width
  96. columns.append(ColumnRegion(
  97. x_left=col_left,
  98. x_right=col_right,
  99. column_index=i
  100. ))
  101. # 5. 调整相邻列的边界 (防止重叠)
  102. columns = self._adjust_column_boundaries(columns)
  103. # 6. 重新设置列索引
  104. for i, col in enumerate(columns):
  105. col.column_index = i
  106. return columns
  107. def _cluster_x_positions(self, x_values: List[float]) -> List[float]:
  108. """
  109. 对 X 坐标进行聚类
  110. 使用简单的贪心聚类算法:
  111. 1. 排序
  112. 2. 相邻值差距小于 tolerance 则归入同一聚类
  113. 3. 返回每个聚类的中心
  114. """
  115. if not x_values:
  116. return []
  117. x_sorted = sorted(x_values)
  118. clusters = []
  119. current_cluster = [x_sorted[0]]
  120. for x in x_sorted[1:]:
  121. if x - current_cluster[-1] <= self.x_tolerance:
  122. current_cluster.append(x)
  123. else:
  124. # 只保留包含足够多点的聚类
  125. if len(current_cluster) >= self.min_boxes_per_column:
  126. clusters.append(sum(current_cluster) / len(current_cluster))
  127. current_cluster = [x]
  128. # 处理最后一个聚类
  129. if len(current_cluster) >= self.min_boxes_per_column:
  130. clusters.append(sum(current_cluster) / len(current_cluster))
  131. return clusters
  132. def _adjust_column_boundaries(self, columns: List[ColumnRegion]) -> List[ColumnRegion]:
  133. """
  134. 调整相邻列的边界,防止重叠
  135. """
  136. if len(columns) <= 1:
  137. return columns
  138. # 按 x_left 排序
  139. columns = sorted(columns, key=lambda c: c.x_left)
  140. for i in range(len(columns) - 1):
  141. curr_col = columns[i]
  142. next_col = columns[i + 1]
  143. # 如果当前列的右边界超过了下一列的左边界
  144. if curr_col.x_right > next_col.x_left - self.column_gap:
  145. # 取中点作为分界
  146. mid = (curr_col.x_right + next_col.x_left) // 2
  147. columns[i] = ColumnRegion(
  148. x_left=curr_col.x_left,
  149. x_right=mid - self.column_gap // 2,
  150. column_index=curr_col.column_index
  151. )
  152. columns[i + 1] = ColumnRegion(
  153. x_left=mid + self.column_gap // 2,
  154. x_right=next_col.x_right,
  155. column_index=next_col.column_index
  156. )
  157. return columns
  158. def _box_in_region(self, bbox: List[int], region: Tuple[int, int, int, int]) -> bool:
  159. """判断 box 是否在区域内"""
  160. x1, y1, x2, y2 = region
  161. bx1, by1, bx2, by2 = bbox
  162. cx, cy = (bx1 + bx2) / 2, (by1 + by2) / 2
  163. return x1 <= cx <= x2 and y1 <= cy <= y2
  164. def detect_from_header(self,
  165. header_boxes: List[Dict],
  166. page_width: int) -> List[ColumnRegion]:
  167. """
  168. 从表头检测列边界
  169. 适用于表头列名完整且间距清晰的情况
  170. """
  171. if not header_boxes:
  172. return []
  173. # 按 x 坐标排序
  174. sorted_boxes = sorted(header_boxes, key=lambda b: b['bbox'][0])
  175. columns = []
  176. for i, box in enumerate(sorted_boxes):
  177. x_left = box['bbox'][0]
  178. x_right = box['bbox'][2]
  179. # 扩展右边界到下一个 box 的左边界
  180. if i + 1 < len(sorted_boxes):
  181. next_x_left = sorted_boxes[i + 1]['bbox'][0]
  182. x_right = next_x_left - self.column_gap
  183. else:
  184. # 最后一列扩展到页面边缘
  185. x_right = page_width - 20
  186. columns.append(ColumnRegion(
  187. x_left=int(x_left),
  188. x_right=int(x_right),
  189. column_index=i
  190. ))
  191. return columns