|
|
@@ -10,6 +10,15 @@ import numpy as np
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
+# 导入倾斜矫正工具
|
|
|
+try:
|
|
|
+ from ocr_utils import BBoxExtractor
|
|
|
+ BBOX_EXTRACTOR_AVAILABLE = True
|
|
|
+except ImportError:
|
|
|
+ BBoxExtractor = None
|
|
|
+ BBOX_EXTRACTOR_AVAILABLE = False
|
|
|
+ logger.warning("BBoxExtractor not available, deskewing will be disabled")
|
|
|
+
|
|
|
# 确保 mineru 库可导入
|
|
|
mineru_path = str((__file__ and __file__) and __file__)
|
|
|
# 使用已有 mineru_adapter 中的路径追加逻辑
|
|
|
@@ -50,6 +59,9 @@ class MinerUWiredTableRecognizer:
|
|
|
self.cell_crop_margin: int = self.config.get("cell_crop_margin", 2)
|
|
|
# 是否使用自定义后处理(v2),默认启用
|
|
|
self.use_custom_postprocess: bool = self.config.get("use_custom_postprocess", True)
|
|
|
+ # 倾斜矫正配置
|
|
|
+ self.enable_deskew: bool = self.config.get("enable_deskew", True) and BBOX_EXTRACTOR_AVAILABLE
|
|
|
+ self.skew_threshold: float = self.config.get("skew_threshold", 0.1) # 小于此角度不矫正
|
|
|
self.table_model = UnetTableModel(ocr_engine)
|
|
|
self.ocr_engine = ocr_engine
|
|
|
|
|
|
@@ -145,191 +157,201 @@ class MinerUWiredTableRecognizer:
|
|
|
ys = poly[:, 1]
|
|
|
return [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]
|
|
|
|
|
|
- # ========== 行列分组与网格计算 (修复版) ==========
|
|
|
+ # ========== 倾斜检测与矫正 ==========
|
|
|
|
|
|
- def _group_cells_into_rows(self, bboxes: List[List[float]]) -> List[List[int]]:
|
|
|
+ def _convert_ocr_boxes_to_paddle_format(self, ocr_boxes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
- 按垂直位置将单元格分组到行 (修复版)
|
|
|
+ 将OCR框转换为BBoxExtractor期望的格式(包含poly字段)
|
|
|
|
|
|
- 使用单元格的垂直中心点进行聚类分组
|
|
|
+ Args:
|
|
|
+ ocr_boxes: OCR结果 [{"bbox": [...], "text": "..."}, ...]
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 转换后的OCR框列表,包含poly字段
|
|
|
"""
|
|
|
- if not bboxes:
|
|
|
- return []
|
|
|
-
|
|
|
- # 计算每个单元格的垂直中心和高度
|
|
|
- cells_info = []
|
|
|
- for i, bbox in enumerate(bboxes):
|
|
|
- y_center = (bbox[1] + bbox[3]) / 2
|
|
|
- height = bbox[3] - bbox[1]
|
|
|
- cells_info.append({
|
|
|
- 'index': i,
|
|
|
- 'y_center': y_center,
|
|
|
- 'y_min': bbox[1],
|
|
|
- 'y_max': bbox[3],
|
|
|
- 'height': height,
|
|
|
- 'bbox': bbox
|
|
|
- })
|
|
|
-
|
|
|
- # 按y中心排序
|
|
|
- cells_info.sort(key=lambda c: c['y_center'])
|
|
|
-
|
|
|
- # 计算自适应行高阈值(使用高度的中位数)
|
|
|
- heights = [c['height'] for c in cells_info if c['height'] > 0]
|
|
|
- if not heights:
|
|
|
- return [[i for i in range(len(bboxes))]]
|
|
|
-
|
|
|
- median_height = sorted(heights)[len(heights) // 2]
|
|
|
- # 行分组阈值:同一行的单元格y中心差异不超过中位高度的40%
|
|
|
- row_thresh = median_height * 0.4
|
|
|
-
|
|
|
- logger.debug(f"行分组: median_height={median_height:.1f}, row_thresh={row_thresh:.1f}")
|
|
|
-
|
|
|
- # 基于y中心进行分组
|
|
|
- rows = []
|
|
|
- current_row = [cells_info[0]['index']]
|
|
|
- current_row_y_centers = [cells_info[0]['y_center']]
|
|
|
-
|
|
|
- for cell in cells_info[1:]:
|
|
|
- # 计算当前行的平均y中心
|
|
|
- avg_y_center = sum(current_row_y_centers) / len(current_row_y_centers)
|
|
|
+ paddle_boxes = []
|
|
|
+ for item in ocr_boxes:
|
|
|
+ bbox = item.get("bbox", [])
|
|
|
+ if not bbox or len(bbox) < 4:
|
|
|
+ continue
|
|
|
|
|
|
- # 如果当前单元格的y中心与行平均y中心的差距在阈值内,加入当前行
|
|
|
- if abs(cell['y_center'] - avg_y_center) <= row_thresh:
|
|
|
- current_row.append(cell['index'])
|
|
|
- current_row_y_centers.append(cell['y_center'])
|
|
|
+ # 转换为4点格式
|
|
|
+ if len(bbox) == 8:
|
|
|
+ # 8-value flat polygon [x1, y1, x2, y2, x3, y3, x4, y4]: regroup into four (x, y) points
|
|
|
+ poly = [
|
|
|
+ [float(bbox[0]), float(bbox[1])],
|
|
|
+ [float(bbox[2]), float(bbox[3])],
|
|
|
+ [float(bbox[4]), float(bbox[5])],
|
|
|
+ [float(bbox[6]), float(bbox[7])],
|
|
|
+ ]
|
|
|
+ elif len(bbox) == 4:
|
|
|
+ # 4-value axis-aligned bbox [x1, y1, x2, y2]: expand into a 4-point polygon (x1,y1) -> (x2,y1) -> (x2,y2) -> (x1,y2)
|
|
|
+ x1, y1, x2, y2 = map(float, bbox)
|
|
|
+ poly = [
|
|
|
+ [x1, y1],
|
|
|
+ [x2, y1],
|
|
|
+ [x2, y2],
|
|
|
+ [x1, y2],
|
|
|
+ ]
|
|
|
else:
|
|
|
- # 开始新行
|
|
|
- rows.append(current_row)
|
|
|
- current_row = [cell['index']]
|
|
|
- current_row_y_centers = [cell['y_center']]
|
|
|
-
|
|
|
- # 添加最后一行
|
|
|
- if current_row:
|
|
|
- rows.append(current_row)
|
|
|
-
|
|
|
- # 每行内按x坐标排序
|
|
|
- for row in rows:
|
|
|
- row.sort(key=lambda i: bboxes[i][0])
|
|
|
-
|
|
|
- logger.info(f"行分组结果: {len(rows)} 行, 每行单元格数: {[len(r) for r in rows[:10]]}...")
|
|
|
+ continue
|
|
|
+
|
|
|
+ paddle_box = {
|
|
|
+ "bbox": [poly[0][0], poly[0][1], poly[2][0], poly[2][1]], # [x_min, y_min, x_max, y_max]
|
|
|
+ "poly": poly,
|
|
|
+ "text": item.get("text", ""),
|
|
|
+ "confidence": item.get("confidence", item.get("score", 1.0)),
|
|
|
+ }
|
|
|
+ paddle_boxes.append(paddle_box)
|
|
|
|
|
|
- return rows
|
|
|
+ return paddle_boxes
|
|
|
|
|
|
- def _find_grid_index(self, value: float, edges: List[float]) -> int:
|
|
|
+ def _detect_skew_angle(self, ocr_boxes: List[Dict[str, Any]]) -> float:
|
|
|
"""
|
|
|
- 找到值对应的网格索引
|
|
|
+ 检测表格图像的倾斜角度
|
|
|
|
|
|
- 边界将坐标空间划分为 N-1 个网格区间
|
|
|
- 返回 value 所在的网格区间索引
|
|
|
+ Args:
|
|
|
+ ocr_boxes: OCR结果列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 倾斜角度(度数,正值=逆时针,负值=顺时针)
|
|
|
"""
|
|
|
- if not edges:
|
|
|
- return 0
|
|
|
-
|
|
|
- if len(edges) == 1:
|
|
|
- return 0
|
|
|
+ if not self.enable_deskew or not BBOX_EXTRACTOR_AVAILABLE:
|
|
|
+ return 0.0
|
|
|
|
|
|
- # value 小于第一个边界
|
|
|
- if value <= edges[0]:
|
|
|
- return 0
|
|
|
+ if not ocr_boxes or len(ocr_boxes) < 5:
|
|
|
+ # OCR框太少,无法准确检测倾斜
|
|
|
+ return 0.0
|
|
|
|
|
|
- # value 大于最后一个边界
|
|
|
- if value >= edges[-1]:
|
|
|
- return len(edges) - 2
|
|
|
-
|
|
|
- # 找到 value 所在的区间 [edges[i], edges[i+1])
|
|
|
- for i in range(len(edges) - 1):
|
|
|
- if edges[i] <= value < edges[i + 1]:
|
|
|
- return i
|
|
|
-
|
|
|
- return len(edges) - 2
|
|
|
-
|
|
|
- # ========== HTML生成与文本填充 ==========
|
|
|
+ try:
|
|
|
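+ # NOTE: conversion to the BBoxExtractor format is currently disabled (call commented out); ocr_boxes are passed through unchanged, so they presumably already carry the fields BBoxExtractor reads - TODO confirm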
|
|
|
+ # paddle_boxes = self._convert_ocr_boxes_to_paddle_format(ocr_boxes)
|
|
|
+ paddle_boxes = ocr_boxes
|
|
|
+
|
|
|
+ if len(paddle_boxes) < 5:
|
|
|
+ return 0.0
|
|
|
+
|
|
|
+ # 使用BBoxExtractor计算倾斜角度
|
|
|
+ if BBoxExtractor is None:
|
|
|
+ return 0.0
|
|
|
+ skew_angle = BBoxExtractor.calculate_skew_angle(paddle_boxes)
|
|
|
+
|
|
|
+ logger.debug(f"检测到倾斜角度: {skew_angle:.3f}°")
|
|
|
+ return skew_angle
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"倾斜角度检测失败: {e}")
|
|
|
+ return 0.0
|
|
|
|
|
|
- def _plot_html_with_bbox(
|
|
|
+ def _apply_deskew(
|
|
|
self,
|
|
|
- bboxes: List[List[float]],
|
|
|
- logic_points: List[List[int]],
|
|
|
- texts: List[str],
|
|
|
- row_edges: List[float],
|
|
|
- col_edges: List[float],
|
|
|
- ) -> str:
|
|
|
+ table_image: np.ndarray,
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
+ skew_angle: float
|
|
|
+ ) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
|
|
|
"""
|
|
|
- 生成带 data-bbox 属性的 HTML 表格。
|
|
|
-
|
|
|
- 直接在生成 <td> 时附加 data-bbox="[x1,y1,x2,y2]"(原图坐标)。
|
|
|
+ 应用倾斜矫正并同步更新OCR坐标
|
|
|
|
|
|
Args:
|
|
|
- bboxes: 单元格坐标(已还原到原图坐标)
|
|
|
- logic_points: 逻辑坐标 [[row_start, row_end, col_start, col_end], ...]
|
|
|
- texts: 单元格文本
|
|
|
- row_edges: 行边界
|
|
|
- col_edges: 列边界
|
|
|
+ table_image: 表格图像
|
|
|
+ ocr_boxes: OCR结果列表
|
|
|
+ skew_angle: 倾斜角度(度数,正值=逆时针,负值=顺时针)
|
|
|
|
|
|
Returns:
|
|
|
- HTML字符串
|
|
|
+ (矫正后的图像, 更新后的OCR框列表)
|
|
|
"""
|
|
|
- if not bboxes or len(row_edges) < 2 or len(col_edges) < 2:
|
|
|
- return ""
|
|
|
-
|
|
|
- n_rows = len(row_edges) - 1
|
|
|
- n_cols = len(col_edges) - 1
|
|
|
-
|
|
|
- # 构建网格,记录每个格子对应的单元格索引
|
|
|
- # -1 表示被合并单元格占用,None 表示空
|
|
|
- grid: List[List[Optional[int]]] = [[None for _ in range(n_cols)] for _ in range(n_rows)]
|
|
|
-
|
|
|
- for idx, lp in enumerate(logic_points):
|
|
|
- r0, r1, c0, c1 = lp
|
|
|
- if r0 < 0 or c0 < 0 or r1 >= n_rows or c1 >= n_cols:
|
|
|
- continue
|
|
|
- for rr in range(r0, r1 + 1):
|
|
|
- for cc in range(c0, c1 + 1):
|
|
|
- if rr == r0 and cc == c0:
|
|
|
- grid[rr][cc] = idx # 主格
|
|
|
- else:
|
|
|
- grid[rr][cc] = -1 # 占位(被合并)
|
|
|
+ if abs(skew_angle) < self.skew_threshold:
|
|
|
+ return table_image, ocr_boxes
|
|
|
|
|
|
- # 生成HTML
|
|
|
- html_parts = ["<table>", "<tbody>"]
|
|
|
+ if not BBOX_EXTRACTOR_AVAILABLE:
|
|
|
+ logger.warning("BBoxExtractor不可用,跳过倾斜矫正")
|
|
|
+ return table_image, ocr_boxes
|
|
|
|
|
|
- for r in range(n_rows):
|
|
|
- html_parts.append("<tr>")
|
|
|
- c = 0
|
|
|
- while c < n_cols:
|
|
|
- cell_idx = grid[r][c]
|
|
|
+ try:
|
|
|
+ h, w = table_image.shape[:2]
|
|
|
+ center = (w / 2, h / 2)
|
|
|
+
|
|
|
+ # 计算矫正角度(需要向相反方向旋转)
|
|
|
+ correction_angle = -skew_angle
|
|
|
+
|
|
|
+ # 构建旋转矩阵
|
|
|
+ rotation_matrix = cv2.getRotationMatrix2D(center, correction_angle, 1.0)
|
|
|
+
|
|
|
+ # 计算旋转后的图像尺寸(避免裁剪)
|
|
|
+ cos_val = abs(rotation_matrix[0, 0])
|
|
|
+ sin_val = abs(rotation_matrix[0, 1])
|
|
|
+ new_w = int((h * sin_val) + (w * cos_val))
|
|
|
+ new_h = int((h * cos_val) + (w * sin_val))
|
|
|
+
|
|
|
+ # 调整旋转矩阵的平移部分,使图像居中
|
|
|
+ rotation_matrix[0, 2] += (new_w / 2) - center[0]
|
|
|
+ rotation_matrix[1, 2] += (new_h / 2) - center[1]
|
|
|
+
|
|
|
+ # 应用旋转(填充背景为白色)
|
|
|
+ if len(table_image.shape) == 2:
|
|
|
+ # 灰度图
|
|
|
+ deskewed_image = cv2.warpAffine(
|
|
|
+ table_image, rotation_matrix, (new_w, new_h),
|
|
|
+ flags=cv2.INTER_LINEAR,
|
|
|
+ borderMode=cv2.BORDER_CONSTANT,
|
|
|
+ borderValue=255
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # 彩色图
|
|
|
+ deskewed_image = cv2.warpAffine(
|
|
|
+ table_image, rotation_matrix, (new_w, new_h),
|
|
|
+ flags=cv2.INTER_LINEAR,
|
|
|
+ borderMode=cv2.BORDER_CONSTANT,
|
|
|
+ borderValue=(255, 255, 255)
|
|
|
+ )
|
|
|
+
|
|
|
+ # 更新OCR框坐标
|
|
|
+ # NOTE: conversion to the paddle format is currently disabled (call commented out); the OCR boxes are used as-is
|
|
|
+ # paddle_boxes = self._convert_ocr_boxes_to_paddle_format(ocr_boxes)
|
|
|
+ paddle_boxes = ocr_boxes
|
|
|
+
|
|
|
+ # 使用BBoxExtractor更新坐标
|
|
|
+ if BBoxExtractor is None:
|
|
|
+ logger.warning("BBoxExtractor不可用,无法更新OCR坐标")
|
|
|
+ return table_image, ocr_boxes
|
|
|
+ updated_paddle_boxes = BBoxExtractor.correct_boxes_skew(
|
|
|
+ paddle_boxes, correction_angle, (new_w, new_h)
|
|
|
+ )
|
|
|
+
|
|
|
+ # 转换回原始格式
|
|
|
+ updated_ocr_boxes = []
|
|
|
+ for i, paddle_box in enumerate(updated_paddle_boxes):
|
|
|
+ original_box = ocr_boxes[i] if i < len(ocr_boxes) else {}
|
|
|
|
|
|
- if cell_idx is None:
|
|
|
- # 空格子,输出空td
|
|
|
- html_parts.append("<td></td>")
|
|
|
- c += 1
|
|
|
- elif cell_idx == -1:
|
|
|
- # 被合并,跳过
|
|
|
- c += 1
|
|
|
+ # 从poly重新计算bbox(确保坐标正确)
|
|
|
+ poly = paddle_box.get("poly", [])
|
|
|
+ if poly and len(poly) >= 4:
|
|
|
+ xs = [p[0] for p in poly]
|
|
|
+ ys = [p[1] for p in poly]
|
|
|
+ bbox = [min(xs), min(ys), max(xs), max(ys)]
|
|
|
else:
|
|
|
- # 主格,输出带span的td
|
|
|
- lp = logic_points[cell_idx]
|
|
|
- rowspan = lp[1] - lp[0] + 1
|
|
|
- colspan = lp[3] - lp[2] + 1
|
|
|
- bbox = bboxes[cell_idx]
|
|
|
- text = html.escape(texts[cell_idx]) if cell_idx < len(texts) else ""
|
|
|
-
|
|
|
- bbox_str = f"[{int(bbox[0])},{int(bbox[1])},{int(bbox[2])},{int(bbox[3])}]"
|
|
|
-
|
|
|
- if rowspan > 1 or colspan > 1:
|
|
|
- html_parts.append(
|
|
|
- f'<td data-bbox="{bbox_str}" rowspan="{rowspan}" colspan="{colspan}">{text}</td>'
|
|
|
- )
|
|
|
- else:
|
|
|
- html_parts.append(f'<td data-bbox="{bbox_str}">{text}</td>')
|
|
|
-
|
|
|
- c += colspan
|
|
|
+ bbox = paddle_box.get("bbox", [])
|
|
|
+
|
|
|
+ updated_box = {
|
|
|
+ "bbox": bbox,
|
|
|
+ "text": paddle_box.get("text", original_box.get("text", "")),
|
|
|
+ "confidence": paddle_box.get("confidence", original_box.get("confidence", original_box.get("score", 1.0))),
|
|
|
+ }
|
|
|
+ # 保留其他字段
|
|
|
+ for key in original_box:
|
|
|
+ if key not in updated_box:
|
|
|
+ updated_box[key] = original_box[key]
|
|
|
+
|
|
|
+ updated_ocr_boxes.append(updated_box)
|
|
|
|
|
|
- html_parts.append("</tr>")
|
|
|
-
|
|
|
- html_parts.append("</tbody>")
|
|
|
- html_parts.append("</table>")
|
|
|
-
|
|
|
- return "".join(html_parts)
|
|
|
+ logger.info(f"✅ 倾斜矫正完成: {skew_angle:.3f}° → 0° (图像尺寸: {w}x{h} → {new_w}x{new_h})")
|
|
|
+
|
|
|
+ return deskewed_image, updated_ocr_boxes
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"倾斜矫正失败: {e}")
|
|
|
+ return table_image, ocr_boxes
|
|
|
+
|
|
|
|
|
|
def _fill_text_by_center_point(
|
|
|
self,
|
|
|
@@ -544,124 +566,6 @@ class MinerUWiredTableRecognizer:
|
|
|
return str(soup)
|
|
|
|
|
|
# ========== 基于表格线交点的单元格计算 ==========
|
|
|
- def _compute_cells_from_lines_4_1(
|
|
|
- self,
|
|
|
- hpred_up: np.ndarray,
|
|
|
- vpred_up: np.ndarray,
|
|
|
- upscale: float = 1.0,
|
|
|
- ) -> List[List[float]]:
|
|
|
- """
|
|
|
- 基于连通域分析从表格线 Mask 提取单元格
|
|
|
-
|
|
|
- 原理:横竖线叠加 -> 反色 -> 提取白色连通块 -> 也就是单元格
|
|
|
- """
|
|
|
- h, w = hpred_up.shape[:2]
|
|
|
-
|
|
|
- # 1. 预处理:二值化
|
|
|
- _, h_bin = cv2.threshold(hpred_up, 127, 255, cv2.THRESH_BINARY)
|
|
|
- _, v_bin = cv2.threshold(vpred_up, 127, 255, cv2.THRESH_BINARY)
|
|
|
-
|
|
|
- # 2. 形态学连接:轻微膨胀以闭合可能的断点
|
|
|
- # 横线横向膨胀,竖线竖向膨胀
|
|
|
- kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 1))
|
|
|
- kernel_v = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 5))
|
|
|
- h_bin = cv2.dilate(h_bin, kernel_h, iterations=1)
|
|
|
- v_bin = cv2.dilate(v_bin, kernel_v, iterations=1)
|
|
|
-
|
|
|
- # 3. 合成网格图 (白线黑底)
|
|
|
- grid_mask = cv2.bitwise_or(h_bin, v_bin)
|
|
|
-
|
|
|
- # 4. 反转图像 (黑线白底),此时单元格变成白色连通域
|
|
|
- inv_grid = cv2.bitwise_not(grid_mask)
|
|
|
-
|
|
|
- # 5. 提取连通域
|
|
|
- num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
|
|
|
-
|
|
|
- bboxes = []
|
|
|
- heights = []
|
|
|
- # 过滤掉背景(label=0)和过小的噪声
|
|
|
- min_area = 50 # 最小面积阈值
|
|
|
-
|
|
|
- for i in range(1, num_labels):
|
|
|
- area = stats[i, cv2.CC_STAT_AREA]
|
|
|
- if area < min_area:
|
|
|
- continue
|
|
|
-
|
|
|
- x = stats[i, cv2.CC_STAT_LEFT]
|
|
|
- y = stats[i, cv2.CC_STAT_TOP]
|
|
|
- w_cell = stats[i, cv2.CC_STAT_WIDTH]
|
|
|
- h_cell = stats[i, cv2.CC_STAT_HEIGHT]
|
|
|
-
|
|
|
- # 过滤掉长条形的非单元格区域(例如边缘的细长空白)
|
|
|
- if w_cell > w * 0.95 or h_cell > h * 0.95:
|
|
|
- continue
|
|
|
-
|
|
|
- # 转换到原图尺度
|
|
|
- orig_h = h_cell / upscale
|
|
|
- orig_w = w_cell / upscale
|
|
|
-
|
|
|
- # 过滤极小高度/宽度的单元格 (可能是边缘噪声或线条残留)
|
|
|
- # 阈值设为 5 像素,通常文本行不会小于这个高度
|
|
|
- if orig_h < 5 or orig_w < 5:
|
|
|
- continue
|
|
|
-
|
|
|
- # 还原到原图坐标
|
|
|
- # 注意:连通域提取的是内部空白,实际单元格边界应该包含线条的一半宽度
|
|
|
- # 这里简单处理,直接使用内部空白作为 bbox,OCR 匹配时通常足够
|
|
|
- bboxes.append([
|
|
|
- x / upscale,
|
|
|
- y / upscale,
|
|
|
- (x + w_cell) / upscale,
|
|
|
- (y + h_cell) / upscale
|
|
|
- ])
|
|
|
- heights.append(orig_h)
|
|
|
-
|
|
|
- # --- 动态过滤逻辑开始 ---
|
|
|
- # 计算中位数高度,代表正常行的典型高度
|
|
|
- median_h = np.median(heights) if heights else 0
|
|
|
- # 设定动态阈值:
|
|
|
- # 1. 高度小于中位数的 1/3 (显著偏矮)
|
|
|
- # 2. 且高度小于 5 像素 (确保不误删本来就很密集的正常小行)
|
|
|
- final_bboxes = []
|
|
|
- for bbox in bboxes:
|
|
|
- h_cell = bbox[3] - bbox[1]
|
|
|
- w_cell = bbox[2] - bbox[0]
|
|
|
-
|
|
|
- # 1. 绝对高度过滤:过滤极矮的噪点
|
|
|
- # 降低阈值到 6px,防止漏掉极小的字号
|
|
|
- if h_cell < 6.0:
|
|
|
- continue
|
|
|
-
|
|
|
- # 2. 相对高度过滤 (更保守的策略)
|
|
|
- # 仅当高度同时满足 "相对极小" AND "绝对较小" 时才过滤
|
|
|
- # 这样可以防止在 median_h 很大(如160px)时误删正常的小行(如25px)
|
|
|
- if median_h > 0:
|
|
|
- ratio = h_cell / median_h
|
|
|
-
|
|
|
- # 策略A: 极矮行过滤
|
|
|
- # 高度 < 10% median 且 绝对高度 < 10px
|
|
|
- # (你的case: 25/164 = 0.15 > 0.1, 且 25 > 10, 故保留)
|
|
|
- if ratio < 0.1 and h_cell < 10.0:
|
|
|
- continue
|
|
|
-
|
|
|
- # 策略B: 扁长缝隙过滤 (通常是双线造成的)
|
|
|
- # 高度 < 20% median 且 宽高比 > 5 且 绝对高度 < 15px
|
|
|
- if ratio < 0.2 and w_cell > h_cell * 5 and h_cell < 15.0:
|
|
|
- continue
|
|
|
-
|
|
|
- final_bboxes.append(bbox)
|
|
|
-
|
|
|
- # --- 动态过滤逻辑结束 ---
|
|
|
- bboxes = final_bboxes
|
|
|
-
|
|
|
- # 按阅读顺序排序 (先上后下,再左后右)
|
|
|
- # 允许一定的行误差
|
|
|
- bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
|
|
|
-
|
|
|
- logger.info(f"连通域分析提取到 {len(bboxes)} 个单元格")
|
|
|
-
|
|
|
- return bboxes
|
|
|
-
|
|
|
def _compute_cells_from_lines(
|
|
|
self,
|
|
|
hpred_up: np.ndarray,
|
|
|
@@ -904,253 +808,6 @@ class MinerUWiredTableRecognizer:
|
|
|
|
|
|
return new_cells
|
|
|
|
|
|
- def _recover_grid_structure_4_1(self, bboxes: List[List[float]]) -> List[Dict]:
|
|
|
- """
|
|
|
- 从散乱的单元格 bbox 恢复表格的行列结构 (row, col, rowspan, colspan)
|
|
|
- 改进版:使用边界投影聚类,更稳健
|
|
|
- """
|
|
|
- if not bboxes:
|
|
|
- return []
|
|
|
-
|
|
|
- # 1. 收集所有 y 坐标 (top, bottom) 并聚类得到行分割线
|
|
|
- y_coords = []
|
|
|
- for b in bboxes:
|
|
|
- y_coords.append(b[1])
|
|
|
- y_coords.append(b[3])
|
|
|
- y_coords.sort()
|
|
|
-
|
|
|
- row_dividers = []
|
|
|
- if y_coords:
|
|
|
- # 聚类阈值:行高的一小部分,例如 10px
|
|
|
- threshold = 10
|
|
|
- curr = [y_coords[0]]
|
|
|
- for y in y_coords[1:]:
|
|
|
- if y - curr[-1] < threshold:
|
|
|
- curr.append(y)
|
|
|
- else:
|
|
|
- row_dividers.append(sum(curr)/len(curr))
|
|
|
- curr = [y]
|
|
|
- row_dividers.append(sum(curr)/len(curr))
|
|
|
-
|
|
|
- # 2. 收集所有 x 坐标 (left, right) 并聚类得到列分割线
|
|
|
- x_coords = []
|
|
|
- for b in bboxes:
|
|
|
- x_coords.append(b[0])
|
|
|
- x_coords.append(b[2])
|
|
|
- x_coords.sort()
|
|
|
-
|
|
|
- col_dividers = []
|
|
|
- if x_coords:
|
|
|
- threshold = 10
|
|
|
- curr = [x_coords[0]]
|
|
|
- for x in x_coords[1:]:
|
|
|
- if x - curr[-1] < threshold:
|
|
|
- curr.append(x)
|
|
|
- else:
|
|
|
- col_dividers.append(sum(curr)/len(curr))
|
|
|
- curr = [x]
|
|
|
- col_dividers.append(sum(curr)/len(curr))
|
|
|
-
|
|
|
- # 3. 匹配单元格到网格
|
|
|
- structured_cells = []
|
|
|
- for bbox in bboxes:
|
|
|
- x1, y1, x2, y2 = bbox
|
|
|
-
|
|
|
- # 找最近的分割线索引
|
|
|
- # Row Start: 离 y1 最近的 divider
|
|
|
- r1 = min(range(len(row_dividers)), key=lambda i: abs(row_dividers[i] - y1))
|
|
|
- # Row End: 离 y2 最近的 divider
|
|
|
- r2 = min(range(len(row_dividers)), key=lambda i: abs(row_dividers[i] - y2))
|
|
|
-
|
|
|
- # Col Start: 离 x1 最近的 divider
|
|
|
- c1 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - x1))
|
|
|
- # Col End: 离 x2 最近的 divider
|
|
|
- c2 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - x2))
|
|
|
-
|
|
|
- # 修正:防止 span=0
|
|
|
- if r1 == r2: r2 = r1 + 1
|
|
|
- if c1 == c2: c2 = c1 + 1
|
|
|
-
|
|
|
- # 确保顺序
|
|
|
- if r1 > r2: r1, r2 = r2, r1
|
|
|
- if c1 > c2: c1, c2 = c2, c1
|
|
|
-
|
|
|
- structured_cells.append({
|
|
|
- "bbox": bbox,
|
|
|
- "row": r1,
|
|
|
- "col": c1,
|
|
|
- "rowspan": r2 - r1,
|
|
|
- "colspan": c2 - c1
|
|
|
- })
|
|
|
-
|
|
|
- # 按行列排序
|
|
|
- structured_cells.sort(key=lambda c: (c["row"], c["col"]))
|
|
|
-
|
|
|
- # 压缩网格,移除空行空列
|
|
|
- structured_cells = self._compress_grid(structured_cells)
|
|
|
-
|
|
|
- return structured_cells
|
|
|
-
|
|
|
- def _recover_grid_structure_4_2(self, bboxes: List[List[float]]) -> List[Dict]:
|
|
|
- """
|
|
|
- 从散乱的单元格 bbox 恢复表格的行列结构 (row, col, rowspan, colspan)
|
|
|
- 重构版:基于标准行骨架的匹配,解决密集行与跨行单元格混合的问题
|
|
|
- """
|
|
|
- if not bboxes:
|
|
|
- return []
|
|
|
-
|
|
|
- # --- 1. 识别行结构 (Row Structure) ---
|
|
|
-
|
|
|
- # 计算高度中位数,用于区分"标准行"和"跨行单元格"
|
|
|
- heights = [b[3] - b[1] for b in bboxes]
|
|
|
- median_h = np.median(heights) if heights else 0
|
|
|
-
|
|
|
- # 定义标准行单元格:高度在 [0.5, 1.5] 倍中位数之间
|
|
|
- # 这样可以排除跨行的大单元格,也可以排除极小的噪点
|
|
|
- standard_cells = []
|
|
|
- for i, bbox in enumerate(bboxes):
|
|
|
- h = bbox[3] - bbox[1]
|
|
|
- if median_h > 0 and 0.5 * median_h < h < 1.8 * median_h:
|
|
|
- standard_cells.append({"bbox": bbox, "index": i})
|
|
|
-
|
|
|
- # 兜底:如果找不到标准行(比如表格全是奇怪的单元格),则使用所有单元格
|
|
|
- if not standard_cells:
|
|
|
- standard_cells = [{"bbox": b, "index": i} for i, b in enumerate(bboxes)]
|
|
|
-
|
|
|
- # 对标准单元格按 Y 中心排序
|
|
|
- standard_cells.sort(key=lambda x: (x["bbox"][1] + x["bbox"][3]) / 2)
|
|
|
-
|
|
|
- # 贪心聚类生成"行骨架"
|
|
|
- # rows_defs 存储每一行的垂直范围 {'top': y1, 'bottom': y2, 'center': yc}
|
|
|
- rows_defs = []
|
|
|
-
|
|
|
- for item in standard_cells:
|
|
|
- box = item["bbox"]
|
|
|
- cy = (box[1] + box[3]) / 2
|
|
|
-
|
|
|
- # 尝试匹配已有的行
|
|
|
- matched = False
|
|
|
- for r_def in rows_defs:
|
|
|
- # 判断条件:中心点距离小于行高的一半 (假设行高近似 median_h)
|
|
|
- # 或者:垂直重叠率高
|
|
|
- r_h = r_def['bottom'] - r_def['top']
|
|
|
- ref_h = max(r_h, median_h) # 参考高度
|
|
|
-
|
|
|
- if abs(cy - r_def['center']) < ref_h * 0.6:
|
|
|
- # 匹配成功,更新行范围
|
|
|
- r_def['top'] = min(r_def['top'], box[1])
|
|
|
- r_def['bottom'] = max(r_def['bottom'], box[3])
|
|
|
- r_def['center'] = (r_def['top'] + r_def['bottom']) / 2
|
|
|
- matched = True
|
|
|
- break
|
|
|
-
|
|
|
- if not matched:
|
|
|
- rows_defs.append({
|
|
|
- 'top': box[1],
|
|
|
- 'bottom': box[3],
|
|
|
- 'center': cy
|
|
|
- })
|
|
|
-
|
|
|
- # 对行骨架按位置排序
|
|
|
- rows_defs.sort(key=lambda x: x['center'])
|
|
|
-
|
|
|
- # 合并靠得太近的行骨架 (防止过度切分)
|
|
|
- # 阈值:0.5 * median_h
|
|
|
- merged_rows = []
|
|
|
- if rows_defs:
|
|
|
- curr = rows_defs[0]
|
|
|
- for next_row in rows_defs[1:]:
|
|
|
- if next_row['center'] - curr['center'] < median_h * 0.5:
|
|
|
- # 合并
|
|
|
- curr['top'] = min(curr['top'], next_row['top'])
|
|
|
- curr['bottom'] = max(curr['bottom'], next_row['bottom'])
|
|
|
- curr['center'] = (curr['top'] + curr['bottom']) / 2
|
|
|
- else:
|
|
|
- merged_rows.append(curr)
|
|
|
- curr = next_row
|
|
|
- merged_rows.append(curr)
|
|
|
- rows_defs = merged_rows
|
|
|
-
|
|
|
- # --- 2. 识别列结构 (Col Structure) ---
|
|
|
- # 列分割线逻辑保持不变,通常列比较规整
|
|
|
- x_coords = []
|
|
|
- for b in bboxes:
|
|
|
- x_coords.append(b[0])
|
|
|
- x_coords.append(b[2])
|
|
|
- x_coords.sort()
|
|
|
-
|
|
|
- col_dividers = []
|
|
|
- if x_coords:
|
|
|
- thresh = 5 # 列间隙阈值
|
|
|
- curr_cluster = [x_coords[0]]
|
|
|
- for x in x_coords[1:]:
|
|
|
- if x - curr_cluster[-1] < thresh:
|
|
|
- curr_cluster.append(x)
|
|
|
- else:
|
|
|
- col_dividers.append(sum(curr_cluster)/len(curr_cluster))
|
|
|
- curr_cluster = [x]
|
|
|
- col_dividers.append(sum(curr_cluster)/len(curr_cluster))
|
|
|
-
|
|
|
- # --- 3. 匹配单元格到网格 ---
|
|
|
- structured_cells = []
|
|
|
- for bbox in bboxes:
|
|
|
- # --- 匹配行 (Row) ---
|
|
|
- b_top, b_bottom = bbox[1], bbox[3]
|
|
|
- b_h = b_bottom - b_top
|
|
|
-
|
|
|
- matched_row_indices = []
|
|
|
-
|
|
|
- for r_idx, r_def in enumerate(rows_defs):
|
|
|
- # 计算 Y 轴重叠
|
|
|
- inter_top = max(b_top, r_def['top'])
|
|
|
- inter_bottom = min(b_bottom, r_def['bottom'])
|
|
|
- inter_h = max(0, inter_bottom - inter_top)
|
|
|
-
|
|
|
- r_h = r_def['bottom'] - r_def['top']
|
|
|
-
|
|
|
- # 判定覆盖:
|
|
|
- # 1. 单元格覆盖了该行的大部分 (跨行情况) -> inter_h / r_h > 0.5
|
|
|
- # 2. 该行覆盖了单元格的大部分 (小单元格情况) -> inter_h / b_h > 0.5
|
|
|
- if r_h > 0 and (inter_h / r_h > 0.5 or inter_h / b_h > 0.5):
|
|
|
- matched_row_indices.append(r_idx)
|
|
|
-
|
|
|
- if not matched_row_indices:
|
|
|
- # 兜底:找中心点最近的行
|
|
|
- cy = (b_top + b_bottom) / 2
|
|
|
- closest_r = min(range(len(rows_defs)), key=lambda i: abs(rows_defs[i]['center'] - cy))
|
|
|
- matched_row_indices = [closest_r]
|
|
|
-
|
|
|
- row_start = min(matched_row_indices)
|
|
|
- row_end = max(matched_row_indices)
|
|
|
- rowspan = row_end - row_start + 1
|
|
|
-
|
|
|
- # --- 匹配列 (Col) ---
|
|
|
- # 找左右边界最近的 divider
|
|
|
- c1 = 0
|
|
|
- c2 = 0
|
|
|
- if len(col_dividers) >= 2:
|
|
|
- c1 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - bbox[0]))
|
|
|
- c2 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - bbox[2]))
|
|
|
- if c1 > c2: c1, c2 = c2, c1
|
|
|
-
|
|
|
- colspan = max(1, c2 - c1)
|
|
|
-
|
|
|
- structured_cells.append({
|
|
|
- "bbox": bbox,
|
|
|
- "row": row_start,
|
|
|
- "col": c1,
|
|
|
- "rowspan": rowspan,
|
|
|
- "colspan": colspan
|
|
|
- })
|
|
|
-
|
|
|
- # 按行列排序
|
|
|
- structured_cells.sort(key=lambda c: (c["row"], c["col"]))
|
|
|
-
|
|
|
- # 压缩网格,移除空行空列
|
|
|
- structured_cells = self._compress_grid(structured_cells)
|
|
|
-
|
|
|
- return structured_cells
|
|
|
-
|
|
|
def _recover_grid_structure(self, bboxes: List[List[float]]) -> List[Dict]:
|
|
|
"""
|
|
|
从散乱的单元格 bbox 恢复表格的行列结构 (row, col, rowspan, colspan)
|
|
|
@@ -1412,6 +1069,13 @@ class MinerUWiredTableRecognizer:
|
|
|
"""
|
|
|
V4版本:直接从表格线计算单元格,绕过 MinerU 的 cal_region_boxes
|
|
|
"""
|
|
|
+ # Step 0: 倾斜检测和矫正(在UNet预测之前)
|
|
|
+ if self.enable_deskew and ocr_boxes:
|
|
|
+ skew_angle = self._detect_skew_angle(ocr_boxes)
|
|
|
+ if abs(skew_angle) > self.skew_threshold:
|
|
|
+ logger.info(f"📐 检测到表格倾斜: {skew_angle:.3f}°,开始矫正...")
|
|
|
+ table_image, ocr_boxes = self._apply_deskew(table_image, ocr_boxes, skew_angle)
|
|
|
+
|
|
|
upscale = self.upscale_ratio if self.upscale_ratio and self.upscale_ratio > 0 else 1.0
|
|
|
h, w = table_image.shape[:2]
|
|
|
|