|
|
@@ -15,15 +15,19 @@ from bs4 import BeautifulSoup
|
|
|
class TableLineGenerator:
|
|
|
"""表格线生成器"""
|
|
|
|
|
|
- def __init__(self, image: Union[str, Image.Image], ocr_data: Dict):
|
|
|
+ def __init__(self, image: Union[str, Image.Image, None], ocr_data: Dict):
|
|
|
"""
|
|
|
初始化表格线生成器
|
|
|
|
|
|
Args:
|
|
|
- image: 图片路径(str) 或 PIL.Image 对象
|
|
|
+ image: 图片路径(str) 或 PIL.Image 对象,或 None(仅分析结构时)
|
|
|
ocr_data: OCR识别结果(包含bbox)
|
|
|
"""
|
|
|
- if isinstance(image, str):
|
|
|
+ if image is None:
|
|
|
+ # 🆕 无图片模式:仅用于结构分析
|
|
|
+ self.image_path = None
|
|
|
+ self.image = None
|
|
|
+ elif isinstance(image, str):
|
|
|
self.image_path = image
|
|
|
self.image = Image.open(image)
|
|
|
elif isinstance(image, Image.Image):
|
|
|
@@ -31,7 +35,7 @@ class TableLineGenerator:
|
|
|
self.image = image
|
|
|
else:
|
|
|
raise TypeError(
|
|
|
- f"image 参数必须是 str (路径) 或 PIL.Image.Image 对象,"
|
|
|
+ f"image 参数必须是 str (路径)、PIL.Image.Image 对象或 None,"
|
|
|
f"实际类型: {type(image)}"
|
|
|
)
|
|
|
|
|
|
@@ -221,9 +225,6 @@ class TableLineGenerator:
|
|
|
"""
|
|
|
基于单元格的 row/col 索引分析(MinerU 专用)
|
|
|
|
|
|
- Args:
|
|
|
- use_table_body: 是否使用 table_body 确定准确的行列数
|
|
|
-
|
|
|
Returns:
|
|
|
表格结构信息
|
|
|
"""
|
|
|
@@ -265,20 +266,40 @@ class TableLineGenerator:
|
|
|
bboxes = cells_by_row[row_num]
|
|
|
y_min = min(bbox[1] for bbox in bboxes)
|
|
|
y_max = max(bbox[3] for bbox in bboxes)
|
|
|
- row_boundaries[row_num] = (y_min, y_max)
|
|
|
-
|
|
|
- # 🔑 计算横线
|
|
|
+ row_boundaries[row_num] = (y_min, y_max)
|
|
|
+
|
|
|
+ # 🔑 计算横线(现在使用的是过滤后的数据)
|
|
|
horizontal_lines = _calculate_horizontal_lines_with_spacing(row_boundaries)
|
|
|
|
|
|
- # 🔑 计算每列的 x 边界
|
|
|
+ # 🔑 列边界计算(同样需要过滤异常值)
|
|
|
col_boundaries = {}
|
|
|
for col_num in range(1, actual_cols + 1):
|
|
|
if col_num in cells_by_col:
|
|
|
bboxes = cells_by_col[col_num]
|
|
|
- x_min = min(bbox[0] for bbox in bboxes)
|
|
|
- x_max = max(bbox[2] for bbox in bboxes)
|
|
|
- col_boundaries[col_num] = (x_min, x_max)
|
|
|
-
|
|
|
+
|
|
|
+ # 🎯 过滤 x 方向的异常值(使用 IQR)
|
|
|
+ if len(bboxes) > 1:
|
|
|
+ x_centers = [(bbox[0] + bbox[2]) / 2 for bbox in bboxes]
|
|
|
+ x_center_q1 = np.percentile(x_centers, 25)
|
|
|
+ x_center_q3 = np.percentile(x_centers, 75)
|
|
|
+ x_center_iqr = x_center_q3 - x_center_q1
|
|
|
+ x_center_median = np.median(x_centers)
|
|
|
+
|
|
|
+ # 允许偏移 3 倍 IQR 或至少 100px
|
|
|
+ x_threshold = max(3 * x_center_iqr, 100)
|
|
|
+
|
|
|
+ valid_bboxes = [
|
|
|
+ bbox for bbox in bboxes
|
|
|
+ if abs((bbox[0] + bbox[2]) / 2 - x_center_median) <= x_threshold
|
|
|
+ ]
|
|
|
+ else:
|
|
|
+ valid_bboxes = bboxes
|
|
|
+
|
|
|
+ if valid_bboxes:
|
|
|
+ x_min = min(bbox[0] for bbox in valid_bboxes)
|
|
|
+ x_max = max(bbox[2] for bbox in valid_bboxes)
|
|
|
+ col_boundaries[col_num] = (x_min, x_max)
|
|
|
+
|
|
|
# 🔑 计算竖线
|
|
|
vertical_lines = _calculate_vertical_lines_with_spacing(col_boundaries)
|
|
|
|
|
|
@@ -317,7 +338,9 @@ class TableLineGenerator:
|
|
|
'table_bbox': self._get_table_bbox(),
|
|
|
'total_rows': actual_rows,
|
|
|
'total_cols': actual_cols,
|
|
|
- 'method': 'mineru'
|
|
|
+ 'mode': 'hybrid', # ✅ 添加 mode 字段
|
|
|
+ 'modified_h_lines': [], # ✅ 添加修改记录字段
|
|
|
+ 'modified_v_lines': [] # ✅ 添加修改记录字段
|
|
|
}
|
|
|
|
|
|
def _analyze_by_clustering(self, y_tolerance: int, x_tolerance: int, min_row_height: int) -> Dict:
|
|
|
@@ -356,7 +379,7 @@ class TableLineGenerator:
|
|
|
|
|
|
# 4. 提取所有bbox的X坐标(用于列检测)
|
|
|
x_coords = []
|
|
|
- for item in self.ocr_data:
|
|
|
+ for item in ocr_data:
|
|
|
bbox = item.get('bbox', [])
|
|
|
if len(bbox) >= 4:
|
|
|
x1, x2 = bbox[0], bbox[2]
|
|
|
@@ -390,7 +413,9 @@ class TableLineGenerator:
|
|
|
'row_height': self.row_height,
|
|
|
'col_widths': self.col_widths,
|
|
|
'table_bbox': self._get_table_bbox(),
|
|
|
- 'method': 'cluster'
|
|
|
+ 'mode': 'fixed', # ✅ 添加 mode 字段
|
|
|
+ 'modified_h_lines': [], # ✅ 添加修改记录字段
|
|
|
+ 'modified_v_lines': [] # ✅ 添加修改记录字段
|
|
|
}
|
|
|
|
|
|
@staticmethod
|
|
|
@@ -498,6 +523,12 @@ class TableLineGenerator:
|
|
|
line_color: Tuple[int, int, int] = (0, 0, 255),
|
|
|
line_width: int = 2) -> Image.Image:
|
|
|
"""在原图上绘制表格线"""
|
|
|
+ if self.image is None:
|
|
|
+ raise ValueError(
|
|
|
+ "无图片模式下不能调用 generate_table_lines(),"
|
|
|
+ "请在初始化时提供图片"
|
|
|
+ )
|
|
|
+
|
|
|
img_with_lines = self.image.copy()
|
|
|
draw = ImageDraw.Draw(img_with_lines)
|
|
|
|
|
|
@@ -526,6 +557,38 @@ class TableLineGenerator:
|
|
|
|
|
|
return img_with_lines
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def analyze_structure_only(
|
|
|
+ ocr_data: Dict,
|
|
|
+ y_tolerance: int = 5,
|
|
|
+ x_tolerance: int = 10,
|
|
|
+ min_row_height: int = 20,
|
|
|
+ method: str = "auto"
|
|
|
+ ) -> Dict:
|
|
|
+ """
|
|
|
+ Analyze the table structure only (no image required); builds a temporary image-less TableLineGenerator and delegates to analyze_table_structure().
|
|
|
+
|
|
|
+ Args:
|
|
|
+ ocr_data: OCR recognition result, passed unchanged to the TableLineGenerator constructor
|
|
|
+ y_tolerance: Y-axis clustering tolerance (pixels)
|
|
|
+ x_tolerance: X-axis clustering tolerance (pixels)
|
|
|
+ min_row_height: minimum row height (pixels)
|
|
|
+ method: analysis method ("auto" / "cluster" / "mineru")
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Table structure information, exactly as returned by analyze_table_structure()
|
|
|
+ """
|
|
|
+ # 🔑 Create a generator in image-less mode (image argument is None)
|
|
|
+ temp_generator = TableLineGenerator(None, ocr_data)
|
|
|
+
|
|
|
+ # 🔑 Run the structure analysis, forwarding every tuning parameter as-is
|
|
|
+ return temp_generator.analyze_table_structure(
|
|
|
+ y_tolerance=y_tolerance,
|
|
|
+ x_tolerance=x_tolerance,
|
|
|
+ min_row_height=min_row_height,
|
|
|
+ method=method
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int, int]]) -> List[int]:
|
|
|
"""
|
|
|
@@ -593,7 +656,8 @@ def _calculate_horizontal_lines_with_spacing(row_boundaries: Dict[int, Tuple[int
|
|
|
horizontal_lines.append(separator_y)
|
|
|
else:
|
|
|
# 重叠或紧贴:在当前行的下边界画线
|
|
|
- horizontal_lines.append(y_max)
|
|
|
+ separator_y = int(next_y_min) - max(int(gap / 4), 2)
|
|
|
+ horizontal_lines.append(separator_y)
|
|
|
else:
|
|
|
# 最后一行的下边界
|
|
|
horizontal_lines.append(y_max)
|