Bladeren bron

feat: 添加Batch数据生成流程文档,详细描述数据加载和预处理过程

zhch158_admin 3 maanden geleden
bovenliggende
commit
221051dd07
1 gewijzigde bestanden met toevoegingen van 384 en 0 verwijderingen
  1. 384 0
      zhch/table_structure_recognition/eval程序流程说明.md

+ 384 - 0
zhch/table_structure_recognition/eval程序流程说明.md

@@ -0,0 +1,384 @@
+
+## Batch数据生成流程
+
+### 1. **入口点 - eval.py**
+
+在 `eval.py` 第37行:
+
+```python
+# build dataloader
+set_signal_handlers()
+valid_dataloader = build_dataloader(config, "Eval", device, logger)
+```
+
+这里调用 `build_dataloader` 函数构建数据加载器。
+
+### 2. **DataLoader构建**
+
+`build_dataloader` 函数会:
+- 根据配置文件创建对应的Dataset
+- 设置数据预处理pipeline
+- 创建DataLoader用于批量加载数据
+
+### 3. **数据集类型**
+
+对于表格识别任务,通常使用的数据集类是 **`TableDataSet`** 或类似的表格专用数据集类。
+
+### 4. **数据预处理Pipeline**
+
+根据您显示的batch数据结构,预处理pipeline包括:
+
+```python
+# 典型的表格识别预处理流程
+transforms = [
+    DecodeImage(),           # 图像解码
+    TableLabelEncode(),      # 表格标签编码
+    TableBoxEncode(),        # 边界框编码
+    ResizeTableImage(),      # 图像尺寸调整
+    NormalizeImage(),        # 图像归一化
+    PaddingTableImage(),     # 图像填充
+    ToTensor(),             # 转换为张量
+    Collect()               # 收集数据
+]
+```
+
+### 5. **具体数据生成位置**
+
+#### TableLabelEncode
+负责生成 `batch[1]` (结构序列标签):
+```python
+# 将HTML结构转换为token ID序列
+# 例如: "<tr><td></td></tr>" -> [5, 7, 8, 9, ...]
+```
+
+#### TableBoxEncode  
+负责生成 `batch[2]` (边界框坐标):
+```python
+# 将单元格边界框转换为归一化坐标
+# 每个单元格8个坐标值 [x1,y1,x2,y2,x3,y3,x4,y4]
+```
+
+#### 其他编码器
+- `batch[3]`: 单元格掩码 (有效性标识)
+- `batch[4]`: 序列长度 (每个样本的实际token数量)
+- `batch[5]`: 图像元信息 (原始尺寸、缩放比例等)
+
+### 6. **数据流转过程**
+
+```mermaid
+graph TB
+    A[原始标注文件] --> B[Dataset.__getitem__]
+    B --> C[图像读取]
+    B --> D[标签解析]
+    
+    C --> E[图像预处理Pipeline]
+    D --> F[标签编码Pipeline]
+    
+    E --> G["batch[0]: 图像张量"]
+    F --> H["batch[1]: 结构序列"]
+    F --> I["batch[2]: 边界框"]
+    F --> J["batch[3]: 掩码"]
+    F --> K["batch[4]: 序列长度"]
+    E --> L["batch[5]: 图像信息"]
+    
+    G --> M[DataLoader.collate_fn]
+    H --> M
+    I --> M
+    J --> M
+    K --> M
+    L --> M
+    
+    M --> N[Batch数据]
+    N --> O[传入program.eval]
+```
+
+### 7. **关键配置文件位置**
+
+数据预处理的具体配置通常在:
+```yaml
+# 配置文件中的数据部分
+DataLoader:
+  Eval:
+    dataset:
+      name: TableDataset
+      data_dir: ./dataset/
+      label_file_list: ["./dataset/val.txt"]
+      transforms:
+        - DecodeImage:
+            img_mode: RGB
+            channel_first: False
+        - TableLabelEncode:
+            max_text_length: 500
+        - TableBoxEncode: null
+        - ResizeTableImage:
+            max_len: 488
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+        - PaddingTableImage: null
+        - ToTensor: null
+        - Collect:
+            keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
+```
+
+### 8. **数据加载时机**
+
+在 `program.py` 第667行的循环中:
+
+```python
+for idx, batch in enumerate(valid_dataloader):
+    # 这里的batch就是预处理完成的6元组数据
+    # batch[0]: 图像张量 [48, 3, 488, 488]
+    # batch[1]: 结构序列 [48, 502] 
+    # batch[2]: 边界框 [48, 502, 8]
+    # batch[3]: 掩码 [48, 502, 1]
+    # batch[4]: 序列长度 [48]
+    # batch[5]: 图像信息 [48, 6]
+```
+
+## 总结
+
+Batch数据是在 **DataLoader的数据预处理pipeline** 中生成的,具体包括:
+
+1. **Dataset类** 负责读取原始数据和标注
+2. **Transform pipeline** 负责数据预处理和编码
+3. **DataLoader** 负责批量组装和张量化
+4. **Collate函数** 负责批数据的最终整理
+
+这个过程将原始的表格图像和HTML标注转换为模型训练/评估所需的标准化张量格式。
+
+正在收集工作区信息正在筛选到最相关的信息根据您的工作空间代码分析,DataLoader并不直接调用模型SLANet_plus。实际的调用链是这样的:
+
+
+## DataLoader与模型的调用关系
+
+### 1. **DataLoader的作用**
+DataLoader只负责**数据加载和预处理**,不直接调用模型:
+
+```python
+# 在 eval.py 第37行
+valid_dataloader = build_dataloader(config, "Eval", device, logger)
+```
+
+DataLoader通过以下流程工作:
+```python
+# 从 __init__.py 可以看到
+def build_dataloader(config, mode, device, logger, seed=None):
+    # 创建数据集
+    dataset = eval(module_name)(config, mode, logger, seed)  # PubTabDataSet
+    
+    # 创建DataLoader
+    data_loader = DataLoader(
+        dataset=dataset,
+        batch_sampler=batch_sampler,
+        places=device,
+        num_workers=num_workers,
+        return_list=True,
+        use_shared_memory=use_shared_memory,
+        collate_fn=collate_fn,
+    )
+    return data_loader
+```
+
+### 2. **模型的实际调用位置**
+
+模型SLANet_plus是在 `eval.py` 中被构建和调用的:
+
+```python
+# eval.py 第83行 - 构建模型
+model = build_model(config["Architecture"])
+
+# eval.py 第163行 - 调用模型进行评估
+metric = program.eval(
+    model,                    # ← SLANet_plus模型在这里被传入
+    valid_dataloader,         # ← DataLoader提供数据
+    post_process_class,
+    eval_class,
+    model_type,
+    extra_input,
+    scaler,
+    amp_level,
+    amp_custom_black_list,
+)
+```
+
+### 3. **模型配置来源**
+
+SLANet_plus模型是根据您的配置文件 `SLANet_plus_paddleocr.yml` 第40行构建的:
+
+```yaml
+Architecture:
+  model_type: table
+  algorithm: SLANet      # ← 指定算法类型
+  Backbone:
+    name: PPLCNet        # ← 指定骨干网络
+    scale: 1.0
+    pretrained: True
+    use_ssld: True
+  Neck:
+    name: CSPPAN         # ← 指定颈部网络
+    out_channels: 96
+  Head:
+    name: SLAHead        # ← 指定头部网络
+    hidden_size: 256
+    max_text_length: 500
+    loc_reg_num: 8
+```
+
+### 4. **完整的调用流程**
+
+```mermaid
+graph TB
+    A["SLANet_plus_paddleocr.yml"] --> B["eval.py main()"]
+    
+    B --> C["build_dataloader()"]
+    C --> D["PubTabDataSet"]
+    D --> E["DataLoader实例"]
+    
+    B --> F["build_model()"]
+    F --> G["SLANet_plus模型实例"]
+    
+    B --> H["program.eval()"]
+    E --> H
+    G --> H
+    
+    H --> I["模型推理循环"]
+    I --> J["for batch in dataloader:"]
+    J --> K["model(images, data=batch)"]
+    
+    style A fill:#e1f5fe
+    style E fill:#fff3e0
+    style G fill:#e8f5e8
+    style K fill:#ffebee
+```
+
+### 5. **实际的模型调用代码**
+
+在 `program.py` 的eval函数中:
+
+```python
+def eval(model, valid_dataloader, post_process_class, eval_class, model_type=None, ...):
+    with paddle.no_grad():
+        for idx, batch in enumerate(valid_dataloader):  # ← DataLoader提供batch数据
+            images = batch[0]
+            
+            # 模型推理 - 这里是真正调用SLANet_plus的地方
+            if model_type == "table" or extra_input:
+                preds = model(images, data=batch[1:])  # ← SLANet_plus在这里被调用
+            
+            # 后处理和评估
+            post_result = post_process_class(preds, batch)
+            eval_class(post_result, batch)
+```
+
+## 总结
+
+**DataLoader不直接调用模型**,而是:
+
+1. **DataLoader**: 负责加载和预处理表格数据,生成batch
+2. **model (SLANet_plus)**: 接收DataLoader提供的batch数据进行推理
+3. **program.eval()**: 协调DataLoader和模型之间的交互
+
+这种设计分离了数据处理和模型推理的职责,使代码更加模块化和可维护。
+
+
+## Batch数据结构解析
+
+### batch[0] - 输入图像张量
+```python
+Tensor(shape=[48, 3, 488, 488], dtype=float32)
+```
+- **含义**: 预处理后的表格图像
+- **维度**: [batch_size, channels, height, width]
+- **数值**: 48张图像,RGB 3通道,尺寸488×488像素
+- **数据类型**: 已归一化的float32图像数据
+
+### batch[1] - 结构序列标签
+```python
+Tensor(shape=[48, 502], dtype=int64)
+```
+- **含义**: 表格HTML结构序列的token ID
+- **维度**: [batch_size, max_sequence_length]
+- **数值**: 每个样本最大502个token的结构序列
+- **内容**: HTML标签对应的数字ID,如 `<tr>`, `<td>`, `<td colspan="2">` 等
+
+### batch[2] - 边界框坐标
+```python
+Tensor(shape=[48, 502, 8], dtype=float32)
+```
+- **含义**: 表格单元格的边界框坐标
+- **维度**: [batch_size, max_sequence_length, 8]
+- **数值**: 每个token对应的8个坐标值(四边形的4个顶点坐标)
+- **格式**: [x1, y1, x2, y2, x3, y3, x4, y4] 表示四边形顶点
+
+### batch[3] - 单元格掩码
+```python
+Tensor(shape=[48, 502, 1], dtype=float32)
+```
+- **含义**: 标识哪些位置是有效的单元格
+- **维度**: [batch_size, max_sequence_length, 1]
+- **数值**: 1表示有效单元格,0表示填充位置
+
+### batch[4] - 序列长度
+```python
+Tensor(shape=[48], dtype=int64)
+```
+- **含义**: 每个样本的实际结构序列长度
+- **维度**: [batch_size]
+- **数值**: 如[146, 47, 148, 65, ...]表示每个样本的真实token数量
+
+### batch[5] - 图像信息
+```python
+Tensor(shape=[48, 6], dtype=float64)
+```
+- **含义**: 图像预处理相关的元信息
+- **维度**: [batch_size, 6]
+- **数值含义**:
+  - `[0]`: 原始图像宽度
+  - `[1]`: 原始图像高度  
+  - `[2]`: 宽度缩放比例
+  - `[3]`: 高度缩放比例
+  - `[4]`: 目标宽度 (488)
+  - `[5]`: 目标高度 (488)
+
+## 表格识别训练的数据流程
+
+```mermaid
+graph LR
+    A[原始表格图像] --> B[图像预处理]
+    B --> C["batch[0]: 图像张量"]
+    
+    D[HTML结构标注] --> E[Token化]
+    E --> F["batch[1]: 结构序列"]
+    
+    G[单元格标注] --> H[坐标归一化]
+    H --> I["batch[2]: 边界框"]
+    
+    J[有效性标注] --> K["batch[3]: 掩码"]
+    
+    L[序列长度统计] --> M["batch[4]: 长度"]
+    
+    N[图像元信息] --> O["batch[5]: 尺寸信息"]
+```
+
+## 在SLANet模型中的使用
+
+根据您之前显示的代码,这些数据会被这样使用:
+
+```python
+# 从program.py第683行可以看到
+if model_type == "table" or extra_input:
+    preds = model(images, data=batch[1:])  # images=batch[0], data=batch[1:]
+```
+
+- **images**: batch[0] 作为图像输入
+- **data**: batch[1:] 包含结构序列、边界框、掩码等作为额外监督信息
+
+这种设计使得SLANet模型能够:
+1. **视觉理解**: 通过图像学习表格的视觉特征
+2. **结构学习**: 通过HTML序列学习表格的逻辑结构
+3. **空间定位**: 通过边界框学习单元格的精确位置
+4. **联合优化**: 同时优化结构识别和位置检测两个任务
+
+这就是为什么SLANet在表格识别任务中能够达到高精度的原因 - 它结合了视觉、结构和空间信息的多模态学习。