graph TB
A["输入图片 + OCR结果"] --> B["SLANet+ 表格结构识别"]
A --> C["OCR 结果预处理"]
B --> D["pred_structures<br/>HTML标签序列"]
B --> E["cell_bboxes<br/>单元格坐标"]
C --> F["dt_boxes<br/>OCR文本框坐标"]
C --> G["rec_res<br/>OCR识别文本"]
F --> H{"过滤OCR结果"}
G --> H
E --> H
H --> I["match_result<br/>坐标匹配"]
E --> I
I --> J["matched_index<br/>单元格-OCR框映射"]
J --> K["get_pred_html<br/>生成HTML"]
D --> K
G --> K
K --> L["最终HTML输出"]
main.py 第 41-69 行)def predict(self, img, ocr_result):
# 1. 提取 OCR 结果中的坐标和文本
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
# 2. SLANet+ 模型推理,获取表格结构
pred_structures, cell_bboxes, _ = self.table_structure.process(img)
# 3. 坐标缩放还原
cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
# 4. 核心:匹配并生成 HTML
pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
两个关键输入:
pred_structures: SLANet+ 输出的 HTML 标签序列,如 ['<table>', '<tr>', '<td>', '</td>', ...]cell_bboxes: SLANet+ 输出的每个单元格的坐标 [x1, y1, x2, y2]matcher.py 第 188-198 行)def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
"""过滤掉表格区域外的 OCR 结果"""
y1 = cell_bboxes[:, 1::2].min() # 表格最小 y 坐标
new_dt_boxes = []
new_rec_res = []
for box, rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1: # OCR框在表格上方,跳过
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res
作用:过滤掉表格区域外(如表头上方)的 OCR 结果,避免干扰匹配。
matcher.py 第 31-59 行) ⭐ 核心算法def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
"""将 OCR 文本框匹配到对应的单元格"""
matched = {}
for i, gt_box in enumerate(dt_boxes): # 遍历每个 OCR 框
distances = []
for j, pred_box in enumerate(cell_bboxes): # 遍历每个单元格
# 计算两个度量:
# 1. L1 距离(坐标差的绝对值之和)
# 2. 1 - IoU(交并比的补)
distances.append(
(distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
)
# 按 (1-IoU, L1距离) 排序,优先选择 IoU 高的
sorted_distances = sorted(distances, key=lambda item: (item[1], item[0]))
# 必须满足最小 IoU 阈值
if sorted_distances[0][1] >= 1 - min_iou:
continue
# 记录匹配关系:单元格索引 → OCR框索引列表
best_cell_idx = distances.index(sorted_distances[0])
if best_cell_idx not in matched:
matched[best_cell_idx] = [i]
else:
matched[best_cell_idx].append(i) # 一个单元格可能对应多个 OCR 框
return matched
匹配策略:
matcher.py 第 61-116 行)def get_pred_html(self, pred_structures, matched_index, ocr_contents):
"""将 OCR 文本填充到表格结构中"""
end_html = []
td_index = 0 # 单元格计数器
for tag in pred_structures:
if "</td>" not in tag:
end_html.append(tag)
continue
# 处理 <td></td> 标签
if "<td></td>" == tag:
end_html.extend("<td>")
# 如果该单元格有匹配的 OCR 结果
if td_index in matched_index.keys():
# 合并多个 OCR 框的文本
for i, ocr_idx in enumerate(matched_index[td_index]):
content = ocr_contents[ocr_idx][0]
# 处理多行文本:添加空格分隔
if len(matched_index[td_index]) > 1:
if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
content += " "
end_html.extend(content)
if "<td></td>" == tag:
end_html.append("</td>")
else:
end_html.append(tag)
td_index += 1
return "".join(end_html), end_html
填充逻辑:
pred_structures 中的每个标签<td>...</td> 时,查找 matched_index 获取对应的 OCR 文本def compute_iou(rec1, rec2):
"""计算两个矩形的交并比"""
# 计算交集面积
left = max(rec1[0], rec2[0])
right = min(rec1[2], rec2[2])
top = max(rec1[1], rec2[1])
bottom = min(rec1[3], rec2[3])
if left >= right or top >= bottom:
return 0.0
intersect = (right - left) * (bottom - top)
union = S_rec1 + S_rec2 - intersect
return intersect / union
def distance(box_1, box_2):
"""计算两个矩形的 L1 距离"""
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
# 四个角点的曼哈顿距离之和
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
# 加上左上角和右下角的距离(加权)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4 - x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
| 步骤 | 输入 | 输出 | 说明 |
|---|---|---|---|
| 1. 结构识别 | 图片 | pred_structures, cell_bboxes |
SLANet+ 模型推理 |
| 2. OCR 过滤 | OCR 结果, 单元格坐标 | 过滤后的 OCR 结果 | 去除表格外的文本 |
| 3. 坐标匹配 | OCR 框, 单元格框 | matched_index |
基于 IoU + L1 距离 |
| 4. HTML 生成 | 结构标签, 匹配索引, OCR 文本 | 完整 HTML | 将文本填充到结构中 |
关键点: