(正在收集工作区信息……正在筛选最相关的信息……)

根据代码分析,表格处理的代码主要分布在以下几个文件中:
batch_analyze.py:
# 1. 裁剪表格区域图像
def get_crop_table_img(scale):
    """Crop the current table region out of ``np_img`` at the given scale.

    ``table_res['poly'][0]`` / ``[1]`` are used as the (xmin, ymin) corner and
    ``table_res['poly'][4]`` / ``[5]`` as the (xmax, ymax) corner. Each
    coordinate is divided by ``scale`` and truncated to int. The resulting
    bbox and ``scale`` are then handed to ``get_crop_np_img``.

    NOTE(review): ``table_res``, ``np_img`` and ``get_crop_np_img`` are free
    names defined outside this excerpt (presumably the enclosing per-table
    loop in batch_analyze.py); confirm there.
    """
    crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
    crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
    # Divide every corner by scale; get_crop_np_img also receives scale.
    # NOTE(review): presumably it rescales when cropping; confirm in its source.
    bbox = (int(crop_xmin / scale), int(crop_ymin / scale),
            int(crop_xmax / scale), int(crop_ymax / scale))
    return get_crop_np_img(bbox, np_img, scale=scale)
# Crop at scale 1, stored as the wireless-table image.
wireless_table_img = get_crop_table_img(scale = 1)
# Crop at scale 10/3, stored as the wired-table image.
# NOTE(review): the reason for the 10/3 factor is not shown in this excerpt.
wired_table_img = get_crop_table_img(scale = 10/3)
# 2. Table classification (wired / wireless tables).
table_cls_model = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.TableCls,
)
# Called for its effect on the table dicts; the return value is not used here.
# NOTE(review): presumably this records a wired/wireless label on each dict
# that step 6 later relies on; confirm in the TableCls model.
table_cls_model.batch_predict(table_res_list_all_page, batch_size=16)
# 3. OCR detection: obtain text boxes only (recognition happens in step 4).
det_ocr_engine = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    det_db_box_thresh=0.5,
    det_db_unclip_ratio=1.6,
    enable_merge_det_boxes=False,
)
for table_res_dict in table_res_list_all_page:
    # The table crop is converted RGB -> BGR before it is passed to the OCR engine.
    bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
    ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]  # detection only, no recognition
    # Build the dt_box list (that code is elided in this excerpt).
batch_analyze.py(第 157-178 行):
# 4. OCR识别 - 识别文本内容
# Recognize the cropped text boxes, using one OCR engine per language group.
# NOTE(review): rec_img_lang_group is built in code not shown here. Each item
# is read for its "cropped_img", "table_id" and "dt_box" keys below.
for _lang, rec_img_list in rec_img_lang_group.items():
    ocr_engine = atom_model_manager.get_atom_model(
        atom_model_name=AtomicModel.OCR,
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang=_lang,
        enable_merge_det_boxes=False,
    )
    cropped_img_list = [item["cropped_img"] for item in rec_img_list]
    # det=False: no detection pass; the crops are recognized directly.
    ocr_res_list = ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True)[0]
    # Write the OCR results back onto the owning table (table_id indexes
    # table_res_list_all_page).
    # NOTE(review): zip() stops at the shorter sequence, so an image without a
    # matching result is silently dropped.
    for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
        # Appends [dt_box, HTML-escaped text, ocr_res[1]]. The text is escaped
        # because it may later be inserted into generated table HTML.
        # ocr_res[1] is presumably the recognition score; confirm.
        table_res_list_all_page[img_dict["table_id"]]["ocr_result"].append(
            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
        )
batch_analyze.py(第 192-195 行):
# 5. 无线表格模型预测(所有表格先用无线表格模型)
# 5. Run the wireless-table model over every table, not only the wireless ones.
# Per RapidTableModel.batch_predict, results are written in place into each
# table's ['table_res']['html']; nothing is returned.
wireless_table_model = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.WirelessTable,
)
wireless_table_model.batch_predict(table_res_list_all_page)
对应的模型代码位于 RapidTableModel.batch_predict:
def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
    """Run the RapidTable model over the tables in batches, updating them in place; returns nothing.

    Each batch pairs the BGR-converted ``table_img`` with its ``ocr_result``.
    The model's ``pred_html`` for each table is stored under
    ``['table_res']['html']``.

    NOTE(review): this is an abridged excerpt.
    ``not_none_table_res_list`` is derived from ``table_res_list`` in code not
    shown here (presumably by dropping ``None`` entries). ``table[i]`` and
    ``len(...)`` below presumably stand in for ``not_none_table_res_list[i]``
    and ``len(not_none_table_res_list)``. Run verbatim, this would fail:
    ``table`` is undefined and ``len(...)`` is ``len(Ellipsis)``.
    """
    for index in range(0, len(not_none_table_res_list), batch_size):
        batch_imgs = [cv2.cvtColor(np.asarray(table[i]["table_img"]),
                                   cv2.COLOR_RGB2BGR)
                      for i in range(index, min(index + batch_size, len(...)))]
        batch_ocrs = [table[i]["ocr_result"]
                      for i in range(index, min(index + batch_size, len(...)))]
        # Call the RapidTable model; batch_size is forwarded to it as well.
        batch_results = self.table_model.batch_predict(batch_imgs, batch_ocrs, batch_size)
        # Write the results back; batch_results is indexed relative to `index`.
        for i, result in enumerate(batch_results):
            not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
batch_analyze.py(第 204-225 行):
# 6. 有线表格模型预测(针对分类为有线的表格)
# 6. Re-run only the tables in wired_table_res_list through the wired model.
# NOTE(review): wired_table_res_list is built in code not shown here,
# presumably from the step-2 classification; confirm.
if wired_table_res_list:
    for table_res_dict in tqdm(wired_table_res_list, desc="Table-wired Predict"):
        # The model is fetched per table because lang is taken from each table.
        # NOTE(review): presumably atom_model_manager caches instances; confirm.
        wired_table_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.WiredTable,
            lang=table_res_dict["lang"],
        )
        # Pass the wireless-table HTML as a reference. Per UnetTableModel.predict,
        # the better of the two results is returned and overwrites table_res["html"].
        table_res_dict["table_res"]["html"] = wired_table_model.predict(
            table_res_dict["wired_table_img"],
            table_res_dict["ocr_result"],
            table_res_dict["table_res"].get("html", None)  # wireless-table HTML (None if absent)
        )
对应的模型代码位于 UnetTableModel.predict:
def predict(self, input_img, ocr_result, wireless_html_code):
    """Run the wired-table model and return the better of the wired and wireless HTML.

    The result with more non-blank cells wins. On a tie, the wired HTML is
    returned.

    NOTE(review): this is an abridged excerpt. ``np_img`` (presumably derived
    from ``input_img``), ``wired_blank_count`` and ``wireless_blank_count`` are
    computed in code not shown here. ``wireless_html_code`` may be ``None``
    because the caller passes ``.get("html", None)``; confirm that
    ``count_table_cells_physical`` handles that case.
    """
    # 7. Generate the wired-table HTML.
    wired_table_results = self.wired_table_model(np_img, ocr_result)
    wired_html_code = wired_table_results.pred_html
    # 8. Compare the wired and wireless results and choose the better one.
    wired_len = count_table_cells_physical(wired_html_code)
    wireless_len = count_table_cells_physical(wireless_html_code)
    # Non-blank cells = total physical cells - blank cells.
    wireless_non_blank_count = wireless_len - wireless_blank_count
    wired_non_blank_count = wired_len - wired_blank_count
    # Pick the better result; strict '>' means ties go to the wired HTML.
    if wireless_non_blank_count > wired_non_blank_count:
        return wireless_html_code
    else:
        return wired_html_code
# WiredTableRecognition.__call__ (the class body is not shown in this excerpt)
def __call__(self, img, ocr_result=None, **kwargs):
    """Recognize a wired (ruled) table image and return its HTML.

    Pipeline:
        1. Detect cell polygons.
        2. Recover logical cell positions.
        3. Match OCR boxes to cells.
        4. Run OCR on blank cells.
        5. Sort and merge the OCR results.
        6. Render HTML with ``plot_html_table``.

    Args:
        img: Table image. It is passed to ``self.table_structure`` and
            ``self.fill_blank_rec``.
        ocr_result: OCR results, matched to cells by ``match_ocr_cell``.
        **kwargs: Forwarded to ``self.table_structure``.

    Returns:
        ``WiredTableOutput(pred_html, polygons, logi_points, elapse)``.

    NOTE(review): this is an abridged excerpt. ``row_threshold``,
    ``col_threshold``, ``elapse`` and the initial ``t_rec_ocr_list`` are
    defined in code not shown here, so the body is not runnable as quoted.
    """
    # 1. Table structure detection; yields both polygons and rotated_polygons.
    polygons, rotated_polygons = self.table_structure(img, **kwargs)
    # 2. Table recovery: produce the logical points and the cell mapping.
    # NOTE(review): table_res is not used later in this excerpt.
    table_res, logi_points = self.table_recover(
        rotated_polygons, row_threshold, col_threshold
    )
    # 3. Match OCR results to cells (uses the rotated polygons).
    cell_box_det_map = match_ocr_cell(ocr_result, rotated_polygons, logi_points)
    # 4. Run OCR recognition for blank cells (uses the unrotated polygons).
    cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
    # 5. Sort and merge OCR results.
    # NOTE(review): as quoted, t_rec_ocr_list is read before it is assigned,
    # and plot_html_table below does not use it. Presumably the step that
    # builds it was elided from this excerpt; confirm.
    t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
    # 6. Generate the HTML.
    pred_html = plot_html_table(logi_points, cell_box_det_map)
    return WiredTableOutput(pred_html, polygons, logi_points, elapse)
HTML 的具体生成逻辑位于 plot_html_table:
def plot_html_table(logi_points, cell_box_map) -> str:
    """Render logical cell positions and their text as an HTML table string.

    Each occupied ``grid[row][col]`` holds
    ``(i, row_start, row_end, col_start, col_end)``. A ``<td>`` with the
    matching rowspan/colspan is emitted only at a cell's top-left grid
    position. Its text is ``cell_box_map.get(i)`` joined with no separator.

    NOTE(review): this is an abridged excerpt. ``grid``, ``max_row`` and
    ``max_col`` are presumably built from ``logi_points`` in code not shown
    here.
    """
    table_html = "<html><body><table>"
    # Iterate over every row.
    for row in range(max_row):
        temp = "<tr>"
        for col in range(max_col):
            # NOTE(review): as quoted, an empty grid position emits no <td>,
            # which shifts later cells in the row to the left; confirm
            # against the full source.
            if grid[row][col]:
                i, row_start, row_end, col_start, col_end = grid[row][col]
                # Only the top-left position of a (possibly spanning) cell emits a <td>.
                if row == row_start and col == col_start:
                    # Cell text.
                    # NOTE(review): cell_box_map.get(i) returns None for a
                    # missing key, and "".join(None) raises TypeError.
                    text = "".join(cell_box_map.get(i))
                    # Row/column span; end indices are inclusive, hence the +1.
                    row_span = row_end - row_start + 1
                    col_span = col_end - col_start + 1
                    # Emit the HTML cell; text is inserted as-is (not escaped here).
                    cell_content = f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
                    temp += cell_content
        table_html = table_html + temp + "</tr>"
    table_html += "</table></body></html>"
    return table_html
整体处理流程如下:

```mermaid
graph LR
A[裁剪表格图像] --> B[表格分类<br/>有线/无线]
B --> C[OCR检测文本框]
C --> D[OCR识别文本]
D --> E[无线表格识别<br/>RapidTable]
E --> F{是否有线表格?}
F -->|是| G[有线表格识别<br/>UnetTable]
F -->|否| H[直接使用无线结果]
G --> I[比较有线/无线结果]
I --> J[选择更优HTML]
H --> J
J --> K[清理HTML格式]
K --> L[返回table HTML]
```
关键代码文件:
- batch_analyze.py
- slanet_plus/main.py
- unet_table/main.py
- utils_table_recover.py