|
@@ -0,0 +1,210 @@
|
|
|
|
|
+正在收集工作区信息正在筛选到最相关的信息根据代码分析,表格处理的代码主要在以下几个文件中:
|
|
|
|
|
+
|
|
|
|
|
+## 表格处理流程代码位置
|
|
|
|
|
+
|
|
|
|
|
+### 1. **表格识别入口** - `batch_analyze.py`
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+# 1. 裁剪表格区域图像
|
|
|
|
|
+def get_crop_table_img(scale):
|
|
|
|
|
+ crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
|
|
|
+ crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
|
|
|
+ bbox = (int(crop_xmin / scale), int(crop_ymin / scale),
|
|
|
|
|
+ int(crop_xmax / scale), int(crop_ymax / scale))
|
|
|
|
|
+ return get_crop_np_img(bbox, np_img, scale=scale)
|
|
|
|
|
+
|
|
|
|
|
+wireless_table_img = get_crop_table_img(scale = 1)
|
|
|
|
|
+wired_table_img = get_crop_table_img(scale = 10/3)
|
|
|
|
|
+
|
|
|
|
|
+# 2. 表格分类(有线/无线表格)
|
|
|
|
|
+table_cls_model = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=AtomicModel.TableCls,
|
|
|
|
|
+)
|
|
|
|
|
+table_cls_model.batch_predict(table_res_list_all_page, batch_size=16)
|
|
|
|
|
+
|
|
|
|
|
+# 3. OCR检测 - 获取文本框
|
|
|
|
|
+det_ocr_engine = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
|
|
+ det_db_box_thresh=0.5,
|
|
|
|
|
+ det_db_unclip_ratio=1.6,
|
|
|
|
|
+ enable_merge_det_boxes=False,
|
|
|
|
|
+)
|
|
|
|
|
+for table_res_dict in table_res_list_all_page:
|
|
|
|
|
+ bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
|
|
|
|
|
+ ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0] # 只检测不识别
|
|
|
|
|
+ # 构造dt_box列表
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+### 2. **OCR识别** - `batch_analyze.py` (line 157-178)
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+# 4. OCR识别 - 识别文本内容
|
|
|
|
|
+for _lang, rec_img_list in rec_img_lang_group.items():
|
|
|
|
|
+ ocr_engine = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
|
|
+ det_db_box_thresh=0.5,
|
|
|
|
|
+ det_db_unclip_ratio=1.6,
|
|
|
|
|
+ lang=_lang,
|
|
|
|
|
+ enable_merge_det_boxes=False,
|
|
|
|
|
+ )
|
|
|
|
|
+ cropped_img_list = [item["cropped_img"] for item in rec_img_list]
|
|
|
|
|
+ ocr_res_list = ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True)[0]
|
|
|
|
|
+
|
|
|
|
|
+ # 回填OCR结果
|
|
|
|
|
+ for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
|
|
|
|
|
+ table_res_list_all_page[img_dict["table_id"]]["ocr_result"].append(
|
|
|
|
|
+ [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
|
|
|
|
|
+ )
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+### 3. **无线表格识别** - `batch_analyze.py` (line 192-195)
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+# 5. 无线表格模型预测(所有表格先用无线表格模型)
|
|
|
|
|
+wireless_table_model = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=AtomicModel.WirelessTable,
|
|
|
|
|
+)
|
|
|
|
|
+wireless_table_model.batch_predict(table_res_list_all_page)
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+对应的模型代码在 `RapidTableModel`:
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
|
|
|
|
|
+ """对传入的字典列表进行批量预测,无返回值"""
|
|
|
|
|
+
|
|
|
|
|
+ for index in range(0, len(not_none_table_res_list), batch_size):
|
|
|
|
|
+ batch_imgs = [cv2.cvtColor(np.asarray(table[i]["table_img"]),
|
|
|
|
|
+ cv2.COLOR_RGB2BGR)
|
|
|
|
|
+ for i in range(index, min(index + batch_size, len(...)))]
|
|
|
|
|
+ batch_ocrs = [table[i]["ocr_result"]
|
|
|
|
|
+ for i in range(index, min(index + batch_size, len(...)))]
|
|
|
|
|
+
|
|
|
|
|
+ # 调用RapidTable模型
|
|
|
|
|
+ batch_results = self.table_model.batch_predict(batch_imgs, batch_ocrs, batch_size)
|
|
|
|
|
+
|
|
|
|
|
+ # 更新结果
|
|
|
|
|
+ for i, result in enumerate(batch_results):
|
|
|
|
|
+ not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+### 4. **有线表格识别** - `batch_analyze.py` (line 204-225)
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+# 6. 有线表格模型预测(针对分类为有线的表格)
|
|
|
|
|
+if wired_table_res_list:
|
|
|
|
|
+ for table_res_dict in tqdm(wired_table_res_list, desc="Table-wired Predict"):
|
|
|
|
|
+ wired_table_model = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=AtomicModel.WiredTable,
|
|
|
|
|
+ lang=table_res_dict["lang"],
|
|
|
|
|
+ )
|
|
|
|
|
+ # 传入无线表格的HTML作为参考
|
|
|
|
|
+ table_res_dict["table_res"]["html"] = wired_table_model.predict(
|
|
|
|
|
+ table_res_dict["wired_table_img"],
|
|
|
|
|
+ table_res_dict["ocr_result"],
|
|
|
|
|
+ table_res_dict["table_res"].get("html", None) # 无线表格HTML
|
|
|
|
|
+ )
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+对应的模型代码在 `UnetTableModel`:
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+def predict(self, input_img, ocr_result, wireless_html_code):
|
|
|
|
|
+ # 7. 生成有线表格HTML
|
|
|
|
|
+ wired_table_results = self.wired_table_model(np_img, ocr_result)
|
|
|
|
|
+ wired_html_code = wired_table_results.pred_html
|
|
|
|
|
+
|
|
|
|
|
+ # 8. 比较有线/无线表格结果,选择更好的
|
|
|
|
|
+ wired_len = count_table_cells_physical(wired_html_code)
|
|
|
|
|
+ wireless_len = count_table_cells_physical(wireless_html_code)
|
|
|
|
|
+
|
|
|
|
|
+ # 计算非空单元格数量
|
|
|
|
|
+ wireless_non_blank_count = wireless_len - wireless_blank_count
|
|
|
|
|
+ wired_non_blank_count = wired_len - wired_blank_count
|
|
|
|
|
+
|
|
|
|
|
+ # 选择更优结果
|
|
|
|
|
+ if wireless_non_blank_count > wired_non_blank_count:
|
|
|
|
|
+ return wireless_html_code
|
|
|
|
|
+ else:
|
|
|
|
|
+ return wired_html_code
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+### 5. **表格HTML生成** - `WiredTableRecognition`
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+def __call__(self, img, ocr_result=None, **kwargs):
|
|
|
|
|
+ # 1. 表格结构检测
|
|
|
|
|
+ polygons, rotated_polygons = self.table_structure(img, **kwargs)
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 表格恢复(生成逻辑点和单元格映射)
|
|
|
|
|
+ table_res, logi_points = self.table_recover(
|
|
|
|
|
+ rotated_polygons, row_threshold, col_threshold
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 3. OCR结果匹配到单元格
|
|
|
|
|
+ cell_box_det_map = match_ocr_cell(ocr_result, rotated_polygons, logi_points)
|
|
|
|
|
+
|
|
|
|
|
+ # 4. 补充空单元格的OCR识别
|
|
|
|
|
+ cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
|
|
|
|
|
+
|
|
|
|
|
+ # 5. 排序和合并OCR结果
|
|
|
|
|
+ t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
|
|
|
|
|
+
|
|
|
|
|
+ # 6. 生成HTML
|
|
|
|
|
+ pred_html = plot_html_table(logi_points, cell_box_det_map)
|
|
|
|
|
+
|
|
|
|
|
+ return WiredTableOutput(pred_html, polygons, logi_points, elapse)
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+HTML生成的具体实现在 `plot_html_table`:
|
|
|
|
|
+
|
|
|
|
|
+```python
|
|
|
|
|
+def plot_html_table(logi_points, cell_box_map) -> str:
|
|
|
|
|
+ table_html = "<html><body><table>"
|
|
|
|
|
+
|
|
|
|
|
+ # 遍历每行
|
|
|
|
|
+ for row in range(max_row):
|
|
|
|
|
+ temp = "<tr>"
|
|
|
|
|
+ for col in range(max_col):
|
|
|
|
|
+ if grid[row][col]:
|
|
|
|
|
+ i, row_start, row_end, col_start, col_end = grid[row][col]
|
|
|
|
|
+ if row == row_start and col == col_start:
|
|
|
|
|
+ # 获取单元格文本
|
|
|
|
|
+ text = "".join(cell_box_map.get(i))
|
|
|
|
|
+
|
|
|
|
|
+ # 计算跨行跨列
|
|
|
|
|
+ row_span = row_end - row_start + 1
|
|
|
|
|
+ col_span = col_end - col_start + 1
|
|
|
|
|
+
|
|
|
|
|
+ # 生成HTML单元格
|
|
|
|
|
+ cell_content = f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
|
|
|
|
|
+ temp += cell_content
|
|
|
|
|
+ table_html = table_html + temp + "</tr>"
|
|
|
|
|
+
|
|
|
|
|
+ table_html += "</table></body></html>"
|
|
|
|
|
+ return table_html
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+## 完整流程总结
|
|
|
|
|
+
|
|
|
|
|
+```mermaid
|
|
|
|
|
+graph LR
|
|
|
|
|
+ A[裁剪表格图像] --> B[表格分类<br/>有线/无线]
|
|
|
|
|
+ B --> C[OCR检测文本框]
|
|
|
|
|
+ C --> D[OCR识别文本]
|
|
|
|
|
+ D --> E[无线表格识别<br/>RapidTable]
|
|
|
|
|
+ E --> F{是否有线表格?}
|
|
|
|
|
+ F -->|是| G[有线表格识别<br/>UnetTable]
|
|
|
|
|
+ F -->|否| H[直接使用无线结果]
|
|
|
|
|
+ G --> I[比较有线/无线结果]
|
|
|
|
|
+ I --> J[选择更优HTML]
|
|
|
|
|
+ H --> J
|
|
|
|
|
+ J --> K[清理HTML格式]
|
|
|
|
|
+ K --> L[返回table HTML]
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+关键代码文件:
|
|
|
|
|
+- 表格检测和分类: `batch_analyze.py`
|
|
|
|
|
+- 无线表格模型: `slanet_plus/main.py`
|
|
|
|
|
+- 有线表格模型: `unet_table/main.py`
|
|
|
|
|
+- HTML生成: `utils_table_recover.py`
|