|
|
@@ -1,9 +1,14 @@
|
|
|
import sys
|
|
|
-from typing import Any, Dict, List, Tuple, cast
|
|
|
+import html
|
|
|
+import copy
|
|
|
+from typing import Any, Dict, List, Tuple, Optional, cast
|
|
|
+import ast
|
|
|
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
|
|
|
+from loguru import logger
|
|
|
+
|
|
|
# 确保 mineru 库可导入
|
|
|
mineru_path = str((__file__ and __file__) and __file__)
|
|
|
# 使用已有 mineru_adapter 中的路径追加逻辑
|
|
|
@@ -16,7 +21,12 @@ from mineru.model.table.rec.unet_table.main import UnetTableModel
|
|
|
|
|
|
|
|
|
class MinerUWiredTableRecognizer:
|
|
|
- """有线表格识别封装:裁剪+放大→UNet→坐标回写+按中心点匹配OCR文本"""
|
|
|
+ """有线表格识别封装:裁剪+放大→UNet→坐标回写+按中心点匹配OCR文本
|
|
|
+
|
|
|
+ 支持两种后处理模式:
|
|
|
+ - recognize_legacy(): 原始流程,使用MinerU的plot_html_table
|
|
|
+ - recognize_v4(): 改进流程,使用自定义HTML生成和文本填充(支持data-bbox属性)
|
|
|
+ """
|
|
|
|
|
|
def __init__(self, config: Dict[str, Any], ocr_engine: Any):
|
|
|
self.config = config or {}
|
|
|
@@ -26,9 +36,28 @@ class MinerUWiredTableRecognizer:
|
|
|
self.col_threshold: int = self.config.get("col_threshold", 15)
|
|
|
self.ocr_conf_threshold: float = self.config.get("ocr_conf_threshold", 0.5)
|
|
|
self.cell_crop_margin: int = self.config.get("cell_crop_margin", 2)
|
|
|
+ # 是否使用自定义后处理(v2),默认启用
|
|
|
+ self.use_custom_postprocess: bool = self.config.get("use_custom_postprocess", True)
|
|
|
self.table_model = UnetTableModel(ocr_engine)
|
|
|
self.ocr_engine = ocr_engine
|
|
|
|
|
|
+ # ========== 坐标格式转换工具 ==========
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _normalize_bbox(box: List[float]) -> List[float]:
|
|
|
+ """将8点或4点坐标统一转换为 [x_min, y_min, x_max, y_max] 格式"""
|
|
|
+ if not box:
|
|
|
+ return []
|
|
|
+ if len(box) == 8:
|
|
|
+ xs = [box[0], box[2], box[4], box[6]]
|
|
|
+ ys = [box[1], box[3], box[5], box[7]]
|
|
|
+ return [min(xs), min(ys), max(xs), max(ys)]
|
|
|
+ elif len(box) == 4:
|
|
|
+ # 已经是4点格式,确保是 [x_min, y_min, x_max, y_max]
|
|
|
+ x1, y1, x2, y2 = box
|
|
|
+ return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
|
|
|
+ return []
|
|
|
+
|
|
|
@staticmethod
|
|
|
def _to_unet_ocr_format(ocr_boxes: List[Dict[str, Any]]) -> List[List[Any]]:
|
|
|
"""将OCR结果转成 UNet 期望格式 [[poly4,text,score], ...],坐标用浮点。"""
|
|
|
@@ -57,44 +86,263 @@ class MinerUWiredTableRecognizer:
|
|
|
|
|
|
@staticmethod
|
|
|
def _poly_to_bbox(poly: np.ndarray) -> List[float]:
|
|
|
+ """将4点多边形转换为 [x_min, y_min, x_max, y_max]"""
|
|
|
xs = poly[:, 0]
|
|
|
ys = poly[:, 1]
|
|
|
return [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]
|
|
|
|
|
|
- def _match_text_by_center(
|
|
|
+ # ========== 行列分组与网格计算 (修复版) ==========
|
|
|
+
|
|
|
+ def _group_cells_into_rows(self, bboxes: List[List[float]]) -> List[List[int]]:
|
|
|
+ """
|
|
|
+ 按垂直位置将单元格分组到行 (修复版)
|
|
|
+
|
|
|
+ 使用单元格的垂直中心点进行聚类分组
|
|
|
+ """
|
|
|
+ if not bboxes:
|
|
|
+ return []
|
|
|
+
|
|
|
+ # 计算每个单元格的垂直中心和高度
|
|
|
+ cells_info = []
|
|
|
+ for i, bbox in enumerate(bboxes):
|
|
|
+ y_center = (bbox[1] + bbox[3]) / 2
|
|
|
+ height = bbox[3] - bbox[1]
|
|
|
+ cells_info.append({
|
|
|
+ 'index': i,
|
|
|
+ 'y_center': y_center,
|
|
|
+ 'y_min': bbox[1],
|
|
|
+ 'y_max': bbox[3],
|
|
|
+ 'height': height,
|
|
|
+ 'bbox': bbox
|
|
|
+ })
|
|
|
+
|
|
|
+ # 按y中心排序
|
|
|
+ cells_info.sort(key=lambda c: c['y_center'])
|
|
|
+
|
|
|
+ # 计算自适应行高阈值(使用高度的中位数)
|
|
|
+ heights = [c['height'] for c in cells_info if c['height'] > 0]
|
|
|
+ if not heights:
|
|
|
+ return [[i for i in range(len(bboxes))]]
|
|
|
+
|
|
|
+ median_height = sorted(heights)[len(heights) // 2]
|
|
|
+ # 行分组阈值:同一行的单元格y中心差异不超过中位高度的40%
|
|
|
+ row_thresh = median_height * 0.4
|
|
|
+
|
|
|
+ logger.debug(f"行分组: median_height={median_height:.1f}, row_thresh={row_thresh:.1f}")
|
|
|
+
|
|
|
+ # 基于y中心进行分组
|
|
|
+ rows = []
|
|
|
+ current_row = [cells_info[0]['index']]
|
|
|
+ current_row_y_centers = [cells_info[0]['y_center']]
|
|
|
+
|
|
|
+ for cell in cells_info[1:]:
|
|
|
+ # 计算当前行的平均y中心
|
|
|
+ avg_y_center = sum(current_row_y_centers) / len(current_row_y_centers)
|
|
|
+
|
|
|
+ # 如果当前单元格的y中心与行平均y中心的差距在阈值内,加入当前行
|
|
|
+ if abs(cell['y_center'] - avg_y_center) <= row_thresh:
|
|
|
+ current_row.append(cell['index'])
|
|
|
+ current_row_y_centers.append(cell['y_center'])
|
|
|
+ else:
|
|
|
+ # 开始新行
|
|
|
+ rows.append(current_row)
|
|
|
+ current_row = [cell['index']]
|
|
|
+ current_row_y_centers = [cell['y_center']]
|
|
|
+
|
|
|
+ # 添加最后一行
|
|
|
+ if current_row:
|
|
|
+ rows.append(current_row)
|
|
|
+
|
|
|
+ # 每行内按x坐标排序
|
|
|
+ for row in rows:
|
|
|
+ row.sort(key=lambda i: bboxes[i][0])
|
|
|
+
|
|
|
+ logger.info(f"行分组结果: {len(rows)} 行, 每行单元格数: {[len(r) for r in rows[:10]]}...")
|
|
|
+
|
|
|
+ return rows
|
|
|
+
|
|
|
+ def _find_grid_index(self, value: float, edges: List[float]) -> int:
|
|
|
+ """
|
|
|
+ 找到值对应的网格索引
|
|
|
+
|
|
|
+ 边界将坐标空间划分为 N-1 个网格区间
|
|
|
+ 返回 value 所在的网格区间索引
|
|
|
+ """
|
|
|
+ if not edges:
|
|
|
+ return 0
|
|
|
+
|
|
|
+ if len(edges) == 1:
|
|
|
+ return 0
|
|
|
+
|
|
|
+ # value 小于第一个边界
|
|
|
+ if value <= edges[0]:
|
|
|
+ return 0
|
|
|
+
|
|
|
+ # value 大于最后一个边界
|
|
|
+ if value >= edges[-1]:
|
|
|
+ return len(edges) - 2
|
|
|
+
|
|
|
+ # 找到 value 所在的区间 [edges[i], edges[i+1])
|
|
|
+ for i in range(len(edges) - 1):
|
|
|
+ if edges[i] <= value < edges[i + 1]:
|
|
|
+ return i
|
|
|
+
|
|
|
+ return len(edges) - 2
|
|
|
+
|
|
|
+ # ========== HTML生成与文本填充 ==========
|
|
|
+
|
|
|
+ def _plot_html_with_bbox(
|
|
|
self,
|
|
|
- cells_bbox: List[List[float]],
|
|
|
+ bboxes: List[List[float]],
|
|
|
+ logic_points: List[List[int]],
|
|
|
+ texts: List[str],
|
|
|
+ row_edges: List[float],
|
|
|
+ col_edges: List[float],
|
|
|
+ ) -> str:
|
|
|
+ """
|
|
|
+ 生成带 data-bbox 属性的 HTML 表格。
|
|
|
+
|
|
|
+ 直接在生成 <td> 时附加 data-bbox="[x1,y1,x2,y2]"(原图坐标)。
|
|
|
+
|
|
|
+ Args:
|
|
|
+ bboxes: 单元格坐标(已还原到原图坐标)
|
|
|
+ logic_points: 逻辑坐标 [[row_start, row_end, col_start, col_end], ...]
|
|
|
+ texts: 单元格文本
|
|
|
+ row_edges: 行边界
|
|
|
+ col_edges: 列边界
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ HTML字符串
|
|
|
+ """
|
|
|
+ if not bboxes or len(row_edges) < 2 or len(col_edges) < 2:
|
|
|
+ return ""
|
|
|
+
|
|
|
+ n_rows = len(row_edges) - 1
|
|
|
+ n_cols = len(col_edges) - 1
|
|
|
+
|
|
|
+ # 构建网格,记录每个格子对应的单元格索引
|
|
|
+ # -1 表示被合并单元格占用,None 表示空
|
|
|
+ grid: List[List[Optional[int]]] = [[None for _ in range(n_cols)] for _ in range(n_rows)]
|
|
|
+
|
|
|
+ for idx, lp in enumerate(logic_points):
|
|
|
+ r0, r1, c0, c1 = lp
|
|
|
+ if r0 < 0 or c0 < 0 or r1 >= n_rows or c1 >= n_cols:
|
|
|
+ continue
|
|
|
+ for rr in range(r0, r1 + 1):
|
|
|
+ for cc in range(c0, c1 + 1):
|
|
|
+ if rr == r0 and cc == c0:
|
|
|
+ grid[rr][cc] = idx # 主格
|
|
|
+ else:
|
|
|
+ grid[rr][cc] = -1 # 占位(被合并)
|
|
|
+
|
|
|
+ # 生成HTML
|
|
|
+ html_parts = ["<table>", "<tbody>"]
|
|
|
+
|
|
|
+ for r in range(n_rows):
|
|
|
+ html_parts.append("<tr>")
|
|
|
+ c = 0
|
|
|
+ while c < n_cols:
|
|
|
+ cell_idx = grid[r][c]
|
|
|
+
|
|
|
+ if cell_idx is None:
|
|
|
+ # 空格子,输出空td
|
|
|
+ html_parts.append("<td></td>")
|
|
|
+ c += 1
|
|
|
+ elif cell_idx == -1:
|
|
|
+ # 被合并,跳过
|
|
|
+ c += 1
|
|
|
+ else:
|
|
|
+ # 主格,输出带span的td
|
|
|
+ lp = logic_points[cell_idx]
|
|
|
+ rowspan = lp[1] - lp[0] + 1
|
|
|
+ colspan = lp[3] - lp[2] + 1
|
|
|
+ bbox = bboxes[cell_idx]
|
|
|
+ text = html.escape(texts[cell_idx]) if cell_idx < len(texts) else ""
|
|
|
+
|
|
|
+ bbox_str = f"[{int(bbox[0])},{int(bbox[1])},{int(bbox[2])},{int(bbox[3])}]"
|
|
|
+
|
|
|
+ if rowspan > 1 or colspan > 1:
|
|
|
+ html_parts.append(
|
|
|
+ f'<td data-bbox="{bbox_str}" rowspan="{rowspan}" colspan="{colspan}">{text}</td>'
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ html_parts.append(f'<td data-bbox="{bbox_str}">{text}</td>')
|
|
|
+
|
|
|
+ c += colspan
|
|
|
+
|
|
|
+ html_parts.append("</tr>")
|
|
|
+
|
|
|
+ html_parts.append("</tbody>")
|
|
|
+ html_parts.append("</table>")
|
|
|
+
|
|
|
+ return "".join(html_parts)
|
|
|
+
|
|
|
+ def _fill_text_by_center_point(
|
|
|
+ self,
|
|
|
+ bboxes: List[List[float]],
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
) -> List[str]:
|
|
|
- """使用中心点落格分配文本,行内按 y 排序后拼接。"""
|
|
|
- texts_per_cell: List[str] = []
|
|
|
- centers = []
|
|
|
+ """
|
|
|
+ 使用中心点落格策略填充文本。
|
|
|
+
|
|
|
+ 参考 fill_html_with_ocr_by_bbox:
|
|
|
+ - OCR文本中心点落入单元格bbox内则匹配
|
|
|
+ - 多行文本按y坐标排序拼接
|
|
|
+
|
|
|
+ Args:
|
|
|
+ bboxes: 单元格坐标 [[x1,y1,x2,y2], ...]
|
|
|
+ ocr_boxes: OCR结果 [{"bbox": [...], "text": "..."}, ...]
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 每个单元格的文本列表
|
|
|
+ """
|
|
|
+ texts: List[str] = ["" for _ in bboxes]
|
|
|
+
|
|
|
+ if not ocr_boxes:
|
|
|
+ return texts
|
|
|
+
|
|
|
+ # 预处理OCR结果:计算中心点
|
|
|
+ ocr_items: List[Dict[str, Any]] = []
|
|
|
for item in ocr_boxes:
|
|
|
- poly = item.get("bbox", [])
|
|
|
- if not poly:
|
|
|
+ box = self._normalize_bbox(item.get("bbox", []))
|
|
|
+ if not box:
|
|
|
continue
|
|
|
- if len(poly) == 8:
|
|
|
- xs = [poly[i] for i in range(0, 8, 2)]
|
|
|
- ys = [poly[i] for i in range(1, 8, 2)]
|
|
|
- cx = (min(xs) + max(xs)) / 2
|
|
|
- cy = (min(ys) + max(ys)) / 2
|
|
|
- elif len(poly) == 4:
|
|
|
- x1, y1, x2, y2 = poly
|
|
|
- cx = (x1 + x2) / 2
|
|
|
- cy = (y1 + y2) / 2
|
|
|
- else:
|
|
|
- continue
|
|
|
- centers.append((cx, cy, item.get("text", ""), item.get("confidence", 0.0)))
|
|
|
-
|
|
|
- for bbox in cells_bbox:
|
|
|
+ cx = (box[0] + box[2]) / 2
|
|
|
+ cy = (box[1] + box[3]) / 2
|
|
|
+ ocr_items.append({
|
|
|
+ "center_x": cx,
|
|
|
+ "center_y": cy,
|
|
|
+ "y1": box[1],
|
|
|
+ "text": item.get("text", ""),
|
|
|
+ "confidence": item.get("confidence", 0.0),
|
|
|
+ })
|
|
|
+
|
|
|
+ # 为每个单元格匹配OCR文本
|
|
|
+ for idx, bbox in enumerate(bboxes):
|
|
|
x1, y1, x2, y2 = bbox
|
|
|
- collected = [(t, cy) for cx, cy, t, conf in centers if x1 <= cx <= x2 and y1 <= cy <= y2]
|
|
|
- collected.sort(key=lambda x: x[1])
|
|
|
- cell_text = " ".join([t for t, _ in collected]) if collected else ""
|
|
|
- texts_per_cell.append(cell_text)
|
|
|
- return texts_per_cell
|
|
|
-
|
|
|
- def recognize(
|
|
|
+ matched: List[Tuple[str, float]] = []
|
|
|
+
|
|
|
+ for ocr in ocr_items:
|
|
|
+ if x1 <= ocr["center_x"] <= x2 and y1 <= ocr["center_y"] <= y2:
|
|
|
+ matched.append((ocr["text"], ocr["y1"]))
|
|
|
+
|
|
|
+ if matched:
|
|
|
+ # 按y坐标排序,确保多行文本顺序正确
|
|
|
+ matched.sort(key=lambda x: x[1])
|
|
|
+ texts[idx] = " ".join([t for t, _ in matched])
|
|
|
+
|
|
|
+ return texts
|
|
|
+
|
|
|
+ def _match_text_by_center(
|
|
|
+ self,
|
|
|
+ cells_bbox: List[List[float]],
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
+ ) -> List[str]:
|
|
|
+ """使用中心点落格分配文本,行内按 y 排序后拼接。(旧版兼容)"""
|
|
|
+ return self._fill_text_by_center_point(cells_bbox, ocr_boxes)
|
|
|
+
|
|
|
+
|
|
|
+ def recognize_legacy(
|
|
|
self,
|
|
|
table_image: np.ndarray,
|
|
|
ocr_boxes: List[Dict[str, Any]],
|
|
|
@@ -272,3 +520,637 @@ class MinerUWiredTableRecognizer:
|
|
|
col_idx += colspan
|
|
|
|
|
|
return str(soup)
|
|
|
+
|
|
|
+ # ========== 基于表格线交点的单元格计算 ==========
|
|
|
+ def _compute_cells_from_lines(
|
|
|
+ self,
|
|
|
+ hpred_up: np.ndarray,
|
|
|
+ vpred_up: np.ndarray,
|
|
|
+ upscale: float = 1.0,
|
|
|
+ debug_output_dir: Optional[str] = None
|
|
|
+ ) -> List[List[float]]:
|
|
|
+ """
|
|
|
+ 基于连通域分析从表格线 Mask 提取单元格
|
|
|
+
|
|
|
+ 原理:横竖线叠加 -> 反色 -> 提取白色连通块 -> 也就是单元格
|
|
|
+ """
|
|
|
+ h, w = hpred_up.shape[:2]
|
|
|
+
|
|
|
+ # 1. 预处理:二值化
|
|
|
+ _, h_bin = cv2.threshold(hpred_up, 127, 255, cv2.THRESH_BINARY)
|
|
|
+ _, v_bin = cv2.threshold(vpred_up, 127, 255, cv2.THRESH_BINARY)
|
|
|
+
|
|
|
+ # 2. 形态学连接:轻微膨胀以闭合可能的断点
|
|
|
+ # 横线横向膨胀,竖线竖向膨胀
|
|
|
+ kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 1))
|
|
|
+ kernel_v = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 5))
|
|
|
+ h_bin = cv2.dilate(h_bin, kernel_h, iterations=1)
|
|
|
+ v_bin = cv2.dilate(v_bin, kernel_v, iterations=1)
|
|
|
+
|
|
|
+ # 3. 合成网格图 (白线黑底)
|
|
|
+ grid_mask = cv2.bitwise_or(h_bin, v_bin)
|
|
|
+
|
|
|
+ # 4. 反转图像 (黑线白底),此时单元格变成白色连通域
|
|
|
+ inv_grid = cv2.bitwise_not(grid_mask)
|
|
|
+
|
|
|
+ # 5. 提取连通域
|
|
|
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(inv_grid, connectivity=8)
|
|
|
+
|
|
|
+ bboxes = []
|
|
|
+ # 过滤掉背景(label=0)和过小的噪声
|
|
|
+ min_area = 50 # 最小面积阈值
|
|
|
+
|
|
|
+ for i in range(1, num_labels):
|
|
|
+ area = stats[i, cv2.CC_STAT_AREA]
|
|
|
+ if area < min_area:
|
|
|
+ continue
|
|
|
+
|
|
|
+ x = stats[i, cv2.CC_STAT_LEFT]
|
|
|
+ y = stats[i, cv2.CC_STAT_TOP]
|
|
|
+ w_cell = stats[i, cv2.CC_STAT_WIDTH]
|
|
|
+ h_cell = stats[i, cv2.CC_STAT_HEIGHT]
|
|
|
+
|
|
|
+ # 过滤掉长条形的非单元格区域(例如边缘的细长空白)
|
|
|
+ if w_cell > w * 0.95 or h_cell > h * 0.95:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 还原到原图坐标
|
|
|
+ # 注意:连通域提取的是内部空白,实际单元格边界应该包含线条的一半宽度
|
|
|
+ # 这里简单处理,直接使用内部空白作为 bbox,OCR 匹配时通常足够
|
|
|
+ bboxes.append([
|
|
|
+ x / upscale,
|
|
|
+ y / upscale,
|
|
|
+ (x + w_cell) / upscale,
|
|
|
+ (y + h_cell) / upscale
|
|
|
+ ])
|
|
|
+
|
|
|
+ # 按阅读顺序排序 (先上后下,再左后右)
|
|
|
+ # 允许一定的行误差
|
|
|
+ bboxes.sort(key=lambda b: (int(b[1] / 10), b[0]))
|
|
|
+
|
|
|
+ logger.info(f"连通域分析提取到 {len(bboxes)} 个单元格")
|
|
|
+
|
|
|
+ # 调试可视化
|
|
|
+ if debug_output_dir:
|
|
|
+ vis = np.zeros((h, w, 3), dtype=np.uint8)
|
|
|
+ vis[grid_mask > 0] = [0, 0, 255] # 红色线条
|
|
|
+
|
|
|
+ # 绘制提取出的框
|
|
|
+ for i, box in enumerate(bboxes):
|
|
|
+ x1, y1, x2, y2 = [int(c * upscale) for c in box]
|
|
|
+ cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
|
|
+
|
|
|
+ cv2.imwrite(f"{debug_output_dir}/connected_components.png", vis)
|
|
|
+
|
|
|
+ return bboxes
|
|
|
+
|
|
|
+ def _visualize_detected_lines(
|
|
|
+ self,
|
|
|
+ hpred: np.ndarray,
|
|
|
+ vpred: np.ndarray,
|
|
|
+ h_lines_y: List[int],
|
|
|
+ v_lines_x: List[int],
|
|
|
+ output_path: str
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 可视化检测到的横竖线位置
|
|
|
+ """
|
|
|
+ h, w = hpred.shape[:2]
|
|
|
+
|
|
|
+ # 创建彩色图像
|
|
|
+ vis_img = np.zeros((h, w, 3), dtype=np.uint8)
|
|
|
+
|
|
|
+ # 显示原始mask(淡色背景)
|
|
|
+ vis_img[hpred > 128] = [100, 100, 255] # 淡红色横线
|
|
|
+ vis_img[vpred > 128] = [255, 100, 100] # 淡蓝色竖线
|
|
|
+
|
|
|
+ # 绘制检测到的横线位置(亮绿色)
|
|
|
+ for y in h_lines_y:
|
|
|
+ if 0 <= y < h:
|
|
|
+ cv2.line(vis_img, (0, y), (w, y), (0, 255, 0), 2)
|
|
|
+
|
|
|
+ # 绘制检测到的竖线位置(亮黄色)
|
|
|
+ for x in v_lines_x:
|
|
|
+ if 0 <= x < w:
|
|
|
+ cv2.line(vis_img, (x, 0), (x, h), (0, 255, 255), 2)
|
|
|
+
|
|
|
+ # 添加标注
|
|
|
+ cv2.putText(
|
|
|
+ vis_img, f"H-lines: {len(h_lines_y)}, V-lines: {len(v_lines_x)}",
|
|
|
+ (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2
|
|
|
+ )
|
|
|
+
|
|
|
+ cv2.imwrite(output_path, vis_img)
|
|
|
+ logger.info(f"检测线可视化: {output_path}")
|
|
|
+
|
|
|
+
|
|
|
+ def _recover_grid_structure(self, bboxes: List[List[float]]) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 从散乱的单元格 bbox 恢复表格的行列结构 (row, col, rowspan, colspan)
|
|
|
+ 改进版:使用边界投影聚类,更稳健
|
|
|
+ """
|
|
|
+ if not bboxes:
|
|
|
+ return []
|
|
|
+
|
|
|
+ # 1. 收集所有 y 坐标 (top, bottom) 并聚类得到行分割线
|
|
|
+ y_coords = []
|
|
|
+ for b in bboxes:
|
|
|
+ y_coords.append(b[1])
|
|
|
+ y_coords.append(b[3])
|
|
|
+ y_coords.sort()
|
|
|
+
|
|
|
+ row_dividers = []
|
|
|
+ if y_coords:
|
|
|
+ # 聚类阈值:行高的一小部分,例如 10px
|
|
|
+ threshold = 10
|
|
|
+ curr = [y_coords[0]]
|
|
|
+ for y in y_coords[1:]:
|
|
|
+ if y - curr[-1] < threshold:
|
|
|
+ curr.append(y)
|
|
|
+ else:
|
|
|
+ row_dividers.append(sum(curr)/len(curr))
|
|
|
+ curr = [y]
|
|
|
+ row_dividers.append(sum(curr)/len(curr))
|
|
|
+
|
|
|
+ # 2. 收集所有 x 坐标 (left, right) 并聚类得到列分割线
|
|
|
+ x_coords = []
|
|
|
+ for b in bboxes:
|
|
|
+ x_coords.append(b[0])
|
|
|
+ x_coords.append(b[2])
|
|
|
+ x_coords.sort()
|
|
|
+
|
|
|
+ col_dividers = []
|
|
|
+ if x_coords:
|
|
|
+ threshold = 10
|
|
|
+ curr = [x_coords[0]]
|
|
|
+ for x in x_coords[1:]:
|
|
|
+ if x - curr[-1] < threshold:
|
|
|
+ curr.append(x)
|
|
|
+ else:
|
|
|
+ col_dividers.append(sum(curr)/len(curr))
|
|
|
+ curr = [x]
|
|
|
+ col_dividers.append(sum(curr)/len(curr))
|
|
|
+
|
|
|
+ # 3. 匹配单元格到网格
|
|
|
+ structured_cells = []
|
|
|
+ for bbox in bboxes:
|
|
|
+ x1, y1, x2, y2 = bbox
|
|
|
+
|
|
|
+ # 找最近的分割线索引
|
|
|
+ # Row Start: 离 y1 最近的 divider
|
|
|
+ r1 = min(range(len(row_dividers)), key=lambda i: abs(row_dividers[i] - y1))
|
|
|
+ # Row End: 离 y2 最近的 divider
|
|
|
+ r2 = min(range(len(row_dividers)), key=lambda i: abs(row_dividers[i] - y2))
|
|
|
+
|
|
|
+ # Col Start: 离 x1 最近的 divider
|
|
|
+ c1 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - x1))
|
|
|
+ # Col End: 离 x2 最近的 divider
|
|
|
+ c2 = min(range(len(col_dividers)), key=lambda i: abs(col_dividers[i] - x2))
|
|
|
+
|
|
|
+ # 修正:防止 span=0
|
|
|
+ if r1 == r2: r2 = r1 + 1
|
|
|
+ if c1 == c2: c2 = c1 + 1
|
|
|
+
|
|
|
+ # 确保顺序
|
|
|
+ if r1 > r2: r1, r2 = r2, r1
|
|
|
+ if c1 > c2: c1, c2 = c2, c1
|
|
|
+
|
|
|
+ structured_cells.append({
|
|
|
+ "bbox": bbox,
|
|
|
+ "row": r1,
|
|
|
+ "col": c1,
|
|
|
+ "rowspan": r2 - r1,
|
|
|
+ "colspan": c2 - c1
|
|
|
+ })
|
|
|
+
|
|
|
+ # 按行列排序
|
|
|
+ structured_cells.sort(key=lambda c: (c["row"], c["col"]))
|
|
|
+
|
|
|
+ return structured_cells
|
|
|
+
|
|
|
+ def _build_html_from_merged_cells(self, merged_cells: List[Dict]) -> str:
|
|
|
+ """
|
|
|
+ 基于矩阵填充法生成 HTML,防止错位
|
|
|
+ """
|
|
|
+ if not merged_cells:
|
|
|
+ return "<table><tbody></tbody></table>"
|
|
|
+
|
|
|
+ # 1. 计算网格尺寸
|
|
|
+ max_row = 0
|
|
|
+ max_col = 0
|
|
|
+ for cell in merged_cells:
|
|
|
+ max_row = max(max_row, cell["row"] + cell.get("rowspan", 1))
|
|
|
+ max_col = max(max_col, cell["col"] + cell.get("colspan", 1))
|
|
|
+
|
|
|
+ # 2. 构建占用矩阵 (True 表示该位置已被占据)
|
|
|
+ occupied = [[False for _ in range(max_col)] for _ in range(max_row)]
|
|
|
+
|
|
|
+ # 3. 将单元格放入查找表,方便按坐标检索
|
|
|
+ cell_map = {}
|
|
|
+ for cell in merged_cells:
|
|
|
+ key = (cell["row"], cell["col"])
|
|
|
+ cell_map[key] = cell
|
|
|
+
|
|
|
+ html_parts = ["<table><tbody>"]
|
|
|
+
|
|
|
+ # 4. 逐行逐列扫描
|
|
|
+ for r in range(max_row):
|
|
|
+ html_parts.append("<tr>")
|
|
|
+ for c in range(max_col):
|
|
|
+ # 如果该位置已被之前的 rowspan/colspan 占据,跳过
|
|
|
+ if occupied[r][c]:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 检查是否有单元格起始于此
|
|
|
+ cell = cell_map.get((r, c))
|
|
|
+
|
|
|
+ if cell:
|
|
|
+ # 有单元格:输出 td 并标记占用区域
|
|
|
+ bbox = cell["bbox"]
|
|
|
+ colspan = cell.get("colspan", 1)
|
|
|
+ rowspan = cell.get("rowspan", 1)
|
|
|
+ text = html.escape(cell.get("text", ""))
|
|
|
+ bbox_str = f"[{int(bbox[0])},{int(bbox[1])},{int(bbox[2])},{int(bbox[3])}]"
|
|
|
+
|
|
|
+ attrs = [f'data-bbox="{bbox_str}"']
|
|
|
+ if colspan > 1:
|
|
|
+ attrs.append(f'colspan="{colspan}"')
|
|
|
+ if rowspan > 1:
|
|
|
+ attrs.append(f'rowspan="{rowspan}"')
|
|
|
+
|
|
|
+ html_parts.append(f'<td {" ".join(attrs)}>{text}</td>')
|
|
|
+
|
|
|
+ # 标记占用
|
|
|
+ for i in range(rowspan):
|
|
|
+ for j in range(colspan):
|
|
|
+ if r + i < max_row and c + j < max_col:
|
|
|
+ occupied[r + i][c + j] = True
|
|
|
+ else:
|
|
|
+ # 无单元格(空洞):输出空 td 占位,防止后续单元格左移
|
|
|
+ # 这种情况通常是网格对齐产生的微小缝隙,或者是漏检
|
|
|
+ html_parts.append("<td></td>")
|
|
|
+ occupied[r][c] = True
|
|
|
+
|
|
|
+ html_parts.append("</tr>")
|
|
|
+
|
|
|
+ html_parts.append("</tbody></table>")
|
|
|
+ return "".join(html_parts)
|
|
|
+
|
|
|
+ def _visualize_grid_structure(
|
|
|
+ self,
|
|
|
+ table_image: np.ndarray,
|
|
|
+ cells: List[Dict],
|
|
|
+ output_path: str
|
|
|
+ ):
|
|
|
+ """可视化表格逻辑结构 (row, col, span)"""
|
|
|
+ vis = table_image.copy()
|
|
|
+ if len(vis.shape) == 2:
|
|
|
+ vis = cv2.cvtColor(vis, cv2.COLOR_GRAY2BGR)
|
|
|
+
|
|
|
+ for cell in cells:
|
|
|
+ x1, y1, x2, y2 = [int(c) for c in cell["bbox"]]
|
|
|
+
|
|
|
+ # 绘制边框
|
|
|
+ cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
|
|
+
|
|
|
+ # 绘制逻辑坐标
|
|
|
+ info = f"R{cell['row']}C{cell['col']}"
|
|
|
+ if cell.get('rowspan', 1) > 1: info += f" rs{cell['rowspan']}"
|
|
|
+ if cell.get('colspan', 1) > 1: info += f" cs{cell['colspan']}"
|
|
|
+
|
|
|
+ # 居中显示
|
|
|
+ font_scale = 0.5
|
|
|
+ thickness = 1
|
|
|
+ (tw, th), _ = cv2.getTextSize(info, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
|
|
|
+ tx = x1 + (x2 - x1 - tw) // 2
|
|
|
+ ty = y1 + (y2 - y1 + th) // 2
|
|
|
+
|
|
|
+ # 描边以增加可读性
|
|
|
+ cv2.putText(vis, info, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness + 2)
|
|
|
+ cv2.putText(vis, info, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 255, 255), thickness)
|
|
|
+
|
|
|
+ cv2.imwrite(output_path, vis)
|
|
|
+ logger.info(f"表格结构可视化: {output_path}")
|
|
|
+
|
|
|
+ def recognize_v4(
|
|
|
+ self,
|
|
|
+ table_image: np.ndarray,
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
+ debug_output_dir: Optional[str] = None,
|
|
|
+ ) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ V4版本:直接从表格线计算单元格,绕过 MinerU 的 cal_region_boxes
|
|
|
+ """
|
|
|
+ upscale = self.upscale_ratio if self.upscale_ratio and self.upscale_ratio > 0 else 1.0
|
|
|
+ h, w = table_image.shape[:2]
|
|
|
+
|
|
|
+ # Step 1: 获取 UNet 预测的横竖线 mask(供后续合并检测使用)
|
|
|
+ if upscale != 1.0:
|
|
|
+ img_up = cv2.resize(table_image, (int(w * upscale), int(h * upscale)))
|
|
|
+ else:
|
|
|
+ img_up = table_image
|
|
|
+
|
|
|
+ wired_rec = self.table_model.wired_table_model
|
|
|
+ img = wired_rec.load_img(img_up)
|
|
|
+ img_info = wired_rec.table_structure.preprocess(img)
|
|
|
+ pred = wired_rec.table_structure.infer(img_info)
|
|
|
+
|
|
|
+ hpred = np.where(pred == 1, 255, 0).astype(np.uint8)
|
|
|
+ vpred = np.where(pred == 2, 255, 0).astype(np.uint8)
|
|
|
+
|
|
|
+ h_up, w_up = img_up.shape[:2]
|
|
|
+ hpred_up = cv2.resize(hpred, (w_up, h_up), interpolation=cv2.INTER_NEAREST)
|
|
|
+ vpred_up = cv2.resize(vpred, (w_up, h_up), interpolation=cv2.INTER_NEAREST)
|
|
|
+
|
|
|
+ # Step 1.5: 可视化表格线(调试用)- 需要缩放回原图
|
|
|
+ if debug_output_dir:
|
|
|
+ hpred_orig = cv2.resize(hpred_up, (w, h), interpolation=cv2.INTER_NEAREST)
|
|
|
+ vpred_orig = cv2.resize(vpred_up, (w, h), interpolation=cv2.INTER_NEAREST)
|
|
|
+ self._visualize_table_lines(
|
|
|
+ table_image,
|
|
|
+ hpred_orig,
|
|
|
+ vpred_orig,
|
|
|
+ output_path=f"{debug_output_dir}/unet_table_lines.png"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Step 2: 使用连通域法提取单元格 (替换了原来的投影法)
|
|
|
+ bboxes = self._compute_cells_from_lines(hpred_up, vpred_up, upscale, debug_output_dir)
|
|
|
+
|
|
|
+ if not bboxes:
|
|
|
+ raise RuntimeError("未能提取出单元格")
|
|
|
+
|
|
|
+ # Step 3: 重建网格结构 (计算 row, col, rowspan, colspan)
|
|
|
+ # 这一步替代了原来的 _merge_cells_without_separator
|
|
|
+ merged_cells = self._recover_grid_structure(bboxes)
|
|
|
+
|
|
|
+ # Step 3.5: 可视化逻辑结构 (新增)
|
|
|
+ if debug_output_dir:
|
|
|
+ self._visualize_grid_structure(
|
|
|
+ table_image, merged_cells,
|
|
|
+ output_path=f"{debug_output_dir}/grid_structure.png"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Step 4: 统一计算文本填充
|
|
|
+ bboxes_merged = [cell["bbox"] for cell in merged_cells]
|
|
|
+ texts = self._fill_text_by_center_point(bboxes_merged, ocr_boxes or [])
|
|
|
+
|
|
|
+ # Step 4.5: 对空单元格尝试二次 OCR (新增)
|
|
|
+ # 针对漏检问题(特别是竖排小字),进行切片放大识别
|
|
|
+ if hasattr(self, 'ocr_engine') and self.ocr_engine and any(not t for t in texts):
|
|
|
+ crop_list = []
|
|
|
+ crop_indices = []
|
|
|
+ h_img, w_img = table_image.shape[:2]
|
|
|
+ margin = self.cell_crop_margin
|
|
|
+
|
|
|
+ for i, text in enumerate(texts):
|
|
|
+ if text.strip():
|
|
|
+ continue
|
|
|
+
|
|
|
+ bbox = bboxes_merged[i]
|
|
|
+ x1, y1, x2, y2 = map(int, bbox)
|
|
|
+
|
|
|
+ # 边界保护 + 少量外扩
|
|
|
+ x1 = max(0, x1 - margin)
|
|
|
+ y1 = max(0, y1 - margin)
|
|
|
+ x2 = min(w_img, x2 + margin)
|
|
|
+ y2 = min(h_img, y2 + margin)
|
|
|
+
|
|
|
+ if x2 <= x1 or y2 <= y1:
|
|
|
+ continue
|
|
|
+
|
|
|
+ cell_img = table_image[y1:y2, x1:x2]
|
|
|
+ if cell_img.size == 0:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # --- 关键改进:放大与旋转 ---
|
|
|
+ cell_h, cell_w = cell_img.shape[:2]
|
|
|
+
|
|
|
+ # 1. 放大图像:对于表格中的小字,放大能显著提高识别率
|
|
|
+ # 建议放大 2 倍,如果原图特别小可以更大
|
|
|
+ scale = 2.0
|
|
|
+ if cell_h < 64 or cell_w < 64: # 只有较小的图才放大,避免大图过大
|
|
|
+ cell_img = cv2.resize(cell_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
|
|
+
|
|
|
+ # 2. 处理竖排文本:如果高宽比很大(>2),很可能是竖排表头(如"优先股")
|
|
|
+ # 通用 OCR 模型通常只支持横排,旋转 90 度变成横排
|
|
|
+ if cell_h > cell_w * 2:
|
|
|
+ cell_img = cv2.rotate(cell_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
|
|
+ # -------------------------
|
|
|
+
|
|
|
+ crop_list.append(cell_img)
|
|
|
+ crop_indices.append(i)
|
|
|
+
|
|
|
+ if crop_list:
|
|
|
+ try:
|
|
|
+ # 批量识别,det=False 表示直接识别内容(假设裁剪图就是文本行)
|
|
|
+ ocr_res = self.ocr_engine.ocr(crop_list, det=False)
|
|
|
+
|
|
|
+ # 解析结果 (兼容 PaddleOCR 返回格式)
|
|
|
+ # ocr_res 结构通常为 [(text, score), (text, score), ...] 对应每张图
|
|
|
+ # 但有时可能包裹在列表中,需做兼容处理
|
|
|
+ results = ocr_res
|
|
|
+ if isinstance(ocr_res, list) and len(ocr_res) == 1 and isinstance(ocr_res[0], list) and len(ocr_res[0]) == len(crop_list):
|
|
|
+ # 兼容 legacy 代码中遇到的 [[(t,s), (t,s)...]] 情况
|
|
|
+ results = ocr_res[0]
|
|
|
+
|
|
|
+ if len(results) == len(crop_list):
|
|
|
+ for idx, res in enumerate(results):
|
|
|
+ # res 可能是 (text, score) 或 [(text, score)] 或 None
|
|
|
+ if not res: continue
|
|
|
+
|
|
|
+ text = ""
|
|
|
+ score = 0.0
|
|
|
+
|
|
|
+ if isinstance(res, tuple):
|
|
|
+ text, score = res
|
|
|
+ elif isinstance(res, list) and len(res) > 0:
|
|
|
+ text, score = res[0]
|
|
|
+
|
|
|
+ if score >= self.ocr_conf_threshold and text:
|
|
|
+ texts[crop_indices[idx]] = text
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"二次OCR失败: {e}")
|
|
|
+
|
|
|
+ # 将文本填入 merged_cells
|
|
|
+ for i, cell in enumerate(merged_cells):
|
|
|
+ cell["text"] = texts[i] if i < len(texts) else ""
|
|
|
+
|
|
|
+ # Step 5: 生成带文本和 colspan/rowspan 的 HTML
|
|
|
+ html_filled = self._build_html_from_merged_cells(merged_cells)
|
|
|
+
|
|
|
+ # Step 6: 可视化文本填充(调试用)
|
|
|
+ if debug_output_dir:
|
|
|
+ self._visualize_with_text(
|
|
|
+ table_image, bboxes_merged, texts,
|
|
|
+ output_path=f"{debug_output_dir}/text_filled_v4.png"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Step 7: 组装 cells 输出
|
|
|
+ cells = []
|
|
|
+ for idx, cell in enumerate(merged_cells):
|
|
|
+ cells.append({
|
|
|
+ "bbox": cell["bbox"],
|
|
|
+ "row": cell.get("row", 0),
|
|
|
+ "col": cell.get("col", 0),
|
|
|
+ "rowspan": cell.get("rowspan", 1),
|
|
|
+ "colspan": cell.get("colspan", 1),
|
|
|
+ "text": cell["text"],
|
|
|
+ "matched_text": cell["text"],
|
|
|
+ "score": 100.0,
|
|
|
+ })
|
|
|
+
|
|
|
+ return {
|
|
|
+ "html": html_filled,
|
|
|
+ "cells": cells,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ # ========== 调试可视化 ==========
|
|
|
+ def _visualize_table_lines(
|
|
|
+ self,
|
|
|
+ table_image: np.ndarray,
|
|
|
+ hpred: np.ndarray,
|
|
|
+ vpred: np.ndarray,
|
|
|
+ output_path: str
|
|
|
+ ) -> np.ndarray:
|
|
|
+ """
|
|
|
+ 可视化 UNet 检测到的表格线
|
|
|
+
|
|
|
+ Args:
|
|
|
+ table_image: 原始图片
|
|
|
+ hpred: 横线mask(已缩放到原图大小)
|
|
|
+ vpred: 竖线mask(已缩放到原图大小)
|
|
|
+ output_path: 输出路径
|
|
|
+ """
|
|
|
+ vis_img = table_image.copy()
|
|
|
+ if len(vis_img.shape) == 2:
|
|
|
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2BGR)
|
|
|
+
|
|
|
+ # 横线用红色,竖线用蓝色
|
|
|
+ vis_img[hpred > 128] = [0, 0, 255] # 红色横线
|
|
|
+ vis_img[vpred > 128] = [255, 0, 0] # 蓝色竖线
|
|
|
+
|
|
|
+ cv2.imwrite(output_path, vis_img)
|
|
|
+ logger.info(f"表格线可视化: {output_path}")
|
|
|
+
|
|
|
+ return vis_img
|
|
|
+
|
|
|
+ def _visualize_table_structure(
|
|
|
+ self,
|
|
|
+ image: np.ndarray,
|
|
|
+ bboxes: List[List[float]],
|
|
|
+ output_path: Optional[str] = None,
|
|
|
+ title: str = "Table Structure"
|
|
|
+ ) -> np.ndarray:
|
|
|
+ """
|
|
|
+ 可视化表格结构检测结果
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: 原始图片
|
|
|
+ bboxes: 单元格坐标 [[x1,y1,x2,y2], ...]
|
|
|
+ output_path: 保存路径(可选)
|
|
|
+ title: 标题
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 标注后的图片
|
|
|
+ """
|
|
|
+ import random
|
|
|
+
|
|
|
+ vis_img = image.copy()
|
|
|
+ if len(vis_img.shape) == 2:
|
|
|
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2BGR)
|
|
|
+
|
|
|
+ # 为每个单元格分配随机颜色
|
|
|
+ colors = []
|
|
|
+ for _ in range(len(bboxes)):
|
|
|
+ colors.append((
|
|
|
+ random.randint(50, 255),
|
|
|
+ random.randint(50, 255),
|
|
|
+ random.randint(50, 255)
|
|
|
+ ))
|
|
|
+
|
|
|
+ # 绘制单元格
|
|
|
+ for idx, bbox in enumerate(bboxes):
|
|
|
+ x1, y1, x2, y2 = map(int, bbox)
|
|
|
+ color = colors[idx]
|
|
|
+
|
|
|
+ # 绘制矩形边框
|
|
|
+ cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
|
|
|
+
|
|
|
+ # 绘制单元格索引
|
|
|
+ cv2.putText(
|
|
|
+ vis_img, str(idx),
|
|
|
+ (x1 + 2, y1 + 15),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1
|
|
|
+ )
|
|
|
+
|
|
|
+ # 添加标题
|
|
|
+ cv2.putText(
|
|
|
+ vis_img, f"{title} ({len(bboxes)} cells)",
|
|
|
+ (10, 25),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2
|
|
|
+ )
|
|
|
+
|
|
|
+ if output_path:
|
|
|
+ cv2.imwrite(output_path, vis_img)
|
|
|
+ logger.info(f"表格结构可视化已保存: {output_path}")
|
|
|
+
|
|
|
+ return vis_img
|
|
|
+
|
|
|
+ def _visualize_with_text(
|
|
|
+ self,
|
|
|
+ image: np.ndarray,
|
|
|
+ bboxes: List[List[float]],
|
|
|
+ texts: List[str],
|
|
|
+ output_path: Optional[str] = None
|
|
|
+ ) -> np.ndarray:
|
|
|
+ """
|
|
|
+ 可视化单元格及其文本内容
|
|
|
+ """
|
|
|
+ vis_img = image.copy()
|
|
|
+ if len(vis_img.shape) == 2:
|
|
|
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2BGR)
|
|
|
+
|
|
|
+ for idx, (bbox, text) in enumerate(zip(bboxes, texts)):
|
|
|
+ x1, y1, x2, y2 = map(int, bbox)
|
|
|
+
|
|
|
+ # 有文本用绿色,无文本用红色
|
|
|
+ color = (0, 255, 0) if text else (0, 0, 255)
|
|
|
+ cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
|
|
|
+
|
|
|
+ # 显示文本预览(最多10个字符)
|
|
|
+ preview = text[:10] + "..." if len(text) > 10 else text
|
|
|
+ if preview:
|
|
|
+ cv2.putText(
|
|
|
+ vis_img, preview,
|
|
|
+ (x1 + 2, y1 + 15),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 0, 0), 1
|
|
|
+ )
|
|
|
+
|
|
|
+ if output_path:
|
|
|
+ cv2.imwrite(output_path, vis_img)
|
|
|
+ logger.info(f"文本填充可视化已保存: {output_path}")
|
|
|
+
|
|
|
+ return vis_img
|
|
|
+
|
|
|
+ def recognize(
|
|
|
+ self,
|
|
|
+ table_image: np.ndarray,
|
|
|
+ ocr_boxes: List[Dict[str, Any]],
|
|
|
+ ) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 统一入口:根据配置选择 recognize_legacy() 或 recognize_v2()。
|
|
|
+
|
|
|
+ 配置项 use_custom_postprocess:
|
|
|
+ - True: 使用 recognize_v2()(自定义后处理)
|
|
|
+ - False: 使用 recognize()(原始流程)
|
|
|
+ """
|
|
|
+ if self.use_custom_postprocess:
|
|
|
+ try:
|
|
|
+ return self.recognize_v4(table_image, ocr_boxes, debug_output_dir="./output")
|
|
|
+ except Exception:
|
|
|
+ # 回退
|
|
|
+ return self.recognize_legacy(table_image, ocr_boxes)
|
|
|
+ else:
|
|
|
+ return self.recognize_legacy(table_image, ocr_boxes)
|