mineru-无线表格结构与OCR结果合并算法.md 6.8 KB

SLANet(Structure Location Alignment Network)+ 无线表格结构与 OCR 结果合并流程

整体流程图

graph TB
    A["输入图片 + OCR结果"] --> B["SLANet+ 表格结构识别"]
    A --> C["OCR 结果预处理"]
    
    B --> D["pred_structures<br/>HTML标签序列"]
    B --> E["cell_bboxes<br/>单元格坐标"]
    
    C --> F["dt_boxes<br/>OCR文本框坐标"]
    C --> G["rec_res<br/>OCR识别文本"]
    
    F --> H{"过滤OCR结果"}
    G --> H
    E --> H
    
    H --> I["match_result<br/>坐标匹配"]
    E --> I
    
    I --> J["matched_index<br/>单元格-OCR框映射"]
    
    J --> K["get_pred_html<br/>生成HTML"]
    D --> K
    G --> K
    
    K --> L["最终HTML输出"]

核心步骤详解

1. 输入准备 (main.py 第 41-69 行)

def predict(self, img, ocr_result):
    # 1. 提取 OCR 结果中的坐标和文本
    dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
    
    # 2. SLANet+ 模型推理,获取表格结构
    pred_structures, cell_bboxes, _ = self.table_structure.process(img)
    
    # 3. 坐标缩放还原
    cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
    
    # 4. 核心:匹配并生成 HTML
    pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)

两个关键输入

  • pred_structures: SLANet+ 输出的 HTML 标签序列,如 ['<table>', '<tr>', '<td>', '</td>', ...]
  • cell_bboxes: SLANet+ 输出的每个单元格的坐标 [x1, y1, x2, y2]

2. OCR 结果过滤 (matcher.py 第 188-198 行)

def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
    """过滤掉表格区域外的 OCR 结果"""
    y1 = cell_bboxes[:, 1::2].min()  # 表格最小 y 坐标
    new_dt_boxes = []
    new_rec_res = []
    
    for box, rec in zip(dt_boxes, rec_res):
        if np.max(box[1::2]) < y1:  # OCR框在表格上方,跳过
            continue
        new_dt_boxes.append(box)
        new_rec_res.append(rec)
    return new_dt_boxes, new_rec_res

作用:过滤掉表格区域外(如表头上方)的 OCR 结果,避免干扰匹配。


3. 坐标匹配 (matcher.py 第 31-59 行) ⭐ 核心算法

def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
    """将 OCR 文本框匹配到对应的单元格"""
    matched = {}
    
    for i, gt_box in enumerate(dt_boxes):  # 遍历每个 OCR 框
        distances = []
        
        for j, pred_box in enumerate(cell_bboxes):  # 遍历每个单元格
            # 计算两个度量:
            # 1. L1 距离(坐标差的绝对值之和)
            # 2. 1 - IoU(交并比的补)
            distances.append(
                (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
            )
        
        # 按 (1-IoU, L1距离) 排序,优先选择 IoU 高的
        sorted_distances = sorted(distances, key=lambda item: (item[1], item[0]))
        
        # 必须满足最小 IoU 阈值
        if sorted_distances[0][1] >= 1 - min_iou:
            continue
        
        # 记录匹配关系:单元格索引 → OCR框索引列表
        best_cell_idx = distances.index(sorted_distances[0])
        if best_cell_idx not in matched:
            matched[best_cell_idx] = [i]
        else:
            matched[best_cell_idx].append(i)  # 一个单元格可能对应多个 OCR 框
    
    return matched

匹配策略

  1. 对每个 OCR 文本框,计算它与所有单元格的 IoUL1 距离
  2. 优先选择 IoU 最大 的单元格
  3. IoU 相同时,选择 L1 距离最小
  4. 一个单元格可以匹配多个 OCR 框(多行文本)

4. 生成 HTML (matcher.py 第 61-116 行)

def get_pred_html(self, pred_structures, matched_index, ocr_contents):
    """将 OCR 文本填充到表格结构中"""
    end_html = []
    td_index = 0  # 单元格计数器
    
    for tag in pred_structures:
        if "</td>" not in tag:
            end_html.append(tag)
            continue
        
        # 处理 <td></td> 标签
        if "<td></td>" == tag:
            end_html.extend("<td>")
        
        # 如果该单元格有匹配的 OCR 结果
        if td_index in matched_index.keys():
            # 合并多个 OCR 框的文本
            for i, ocr_idx in enumerate(matched_index[td_index]):
                content = ocr_contents[ocr_idx][0]
                
                # 处理多行文本:添加空格分隔
                if len(matched_index[td_index]) > 1:
                    if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
                        content += " "
                
                end_html.extend(content)
        
        if "<td></td>" == tag:
            end_html.append("</td>")
        else:
            end_html.append(tag)
        
        td_index += 1
    
    return "".join(end_html), end_html

填充逻辑

  1. 遍历 pred_structures 中的每个标签
  2. 遇到 <td>...</td> 时,查找 matched_index 获取对应的 OCR 文本
  3. 如果一个单元格匹配了多个 OCR 框,将文本用空格连接
  4. 最终拼接成完整的 HTML 字符串

匹配算法的两个核心度量

IoU(交并比)

def compute_iou(rec1, rec2):
    """计算两个矩形的交并比"""
    # 计算交集面积
    left = max(rec1[0], rec2[0])
    right = min(rec1[2], rec2[2])
    top = max(rec1[1], rec2[1])
    bottom = min(rec1[3], rec2[3])
    
    if left >= right or top >= bottom:
        return 0.0
    
    intersect = (right - left) * (bottom - top)
    union = S_rec1 + S_rec2 - intersect
    return intersect / union

L1 距离

def distance(box_1, box_2):
    """计算两个矩形的 L1 距离"""
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    
    # 四个角点的曼哈顿距离之和
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
    
    # 加上左上角和右下角的距离(加权)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4 - x2) + abs(y4 - y2)
    
    return dis + min(dis_2, dis_3)

总结

步骤 输入 输出 说明
1. 结构识别 图片 pred_structures, cell_bboxes SLANet+ 模型推理
2. OCR 过滤 OCR 结果, 单元格坐标 过滤后的 OCR 结果 去除表格外的文本
3. 坐标匹配 OCR 框, 单元格框 matched_index 基于 IoU + L1 距离
4. HTML 生成 结构标签, 匹配索引, OCR 文本 完整 HTML 将文本填充到结构中

关键点

  • SLANet+ 输出的是表格结构(HTML 标签序列)和单元格坐标
  • OCR 输出的是文本框坐标识别文本
  • 通过 IoU + L1 距离 将 OCR 文本框匹配到对应的单元格
  • 一个单元格可以匹配多个 OCR 框(处理多行文本)