根据代码分析,表格处理的代码主要在以下几个文件中:

## 表格处理流程代码位置

### 1. **表格识别入口** - `batch_analyze.py`

```python
# 1. 裁剪表格区域图像
def get_crop_table_img(scale):
    crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
    crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
    bbox = (int(crop_xmin / scale), int(crop_ymin / scale), int(crop_xmax / scale), int(crop_ymax / scale))
    return get_crop_np_img(bbox, np_img, scale=scale)

wireless_table_img = get_crop_table_img(scale = 1)
wired_table_img = get_crop_table_img(scale = 10/3)

# 2. 表格分类(有线/无线表格)
table_cls_model = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.TableCls,
)
table_cls_model.batch_predict(table_res_list_all_page, batch_size=16)

# 3. OCR检测 - 获取文本框
det_ocr_engine = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    det_db_box_thresh=0.5,
    det_db_unclip_ratio=1.6,
    enable_merge_det_boxes=False,
)
for table_res_dict in table_res_list_all_page:
    bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
    ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]  # 只检测不识别
    # 构造dt_box列表
```

### 2. **OCR识别** - `batch_analyze.py` (line 157-178)

```python
# 4. OCR识别 - 识别文本内容
for _lang, rec_img_list in rec_img_lang_group.items():
    ocr_engine = atom_model_manager.get_atom_model(
        atom_model_name=AtomicModel.OCR,
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang=_lang,
        enable_merge_det_boxes=False,
    )
    cropped_img_list = [item["cropped_img"] for item in rec_img_list]
    ocr_res_list = ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True)[0]

    # 回填OCR结果
    for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
        table_res_list_all_page[img_dict["table_id"]]["ocr_result"].append(
            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
        )
```

### 3. **无线表格识别** - `batch_analyze.py` (line 192-195)

```python
# 5. 无线表格模型预测(所有表格先用无线表格模型)
wireless_table_model = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.WirelessTable,
)
wireless_table_model.batch_predict(table_res_list_all_page)
```

对应的模型代码在 `RapidTableModel`:

```python
def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
    """对传入的字典列表进行批量预测,无返回值"""
    # ...(省略:由 table_res_list 构造 not_none_table_res_list)
    for index in range(0, len(not_none_table_res_list), batch_size):
        batch_end = min(index + batch_size, len(not_none_table_res_list))
        batch_imgs = [
            cv2.cvtColor(np.asarray(not_none_table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
            for i in range(index, batch_end)
        ]
        batch_ocrs = [not_none_table_res_list[i]["ocr_result"] for i in range(index, batch_end)]
        # 调用RapidTable模型
        batch_results = self.table_model.batch_predict(batch_imgs, batch_ocrs, batch_size)
        # 更新结果
        for i, result in enumerate(batch_results):
            not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
```

### 4. **有线表格识别** - `batch_analyze.py` (line 204-225)

```python
# 6. 有线表格模型预测(针对分类为有线的表格)
if wired_table_res_list:
    for table_res_dict in tqdm(wired_table_res_list, desc="Table-wired Predict"):
        wired_table_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.WiredTable,
            lang=table_res_dict["lang"],
        )
        # 传入无线表格的HTML作为参考
        table_res_dict["table_res"]["html"] = wired_table_model.predict(
            table_res_dict["wired_table_img"],
            table_res_dict["ocr_result"],
            table_res_dict["table_res"].get("html", None)  # 无线表格HTML
        )
```

对应的模型代码在 `UnetTableModel`:

```python
def predict(self, input_img, ocr_result, wireless_html_code):
    # ...(省略部分代码)
    # 7. 生成有线表格HTML
    wired_table_results = self.wired_table_model(np_img, ocr_result)
    wired_html_code = wired_table_results.pred_html

    # 8. 比较有线/无线表格结果,选择更好的
    wired_len = count_table_cells_physical(wired_html_code)
    wireless_len = count_table_cells_physical(wireless_html_code)
    # ...(省略:wired_blank_count / wireless_blank_count 的计算)
    # 计算非空单元格数量
    wireless_non_blank_count = wireless_len - wireless_blank_count
    wired_non_blank_count = wired_len - wired_blank_count
    # 选择更优结果
    if wireless_non_blank_count > wired_non_blank_count:
        return wireless_html_code
    else:
        return wired_html_code
```

### 5. **表格HTML生成** - `WiredTableRecognition`

```python
def __call__(self, img, ocr_result=None, **kwargs):
    # 1. 表格结构检测
    polygons, rotated_polygons = self.table_structure(img, **kwargs)
    # 2. 表格恢复(生成逻辑点和单元格映射)
    table_res, logi_points = self.table_recover(
        rotated_polygons, row_threshold, col_threshold
    )
    # 3. OCR结果匹配到单元格
    cell_box_det_map = match_ocr_cell(ocr_result, rotated_polygons, logi_points)
    # 4. 补充空单元格的OCR识别
    cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
    # 5. 排序和合并OCR结果
    t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
    # 6. 生成HTML
    pred_html = plot_html_table(logi_points, cell_box_det_map)
    return WiredTableOutput(pred_html, polygons, logi_points, elapse)
```

HTML生成的具体实现在 `plot_html_table`:

```python
def plot_html_table(logi_points, cell_box_map) -> str:
    table_html = "<html><body><table>"
    # 遍历每行
    for row in range(max_row):
        temp = "<tr>"
        for col in range(max_col):
            if grid[row][col]:
                i, row_start, row_end, col_start, col_end = grid[row][col]
                if row == row_start and col == col_start:
                    # 获取单元格文本
                    text = "".join(cell_box_map.get(i))
                    # 计算跨行跨列
                    row_span = row_end - row_start + 1
                    col_span = col_end - col_start + 1
                    # 生成HTML单元格
                    cell_content = f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
                    temp += cell_content
        table_html = table_html + temp + "</tr>"
    table_html += "</table></body></html>"
    return table_html
```

## 完整流程总结

```mermaid
graph LR
    A[裁剪表格图像] --> B[表格分类<br/>有线/无线]
    B --> C[OCR检测文本框]
    C --> D[OCR识别文本]
    D --> E[无线表格识别<br/>RapidTable]
    E --> F{是否有线表格?}
    F -->|是| G[有线表格识别<br/>UnetTable]
    F -->|否| H[直接使用无线结果]
    G --> I[比较有线/无线结果]
    I --> J[选择更优HTML]
    H --> J
    J --> K[清理HTML格式]
    K --> L[返回table HTML]
```

关键代码文件:

- 表格检测和分类: `batch_analyze.py`
- 无线表格模型: `slanet_plus/main.py`
- 有线表格模型: `unet_table/main.py`
- HTML生成: `utils_table_recover.py`