Bladeren bron

feat: 新增文字PDF表格处理文档,详细描述表格处理流程及关键代码实现

zhch158_admin 1 maand geleden
bovenliggende
commit
f16ca60453
1 gewijzigde bestanden met toevoegingen van 210 en 0 verwijderingen
  1. 210 0
      zhch/文字PDF-表格处理.md

+ 210 - 0
zhch/文字PDF-表格处理.md

@@ -0,0 +1,210 @@
+正在收集工作区信息正在筛选到最相关的信息根据代码分析,表格处理的代码主要在以下几个文件中:
+
+## 表格处理流程代码位置
+
+### 1. **表格识别入口** - `batch_analyze.py`
+
+```python
+# 1. 裁剪表格区域图像
+def get_crop_table_img(scale):
+    crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
+    crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
+    bbox = (int(crop_xmin / scale), int(crop_ymin / scale), 
+            int(crop_xmax / scale), int(crop_ymax / scale))
+    return get_crop_np_img(bbox, np_img, scale=scale)
+
+wireless_table_img = get_crop_table_img(scale = 1)
+wired_table_img = get_crop_table_img(scale = 10/3)
+
+# 2. 表格分类(有线/无线表格)
+table_cls_model = atom_model_manager.get_atom_model(
+    atom_model_name=AtomicModel.TableCls,
+)
+table_cls_model.batch_predict(table_res_list_all_page, batch_size=16)
+
+# 3. OCR检测 - 获取文本框
+det_ocr_engine = atom_model_manager.get_atom_model(
+    atom_model_name=AtomicModel.OCR,
+    det_db_box_thresh=0.5,
+    det_db_unclip_ratio=1.6,
+    enable_merge_det_boxes=False,
+)
+for table_res_dict in table_res_list_all_page:
+    bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
+    ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]  # 只检测不识别
+    # 构造dt_box列表
+```
+
+### 2. **OCR识别** - `batch_analyze.py` (line 157-178)
+
+```python
+# 4. OCR识别 - 识别文本内容
+for _lang, rec_img_list in rec_img_lang_group.items():
+    ocr_engine = atom_model_manager.get_atom_model(
+        atom_model_name=AtomicModel.OCR,
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        lang=_lang,
+        enable_merge_det_boxes=False,
+    )
+    cropped_img_list = [item["cropped_img"] for item in rec_img_list]
+    ocr_res_list = ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True)[0]
+    
+    # 回填OCR结果
+    for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
+        table_res_list_all_page[img_dict["table_id"]]["ocr_result"].append(
+            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
+        )
+```
+
+### 3. **无线表格识别** - `batch_analyze.py` (line 192-195)
+
+```python
+# 5. 无线表格模型预测(所有表格先用无线表格模型)
+wireless_table_model = atom_model_manager.get_atom_model(
+    atom_model_name=AtomicModel.WirelessTable,
+)
+wireless_table_model.batch_predict(table_res_list_all_page)
+```
+
+对应的模型代码在 `RapidTableModel`:
+
+```python
+def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
+    """对传入的字典列表进行批量预测,无返回值"""
+    
+    for index in range(0, len(not_none_table_res_list), batch_size):
+        batch_imgs = [cv2.cvtColor(np.asarray(table[i]["table_img"]), 
+                                   cv2.COLOR_RGB2BGR) 
+                     for i in range(index, min(index + batch_size, len(...)))]
+        batch_ocrs = [table[i]["ocr_result"] 
+                     for i in range(index, min(index + batch_size, len(...)))]
+        
+        # 调用RapidTable模型
+        batch_results = self.table_model.batch_predict(batch_imgs, batch_ocrs, batch_size)
+        
+        # 更新结果
+        for i, result in enumerate(batch_results):
+            not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
+```
+
+### 4. **有线表格识别** - `batch_analyze.py` (line 204-225)
+
+```python
+# 6. 有线表格模型预测(针对分类为有线的表格)
+if wired_table_res_list:
+    for table_res_dict in tqdm(wired_table_res_list, desc="Table-wired Predict"):
+        wired_table_model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.WiredTable,
+            lang=table_res_dict["lang"],
+        )
+        # 传入无线表格的HTML作为参考
+        table_res_dict["table_res"]["html"] = wired_table_model.predict(
+            table_res_dict["wired_table_img"],
+            table_res_dict["ocr_result"],
+            table_res_dict["table_res"].get("html", None)  # 无线表格HTML
+        )
+```
+
+对应的模型代码在 `UnetTableModel`:
+
+```python
+def predict(self, input_img, ocr_result, wireless_html_code):
+    # 7. 生成有线表格HTML
+    wired_table_results = self.wired_table_model(np_img, ocr_result)
+    wired_html_code = wired_table_results.pred_html
+    
+    # 8. 比较有线/无线表格结果,选择更好的
+    wired_len = count_table_cells_physical(wired_html_code)
+    wireless_len = count_table_cells_physical(wireless_html_code)
+    
+    # 计算非空单元格数量
+    wireless_non_blank_count = wireless_len - wireless_blank_count
+    wired_non_blank_count = wired_len - wired_blank_count
+    
+    # 选择更优结果
+    if wireless_non_blank_count > wired_non_blank_count:
+        return wireless_html_code
+    else:
+        return wired_html_code
+```
+
+### 5. **表格HTML生成** - `WiredTableRecognition`
+
+```python
+def __call__(self, img, ocr_result=None, **kwargs):
+    # 1. 表格结构检测
+    polygons, rotated_polygons = self.table_structure(img, **kwargs)
+    
+    # 2. 表格恢复(生成逻辑点和单元格映射)
+    table_res, logi_points = self.table_recover(
+        rotated_polygons, row_threshold, col_threshold
+    )
+    
+    # 3. OCR结果匹配到单元格
+    cell_box_det_map = match_ocr_cell(ocr_result, rotated_polygons, logi_points)
+    
+    # 4. 补充空单元格的OCR识别
+    cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
+    
+    # 5. 排序和合并OCR结果
+    t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
+    
+    # 6. 生成HTML
+    pred_html = plot_html_table(logi_points, cell_box_det_map)
+    
+    return WiredTableOutput(pred_html, polygons, logi_points, elapse)
+```
+
+HTML生成的具体实现在 `plot_html_table`:
+
+```python
+def plot_html_table(logi_points, cell_box_map) -> str:
+    table_html = "<html><body><table>"
+    
+    # 遍历每行
+    for row in range(max_row):
+        temp = "<tr>"
+        for col in range(max_col):
+            if grid[row][col]:
+                i, row_start, row_end, col_start, col_end = grid[row][col]
+                if row == row_start and col == col_start:
+                    # 获取单元格文本
+                    text = "".join(cell_box_map.get(i))
+                    
+                    # 计算跨行跨列
+                    row_span = row_end - row_start + 1
+                    col_span = col_end - col_start + 1
+                    
+                    # 生成HTML单元格
+                    cell_content = f"<td rowspan={row_span} colspan={col_span}>{text}</td>"
+                    temp += cell_content
+        table_html = table_html + temp + "</tr>"
+    
+    table_html += "</table></body></html>"
+    return table_html
+```
+
+## 完整流程总结
+
+```mermaid
+graph LR
+    A[裁剪表格图像] --> B[表格分类<br/>有线/无线]
+    B --> C[OCR检测文本框]
+    C --> D[OCR识别文本]
+    D --> E[无线表格识别<br/>RapidTable]
+    E --> F{是否有线表格?}
+    F -->|是| G[有线表格识别<br/>UnetTable]
+    F -->|否| H[直接使用无线结果]
+    G --> I[比较有线/无线结果]
+    I --> J[选择更优HTML]
+    H --> J
+    J --> K[清理HTML格式]
+    K --> L[返回table HTML]
+```
+
+关键代码文件:
+- 表格检测和分类: `batch_analyze.py`
+- 无线表格模型: `slanet_plus/main.py`
+- 有线表格模型: `unet_table/main.py`
+- HTML生成: `utils_table_recover.py`