10 次代碼提交 9375927c14 ... 808f38d864

作者 SHA1 備註 提交日期
  zhch158_admin 808f38d864 feat: 更新 DiT 和 Detectron2 模型对比分析,增强文档内容 1 周之前
  zhch158_admin 4ec22c91f7 feat: 新增 Docling RT-DETR 与 DiT 模型性能对比分析文档 1 周之前
  zhch158_admin bdc29cb5a4 feat: 新增 DiT Layout Detector 测试脚本 1 周之前
  zhch158_admin 23326cb1b6 feat: 增强布局处理工具类,新增类别合并限制和误检过滤功能 1 周之前
  zhch158_admin a8e2a5d3e2 feat: 新增文档版面检测模型发展路径整理文档 1 周之前
  zhch158_admin 156c3d90ae feat: 更新 README 文档,新增非结构化文档统一架构概述 1 周之前
  zhch158_admin 66103ab214 feat: 新增 DiT Layout Detector 适配器及其核心功能 1 周之前
  zhch158_admin 1cdd879991 feat: 添加 DiT 适配器的可选导入和布局检测支持 1 周之前
  zhch158_admin 20d936e629 feat: 新增 DiT 支持模块及其核心功能 1 周之前
  zhch158_admin 39d16d50a7 fix: 更新 MinerU 文档中的交点计算结果 1 周之前

+ 68 - 0
README.md

@@ -28,6 +28,74 @@
 - **守护进程**:各种 OCR 工具的守护进程脚本
 - **公共工具包**:统一的工具函数库
 
+## 🏗️ 非结构化文档统一架构
+
+### 架构概览
+
+平台采用"技术层 + 应用层"的分层架构设计,技术层负责OCR识别、校验和迭代优化,应用层负责业务规则校验和报告生成。
+
+
+### 架构说明
+
+#### 技术层(左侧)
+
+1. **OCR工具评估与选择**
+   - 基于 OmniDocBench 公开数据集评估各工具性能
+   - 重点关注扫描文档和表格识别指标
+   - 根据场景特点选择最适合的工具
+
+2. **工具封装与增强**
+   - 采用 Client/Server 架构,支持多线程和批量处理
+   - 通过守护进程管理提升稳定性
+   - 针对流水分析、财报分析等场景构建样本数据
+
+3. **OCR识别工具层(平级关系)**
+   - **MinerU VLM**:多模态文档理解
+   - **PaddleOCR-VL**:基于视觉语言模型的文档解析
+   - **DotsOCR VLM**:专业 VLM OCR 引擎
+   - **PP-StructureV3**:全面的文档结构分析
+   - **统一文档解析器(Universal Doc Parser)**:通过 YAML 配置文件灵活组合多种模型/算法,适配不同业务场景(组合方式示意见本清单之后)
+     - 图片方向识别(Orientation)
+     - Layout 检测(Layout Detection)
+     - OCR 识别(OCR Recognition)
+     - PDF 文字提取(PDF Text Extraction)
+     - 表格识别(Table Recognition)
+
+4. **技术校验**
+   - **交叉验证**:多工具结果智能对比,快速发现差异
+   - **人工交互校验**:可视化界面支持人工审核
+   - **差异分析**:精确到单元格级别的差异检测
+
+5. **标注与改进**
+   - 当识别效果不佳时,使用标注工具将无线表格标注为有线表格
+   - 使用特定配置文件(如 `bank_statement_wired_unet.yaml`)重新识别
+   - 通过有线表格识别模式显著提升准确率
+
+6. **数据迭代闭环**
+   - 标注图片 → 训练样本数据
+   - 模型训练优化 → 模型更新
+   - 更新后的模型反馈到 OCR 工具,形成持续优化的闭环
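+
+以统一文档解析器为例,配置化组合大致如下面的示意(模块名与字段均为假设示例,实际以各适配器支持的配置项为准):
+
+```python
+# 仅为示意:读取场景 YAML,按配置组装处理流水线(字段与模块名为假设示例)
+import yaml  # pip install pyyaml
+
+CONFIG_TEXT = """
+orientation: {module: paddle_cls}          # 图片方向识别
+layout:      {module: docling, conf: 0.3}  # Layout 检测
+ocr:         {module: ppocr}               # OCR 识别
+table:       {module: wired_unet}          # 表格识别(有线表格模式)
+"""
+
+def build_pipeline(config: dict) -> list:
+    """真实实现会按 module 字段实例化对应适配器,这里仅返回步骤清单。"""
+    return [f"{stage} -> {opts['module']}" for stage, opts in config.items()]
+
+for step in build_pipeline(yaml.safe_load(CONFIG_TEXT)):
+    print(step)
+```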
+
+#### 应用层(右侧)
+
+1. **业务校验**
+   - **流水连续性检查**:验证银行流水的时间连续性
+   - **关键要素缺失检测**:检查必要字段是否完整
+   - **财报勾稽关系验证**:验证财务报表的勾稽关系
+   - **数据修改审核**:记录并审核数据修改历史
+
+2. **报告生成**
+   - 生成结构化分析报告
+   - 支持交互式问答
+
+### 关键特性
+
+- ✅ **多工具并行识别**:支持多种 OCR 工具同时处理,通过交叉验证提升准确率
+- ✅ **配置化统一框架**:Universal Doc Parser 通过 YAML 配置灵活组合模型,适配不同场景
+- ✅ **智能标注改进**:无线表格标注为有线表格,显著提升识别准确率
+- ✅ **数据迭代优化**:标注数据形成训练样本,持续优化模型性能
+- ✅ **业务规则校验**:基于业务规则确保识别结果的正确性
+
 ## 📁 项目结构
 
 ```

+ 1 - 1
docs/mineru/表格识别模块说明.md

@@ -327,7 +327,7 @@ UNet输出:
 横线掩码:[━━━━━━━━━━](3条清晰横线)
 竖线掩码:[┃ ┃ ┃ ┃](4条清晰竖线)
 
-→ 求交点 → 得到精确的9个单元格框
+→ 求交点 → 得到精确的6个单元格框
 ✅ 100% 对应真实网格
 ````
 

+ 480 - 0
docs/ocr_tools/universal_doc_parser/layout模型发展路径整理.md

@@ -0,0 +1,480 @@
+# Layout Detection 模型发展路径整理
+
+本文档基于技术讨论和实践经验,系统梳理了文档版面检测(Layout Detection)模型的发展历程,使用 mermaid 图表进行多视角对比和分析。
+
+## 📊 一、发展路线图(时间轴视角)
+
+```mermaid
+timeline
+    title Layout Detection 模型发展时间轴
+    
+    2015 : UNet 诞生
+         : 像素级语义分割
+         : 医学图像分割应用
+    
+    2018 : UNet++ / Attention UNet
+         : 密集跳跃连接
+         : 注意力机制引入
+    
+    2018 : YOLOv3 进入文档领域
+         : 目标检测方法
+         : 工业界采用
+    
+    2020 : UNet3+ 发布
+         : 全尺度特征融合
+         : YOLOv5 广泛应用
+    
+    2021 : BEiT 发布
+         : 掩码图像建模
+         : 自监督预训练
+
+    2022 : DiT 发布
+         : 纯视觉 Transformer
+         : 文档域预训练
+
+    2022 : LayoutLMv3 发布
+         : 多模态 Transformer
+         : 文档理解 SOTA
+
+    2022 : DiT + Detectron2
+         : 组合架构
+         : 文档布局 SOTA
+
+    2023 : YOLOv8 发布
+         : 工业界主流
+         : 快速部署
+```
+
+## 🔄 二、技术演进路径(架构视角)
+
+```mermaid
+flowchart TB
+    subgraph Era1 [第一代: 像素级分割时代 2015-2018]
+        direction TB
+        UNet[UNet<br/>Encoder-Decoder<br/>Skip Connections]
+        FCN[FCN<br/>全卷积网络]
+        SegNet[SegNet<br/>编码-解码结构]
+        
+        UNet --> UNetPlus[UNet++<br/>密集跳跃连接]
+        UNet --> AttUNet[Attention UNet<br/>注意力机制]
+        UNetPlus --> UNet3Plus[UNet3+<br/>全尺度融合]
+    end
+    
+    subgraph Era2 [第二代: 目标检测时代 2018-2022]
+        direction TB
+        YOLOv3[YOLOv3<br/>单阶段检测]
+        YOLOv3 --> YOLOv5[YOLOv5<br/>工业标准]
+        YOLOv5 --> YOLOv8[YOLOv8<br/>当前主流]
+        
+        FasterRCNN[Faster R-CNN<br/>两阶段检测]
+        FasterRCNN --> Detectron2[Detectron2<br/>检测框架]
+    end
+    
+    subgraph Era3 [第三代: Transformer 时代 2021-至今]
+        direction TB
+        ViT[ViT<br/>Vision Transformer]
+        ViT --> BEiT[BEiT<br/>掩码图像建模]
+        ViT --> DiT[DiT<br/>文档图像 Transformer]
+        
+        LayoutLM[LayoutLMv1<br/>文本+布局]
+        LayoutLM --> LayoutLMv2[LayoutLMv2<br/>多模态融合]
+        LayoutLMv2 --> LayoutLMv3[LayoutLMv3<br/>当前 SOTA]
+        
+        DiT --> DiTDetectron[DiT + Detectron2<br/>组合架构]
+    end
+    
+    Era1 --> Era2
+    Era2 --> Era3
+    
+    style Era1 fill:#e1f5ff
+    style Era2 fill:#fff3e0
+    style Era3 fill:#f3e5f5
+```
+
+## 🎯 三、核心模型对比(功能视角)
+
+### 3.1 UNet 系列对比
+
+```mermaid
+graph TB
+    subgraph UNetSeries [UNet 系列模型]
+        direction TB
+        
+        UNet[UNet 2015<br/>输出: 像素级分割Mask<br/>算法: Encoder-Decoder + Skip<br/>场景: 表格线检测/文档分割]
+        
+        UNetPlus[UNet++ 2018<br/>输出: 像素级分割Mask<br/>算法: Dense Skip Connections<br/>场景: 复杂边界分割]
+        
+        UNet3Plus[UNet3+ 2020<br/>输出: 像素级分割Mask<br/>算法: Full-scale Skip<br/>场景: 大尺寸图像分割]
+        
+        TransUNet[TransUNet 2021<br/>输出: 像素级分割Mask<br/>算法: ViT Encoder + UNet Decoder<br/>场景: 全局+局部特征融合]
+        
+        UNet --> UNetPlus
+        UNet --> TransUNet
+        UNetPlus --> UNet3Plus
+    end
+    
+    style UNet fill:#e8f5e9
+    style UNetPlus fill:#e8f5e9
+    style UNet3Plus fill:#e8f5e9
+    style TransUNet fill:#e8f5e9
+```
+
+### 3.2 目标检测系列对比
+
+```mermaid
+graph LR
+    subgraph DetectionSeries [目标检测系列]
+        direction TB
+        
+        YOLOv8[YOLOv8<br/>输出: Bounding Boxes<br/>算法: Anchor-free 单阶段<br/>场景: 快速区域检测]
+        
+        Detectron2[Detectron2 + Mask R-CNN<br/>输出: Bounding Boxes + Masks<br/>算法: 两阶段检测 + 实例分割<br/>场景: 精确区域检测+分割]
+        
+        RTDETR[RT-DETR<br/>输出: Bounding Boxes<br/>算法: DETR 实时版本<br/>场景: 实时检测]
+    end
+    
+    style YOLOv8 fill:#fff3e0
+    style Detectron2 fill:#fff3e0
+    style RTDETR fill:#fff3e0
+```
+
+### 3.3 Transformer 系列对比
+
+```mermaid
+graph TB
+    subgraph TransformerSeries [Transformer 系列]
+        direction TB
+        
+        LayoutLMv3[LayoutLMv3<br/>输出: 区域框 + 文本理解<br/>算法: 文本+图像+位置多模态<br/>场景: 复杂文档理解]
+        
+        DiT[DiT<br/>输出: 区域框 + 视觉特征<br/>算法: 纯视觉 Transformer<br/>场景: 高分辨率文档]
+        
+        BEiT[BEiT<br/>输出: 视觉特征表示<br/>算法: 掩码图像建模<br/>场景: 预训练骨干网络]
+        
+        DiTDetectron[DiT + Detectron2<br/>输出: 区域框 + Masks<br/>算法: DiT Backbone + 检测头<br/>场景: 最强布局检测]
+        
+        LayoutLMv3 --> DiTDetectron
+        DiT --> DiTDetectron
+        BEiT -.-> DiT
+    end
+    
+    style LayoutLMv3 fill:#f3e5f5
+    style DiT fill:#f3e5f5
+    style BEiT fill:#f3e5f5
+    style DiTDetectron fill:#ffcdd2
+```
+
+## 📋 四、详细对比表(多维度视角)
+
+### 4.1 核心算法对比
+
+```mermaid
+graph TB
+    subgraph Algorithm[核心算法对比]
+        direction LR
+        
+        subgraph PixelLevel[像素级方法]
+            UNetAlg[UNet<br/>Encoder-Decoder<br/>Skip Connections]
+            UNetAlgOut[输出: 每个像素的类别]
+        end
+        
+        subgraph BoxLevel[框级方法]
+            YOLOAlg[YOLO<br/>Anchor-free Detection<br/>单阶段检测]
+            YOLOAlgOut[输出: Bounding Box + 类别]
+        end
+        
+        subgraph TransformerLevel[Transformer方法]
+            DiTAlg[DiT/LayoutLMv3<br/>Multi-head Attention<br/>全局建模]
+            DiTAlgOut[输出: 区域框 + 语义理解]
+        end
+        
+        PixelLevel --> BoxLevel
+        BoxLevel --> TransformerLevel
+    end
+```
+
+### 4.2 输出内容对比
+
+```mermaid
+graph LR
+    subgraph OutputType[输出内容类型]
+        direction TB
+        
+        PixelMask[像素级Mask<br/>每个像素的类别标签<br/>适合: 表格线检测]
+        
+        BBox[Bounding Box<br/>矩形框 + 类别<br/>适合: 区域检测]
+        
+        BBoxMask[Bounding Box + Mask<br/>框 + 像素级分割<br/>适合: 精确区域]
+        
+        BBoxSemantic[Bounding Box + 语义<br/>框 + 文本理解<br/>适合: 文档理解]
+    end
+    
+    PixelMask --> BBox
+    BBox --> BBoxMask
+    BBox --> BBoxSemantic
+    
+    style PixelMask fill:#e8f5e9
+    style BBox fill:#fff3e0
+    style BBoxMask fill:#e1f5ff
+    style BBoxSemantic fill:#f3e5f5
+```
+
+### 4.3 适合场景对比
+
+```mermaid
+mindmap
+  root((Layout Detection<br/>应用场景))
+    像素级任务
+      表格线检测
+      文档区域分割
+      图像增强
+      OCR前处理
+    区域检测任务
+      快速布局检测
+      工业文档处理
+      批量文档分析
+    理解任务
+      复杂文档解析
+      多页文档理解
+      跨区域关系
+      语义推理
+```
+
+## 🔍 五、详细模型特性对比表
+
+### 5.1 UNet 系列详细对比
+
+| 模型 | 核心算法 | 输出内容 | 适合场景 | 优势 | 劣势 |
+|------|---------|---------|---------|------|------|
+| **UNet (2015)** | Encoder-Decoder + Skip Connections | 像素级分割 Mask | 表格线检测、文档分割、OCR前处理 | 结构简单、细节好、小数据可训练 | 语义差距大、多尺度融合不足 |
+| **UNet++ (2018)** | Dense Skip Connections + Deep Supervision | 像素级分割 Mask | 复杂边界分割、医学图像 | 边界更精确、收敛更快 | 计算量增加 |
+| **UNet3+ (2020)** | Full-scale Skip Connections | 像素级分割 Mask | 大尺寸图像、遥感图像 | 多尺度信息充分融合 | 内存占用大 |
+| **TransUNet (2021)** | ViT Encoder + UNet Decoder | 像素级分割 Mask | 复杂结构图像、文档结构化 | 全局建模能力强 | 计算复杂度高 |
+
+### 5.2 目标检测系列详细对比
+
+| 模型 | 核心算法 | 输出内容 | 适合场景 | 优势 | 劣势 |
+|------|---------|---------|---------|------|------|
+| **YOLOv8** | Anchor-free 单阶段检测 | Bounding Boxes + 类别 | 快速布局检测、工业文档、批量处理 | 速度快、部署简单、工业稳定 | 不理解文本语义、跨区域关系弱 |
+| **Detectron2 + Mask R-CNN** | 两阶段检测 + 实例分割 | Bounding Boxes + Masks | 精确区域检测、复杂布局 | 精度高、支持实例分割 | 推理速度较慢 |
+| **RT-DETR** | DETR 实时版本 | Bounding Boxes + 类别 | 实时检测场景 | 端到端、无需NMS | 训练难度较高 |
+
+### 5.3 Transformer 系列详细对比
+
+| 模型 | 核心算法 | 输出内容 | 适合场景 | 优势 | 劣势 |
+|------|---------|---------|---------|------|------|
+| **LayoutLMv3** | 文本+图像+位置多模态 Transformer | 区域框 + 文本理解 + 阅读顺序 | 合同、票据、科研论文、多页文档 | 理解文档语义、结构能力强 | 模型大、推理慢、训练成本高 |
+| **DiT** | 纯视觉 Transformer(文档域预训练) | 区域框 + 多尺度视觉特征 | 高分辨率文档、复杂视觉布局 | 视觉理解强、PubLayNet SOTA | 不理解文本语义、需大数据训练、对训练数据依赖性强 |
+| **BEiT** | 掩码图像建模(自监督预训练) | 视觉特征表示(作为Backbone) | 预训练骨干网络、迁移学习 | 自监督学习、通用性强 | 主要用于预训练,不直接用于检测 |
+| **DiT + Detectron2** | DiT Backbone + Detectron2 Detection Head | 区域框 + Masks + 语义特征 | 学术论文布局(PubLayNet SOTA) | 结合Transformer全局理解 + 检测精度 | 计算资源需求高、对训练数据匹配度要求高 ⚠️ |
+| **RT-DETR (Docling)** | Hybrid CNN-Transformer 端到端检测 | 区域框 + 类别(17类) | 商业文档、财务报表、多样化文档 | 类别体系完善、商业文档适配好、无需NMS | 在学术论文上可能不如DiT |
+
+## 🔗 六、模型关系与组合(架构组合视角)
+
+```mermaid
+graph TB
+    subgraph Backbone[Backbone 骨干网络]
+        BEiT[BEiT<br/>视觉特征提取]
+        DiT[DiT<br/>文档视觉特征]
+        ViT[ViT<br/>通用视觉特征]
+    end
+    
+    subgraph Detection[Detection Head 检测头]
+        YOLOHead[YOLO Head<br/>单阶段检测]
+        Detectron2Head[Detectron2 Head<br/>两阶段检测+分割]
+        DETRHead[DETR Head<br/>端到端检测]
+    end
+    
+    subgraph Segmentation[Segmentation Head 分割头]
+        UNetDecoder[UNet Decoder<br/>像素级分割]
+        MaskHead[Mask R-CNN Head<br/>实例分割]
+    end
+    
+    BEiT --> YOLOHead
+    DiT --> Detectron2Head
+    ViT --> UNetDecoder
+    DiT --> MaskHead
+    
+    Detectron2Head --> Final1[DiT + Detectron2<br/>最强组合]
+    UNetDecoder --> Final2[TransUNet<br/>Transformer + UNet]
+    
+    style Final1 fill:#ffcdd2
+    style Final2 fill:#c8e6c9
+```
+
+## 🎯 七、应用场景选择指南(决策视角)
+
+```mermaid
+flowchart TD
+    Start[需要Layout Detection] --> Q1{需要像素级精度?}
+    
+    Q1 -->|是: 表格线检测| UNetChoice[选择 UNet 系列<br/>UNet/UNet++/UNet3+<br/>输出: 像素级Mask]
+    
+    Q1 -->|否: 区域检测即可| Q2{需要文本理解?}
+    
+    Q2 -->|否: 只需快速检测| YOLOChoice[选择 YOLOv8<br/>输出: Bounding Boxes<br/>速度快、易部署]
+    
+    Q2 -->|是: 需要语义理解| Q3{需要跨区域关系?}
+    
+    Q3 -->|是: 复杂文档| LayoutLMv3Choice[选择 LayoutLMv3<br/>输出: 区域框 + 文本理解<br/>多模态理解]
+    
+    Q3 -->|否: 高分辨率文档| DiTChoice[选择 DiT + Detectron2<br/>输出: 区域框 + Masks<br/>最强精度]
+    
+    UNetChoice --> UseCase1[表格线检测<br/>文档区域分割<br/>OCR前处理]
+    YOLOChoice --> UseCase2[工业文档批量处理<br/>快速布局检测]
+    LayoutLMv3Choice --> UseCase3[合同票据解析<br/>科研论文结构化<br/>多页文档理解]
+    DiTChoice --> UseCase4[高分辨率扫描件<br/>复杂视觉布局<br/>最强精度需求]
+    
+    style UNetChoice fill:#e8f5e9
+    style YOLOChoice fill:#fff3e0
+    style LayoutLMv3Choice fill:#f3e5f5
+    style DiTChoice fill:#ffcdd2
+```
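+
+上面的决策树也可以写成一个简单的选择函数(仅为示意,返回值对应本文讨论的几类模型):
+
+```python
+def choose_layout_model(need_pixel_mask: bool,
+                        need_text_semantics: bool,
+                        need_cross_region: bool) -> str:
+    """按上图决策树返回建议模型(仅示意,不是严格的工程选型)。"""
+    if need_pixel_mask:                # 表格线检测等像素级任务
+        return "UNet / UNet++ / UNet3+"
+    if not need_text_semantics:        # 只需快速区域检测
+        return "YOLOv8"
+    if need_cross_region:              # 复杂文档、跨区域关系
+        return "LayoutLMv3"
+    return "DiT + Detectron2"          # 高分辨率、精度优先
+
+
+assert choose_layout_model(True, False, False).startswith("UNet")
+assert choose_layout_model(False, True, False) == "DiT + Detectron2"
+```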
+
+## 📊 八、性能对比(评估视角)
+
+```mermaid
+graph LR
+    subgraph Performance[性能对比维度]
+        direction TB
+        
+        Speed[速度<br/>FPS]
+        Accuracy[准确率<br/>mAP]
+        Memory[内存占用<br/>GB]
+        Training[训练成本<br/>数据量+时间]
+    end
+    
+    subgraph ModelRank[模型排名]
+        YOLOv8Speed[YOLOv8: ⭐⭐⭐⭐⭐<br/>最快]
+        DiTAccuracy[DiT+Detectron2: ⭐⭐⭐⭐⭐<br/>最准]
+        UNetMemory[UNet: ⭐⭐⭐⭐<br/>最轻]
+        YOLOv8Training[YOLOv8: ⭐⭐⭐⭐<br/>易训练]
+    end
+    
+    Speed --> YOLOv8Speed
+    Accuracy --> DiTAccuracy
+    Memory --> UNetMemory
+    Training --> YOLOv8Training
+```
+
+## 🔬 九、技术细节对比(算法视角)
+
+### 9.1 UNet 结构详解
+
+```mermaid
+graph TB
+    subgraph UNetStructure[UNet 结构流程]
+        Input[Input Image<br/>H×W×3] --> Conv1[Conv Block 1<br/>特征提取]
+        Conv1 --> Pool1[MaxPool 2×2<br/>下采样]
+        Pool1 --> Conv2[Conv Block 2]
+        Conv2 --> Pool2[MaxPool 2×2]
+        Pool2 --> Conv3[Conv Block 3]
+        Conv3 --> Pool3[MaxPool 2×2]
+        Pool3 --> Bottleneck[Bottleneck<br/>最深层特征]
+        
+        Bottleneck --> UpConv3[UpConv 2×2<br/>上采样]
+        UpConv3 --> Concat3[Concat<br/>融合Skip连接]
+        Conv3 -.Skip.-> Concat3
+        Concat3 --> Decoder3[Conv Block]
+        
+        Decoder3 --> UpConv2[UpConv 2×2]
+        UpConv2 --> Concat2[Concat]
+        Conv2 -.Skip.-> Concat2
+        Concat2 --> Decoder2[Conv Block]
+        
+        Decoder2 --> UpConv1[UpConv 2×2]
+        UpConv1 --> Concat1[Concat]
+        Conv1 -.Skip.-> Concat1
+        Concat1 --> Decoder1[Conv Block]
+        
+        Decoder1 --> Output[Output Mask<br/>H×W×Classes]
+    end
+    
+    style Bottleneck fill:#ffcdd2
+    style Concat1 fill:#c8e6c9
+    style Concat2 fill:#c8e6c9
+    style Concat3 fill:#c8e6c9
+```
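+
+上图的结构可以对照一个极简的 PyTorch 实现(仅两次下采样、两次上采样,只为说明 Encoder-Decoder 与 Skip Connection 的组织方式,并非可直接用于表格线检测的完整模型):
+
+```python
+import torch
+import torch.nn as nn
+
+
+def conv_block(in_ch, out_ch):
+    # 两层 3x3 卷积 + ReLU,对应图中的 Conv Block
+    return nn.Sequential(
+        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
+        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
+    )
+
+
+class MiniUNet(nn.Module):
+    """最简化的 UNet 示意:Encoder 下采样,Decoder 上采样并拼接 Skip 特征。"""
+
+    def __init__(self, num_classes=2):
+        super().__init__()
+        self.enc1, self.enc2 = conv_block(3, 32), conv_block(32, 64)
+        self.pool = nn.MaxPool2d(2)
+        self.bottleneck = conv_block(64, 128)
+        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
+        self.dec2 = conv_block(128, 64)              # 64(上采样) + 64(skip)
+        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
+        self.dec1 = conv_block(64, 32)               # 32(上采样) + 32(skip)
+        self.head = nn.Conv2d(32, num_classes, 1)    # 输出 H×W×Classes 的像素级 Mask
+
+    def forward(self, x):
+        e1 = self.enc1(x)                            # H×W
+        e2 = self.enc2(self.pool(e1))                # H/2×W/2
+        b = self.bottleneck(self.pool(e2))           # H/4×W/4
+        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))   # Skip 连接 1
+        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # Skip 连接 2
+        return self.head(d1)
+
+
+mask_logits = MiniUNet(num_classes=2)(torch.randn(1, 3, 256, 256))
+print(mask_logits.shape)  # torch.Size([1, 2, 256, 256])
+```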
+
+### 9.2 DiT + Detectron2 组合架构
+
+```mermaid
+graph TB
+    subgraph DiTDetectron[DiT + Detectron2 架构]
+        Input[Document Image<br/>高分辨率] --> DiTBackbone[DiT Backbone<br/>Document Image Transformer]
+        
+        DiTBackbone --> Layer3[Layer 3 Features<br/>多尺度特征]
+        DiTBackbone --> Layer5[Layer 5 Features]
+        DiTBackbone --> Layer7[Layer 7 Features]
+        DiTBackbone --> Layer11[Layer 11 Features]
+        
+        Layer3 --> FPN[FPN<br/>Feature Pyramid Network]
+        Layer5 --> FPN
+        Layer7 --> FPN
+        Layer11 --> FPN
+        
+        FPN --> P2[P2: 高分辨率<br/>小目标]
+        FPN --> P3[P3: 中等分辨率]
+        FPN --> P4[P4: 较低分辨率]
+        FPN --> P5[P5: 低分辨率<br/>大目标]
+        
+        P2 --> RPN[RPN<br/>Region Proposal Network]
+        P3 --> RPN
+        P4 --> RPN
+        P5 --> RPN
+        
+        RPN --> ROIAlign[ROIAlign<br/>特征提取]
+        ROIAlign --> ROIHead[ROI Head<br/>分类+回归+Mask]
+        
+        ROIHead --> Output1[Bounding Boxes<br/>区域框]
+        ROIHead --> Output2[Masks<br/>实例分割]
+        ROIHead --> Output3[Classes<br/>类别标签]
+    end
+    
+    style DiTBackbone fill:#f3e5f5
+    style FPN fill:#fff3e0
+    style ROIHead fill:#e1f5ff
+```
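+
+结合本仓库迁移的 `dit_support` 模块,上述组合架构的最小推理示意如下(假设已安装 detectron2 与 timm,且 `dit_support` 目录已在 Python 路径中,`page.png` 为示例文件名):
+
+```python
+import cv2
+from detectron2.config import get_cfg
+from detectron2.engine import DefaultPredictor
+
+from ditod import add_vit_config  # 来自 dit_support/ditod
+
+cfg = get_cfg()
+add_vit_config(cfg)  # 注册 MODEL.VIT 等 DiT 专用配置项
+cfg.merge_from_file("dit_support/configs/cascade/cascade_dit_large.yaml")
+cfg.MODEL.WEIGHTS = "https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth"
+cfg.MODEL.DEVICE = "cpu"                     # 或 "cuda"
+cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3  # 置信度阈值
+
+predictor = DefaultPredictor(cfg)
+image = cv2.imread("page.png")               # OpenCV 读入 BGR 图像
+instances = predictor(image)["instances"].to("cpu")
+
+# PubLayNet 5 类:text / title / list / table / figure
+for box, score, cls in zip(instances.pred_boxes.tensor.tolist(),
+                           instances.scores.tolist(),
+                           instances.pred_classes.tolist()):
+    print(cls, round(score, 3), [round(v, 1) for v in box])
+```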
+
+## 📚 十、总结与建议
+
+### 10.1 模型选择总结
+
+```mermaid
+graph TB
+    subgraph Summary[模型选择总结]
+        direction TB
+        
+        PixelTask[像素级任务<br/>表格线/区域分割] --> UNetRec[推荐: UNet系列<br/>UNet/UNet++/TransUNet]
+        
+        FastTask[快速区域检测<br/>工业批量处理] --> YOLORec[推荐: YOLOv8<br/>速度快、易部署]
+        
+        ComplexTask[复杂文档理解<br/>多模态推理] --> LayoutLMRec[推荐: LayoutLMv3<br/>文本+图像+布局]
+        
+        BestTask[最强精度需求<br/>高分辨率文档] --> DiTRec[推荐: DiT + Detectron2<br/>SOTA 性能]
+    end
+    
+    style UNetRec fill:#e8f5e9
+    style YOLORec fill:#fff3e0
+    style LayoutLMRec fill:#f3e5f5
+    style DiTRec fill:#ffcdd2
+```
+
+### 10.2 发展趋势
+
+1. **像素级 → 框级 → 理解级**:从像素级分割、区域检测逐步走向语义理解
+2. **单模态 → 多模态**:从纯视觉到文本+视觉+位置的融合
+3. **专用模型 → 通用框架**:从特定任务模型到可组合的模块化架构
+4. **监督学习 → 自监督学习**:BEiT 等自监督预训练方法成为趋势
+5. **单一模型 → 组合架构**:DiT + Detectron2 等组合方案展现更强能力
+
+### 10.3 实际应用建议
+
+- **表格线检测**:UNet 系列(MinerU 采用)
+- **快速布局检测**:YOLOv8(工业标准)
+- **复杂文档解析**:LayoutLMv3(多模态理解)
+- **最强精度需求**:DiT + Detectron2(学术 SOTA)
+- **通用框架**:Universal Doc Parser 通过 YAML 配置灵活组合
+
+---
+
+**文档版本**: v1.0  
+**最后更新**: 2024  
+**参考资料**: 基于技术讨论和实践经验整理
+

+ 470 - 0
docs/ocr_tools/universal_doc_parser/为什么Docling效果比DiT好.md

@@ -0,0 +1,470 @@
+# 为什么 Docling RT-DETR 效果比 DiT 好?
+
+## 📋 问题背景
+
+根据实际测试,在处理**财务报表**(如 `2023年度报告母公司.pdf`)时,**Docling 的 RT-DETR 布局检测模型**效果明显优于 **DiT + Detectron2**,这与文档中描述的"DiT + Detectron2 是最强组合"似乎不符。
+
+## 🔍 核心原因分析
+
+### 1. **训练数据差异** ⭐⭐⭐⭐⭐
+
+这是**最关键的因素**:
+
+| 模型 | 训练数据集 | 数据特点 | 适配性 |
+|------|-----------|---------|--------|
+| **DiT + Detectron2** | **PubLayNet** | 学术论文布局(PDF格式) | ❌ 不适合财务报表 |
+| **Docling RT-DETR** | **DocLayNet + 其他文档数据集** | 多样化商业文档 | ✅ 更适合财务报表 |
+
+#### PubLayNet 数据集特点
+```
+- 数据来源:arXiv 学术论文
+- 文档类型:科研论文(PDF)
+- 布局特点:
+  * 标准双栏布局
+  * 标题、摘要、正文、参考文献
+  * 表格多为数据表格
+  * 图片多为图表
+- 类别:Text, Title, List, Table, Figure
+```
+
+#### DocLayNet 数据集特点
+```
+- 数据来源:多样化商业文档
+- 文档类型:
+  * 财务报表
+  * 技术报告
+  * 法律文件
+  * 科学论文
+  * 专利文档
+- 布局特点:
+  * 复杂多栏布局
+  * 页眉页脚
+  * 表格结构多样
+  * 图片类型丰富
+- 类别:数据集本身定义 11 个基础类别;Docling 布局模型在此基础上扩展到 17 类(包括 Page-header, Page-footer, Key-Value Region 等)
+```
+
+#### 财务报表特点(您的数据)
+```
+- 文档类型:年度财务报告
+- 布局特点:
+  * 页眉(公司信息、联系方式)
+  * 正文段落(审计说明)
+  * 复杂表格(财务报表)
+  * 页脚(签名、日期)
+  * 印章、签名区域
+- 与 PubLayNet 差异:
+  * ❌ 不是学术论文格式
+  * ❌ 没有标准的双栏布局
+  * ❌ 包含更多商业文档元素
+```
+
+**结论**:DiT 在学术论文上训练,对财务报表的适配性较差。
+
+---
+
+### 2. **类别体系差异** ⭐⭐⭐⭐
+
+#### DiT (PubLayNet) 类别
+```python
+# 仅5个类别
+categories = [
+    'Text',      # 文本段落
+    'Title',     # 标题
+    'List',      # 列表
+    'Table',     # 表格
+    'Figure'     # 图片
+]
+```
+
+#### Docling RT-DETR 类别
+```python
+# 17个类别,更细粒度
+categories = [
+    'Text',              # 文本段落
+    'Title',             # 标题
+    'Section-header',    # 章节标题
+    'List-item',         # 列表项
+    'Table',             # 表格
+    'Picture',           # 图片
+    'Caption',           # 图注
+    'Page-header',       # 页眉 ⭐
+    'Page-footer',       # 页脚 ⭐
+    'Footnote',          # 脚注
+    'Formula',           # 公式
+    'Code',              # 代码
+    'Checkbox',          # 复选框
+    'Form',              # 表单
+    'Key-Value Region',  # 键值对区域 ⭐
+    'Document Index',    # 文档索引
+    'Background'         # 背景
+]
+```
+
+**财务报表中的关键元素**:
+- ✅ **Page-header**:公司名称、联系方式(Docling 能识别,DiT 不能)
+- ✅ **Page-footer**:签名、日期(Docling 能识别,DiT 不能)
+- ✅ **Key-Value Region**:键值对信息(Docling 能识别,DiT 不能)
+
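+把 Docling 的原始标签归一到统一类别时,可以用一张简单的映射表(以下映射仅为示意,实际以 `docling_layout_adapter.py` 中的实现为准):
+
+```python
+# 示意:Docling 原始标签 → 统一类别(映射关系为假设,实际以适配器实现为准)
+DOCLING_LABEL_MAP = {
+    "Text": "text",
+    "Section-header": "title",
+    "Table": "table",
+    "Picture": "image_body",
+    "Page-header": "header",      # 页眉:公司信息、联系方式
+    "Page-footer": "footer",      # 页脚:签名、日期
+    "Key-Value Region": "text",   # 键值对区域先归为文本,再做结构化
+}
+
+
+def normalize_label(original_label: str) -> str:
+    """未出现在映射表中的标签按小写原样返回。"""
+    return DOCLING_LABEL_MAP.get(original_label, original_label.lower())
+
+
+print(normalize_label("Page-header"))  # header
+```
+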
+**结论**:Docling 的类别体系更适合商业文档。
+
+---
+
+### 3. **模型微调策略** ⭐⭐⭐
+
+#### DiT 微调情况
+```
+- 预训练:4200万文档图像(IIT-CDIP)
+- 微调:PubLayNet(学术论文)
+- 领域适配:❌ 未针对商业文档微调
+```
+
+#### Docling RT-DETR 微调情况
+```
+- 预训练:COCO 等通用数据集
+- 微调:DocLayNet(多样化商业文档)
+- 领域适配:✅ 专门针对商业文档优化
+- 可能还包含:
+  * 财务报表数据
+  * 技术报告
+  * 法律文件
+```
+
+**结论**:Docling 在商业文档上进行了专门微调。
+
+---
+
+### 4. **后处理策略差异** ⭐⭐⭐
+
+#### DiT 后处理
+```python
+# dit_layout_adapter.py
+# 1. 重叠框处理(可能过于激进)
+if self._remove_overlap:
+    formatted_results = LayoutUtils.remove_overlapping_boxes(
+        formatted_results,
+        iou_threshold=0.8,
+        overlap_ratio_threshold=0.8,
+        max_area_ratio=0.8,  # 可能过滤掉有效的大区域
+        enable_category_restriction=True,
+        enable_category_priority=True
+    )
+
+# 2. 误检图片框过滤
+if self._filter_false_positive_images:
+    formatted_results = LayoutUtils.filter_false_positive_images(
+        formatted_results,
+        min_text_area_ratio=0.3
+    )
+```
+
+**问题**:
+- ❌ 重叠框处理可能过于严格
+- ❌ 误检过滤可能误删有效区域
+- ❌ 类别限制可能不合理
+
+#### Docling 后处理
+```python
+# docling_layout_adapter.py
+# 1. 简单过滤
+if width < 10 or height < 10:
+    continue
+
+# 2. 面积过滤(更宽松)
+if area > img_area * 0.95:  # 只过滤几乎覆盖整页的框
+    continue
+
+# 3. 无复杂的重叠处理
+# RT-DETR 端到端训练,输出质量高,无需复杂后处理
+```
+
+**优势**:
+- ✅ 后处理简单,减少误删
+- ✅ RT-DETR 端到端训练,输出质量高
+- ✅ 无需 NMS,减少后处理误差
+
+---
+
+### 5. **模型架构适配性** ⭐⭐
+
+#### DiT + Detectron2 架构
+```
+输入图像
+    ↓
+[DiT Backbone: ViT]
+    ├─ 全局注意力机制
+    ├─ 文档域预训练
+    └─ 多尺度特征提取
+    ↓
+[FPN + RPN]
+    ├─ 特征金字塔
+    └─ 区域提议网络
+    ↓
+[Mask R-CNN Head]
+    ├─ ROI Align
+    ├─ 分类 + 回归
+    └─ Mask 分割
+    ↓
+[NMS 后处理]
+    ↓
+输出结果
+```
+
+**特点**:
+- ✅ 全局建模能力强
+- ✅ 支持实例分割
+- ❌ 推理速度慢
+- ❌ 需要复杂后处理
+- ❌ 对训练数据依赖性强
+
+#### RT-DETR 架构
+```
+输入图像
+    ↓
+[Hybrid Encoder]
+    ├─ CNN Backbone(ResNet/HGNet)
+    └─ Transformer Encoder
+    ↓
+[Transformer Decoder]
+    ├─ Query-based 检测
+    └─ IoU-aware 选择
+    ↓
+**无需 NMS** ✅
+    ↓
+输出结果
+```
+
+**特点**:
+- ✅ 端到端训练,无需 NMS
+- ✅ 推理速度快
+- ✅ 输出质量高
+- ✅ 对多样化数据适应性强
+
+---
+
+## 📊 实际效果对比
+
+### 财务报表检测结果
+
+#### DiT 检测问题
+```json
+// 问题1:将整页识别为图片
+{
+  "category": "image_body",
+  "bbox": [143, 0, 1485, 2338],  // 几乎覆盖整页
+  "confidence": 0.447
+}
+
+// 问题2:类别识别错误
+{
+  "category": "text",
+  "bbox": [157, 106, 1474, 1344],
+  "confidence": 0.552,
+  "raw": {"original_label": "list"}  // 应该是 text
+}
+
+// 问题3:页眉页脚无法识别
+// DiT 没有 Page-header/Page-footer 类别
+```
+
+#### Docling 检测优势
+```json
+// 优势1:正确识别页眉
+{
+  "category": "header",
+  "bbox": [166, 132, 600, 199],
+  "confidence": 0.85,
+  "raw": {"original_label": "Page-header"}
+}
+
+// 优势2:正确识别页脚
+{
+  "category": "footer",
+  "bbox": [149, 2100, 1450, 2338],
+  "confidence": 0.82,
+  "raw": {"original_label": "Page-footer"}
+}
+
+// 优势3:类别识别准确
+{
+  "category": "text",
+  "bbox": [167, 228, 1462, 322],
+  "confidence": 0.91,
+  "raw": {"original_label": "Text"}
+}
+```
+
+---
+
+## 🎯 为什么"最强组合"不适用?
+
+### "DiT + Detectron2 是最强组合"的前提条件
+
+文档中描述的"最强组合"是**在特定条件下**成立的:
+
+1. **数据集匹配**:在 PubLayNet 或 DocLayNet 等标准数据集上
+2. **任务类型**:学术论文、技术文档等标准布局
+3. **评估指标**:mAP(平均精度均值)等学术指标
+4. **计算资源**:充足的 GPU 资源
+
+### 实际应用中的差异
+
+| 维度 | 学术评估 | 实际应用(财务报表) |
+|------|---------|---------------------|
+| **数据集** | PubLayNet/DocLayNet | 财务报表(未在训练集中) |
+| **布局类型** | 标准学术论文 | 复杂商业文档 |
+| **类别需求** | 5个基础类别 | 17个细粒度类别 |
+| **后处理** | 标准 NMS | 需要复杂规则 |
+| **速度要求** | 不敏感 | 可能需要实时处理 |
+
+**结论**:理论上的"最强组合"不等于实际应用中的"最佳选择"。
+
+---
+
+## 💡 解决方案建议
+
+### 1. **针对财务报表优化 DiT**
+
+```python
+# 方案1:在财务报表数据上微调 DiT
+# 1. 收集财务报表标注数据
+# 2. 在 PubLayNet 预训练模型基础上微调
+# 3. 添加 Page-header/Page-footer 等类别
+
+# 方案2:使用 DocLayNet 微调的 DiT
+# 如果存在 DocLayNet 上微调的 DiT 模型,效果会更好
+```
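+
+方案 1 可以参考下面的最小化微调示意(基于 detectron2 的标准训练接口;数据集名称、标注路径与类别数均为假设示例,实际还需按显存与数据量调整超参):
+
+```python
+import os
+
+from detectron2.config import get_cfg
+from detectron2.data.datasets import register_coco_instances
+from detectron2.engine import DefaultTrainer
+
+from ditod import add_vit_config  # 来自 dit_support/ditod
+
+# 注册财务报表标注数据(COCO 格式,路径为假设示例)
+register_coco_instances("finreport_train", {}, "annotations/train.json", "images/train")
+register_coco_instances("finreport_val", {}, "annotations/val.json", "images/val")
+
+cfg = get_cfg()
+add_vit_config(cfg)
+cfg.merge_from_file("dit_support/configs/cascade/cascade_dit_large.yaml")
+cfg.MODEL.WEIGHTS = "https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth"
+cfg.DATASETS.TRAIN = ("finreport_train",)
+cfg.DATASETS.TEST = ("finreport_val",)
+cfg.MODEL.ROI_HEADS.NUM_CLASSES = 7   # 假设在 5 类基础上新增页眉、页脚等类别
+cfg.SOLVER.IMS_PER_BATCH = 2
+cfg.SOLVER.MAX_ITER = 10000
+cfg.OUTPUT_DIR = "./output_finreport_dit"
+os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
+
+trainer = DefaultTrainer(cfg)
+trainer.resume_or_load(resume=False)  # 类别数变化时,检测头权重会被跳过并重新初始化
+trainer.train()
+```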
+
+### 2. **使用 Docling RT-DETR(推荐)**
+
+```yaml
+# universal_doc_parser 配置
+layout:
+  module: docling
+  config:
+    model_dir: ds4sd/docling-layout-heron  # 或 egret-large
+    device: cuda
+    conf: 0.3
+```
+
+**优势**:
+- ✅ 开箱即用,无需微调
+- ✅ 类别体系完善
+- ✅ 后处理简单
+- ✅ 推理速度快
+
+### 3. **混合策略**
+
+```python
+# 方案:根据文档类型选择模型
+def select_layout_model(doc_type):
+    if doc_type == "financial_report":
+        return DoclingLayoutDetector()  # 商业文档
+    elif doc_type == "academic_paper":
+        return DitLayoutDetector()       # 学术论文
+    else:
+        return DoclingLayoutDetector()  # 默认
+```
+
+---
+
+## 📈 模型选择决策树
+
+```mermaid
+flowchart TD
+    Start[需要 Layout Detection] --> Q1{文档类型?}
+    
+    Q1 -->|财务报表/商业文档| DoclingChoice[选择 Docling RT-DETR<br/>✅ 类别完善<br/>✅ 商业文档适配<br/>✅ 后处理简单]
+    
+    Q1 -->|学术论文/技术文档| Q2{需要最高精度?}
+    
+    Q2 -->|是: 有充足资源| DiTChoice[选择 DiT + Detectron2<br/>✅ 最高精度<br/>⚠️ 需要微调<br/>⚠️ 速度较慢]
+    
+    Q2 -->|否: 需要平衡| DoclingChoice2[选择 Docling RT-DETR<br/>✅ 精度与速度平衡<br/>✅ 开箱即用]
+    
+    Q1 -->|通用文档| DoclingChoice3[选择 Docling RT-DETR<br/>✅ 通用性强<br/>✅ 类别丰富]
+    
+    style DoclingChoice fill:#c8e6c9
+    style DoclingChoice2 fill:#c8e6c9
+    style DoclingChoice3 fill:#c8e6c9
+    style DiTChoice fill:#fff3e0
+```
+
+---
+
+## 🔬 技术细节补充
+
+### DiT 在财务报表上的问题
+
+1. **类别映射不匹配**
+   ```python
+   # DiT 输出:PubLayNet 5类
+   # 财务报表需要:Page-header, Page-footer 等
+   # 结果:页眉页脚被错误分类为 Text
+   ```
+
+2. **布局理解偏差**
+   ```python
+   # PubLayNet 训练数据:标准双栏布局
+   # 财务报表:复杂多栏、页眉页脚、签名区域
+   # 结果:整页被识别为单个区域
+   ```
+
+3. **后处理过度**
+   ```python
+   # 重叠框处理 + 误检过滤
+   # 可能误删有效的页眉页脚区域
+   ```
+
+### Docling RT-DETR 的优势
+
+1. **端到端训练**
+   ```python
+   # RT-DETR 在 DocLayNet 上端到端训练
+   # 学习到了商业文档的布局规律
+   # 输出质量高,无需复杂后处理
+   ```
+
+2. **类别体系完善**
+   ```python
+   # 17个类别覆盖商业文档所有元素
+   # Page-header, Page-footer, Key-Value Region
+   # 直接适配财务报表需求
+   ```
+
+3. **推理效率高**
+   ```python
+   # 无需 NMS,端到端推理
+   # 速度快,适合生产环境
+   ```
+
+---
+
+## 📚 总结
+
+### 核心结论
+
+1. **"最强组合"是条件性的**
+   - DiT + Detectron2 在 PubLayNet 上表现最好
+   - 但不等于在所有场景下都是最佳选择
+
+2. **数据适配性 > 模型架构**
+   - 训练数据与任务场景的匹配度更重要
+   - Docling 在商业文档上训练,更适合财务报表
+
+3. **实际应用需要综合考虑**
+   - 精度、速度、类别体系、后处理复杂度
+   - Docling RT-DETR 在财务报表场景下综合表现更好
+
+### 建议
+
+- ✅ **财务报表场景**:使用 **Docling RT-DETR**
+- ✅ **学术论文场景**:可以使用 **DiT + Detectron2**
+- ✅ **通用文档场景**:推荐 **Docling RT-DETR**(类别更丰富)
+
+---
+
+**文档版本**: v1.0  
+**最后更新**: 2024  
+**基于**: 实际测试结果和技术分析
+

+ 91 - 0
ocr_tools/universal_doc_parser/dit_support/README.md

@@ -0,0 +1,91 @@
+# DiT 支持模块
+
+本目录包含 DiT (Document Image Transformer) 布局检测所需的核心代码和配置文件。
+
+## 目录结构
+
+```
+dit_support/
+├── ditod/                    # DiT 核心模块
+│   ├── __init__.py          # 模块导出(仅推理必需)
+│   ├── config.py            # 配置扩展(add_vit_config)
+│   ├── backbone.py          # ViT backbone 实现
+│   ├── beit.py              # BEiT/DIT 模型定义
+│   └── deit.py              # DeiT 模型定义(可选)
+└── configs/                  # 配置文件
+    ├── Base-RCNN-FPN.yaml   # 基础配置
+    └── cascade/
+        └── cascade_dit_large.yaml  # Cascade R-CNN + DiT-large 配置
+```
+
+## 使用方法
+
+在 `universal_doc_parser` 中使用 DiT 布局检测:
+
+```python
+from models.adapters import get_layout_detector
+
+# 配置 DiT 检测器
+config = {
+    'module': 'dit',
+    'config_file': 'dit_support/configs/cascade/cascade_dit_large.yaml',
+    'model_weights': 'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth',
+    'device': 'cpu',  # 或 'cuda'
+    'conf': 0.3,
+    'remove_overlap': True,
+    'iou_threshold': 0.8,
+    'overlap_ratio_threshold': 0.8,
+}
+
+# 创建检测器
+detector = get_layout_detector(config)
+detector.initialize()
+
+# 检测布局
+import cv2
+img = cv2.imread('image.jpg')
+results = detector.detect(img)
+
+# 清理
+detector.cleanup()
+```
+
+## 依赖包
+
+需要安装以下 Python 包:
+
+```bash
+# 1. PyTorch(必须先安装)
+pip install torch torchvision
+
+# 2. detectron2
+# Mac M4 Pro / Apple Silicon:
+CC=clang CXX=clang++ ARCHFLAGS="-arch arm64" pip install --no-build-isolation 'git+https://github.com/facebookresearch/detectron2.git'
+
+# Linux (CPU):
+pip install 'git+https://github.com/facebookresearch/detectron2.git'
+
+# Linux (CUDA):
+pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.6'
+
+# 3. timm(Vision Transformer 模型库)
+pip install timm
+
+# 4. 基础依赖
+pip install numpy opencv-python Pillow einops
+```
+
+## 迁移说明
+
+本模块是从 `unilm/dit/object_detection/` 迁移的最小版本,仅包含推理必需的代码:
+
+- ✅ 已迁移:ditod 核心模块(5个文件)、配置文件(2个)
+- ❌ 未迁移:训练相关代码(dataset_mapper.py, mytrainer.py 等)、评估代码(icdar_evaluation.py, table_evaluation/)
+
+## 注意事项
+
+1. **路径问题**:确保 `dit_support` 目录在 Python 路径中(适配器会自动处理)
+2. **模型权重**:首次运行会自动从 HuggingFace 下载,需要网络连接
+3. **PyTorch 2.6+**:代码中已包含兼容性修复
+4. **重叠框处理**:默认启用,可在配置中关闭或调整阈值
+

+ 69 - 0
ocr_tools/universal_doc_parser/dit_support/configs/Base-RCNN-FPN.yaml

@@ -0,0 +1,69 @@
+MODEL:
+  MASK_ON: True
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+  BACKBONE:
+    NAME: "build_vit_fpn_backbone"
+  VIT:
+    OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
+    DROP_PATH: 0.1
+    IMG_SIZE: [224,224]
+    POS_TYPE: "abs"
+  FPN:
+    IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
+  ANCHOR_GENERATOR:
+    SIZES: [[32], [64], [128], [256], [512]]  # One size for each in feature map
+    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
+  RPN:
+    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
+    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
+    PRE_NMS_TOPK_TEST: 1000  # Per FPN level
+    # Detectron1 uses 2000 proposals per-batch,
+    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
+    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
+    POST_NMS_TOPK_TRAIN: 1000
+    POST_NMS_TOPK_TEST: 1000
+  ROI_HEADS:
+    NAME: "StandardROIHeads"
+    IN_FEATURES: ["p2", "p3", "p4", "p5"]
+    NUM_CLASSES: 5
+  ROI_BOX_HEAD:
+    NAME: "FastRCNNConvFCHead"
+    NUM_FC: 2
+    POOLER_RESOLUTION: 7
+  ROI_MASK_HEAD:
+    NAME: "MaskRCNNConvUpsampleHead"
+    NUM_CONV: 4
+    POOLER_RESOLUTION: 14
+DATASETS:
+  TRAIN: ("publaynet_train",)
+  TEST: ("publaynet_val",)
+SOLVER:
+  LR_SCHEDULER_NAME: "WarmupCosineLR"
+  AMP:
+    ENABLED: True
+  OPTIMIZER: "ADAMW"
+  BACKBONE_MULTIPLIER: 1.0
+  CLIP_GRADIENTS:
+    ENABLED: True
+    CLIP_TYPE: "full_model"
+    CLIP_VALUE: 1.0
+    NORM_TYPE: 2.0
+  WARMUP_FACTOR: 0.01
+  BASE_LR: 0.0004
+  WEIGHT_DECAY: 0.05
+  IMS_PER_BATCH: 32
+INPUT:
+  CROP:
+    ENABLED: True
+    TYPE: "absolute_range"
+    SIZE: (384, 600)
+  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
+  FORMAT: "RGB"
+DATALOADER:
+  FILTER_EMPTY_ANNOTATIONS: False
+VERSION: 2
+AUG:
+  DETR: True
+SEED: 42

+ 28 - 0
ocr_tools/universal_doc_parser/dit_support/configs/cascade/cascade_dit_large.yaml

@@ -0,0 +1,28 @@
+_BASE_: "../Base-RCNN-FPN.yaml"
+MODEL:
+  PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
+  PIXEL_STD: [ 127.5, 127.5, 127.5 ]
+  WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-large-224-p16-500k-d7a2fb.pth"
+  VIT:
+    NAME: "dit_large_patch16"
+    OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
+    DROP_PATH: 0.2
+  FPN:
+    IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
+  ROI_HEADS:
+    NAME: CascadeROIHeads
+  ROI_BOX_HEAD:
+    CLS_AGNOSTIC_BBOX_REG: True
+  RPN:
+    POST_NMS_TOPK_TRAIN: 2000
+SOLVER:
+  WARMUP_ITERS: 1000
+  IMS_PER_BATCH: 16
+  MAX_ITER: 60000
+  CHECKPOINT_PERIOD: 2000
+  BASE_LR: 0.0001
+  STEPS: (40000, 53333)
+  AMP:
+    ENABLED: False
+TEST:
+  EVAL_PERIOD: 2000

+ 18 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/__init__.py

@@ -0,0 +1,18 @@
+# --------------------------------------------------------------------------------
+# MPViT: Multi-Path Vision Transformer for Dense Prediction
+# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+# All Rights Reserved.
+# Written by Youngwan Lee
+# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------------------
+# 最小迁移版本:仅包含推理必需的导出
+
+from .config import add_vit_config
+from .backbone import build_vit_fpn_backbone
+
+__all__ = [
+    "add_vit_config",
+    "build_vit_fpn_backbone",
+]
+

+ 156 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/backbone.py

@@ -0,0 +1,156 @@
+# --------------------------------------------------------------------------------
+# VIT: Multi-Path Vision Transformer for Dense Prediction
+# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+# All Rights Reserved.
+# Written by Youngwan Lee
+# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# CoaT: https://github.com/mlpc-ucsd/CoaT
+# --------------------------------------------------------------------------------
+
+
+import torch
+
+from detectron2.layers import (
+    ShapeSpec,
+)
+from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
+from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
+
+from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
+from .deit import deit_base_patch16, mae_base_patch16
+
+__all__ = [
+    "build_vit_fpn_backbone",
+]
+
+
+class VIT_Backbone(Backbone):
+    """
+    Implement VIT backbone.
+    """
+
+    def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs):
+        super().__init__()
+        self._out_features = out_features
+        if 'base' in name:
+            self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
+        else:
+            self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
+
+        if name == 'beit_base_patch16':
+            model_func = beit_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == 'dit_base_patch16':
+            model_func = dit_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == "deit_base_patch16":
+            model_func = deit_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == "mae_base_patch16":
+            model_func = mae_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == "dit_large_patch16":
+            model_func = dit_large_patch16
+            self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+        elif name == "beit_large_patch16":
+            model_func = beit_large_patch16
+            self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+        else:
+            raise ValueError("Unsupported VIT name yet.")
+
+        if 'beit' in name or 'dit' in name:
+            if pos_type == "abs":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_abs_pos_emb=True,
+                                           **model_kwargs)
+            elif pos_type == "shared_rel":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_shared_rel_pos_bias=True,
+                                           **model_kwargs)
+            elif pos_type == "rel":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_rel_pos_bias=True,
+                                           **model_kwargs)
+            else:
+                raise ValueError()
+        else:
+            self.backbone = model_func(img_size=img_size,
+                                       out_features=out_features,
+                                       drop_path_rate=drop_path,
+                                       **model_kwargs)
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+
+        Returns:
+            dict[str->Tensor]: names and the corresponding features
+        """
+        assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+        return self.backbone.forward_features(x)
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+            )
+            for name in self._out_features
+        }
+
+
+def build_VIT_backbone(cfg):
+    """
+    Create a VIT instance from config.
+
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        A VIT backbone instance.
+    """
+    # fmt: off
+    name = cfg.MODEL.VIT.NAME
+    out_features = cfg.MODEL.VIT.OUT_FEATURES
+    drop_path = cfg.MODEL.VIT.DROP_PATH
+    img_size = cfg.MODEL.VIT.IMG_SIZE
+    pos_type = cfg.MODEL.VIT.POS_TYPE
+
+    model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
+
+    return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs)
+
+
+@BACKBONE_REGISTRY.register()
+def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
+    """
+    Create a VIT w/ FPN backbone.
+
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+    """
+    bottom_up = build_VIT_backbone(cfg)
+    in_features = cfg.MODEL.FPN.IN_FEATURES
+    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
+    backbone = FPN(
+        bottom_up=bottom_up,
+        in_features=in_features,
+        out_channels=out_channels,
+        norm=cfg.MODEL.FPN.NORM,
+        top_block=LastLevelMaxPool(),
+        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
+    )
+    return backbone

+ 671 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/beit.py

@@ -0,0 +1,671 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implement of Vision Transformers as described in
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+Status/TODO:
+* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
+* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
+* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
+* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import warnings
+import math
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # dropout after fc1 is commented out to match the original BERT implementation
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+            proj_drop=0., window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token 2 cls & cls to cls
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer("relative_position_index", relative_position_index)
+
+            # trunc_normal_(self.relative_position_bias_table, std=.0)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            if training_window_size == self.window_size:
+                relative_position_bias = \
+                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                        self.window_size[0] * self.window_size[1] + 1,
+                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+            else:
+                training_window_size = tuple(training_window_size.tolist())
+                new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+                # new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
+                new_relative_position_bias_table = F.interpolate(
+                    self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                                 2 * self.window_size[0] - 1,
+                                                                                 2 * self.window_size[1] - 1),
+                    size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                    align_corners=False)
+                new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                         new_num_relative_distance - 3).permute(
+                    1, 0)
+                new_relative_position_bias_table = torch.cat(
+                    [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+                # get pair-wise relative position index for each token inside the window
+                coords_h = torch.arange(training_window_size[0])
+                coords_w = torch.arange(training_window_size[1])
+                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+                relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+                relative_coords[:, :, 1] += training_window_size[1] - 1
+                relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+                relative_position_index = \
+                    torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                                dtype=relative_coords.dtype)
+                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+                relative_position_index[0, 0:] = new_num_relative_distance - 3
+                relative_position_index[0:, 0] = new_num_relative_distance - 2
+                relative_position_index[0, 0] = new_num_relative_distance - 1
+
+                relative_position_bias = \
+                    new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                        training_window_size[0] * training_window_size[1] + 1,
+                        training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(
+                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
+                                                            training_window_size=training_window_size))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches_w = self.patch_shape[0]
+        self.num_patches_h = self.patch_shape[1]
+        # the so-called patch_shape is the patch shape during pre-training
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, position_embedding=None, **kwargs):
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        Hp, Wp = x.shape[2], x.shape[3]
+
+        if position_embedding is not None:
+            # interpolate the position embedding to the corresponding size
+            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
+                                                                                                                  1, 2)
+            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
+            x = x + position_embedding
+
+        x = x.flatten(2).transpose(1, 2)
+        return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token 2 cls & cls to cls
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self, training_window_size):
+        if training_window_size == self.window_size:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        else:
+            training_window_size = tuple(training_window_size.tolist())
+            new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+            # new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
+            new_relative_position_bias_table = F.interpolate(
+                self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                             2 * self.window_size[0] - 1,
+                                                                             2 * self.window_size[1] - 1),
+                size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                align_corners=False)
+            new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                     new_num_relative_distance - 3).permute(
+                1, 0)
+            new_relative_position_bias_table = torch.cat(
+                [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(training_window_size[0])
+            coords_w = torch.arange(training_window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += training_window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                            dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = new_num_relative_distance - 3
+            relative_position_index[0:, 0] = new_num_relative_distance - 2
+            relative_position_index[0, 0] = new_num_relative_distance - 1
+
+            relative_position_bias = \
+                new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                    training_window_size[0] * training_window_size[1] + 1,
+                    training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+        return relative_position_bias
+
+
+class BEiT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 img_size=[224, 224],
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=80,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=None,
+                 init_values=None,
+                 use_abs_pos_emb=False,
+                 use_rel_pos_bias=False,
+                 use_shared_rel_pos_bias=False,
+                 use_checkpoint=True,
+                 pretrained=None,
+                 out_features=None,
+                 ):
+
+        super(BEiT, self).__init__()
+
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.use_checkpoint = use_checkpoint
+
+        if hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+            for i in range(depth)])
+
+        # trunc_normal_(self.mask_token, std=.02)
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                # nn.SyncBatchNorm(embed_dim),
+                nn.BatchNorm2d(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        logger = get_root_logger()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self,
+                            filename=self.init_cfg['checkpoint'],
+                            strict=False,
+                            logger=logger,
+                            beit_spec_expand_rel_pos = self.use_rel_pos_bias,
+                            )
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward_features(self, x):
+        B, C, H, W = x.shape
+        x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+        # Hp, Wp are HW for patches
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.pos_embed is not None:
+            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.pos_drop(x)
+
+        features = []
+        training_window_size = torch.tensor([Hp, Wp])
+
+        rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
+
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
+            else:
+                x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+            if i in self.out_indices:
+                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def beit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def beit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=0.1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=1e-5,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+if __name__ == '__main__':
+    # out_features must be provided: __init__ builds out_indices from it
+    model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True,
+                 out_features=["layer3", "layer5", "layer7", "layer11"])
+    model = model.to("cuda:0")
+    input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
+    input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
+    input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
+    output1 = model(input1)
+    output2 = model(input2)
+    output3 = model(input3)
+    print("all done")

+ 32 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/config.py

@@ -0,0 +1,32 @@
+from detectron2.config import CfgNode as CN
+
+
+def add_vit_config(cfg):
+    """
+    Add config for VIT.
+    """
+    _C = cfg
+
+    _C.MODEL.VIT = CN()
+
+    # CoaT model name.
+    _C.MODEL.VIT.NAME = ""
+
+    # Output features from CoaT backbone.
+    _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
+
+    _C.MODEL.VIT.IMG_SIZE = [224, 224]
+
+    _C.MODEL.VIT.POS_TYPE = "shared_rel"
+
+    _C.MODEL.VIT.DROP_PATH = 0.
+
+    _C.MODEL.VIT.MODEL_KWARGS = "{}"
+
+    _C.SOLVER.OPTIMIZER = "ADAMW"
+
+    _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
+
+    _C.AUG = CN()
+
+    _C.AUG.DETR = False
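
For a quick sanity check of the node defined above, `add_vit_config` can be exercised on its own. A minimal sketch, assuming `dit_support/` is already on `sys.path` (the adapter further below inserts it) and that `MODEL.VIT.NAME` is meant to name one of the backbone builders from `beit.py`:

```python
from detectron2.config import get_cfg
from ditod import add_vit_config  # resolvable once dit_support/ is on sys.path

cfg = get_cfg()
add_vit_config(cfg)                       # registers the MODEL.VIT node plus the SOLVER/AUG extras
cfg.MODEL.VIT.NAME = "dit_base_patch16"   # assumption: matches a builder defined in beit.py
print(cfg.MODEL.VIT.OUT_FEATURES)         # ['layer3', 'layer5', 'layer7', 'layer11']
print(cfg.MODEL.VIT.POS_TYPE)             # 'shared_rel'
```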

+ 476 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/deit.py

@@ -0,0 +1,476 @@
+"""
+Mostly copy-paste from DINO and timm library:
+https://github.com/facebookresearch/dino
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+import warnings
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import trunc_normal_, drop_path, to_2tuple
+from functools import partial
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                      C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+
+        self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+
+        self.num_patches_w, self.num_patches_h = self.window_size
+
+        self.num_patches = self.window_size[0] * self.window_size[1]
+        self.img_size = img_size
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(in_chans, embed_dim,
+                              kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        x = self.proj(x)
+        return x
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(
+                    1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class ViT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 model_name='vit_base_patch16_224',
+                 img_size=384,
+                 patch_size=16,
+                 in_chans=3,
+                 embed_dim=1024,
+                 depth=24,
+                 num_heads=16,
+                 num_classes=19,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.1,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                 norm_cfg=None,
+                 pos_embed_interp=False,
+                 random_init=False,
+                 align_corners=False,
+                 use_checkpoint=False,
+                 num_extra_tokens=1,
+                 out_features=None,
+                 **kwargs,
+                 ):
+
+        super(ViT, self).__init__()
+        self.model_name = model_name
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.depth = depth
+        self.num_heads = num_heads
+        self.num_classes = num_classes
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.hybrid_backbone = hybrid_backbone
+        self.norm_layer = norm_layer
+        self.norm_cfg = norm_cfg
+        self.pos_embed_interp = pos_embed_interp
+        self.random_init = random_init
+        self.align_corners = align_corners
+        self.use_checkpoint = use_checkpoint
+        self.num_extra_tokens = num_extra_tokens
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        # self.num_stages = self.depth
+        # self.out_indices = tuple(range(self.num_stages))
+
+        if self.hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        self.num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        if self.num_extra_tokens == 2:
+            self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        self.pos_embed = nn.Parameter(torch.zeros(
+            1, self.num_patches + self.num_extra_tokens, self.embed_dim))
+        self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+        # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
+        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
+                                                self.depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
+                qk_scale=self.qk_scale,
+                drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
+            for i in range(self.depth)])
+
+        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+        # self.repr = nn.Linear(embed_dim, representation_size)
+        # self.repr_act = nn.Tanh()
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                nn.SyncBatchNorm(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        if self.num_extra_tokens==2:
+            trunc_normal_(self.dist_token, std=.02)
+        self.apply(self._init_weights)
+        # self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        logger = get_root_logger()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def _conv_filter(self, state_dict, patch_size=16):
+        """ convert patch embedding weight from manual patchify + linear proj to conv"""
+        out_dict = {}
+        for k, v in state_dict.items():
+            if 'patch_embed.proj.weight' in k:
+                v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+            out_dict[k] = v
+        return out_dict
+
+    def to_2D(self, x):
+        n, hw, c = x.shape
+        h = w = int(math.sqrt(hw))
+        x = x.transpose(1, 2).reshape(n, c, h, w)
+        return x
+
+    def to_1D(self, x):
+        n, c, h, w = x.shape
+        x = x.reshape(n, c, -1).transpose(1, 2)
+        return x
+
+    def interpolate_pos_encoding(self, x, w, h):
+        npatch = x.shape[1] - self.num_extra_tokens
+        N = self.pos_embed.shape[1] - self.num_extra_tokens
+        if npatch == N and w == h:
+            return self.pos_embed
+
+        class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
+
+        patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
+
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size[0]
+        h0 = h // self.patch_embed.patch_size[1]
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x, mask=None):
+        B, nc, w, h = x.shape
+        # patch linear embedding
+        x = self.patch_embed(x)
+
+        # mask image modeling
+        if mask is not None:
+            x = self.mask_model(x, mask)
+        x = x.flatten(2).transpose(1, 2)
+
+        # add the [CLS] token to the embed patch tokens
+        all_tokens = [self.cls_token.expand(B, -1, -1)]
+
+        if self.num_extra_tokens == 2:
+            dist_tokens = self.dist_token.expand(B, -1, -1)
+            all_tokens.append(dist_tokens)
+        all_tokens.append(x)
+
+        x = torch.cat(all_tokens, dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return self.pos_drop(x)
+
+    def forward_features(self, x):
+        # print(f"==========shape of x is {x.shape}==========")
+        B, _, H, W = x.shape
+        Hp, Wp = H // self.patch_size, W // self.patch_size
+        x = self.prepare_tokens(x)
+
+        features = []
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+            if i in self.out_indices:
+                xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def deit_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=2,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def mae_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
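
The two builders above differ mainly in token bookkeeping: `deit_base_patch16` carries two extra tokens (cls + distillation), `mae_base_patch16` only one, and both rely on `interpolate_pos_encoding` to rescale the position grid to arbitrary input sizes. An illustrative sketch (random weights; `out_features` supplied because `__init__` derives `out_indices` from it):

```python
import torch

model = deit_base_patch16(out_features=["layer3", "layer5", "layer7", "layer11"])
x = torch.rand(1, 3, 448, 640)      # 28 x 40 patches at patch_size=16
tokens = model.prepare_tokens(x)
print(tokens.shape)                  # torch.Size([1, 1122, 768]) = cls + dist + 28*40 patch tokens
```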

+ 15 - 0
ocr_tools/universal_doc_parser/models/adapters/__init__.py

@@ -15,6 +15,13 @@ from .paddle_layout_detector import PaddleLayoutDetector
 from .paddle_vl_adapter import PaddleVLRecognizer
 
 from .docling_layout_adapter import DoclingLayoutDetector
+
+# 可选导入 DiT 适配器
+try:
+    from .dit_layout_adapter import DitLayoutDetector
+    DIT_AVAILABLE = True
+except ImportError:
+    DIT_AVAILABLE = False
 # 可选导入 MinerU 适配器
 try:
     from .mineru_adapter import (
@@ -44,6 +51,10 @@ __all__ = [
     'DoclingLayoutDetector',
 ]
 
+# 如果 DiT 可用,添加到导出列表
+if DIT_AVAILABLE:
+    __all__.append('DitLayoutDetector')
+
 # 如果 MinerU 可用,添加到导出列表
 if MINERU_AVAILABLE:
     __all__.extend([
@@ -75,6 +86,10 @@ def get_layout_detector(config: dict):
         return MinerULayoutDetector(config)
     elif module == 'docling':
         return DoclingLayoutDetector(config)
+    elif module == 'dit':
+        if not DIT_AVAILABLE:
+            raise ImportError("DiT adapter not available. Please ensure detectron2 and ditod are installed.")
+        return DitLayoutDetector(config)
     else:
         raise ValueError(f"Unknown layout detection module: {module}")
 

+ 941 - 0
ocr_tools/universal_doc_parser/models/adapters/dit_layout_adapter.py

@@ -0,0 +1,941 @@
+"""DiT Layout Detector 适配器
+
+基于 DiT (Document Image Transformer) 的布局检测适配器,参考 docling_layout_adapter 的实现方式。
+支持 PubLayNet 数据集的 5 个类别:text, title, list, table, figure。
+
+支持的配置:
+- config_file: DiT 配置文件路径
+- model_weights: 模型权重路径或 URL
+- device: 运行设备 ('cpu', 'cuda', 'mps')
+- conf: 置信度阈值 (默认 0.3)
+- remove_overlap: 是否启用重叠框处理 (默认 True)
+- iou_threshold: IoU 阈值 (默认 0.8)
+- overlap_ratio_threshold: 重叠比例阈值 (默认 0.8)
+"""
+
+import cv2
+import numpy as np
+import threading
+from pathlib import Path
+from typing import Dict, List, Union, Any, Optional, Tuple
+from PIL import Image
+
+try:
+    from .base import BaseLayoutDetector
+except ImportError:
+    from base import BaseLayoutDetector
+
+# 全局锁,防止模型初始化时的线程问题
+_model_init_lock = threading.Lock()
+
+
+class LayoutUtils:
+    """布局处理工具类(简化版,不依赖 external 模块)"""
+    
+    @staticmethod
+    def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
+        """计算两个 bbox 的 IoU(交并比)"""
+        x1_1, y1_1, x2_1, y2_1 = bbox1
+        x1_2, y1_2, x2_2, y2_2 = bbox2
+        
+        # 计算交集
+        x1_i = max(x1_1, x1_2)
+        y1_i = max(y1_1, y1_2)
+        x2_i = min(x2_1, x2_2)
+        y2_i = min(y2_1, y2_2)
+        
+        if x2_i <= x1_i or y2_i <= y1_i:
+            return 0.0
+        
+        intersection = (x2_i - x1_i) * (y2_i - y1_i)
+        
+        # 计算并集
+        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
+        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
+        union = area1 + area2 - intersection
+        
+        if union == 0:
+            return 0.0
+        
+        return intersection / union
+    
+    @staticmethod
+    def calculate_overlap_ratio(bbox1: List[float], bbox2: List[float]) -> float:
+        """计算重叠面积占小框面积的比例"""
+        x1_1, y1_1, x2_1, y2_1 = bbox1
+        x1_2, y1_2, x2_2, y2_2 = bbox2
+        
+        # 计算交集
+        x1_i = max(x1_1, x1_2)
+        y1_i = max(y1_1, y1_2)
+        x2_i = min(x2_1, x2_2)
+        y2_i = min(y2_1, y2_2)
+        
+        if x2_i <= x1_i or y2_i <= y1_i:
+            return 0.0
+        
+        intersection = (x2_i - x1_i) * (y2_i - y1_i)
+        
+        # 计算两个框的面积
+        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
+        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
+        
+        # 返回交集占小框面积的比例
+        min_area = min(area1, area2)
+        if min_area == 0:
+            return 0.0
+        
+        return intersection / min_area
+    
+    # 不允许合并的类别组合
+    FORBIDDEN_MERGE = {
+        'image_body': ['text', 'title', 'table_body', 'table'],
+        'figure': ['text', 'title', 'table_body', 'table'],
+    }
+    
+    # 类别优先级(数字越大优先级越高)
+    CATEGORY_PRIORITY = {
+        'text': 3,
+        'title': 3,
+        'table_body': 3,
+        'table': 3,
+        'image_body': 1,
+        'figure': 1,
+    }
+    
+    @staticmethod
+    def remove_overlapping_boxes(
+        layout_results: List[Dict[str, Any]],
+        iou_threshold: float = 0.8,
+        overlap_ratio_threshold: float = 0.8,
+        image_size: Optional[Tuple[int, int]] = None,
+        max_area_ratio: float = 0.8,
+        enable_category_restriction: bool = True,
+        enable_category_priority: bool = True
+    ) -> List[Dict[str, Any]]:
+        """
+        处理重叠的布局框(参考 MinerU 的去重策略)
+        
+        策略:
+        1. 高 IoU 重叠:保留置信度高的框(考虑类别优先级)
+        2. 包含关系:小框被大框高度包含时,检查类别限制和面积限制后决定是否合并
+        
+        Args:
+            layout_results: Layout 检测结果列表
+            iou_threshold: IoU 阈值,超过此值认为高度重叠
+            overlap_ratio_threshold: 重叠面积占小框面积的比例阈值
+            image_size: 图像尺寸 (width, height),用于计算面积限制
+            max_area_ratio: 最大面积比例,合并后的框超过此比例则拒绝合并(默认0.8)
+            enable_category_restriction: 是否启用类别限制(默认True)
+            enable_category_priority: 是否启用类别优先级(默认True)
+            
+        Returns:
+            去重后的布局结果列表
+        """
+        if not layout_results or len(layout_results) <= 1:
+            return layout_results
+        
+        # 复制列表避免修改原数据
+        results = [item.copy() for item in layout_results]
+        need_remove = set()
+        
+        # 计算图像总面积(如果提供了图像尺寸)
+        img_area = None
+        if image_size is not None:
+            img_width, img_height = image_size
+            img_area = img_width * img_height
+        
+        def can_merge(cat1: str, cat2: str) -> bool:
+            """检查两个类别是否允许合并"""
+            if not enable_category_restriction:
+                return True
+            
+            # 检查是否在禁止合并列表中
+            forbidden1 = LayoutUtils.FORBIDDEN_MERGE.get(cat1, [])
+            if cat2 in forbidden1:
+                return False
+            
+            forbidden2 = LayoutUtils.FORBIDDEN_MERGE.get(cat2, [])
+            if cat1 in forbidden2:
+                return False
+            
+            return True
+        
+        def get_priority(category: str) -> int:
+            """获取类别优先级"""
+            if not enable_category_priority:
+                return 0
+            return LayoutUtils.CATEGORY_PRIORITY.get(category, 0)
+        
+        def check_area_limit(merged_bbox: List[float]) -> bool:
+            """检查合并后的框是否超过面积限制"""
+            if img_area is None:
+                return True  # 如果没有提供图像尺寸,不检查
+            
+            merged_area = (merged_bbox[2] - merged_bbox[0]) * (merged_bbox[3] - merged_bbox[1])
+            area_ratio = merged_area / img_area if img_area > 0 else 0
+            
+            return area_ratio <= max_area_ratio
+        
+        for i in range(len(results)):
+            if i in need_remove:
+                continue
+                
+            for j in range(i + 1, len(results)):
+                if j in need_remove:
+                    continue
+                
+                bbox1 = results[i].get('bbox', [0, 0, 0, 0])
+                bbox2 = results[j].get('bbox', [0, 0, 0, 0])
+                
+                if len(bbox1) < 4 or len(bbox2) < 4:
+                    continue
+                
+                cat1 = results[i].get('category', 'unknown')
+                cat2 = results[j].get('category', 'unknown')
+                
+                # 计算 IoU
+                iou = LayoutUtils.calculate_iou(bbox1, bbox2)
+                
+                if iou > iou_threshold:
+                    # 高度重叠,保留置信度高的框(考虑类别优先级)
+                    score1 = results[i].get('confidence', results[i].get('score', 0))
+                    score2 = results[j].get('confidence', results[j].get('score', 0))
+                    priority1 = get_priority(cat1)
+                    priority2 = get_priority(cat2)
+                    
+                    # 如果类别优先级不同,优先保留高优先级
+                    if priority1 != priority2:
+                        if priority1 > priority2:
+                            need_remove.add(j)
+                        else:
+                            need_remove.add(i)
+                            break
+                    # 如果类别优先级相同,保留置信度高的
+                    elif score1 >= score2:
+                        need_remove.add(j)
+                    else:
+                        need_remove.add(i)
+                        break
+                else:
+                    # 检查包含关系
+                    overlap_ratio = LayoutUtils.calculate_overlap_ratio(bbox1, bbox2)
+                    
+                    if overlap_ratio > overlap_ratio_threshold:
+                        # 检查类别是否允许合并
+                        if not can_merge(cat1, cat2):
+                            continue  # 不允许合并,跳过
+                        
+                        # 小框被大框高度包含
+                        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
+                        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+                        
+                        if area1 <= area2:
+                            small_idx, large_idx = i, j
+                        else:
+                            small_idx, large_idx = j, i
+                        
+                        # 计算合并后的框
+                        small_bbox = results[small_idx]['bbox']
+                        large_bbox = results[large_idx]['bbox']
+                        merged_bbox = [
+                            min(small_bbox[0], large_bbox[0]),
+                            min(small_bbox[1], large_bbox[1]),
+                            max(small_bbox[2], large_bbox[2]),
+                            max(small_bbox[3], large_bbox[3])
+                        ]
+                        
+                        # 检查合并后的面积是否超过限制
+                        if not check_area_limit(merged_bbox):
+                            continue  # 超过面积限制,拒绝合并
+                        
+                        # 检查类别优先级:如果小框优先级更高,不应该被大框合并
+                        small_cat = results[small_idx].get('category', 'unknown')
+                        large_cat = results[large_idx].get('category', 'unknown')
+                        small_priority = get_priority(small_cat)
+                        large_priority = get_priority(large_cat)
+                        
+                        if small_priority > large_priority:
+                            continue  # 小框优先级更高,不应该被合并
+                        
+                        # 执行合并:扩展大框的边界
+                        results[large_idx]['bbox'] = merged_bbox
+                        need_remove.add(small_idx)
+                        
+                        if small_idx == i:
+                            break  # i 被移除,跳出内层循环
+        
+        # 返回去重后的结果
+        return [results[i] for i in range(len(results)) if i not in need_remove]
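+    # Concrete illustration of the defaults above: two same-category boxes with
+    # IoU > 0.8 collapse to the single higher-confidence box, whereas a small
+    # 'text' box fully contained in a large 'image_body' box is left untouched,
+    # because FORBIDDEN_MERGE rules out merging that category pair.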
+    
+    @staticmethod
+    def filter_false_positive_images(
+        layout_results: List[Dict[str, Any]],
+        min_text_area_ratio: float = 0.3
+    ) -> List[Dict[str, Any]]:
+        """
+        过滤误检的图片框:如果图片框内包含的其他类型(如text/title/table)的面积总和
+        与图片框的面积比大于阈值,则认为该图片框是误检,应该移除。
+        
+        Args:
+            layout_results: Layout 检测结果列表
+            min_text_area_ratio: 最小文本面积比例阈值,如果图片框内文本面积占比超过此值则移除(默认0.3)
+            
+        Returns:
+            过滤后的布局结果列表
+        """
+        if not layout_results:
+            return layout_results
+        
+        # 需要移除的图片框索引
+        need_remove = set()
+        
+        # 找出所有图片框
+        image_boxes = []
+        other_boxes = []
+        
+        for i, result in enumerate(layout_results):
+            category = result.get('category', 'unknown')
+            if category in ['image_body', 'figure']:
+                image_boxes.append((i, result))
+            else:
+                other_boxes.append((i, result))
+        
+        # 对每个图片框,检查其内部包含的其他类型框的面积
+        for img_idx, img_result in image_boxes:
+            img_bbox = img_result.get('bbox', [0, 0, 0, 0])
+            if len(img_bbox) < 4:
+                continue
+            
+            img_area = (img_bbox[2] - img_bbox[0]) * (img_bbox[3] - img_bbox[1])
+            if img_area == 0:
+                continue
+            
+            # 计算图片框内包含的其他类型框的总面积
+            total_contained_area = 0.0
+            
+            for other_idx, other_result in other_boxes:
+                if other_idx in need_remove:
+                    continue
+                
+                other_bbox = other_result.get('bbox', [0, 0, 0, 0])
+                if len(other_bbox) < 4:
+                    continue
+                
+                # 检查其他框是否被图片框包含
+                # 使用 IoU 或包含关系判断
+                overlap_ratio = LayoutUtils.calculate_overlap_ratio(other_bbox, img_bbox)
+                
+                # 如果其他框的大部分(>50%)都在图片框内,认为被包含
+                if overlap_ratio > 0.5:
+                    other_area = (other_bbox[2] - other_bbox[0]) * (other_bbox[3] - other_bbox[1])
+                    # 计算实际包含的面积(交集)
+                    x1_i = max(img_bbox[0], other_bbox[0])
+                    y1_i = max(img_bbox[1], other_bbox[1])
+                    x2_i = min(img_bbox[2], other_bbox[2])
+                    y2_i = min(img_bbox[3], other_bbox[3])
+                    
+                    if x2_i > x1_i and y2_i > y1_i:
+                        intersection_area = (x2_i - x1_i) * (y2_i - y1_i)
+                        total_contained_area += intersection_area
+            
+            # 计算文本面积占比
+            text_area_ratio = total_contained_area / img_area if img_area > 0 else 0.0
+            
+            # 如果文本面积占比超过阈值,移除该图片框
+            if text_area_ratio > min_text_area_ratio:
+                need_remove.add(img_idx)
+                # 可选:打印调试信息
+                # print(f"🔄 Removed false positive image box: category={img_result.get('category')}, "
+                #       f"bbox={img_bbox}, text_area_ratio={text_area_ratio:.2f} > {min_text_area_ratio}")
+        
+        # 返回过滤后的结果
+        return [result for i, result in enumerate(layout_results) if i not in need_remove]
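+    # Concrete illustration: an 'image_body' box of 1000 x 800 px (area 800,000 px^2)
+    # whose contained text/title/table boxes overlap it by 300,000 px^2 in total has
+    # a text-area ratio of 0.375 > 0.3, so it is dropped as a false positive.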
+
+
+class DitLayoutDetector(BaseLayoutDetector):
+    """DiT Layout Detector 适配器
+    
+    基于 DiT (Document Image Transformer) 的布局检测器,使用 detectron2 + DiT backbone。
+    支持 PubLayNet 数据集的布局检测。
+    """
+    
+    # DiT/PubLayNet 原始类别定义
+    DIT_LABELS = {
+        0: 'text',
+        1: 'title',
+        2: 'list',
+        3: 'table',
+        4: 'figure',
+    }
+    
+    # 类别映射:PubLayNet → MinerU/EnhancedDocPipeline 类别体系
+    # 参考:
+    # - Pipeline: universal_doc_parser/core/pipeline_manager_v2.py (EnhancedDocPipeline 类别定义)
+    CATEGORY_MAP = {
+        'text': 'text',                    # Text -> text (TEXT_CATEGORIES)
+        'title': 'title',                  # Title -> title (TEXT_CATEGORIES)
+        'list': 'text',                    # List-item -> text (TEXT_CATEGORIES)
+        'table': 'table_body',             # Table -> table_body (TABLE_BODY_CATEGORIES)
+        'figure': 'image_body',            # Figure -> image_body (IMAGE_BODY_CATEGORIES)
+    }
+    
+    def __init__(self, config: Dict[str, Any]):
+        """
+        初始化 DiT Layout 检测器
+        
+        Args:
+            config: 配置字典,支持以下参数:
+                - config_file: DiT 配置文件路径(默认使用 cascade_dit_large.yaml)
+                - model_weights: 模型权重路径或 URL
+                - device: 运行设备 ('cpu', 'cuda', 'mps')
+                - conf: 置信度阈值 (默认 0.3)
+                - remove_overlap: 是否启用重叠框处理 (默认 True)
+                - iou_threshold: IoU 阈值 (默认 0.8)
+                - overlap_ratio_threshold: 重叠比例阈值 (默认 0.8)
+                - max_area_ratio: 最大面积比例 (默认 0.8)
+                - enable_category_restriction: 是否启用类别限制 (默认 True)
+                - enable_category_priority: 是否启用类别优先级 (默认 True)
+                - filter_false_positive_images: 是否过滤误检的图片框 (默认 True)
+                - min_text_area_ratio: 最小文本面积比例阈值,图片框内文本面积占比超过此值则移除 (默认 0.3)
+        """
+        super().__init__(config)
+        self.predictor = None
+        self.cfg = None
+        self._device = None
+        self._threshold = 0.3
+        self._remove_overlap = True
+        self._iou_threshold = 0.8
+        self._overlap_ratio_threshold = 0.8
+        self._max_area_ratio = 0.8 # 最大面积比例,合并后的框超过此比例则拒绝合并(默认0.8)
+        self._enable_category_restriction = True
+        self._enable_category_priority = True
+        self._filter_false_positive_images = True
+        self._min_text_area_ratio = 0.3
+    
+    def initialize(self):
+        """初始化模型"""
+        import os
+        import sys
+        
+        try:
+            import torch
+            from detectron2.config import get_cfg
+            from detectron2.engine import DefaultPredictor
+            from detectron2.data import MetadataCatalog
+            
+            # PyTorch 2.6+ 兼容性修复
+            if hasattr(torch, '__version__'):
+                torch_version = tuple(map(int, torch.__version__.split('.')[:2]))
+                if torch_version >= (2, 6):
+                    _original_torch_load = torch.load
+                    def _patched_torch_load(f, map_location=None, pickle_module=None, 
+                                            weights_only=None, **kwargs):
+                        if weights_only is None:
+                            weights_only = False
+                        return _original_torch_load(f, map_location=map_location, 
+                                                  pickle_module=pickle_module,
+                                                  weights_only=weights_only, **kwargs)
+                    torch.load = _patched_torch_load
+            
+            # Add the bundled dit_support directory to sys.path so `ditod` can be imported
+            current_dir = os.path.dirname(os.path.abspath(__file__))
+            dit_support_path = Path(__file__).parents[2] / 'dit_support'
+            if str(dit_support_path) not in sys.path:
+                sys.path.insert(0, str(dit_support_path))
+            
+            from ditod import add_vit_config
+            
+            # 获取配置参数
+            config_file = self.config.get(
+                'config_file',
+                dit_support_path / 'configs' / 'cascade' / 'cascade_dit_large.yaml'
+            )
+            model_weights = self.config.get(
+                'model_weights',
+                'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth'
+            )
+            device = self.config.get('device', 'cpu')
+            self._threshold = self.config.get('conf', 0.3)
+            self._remove_overlap = self.config.get('remove_overlap', True)
+            self._iou_threshold = self.config.get('iou_threshold', 0.8)
+            self._overlap_ratio_threshold = self.config.get('overlap_ratio_threshold', 0.8)
+            self._max_area_ratio = self.config.get('max_area_ratio', 0.8)
+            self._enable_category_restriction = self.config.get('enable_category_restriction', True)
+            self._enable_category_priority = self.config.get('enable_category_priority', True)
+            self._filter_false_positive_images = self.config.get('filter_false_positive_images', True)
+            self._min_text_area_ratio = self.config.get('min_text_area_ratio', 0.3)
+            
+            # 设置设备
+            self._device = torch.device(device)
+            
+            # 验证配置文件存在
+            if not os.path.exists(config_file):
+                raise FileNotFoundError(f"Config file not found: {config_file}")
+            
+            # 加载配置
+            self.cfg = get_cfg()
+            add_vit_config(self.cfg)
+            self.cfg.merge_from_file(config_file)
+            self.cfg.merge_from_list(["MODEL.WEIGHTS", model_weights])
+            self.cfg.MODEL.DEVICE = str(self._device)
+            
+            # 设置元数据
+            dataset_name = self.cfg.DATASETS.TEST[0]
+            md = MetadataCatalog.get(dataset_name)
+            if dataset_name == 'icdar2019_test':
+                md.set(thing_classes=["table"])
+            else:
+                md.set(thing_classes=["text", "title", "list", "table", "figure"])
+            
+            # 创建预测器(使用锁防止线程问题)
+            with _model_init_lock:
+                self.predictor = DefaultPredictor(self.cfg)
+            
+            print(f"✅ DiT Layout Detector initialized")
+            print(f"   - Config: {config_file}")
+            print(f"   - Device: {self._device}")
+            print(f"   - Threshold: {self._threshold}")
+            print(f"   - Remove overlap: {self._remove_overlap}")
+            
+        except ImportError as e:
+            print(f"❌ Failed to import required libraries: {e}")
+            print("   Please ensure detectron2 and ditod are installed")
+            raise
+        except Exception as e:
+            print(f"❌ Failed to initialize DiT Layout Detector: {e}")
+            raise
+    
+    def cleanup(self):
+        """清理资源"""
+        self.predictor = None
+        self.cfg = None
+        self._device = None
+    
+    def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
+        """
+        检测布局
+        
+        Args:
+            image: 输入图像 (numpy数组或PIL图像)
+            
+        Returns:
+            检测结果列表,每个元素包含:
+            - category: MinerU类别名称
+            - bbox: [x1, y1, x2, y2]
+            - confidence: 置信度
+            - raw: 原始检测结果
+        """
+        if self.predictor is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        # Convert to a numpy array in BGR order (detectron2's DefaultPredictor expects BGR)
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+            if image.ndim == 3 and image.shape[2] == 3:
+                # PIL delivers RGB -> convert once to OpenCV BGR
+                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        elif isinstance(image, np.ndarray) and image.ndim == 3 and image.shape[2] == 3 and image.dtype == np.uint8:
+            # Assume a raw numpy input is RGB and convert it to BGR
+            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        
+        # Grayscale (2D) input from either source is expanded to 3-channel BGR
+        if image.ndim == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+        
+        orig_h, orig_w = image.shape[:2]
+        
+        # 运行推理
+        outputs = self.predictor(image)
+        instances = outputs["instances"]
+        
+        # 解析结果
+        formatted_results = []
+        for i in range(len(instances)):
+            score = float(instances.scores[i].cpu().item())
+            
+            # 过滤低置信度
+            if score < self._threshold:
+                continue
+            
+            # 获取类别
+            class_id = int(instances.pred_classes[i].cpu().item())
+            original_label = self.DIT_LABELS.get(class_id, f'unknown_{class_id}')
+            
+            # 映射到 MinerU 类别
+            mineru_category = self.CATEGORY_MAP.get(original_label, 'text')
+            
+            # 提取边界框
+            bbox_tensor = instances.pred_boxes[i].tensor[0].cpu().numpy()
+            x1 = max(0, min(orig_w, float(bbox_tensor[0])))
+            y1 = max(0, min(orig_h, float(bbox_tensor[1])))
+            x2 = max(0, min(orig_w, float(bbox_tensor[2])))
+            y2 = max(0, min(orig_h, float(bbox_tensor[3])))
+            
+            bbox = [int(x1), int(y1), int(x2), int(y2)]
+            
+            # 计算宽高
+            width = bbox[2] - bbox[0]
+            height = bbox[3] - bbox[1]
+            
+            # 过滤太小的框
+            if width < 10 or height < 10:
+                continue
+            
+            # 过滤面积异常大的框
+            area = width * height
+            img_area = orig_w * orig_h
+            if area > img_area:
+                continue
+            
+            # 生成多边形坐标
+            poly = [
+                bbox[0], bbox[1],  # 左上
+                bbox[2], bbox[1],  # 右上
+                bbox[2], bbox[3],  # 右下
+                bbox[0], bbox[3],  # 左下
+            ]
+            
+            formatted_results.append({
+                'category': mineru_category,
+                'bbox': bbox,
+                'confidence': score,
+                'raw': {
+                    'original_label': original_label,
+                    'original_label_id': class_id,
+                    'poly': poly,
+                    'width': width,
+                    'height': height
+                }
+            })
+        
+        # 应用重叠框处理
+        if self._remove_overlap and len(formatted_results) > 1:
+            formatted_results = LayoutUtils.remove_overlapping_boxes(
+                formatted_results,
+                iou_threshold=self._iou_threshold,
+                overlap_ratio_threshold=self._overlap_ratio_threshold,
+                image_size=(orig_w, orig_h),
+                max_area_ratio=self._max_area_ratio,
+                enable_category_restriction=self._enable_category_restriction,
+                enable_category_priority=self._enable_category_priority
+            )
+        
+        # 过滤误检的图片框(包含过多文本内容的图片框)
+        if self._filter_false_positive_images and len(formatted_results) > 1:
+            before_count = len(formatted_results)
+            formatted_results = LayoutUtils.filter_false_positive_images(
+                formatted_results,
+                min_text_area_ratio=self._min_text_area_ratio
+            )
+            removed_count = before_count - len(formatted_results)
+            if removed_count > 0:
+                print(f"🔄 Filtered {removed_count} false positive image boxes")
+        
+        return formatted_results
+    
+    def detect_batch(
+        self, 
+        images: List[Union[np.ndarray, Image.Image]]
+    ) -> List[List[Dict[str, Any]]]:
+        """
+        批量检测布局
+        
+        Args:
+            images: 输入图像列表
+            
+        Returns:
+            每个图像的检测结果列表
+        """
+        if self.predictor is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        
+        if not images:
+            return []
+        
+        all_results = []
+        for image in images:
+            results = self.detect(image)
+            all_results.append(results)
+        
+        return all_results
+    
+    def visualize(
+        self, 
+        img: np.ndarray, 
+        results: List[Dict],
+        output_path: Optional[str] = None,
+        show_confidence: bool = True,
+        min_confidence: float = 0.0
+    ) -> np.ndarray:
+        """
+        可视化检测结果
+        
+        Args:
+            img: 输入图像 (BGR 格式)
+            results: 检测结果 (MinerU 格式)
+            output_path: 输出路径(可选)
+            show_confidence: 是否显示置信度
+            min_confidence: 最小置信度阈值
+            
+        Returns:
+            标注后的图像
+        """
+        import random
+        
+        vis_img = img.copy()
+        
+        # 预定义类别颜色(与 EnhancedDocPipeline 保持一致)
+        predefined_colors = {
+            # 文本类
+            'text': (153, 0, 76),
+            'title': (102, 102, 255),
+            'header': (128, 128, 128),
+            'footer': (128, 128, 128),
+            'page_footnote': (200, 200, 200),
+            # 表格类
+            'table_body': (204, 204, 0),
+            'table_caption': (255, 255, 102),
+            # 图片类
+            'image_body': (153, 255, 51),
+            'image_caption': (102, 178, 255),
+            # 公式类
+            'interline_equation': (0, 255, 0),
+            # 代码类
+            'code': (102, 0, 204),
+            # 丢弃类
+            'abandon': (100, 100, 100),
+        }
+        
+        # 过滤低置信度结果
+        filtered_results = [
+            res for res in results 
+            if res['confidence'] >= min_confidence
+        ]
+        
+        if not filtered_results:
+            print(f"⚠️ No results to visualize (min_confidence={min_confidence})")
+            return vis_img
+        
+        # 为每个出现的类别分配颜色
+        category_colors = {}
+        for res in filtered_results:
+            cat = res['category']
+            if cat not in category_colors:
+                if cat in predefined_colors:
+                    category_colors[cat] = predefined_colors[cat]
+                else:
+                    category_colors[cat] = (
+                        random.randint(50, 255),
+                        random.randint(50, 255),
+                        random.randint(50, 255)
+                    )
+        
+        # 绘制检测框
+        for res in filtered_results:
+            bbox = res['bbox']
+            x1, y1, x2, y2 = bbox
+            cat = res['category']
+            confidence = res['confidence']
+            color = category_colors[cat]
+            
+            # 获取原始标签
+            original_label = res.get('raw', {}).get('original_label', cat)
+            
+            # 绘制矩形边框
+            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
+            
+            # 构造标签文本
+            if show_confidence:
+                label = f"{original_label}->{cat} {confidence:.2f}"
+            else:
+                label = f"{original_label}->{cat}"
+            
+            # 计算标签尺寸
+            label_size, baseline = cv2.getTextSize(
+                label, 
+                cv2.FONT_HERSHEY_SIMPLEX, 
+                0.4, 
+                1
+            )
+            label_w, label_h = label_size
+            
+            # 绘制标签背景
+            cv2.rectangle(
+                vis_img,
+                (x1, y1 - label_h - 4),
+                (x1 + label_w, y1),
+                color,
+                -1
+            )
+            
+            # 绘制标签文字
+            cv2.putText(
+                vis_img,
+                label,
+                (x1, y1 - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.4,
+                (255, 255, 255),
+                1,
+                cv2.LINE_AA
+            )
+        
+        # 添加图例
+        if category_colors:
+            self._draw_legend(vis_img, category_colors, len(filtered_results))
+        
+        # 保存可视化结果
+        if output_path:
+            output_path_obj = Path(output_path)
+            output_path_obj.parent.mkdir(parents=True, exist_ok=True)
+            cv2.imwrite(str(output_path_obj), vis_img)
+            print(f"💾 Visualization saved to: {output_path_obj}")
+        
+        return vis_img
+    
+    def _draw_legend(
+        self, 
+        img: np.ndarray, 
+        category_colors: Dict[str, tuple],
+        total_count: int
+    ):
+        """在图像上绘制图例"""
+        legend_x = img.shape[1] - 200
+        legend_y = 20
+        line_height = 25
+        
+        # 绘制半透明背景
+        overlay = img.copy()
+        cv2.rectangle(
+            overlay,
+            (legend_x - 10, legend_y - 10),
+            (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30),
+            (255, 255, 255),
+            -1
+        )
+        cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img)
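+        # 上一步以 0.7:0.3 的权重把白色底叠加回原图,形成半透明图例背景,保证文字在任意底色上可读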
+        
+        # 绘制标题
+        cv2.putText(
+            img,
+            f"Legend ({total_count} total)",
+            (legend_x, legend_y),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (0, 0, 0),
+            1,
+            cv2.LINE_AA
+        )
+        
+        # 绘制每个类别
+        y_offset = legend_y + line_height
+        for cat, color in sorted(category_colors.items()):
+            cv2.rectangle(
+                img,
+                (legend_x, y_offset - 10),
+                (legend_x + 15, y_offset),
+                color,
+                -1
+            )
+            cv2.rectangle(
+                img,
+                (legend_x, y_offset - 10),
+                (legend_x + 15, y_offset),
+                (0, 0, 0),
+                1
+            )
+            
+            cv2.putText(
+                img,
+                cat,
+                (legend_x + 20, y_offset - 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.4,
+                (0, 0, 0),
+                1,
+                cv2.LINE_AA
+            )
+            
+            y_offset += line_height
+
+
+# 测试代码
+if __name__ == "__main__":
+    import sys
+    import os
+    
+    # 测试配置
+    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+    config = {
+        'config_file': os.path.join(project_root, 'dit', 'object_detection',
+                                   'publaynet_configs', 'cascade', 'cascade_dit_large.yaml'),
+        'model_weights': 'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth',
+        'device': 'cpu',
+        'conf': 0.3,
+        'remove_overlap': True,
+        'iou_threshold': 0.8,
+        'overlap_ratio_threshold': 0.8
+    }
+    
+    # 初始化检测器
+    print("🔧 Initializing DiT Layout Detector...")
+    detector = DitLayoutDetector(config)
+    detector.initialize()
+    
+    # 读取测试图像
+    img_path = "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司/paddleocr_vl_results/2023年度报告母公司/2023年度报告母公司_page_021.png"
+    
+    print(f"\n📖 Loading image: {img_path}")
+    img = cv2.imread(img_path)
+    
+    if img is None:
+        print(f"❌ Failed to load image: {img_path}")
+        sys.exit(1)
+    
+    print(f"   Image shape: {img.shape}")
+    
+    # 执行检测
+    print("\n🔍 Detecting layout...")
+    results = detector.detect(img)
+    
+    print(f"\n✅ 检测到 {len(results)} 个区域:")
+    for i, res in enumerate(results, 1):
+        print(f"  [{i}] {res['category']}: "
+              f"score={res['confidence']:.3f}, "
+              f"bbox={res['bbox']}, "
+              f"original={res['raw']['original_label']}")
+    
+    # 统计各类别
+    category_counts = {}
+    for res in results:
+        cat = res['category']
+        category_counts[cat] = category_counts.get(cat, 0) + 1
+    
+    print(f"\n📊 类别统计 (MinerU格式):")
+    for cat, count in sorted(category_counts.items()):
+        print(f"  - {cat}: {count}")
+    
+    # 可视化
+    if len(results) > 0:
+        print("\n🎨 Generating visualization...")
+        
+        output_dir = Path(__file__).parent / "output"
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir / f"{Path(img_path).stem}_dit_layout_vis.jpg"
+        
+        vis_img = detector.visualize(
+            img, 
+            results, 
+            output_path=str(output_path),
+            show_confidence=True,
+            min_confidence=0.0
+        )
+        
+        print(f"💾 Visualization saved to: {output_path}")
+    
+    # 清理
+    detector.cleanup()
+    print("\n✅ 测试完成!")
+

+ 529 - 0
ocr_tools/universal_doc_parser/tests/test_dit_layout_adapter.py

@@ -0,0 +1,529 @@
+"""
+DiT Layout Detector 测试脚本
+
+测试 DitLayoutDetector 适配器,支持:
+- PDF 文件输入(自动转换为图像)
+- 图像文件输入
+- 目录输入(批量处理)
+- 页面范围过滤
+- 布局检测和结果统计
+- 可视化结果保存
+"""
+
+import sys
+import json
+import argparse
+from pathlib import Path
+from typing import List, Dict, Any
+
+import cv2
+
+# 添加项目根目录到路径
+project_root = Path(__file__).parents[1]
+sys.path.insert(0, str(project_root))
+
+# 添加 ocr_platform 根目录(用于导入 ocr_utils)
+ocr_platform_root = project_root.parents[1]
+if str(ocr_platform_root) not in sys.path:
+    sys.path.insert(0, str(ocr_platform_root))
+
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+from models.adapters.dit_layout_adapter import DitLayoutDetector
+from ocr_utils.file_utils import convert_pdf_to_images, get_image_files_from_dir
+
+
+def parse_args():
+    """解析命令行参数"""
+    parser = argparse.ArgumentParser(
+        description="测试 DiT Layout Detector 适配器",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+示例:
+  # 测试 PDF 文件(处理所有页面)
+  python test_dit_layout_adapter.py --input /path/to/document.pdf
+
+  # 测试 PDF 文件(指定页面范围)
+  python test_dit_layout_adapter.py --input /path/to/document.pdf --pages "1-5,10-15"
+
+  # 测试图像文件
+  python test_dit_layout_adapter.py --input /path/to/image.png
+
+  # 测试目录(批量处理)
+  python test_dit_layout_adapter.py --input /path/to/images/ --output-dir ./results
+
+  # 使用自定义配置
+  python test_dit_layout_adapter.py --input /path/to/document.pdf \\
+      --config-file ./custom_config.yaml \\
+      --model-weights /path/to/model.pth \\
+      --device cuda \\
+      --conf 0.5
+        """
+    )
+    
+    parser.add_argument(
+        "--input",
+        type=str,
+        required=True,
+        help="输入路径(PDF文件/图像文件/图像目录)"
+    )
+    
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=None,
+        help="输出目录(默认: tests/output/)"
+    )
+    
+    parser.add_argument(
+        "--config-file",
+        type=str,
+        default=None,
+        help="DiT 配置文件路径(可选,默认使用内置配置)"
+    )
+    
+    parser.add_argument(
+        "--model-weights",
+        type=str,
+        default=None,
+        help="模型权重路径或 URL(可选,默认从 HuggingFace 下载)"
+    )
+    
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        choices=["cpu", "cuda", "mps"],
+        help="运行设备 (默认: cpu)"
+    )
+    
+    parser.add_argument(
+        "--conf",
+        type=float,
+        default=0.3,
+        help="置信度阈值 (默认: 0.3)"
+    )
+    
+    parser.add_argument(
+        "--pages",
+        type=str,
+        default=None,
+        help="页面范围(如 '1-5,7,9-12'),仅对 PDF 有效"
+    )
+    
+    parser.add_argument(
+        "--remove-overlap",
+        action="store_true",
+        default=True,
+        help="启用重叠框处理(默认启用)"
+    )
+    
+    parser.add_argument(
+        "--no-remove-overlap",
+        action="store_false",
+        dest="remove_overlap",
+        help="禁用重叠框处理"
+    )
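+    # --remove-overlap 与 --no-remove-overlap 共用 dest="remove_overlap",后者用于显式关闭重叠框处理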
+    
+    parser.add_argument(
+        "--iou-threshold",
+        type=float,
+        default=0.8,
+        help="IoU 阈值 (默认: 0.8)"
+    )
+    
+    parser.add_argument(
+        "--overlap-ratio-threshold",
+        type=float,
+        default=0.8,
+        help="重叠比例阈值 (默认: 0.8)"
+    )
+    
+    parser.add_argument(
+        "--dpi",
+        type=int,
+        default=200,
+        help="PDF 转图像 DPI (默认: 200)"
+    )
+    
+    parser.add_argument(
+        "--save-json",
+        action="store_true",
+        help="保存 JSON 格式的检测结果"
+    )
+    
+    parser.add_argument(
+        "--min-confidence",
+        type=float,
+        default=0.0,
+        help="可视化时的最小置信度阈值 (默认: 0.0)"
+    )
+    
+    return parser.parse_args()
+
+
+def get_input_images(input_path: str, page_range: str = None, dpi: int = 200) -> List[str]:
+    """
+    获取输入图像文件列表
+    
+    Args:
+        input_path: 输入路径(PDF/图像/目录)
+        page_range: 页面范围(仅对 PDF 有效)
+        dpi: PDF 转图像 DPI
+    
+    Returns:
+        图像文件路径列表
+    """
+    input_path_obj = Path(input_path)
+    
+    if not input_path_obj.exists():
+        raise FileNotFoundError(f"输入路径不存在: {input_path}")
+    
+    image_files = []
+    
+    if input_path_obj.is_file():
+        if input_path_obj.suffix.lower() == '.pdf':
+            # PDF 文件:转换为图像
+            print(f"📄 处理 PDF 文件: {input_path_obj.name}")
+            image_files = convert_pdf_to_images(
+                str(input_path_obj),
+                output_dir=None,  # 使用默认输出目录
+                dpi=dpi,
+                page_range=page_range
+            )
+            if not image_files:
+                raise ValueError(f"PDF 转换失败,未生成图像文件")
+            print(f"✅ PDF 转换为 {len(image_files)} 张图像")
+        
+        elif input_path_obj.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']:
+            # 图像文件:直接添加
+            image_files = [str(input_path_obj)]
+            print(f"📷 处理图像文件: {input_path_obj.name}")
+        
+        else:
+            raise ValueError(f"不支持的文件类型: {input_path_obj.suffix}")
+    
+    elif input_path_obj.is_dir():
+        # 目录:扫描所有图像文件
+        image_files = get_image_files_from_dir(input_path_obj)
+        if not image_files:
+            raise ValueError(f"目录中未找到图像文件: {input_path}")
+        print(f"📁 从目录中找到 {len(image_files)} 张图像")
+    
+    else:
+        raise ValueError(f"无效的输入路径: {input_path}")
+    
+    return sorted(image_files)
+
+
+def build_config(args, project_root: Path) -> Dict[str, Any]:
+    """
+    构建检测器配置
+    
+    Args:
+        args: 命令行参数
+        project_root: 项目根目录
+    
+    Returns:
+        配置字典
+    """
+    config = {
+        'device': args.device,
+        'conf': args.conf,
+        'remove_overlap': args.remove_overlap,
+        'iou_threshold': args.iou_threshold,
+        'overlap_ratio_threshold': args.overlap_ratio_threshold,
+    }
+    
+    # 配置文件路径
+    if args.config_file:
+        config['config_file'] = args.config_file
+    else:
+        # 使用默认配置文件
+        default_config_file = project_root / 'dit_support' / 'configs' / 'cascade' / 'cascade_dit_large.yaml'
+        if default_config_file.exists():
+            config['config_file'] = str(default_config_file)
+        else:
+            print(f"⚠️  警告: 默认配置文件不存在: {default_config_file}")
+            print("   请使用 --config-file 指定配置文件路径")
+    
+    # 模型权重
+    if args.model_weights:
+        config['model_weights'] = args.model_weights
+    else:
+        # 使用默认模型权重 URL
+        config['model_weights'] = (
+            'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth'
+        )
+    
+    return config
+
+
+def process_images(
+    detector: DitLayoutDetector,
+    image_files: List[str],
+    output_dir: Path,
+    save_json: bool = False,
+    min_confidence: float = 0.0
+) -> Dict[str, Any]:
+    """
+    处理图像列表,进行布局检测
+    
+    Args:
+        detector: 布局检测器
+        image_files: 图像文件路径列表
+        output_dir: 输出目录
+        save_json: 是否保存 JSON 结果
+        min_confidence: 最小置信度阈值
+    
+    Returns:
+        统计结果字典
+    """
+    all_results = {}
+    total_stats = {
+        'total_pages': len(image_files),
+        'total_regions': 0,
+        'category_counts': {},
+        'confidence_stats': {
+            'min': float('inf'),
+            'max': 0.0,
+            'sum': 0.0,
+            'count': 0
+        }
+    }
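+    # confidence_stats 的 min 初始为 +inf、max 为 0,首个检测结果即覆盖初始值;若全程无检测结果,函数末尾会回填为 0.0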
+    
+    for idx, image_path in enumerate(image_files, 1):
+        print(f"\n{'='*60}")
+        print(f"📖 处理图像 {idx}/{len(image_files)}: {Path(image_path).name}")
+        print(f"{'='*60}")
+        
+        # 读取图像
+        img = cv2.imread(image_path)
+        if img is None:
+            print(f"❌ 无法读取图像: {image_path}")
+            continue
+        
+        print(f"   图像尺寸: {img.shape[1]}x{img.shape[0]}")
+        
+        # 执行检测
+        try:
+            results = detector.detect(img)
+            print(f"✅ 检测到 {len(results)} 个区域")
+            
+            # 统计结果
+            page_stats = {
+                'image_path': image_path,
+                'image_size': [img.shape[1], img.shape[0]],
+                'regions': [],
+                'category_counts': {}
+            }
+            
+            for res in results:
+                # 添加到页面统计
+                page_stats['regions'].append({
+                    'category': res['category'],
+                    'bbox': res['bbox'],
+                    'confidence': float(res['confidence']),
+                    'original_label': res.get('raw', {}).get('original_label', 'unknown')
+                })
+                
+                # 更新类别统计
+                cat = res['category']
+                page_stats['category_counts'][cat] = page_stats['category_counts'].get(cat, 0) + 1
+                total_stats['category_counts'][cat] = total_stats['category_counts'].get(cat, 0) + 1
+                
+                # 更新置信度统计
+                conf = res['confidence']
+                total_stats['confidence_stats']['min'] = min(total_stats['confidence_stats']['min'], conf)
+                total_stats['confidence_stats']['max'] = max(total_stats['confidence_stats']['max'], conf)
+                total_stats['confidence_stats']['sum'] += conf
+                total_stats['confidence_stats']['count'] += 1
+            
+            total_stats['total_regions'] += len(results)
+            all_results[image_path] = page_stats
+            
+            # 打印页面统计
+            if page_stats['category_counts']:
+                print(f"\n   类别统计:")
+                for cat, count in sorted(page_stats['category_counts'].items()):
+                    print(f"     - {cat}: {count}")
+            
+            # 可视化
+            if len(results) > 0:
+                print(f"\n   🎨 生成可视化图像...")
+                
+                image_stem = Path(image_path).stem
+                output_path = output_dir / f"{image_stem}_dit_layout_vis.jpg"
+                
+                vis_img = detector.visualize(
+                    img,
+                    results,
+                    output_path=str(output_path),
+                    show_confidence=True,
+                    min_confidence=min_confidence
+                )
+                
+                print(f"   💾 可视化图像已保存: {output_path}")
+            
+            # 保存 JSON 结果
+            if save_json:
+                json_path = output_dir / f"{Path(image_path).stem}_dit_layout_results.json"
+                with open(json_path, 'w', encoding='utf-8') as f:
+                    json.dump(page_stats, f, ensure_ascii=False, indent=2)
+                print(f"   💾 JSON 结果已保存: {json_path}")
+        
+        except Exception as e:
+            print(f"❌ 检测失败: {e}")
+            import traceback
+            traceback.print_exc()
+            continue
+    
+    # 计算平均置信度
+    if total_stats['confidence_stats']['count'] > 0:
+        total_stats['confidence_stats']['mean'] = (
+            total_stats['confidence_stats']['sum'] / total_stats['confidence_stats']['count']
+        )
+    else:
+        total_stats['confidence_stats']['mean'] = 0.0
+        total_stats['confidence_stats']['min'] = 0.0
+    
+    return {
+        'all_results': all_results,
+        'total_stats': total_stats
+    }
+
+
+def print_summary(stats: Dict[str, Any]):
+    """打印统计摘要"""
+    total_stats = stats['total_stats']
+    
+    print(f"\n{'='*60}")
+    print(f"📊 检测结果摘要")
+    print(f"{'='*60}")
+    print(f"总页数: {total_stats['total_pages']}")
+    print(f"总区域数: {total_stats['total_regions']}")
+    
+    if total_stats['total_regions'] > 0:
+        print(f"\n类别统计:")
+        for cat, count in sorted(total_stats['category_counts'].items()):
+            percentage = (count / total_stats['total_regions']) * 100
+            print(f"  - {cat}: {count} ({percentage:.1f}%)")
+        
+        conf_stats = total_stats['confidence_stats']
+        print(f"\n置信度统计:")
+        print(f"  - 最小值: {conf_stats['min']:.3f}")
+        print(f"  - 最大值: {conf_stats['max']:.3f}")
+        print(f"  - 平均值: {conf_stats['mean']:.3f}")
+
+
+def main():
+    """主函数"""
+    args = parse_args()
+    
+    # 设置输出目录
+    if args.output_dir:
+        output_dir = Path(args.output_dir)
+    else:
+        output_dir = Path(__file__).parent / "output"
+    output_dir.mkdir(parents=True, exist_ok=True)
+    print(f"📁 输出目录: {output_dir}")
+    
+    # 获取输入图像列表
+    try:
+        image_files = get_input_images(
+            args.input,
+            page_range=args.pages,
+            dpi=args.dpi
+        )
+    except Exception as e:
+        print(f"❌ 错误: {e}")
+        sys.exit(1)
+    
+    if not image_files:
+        print("❌ 未找到要处理的图像文件")
+        sys.exit(1)
+    
+    # 构建配置
+    project_root = Path(__file__).parents[1]
+    config = build_config(args, project_root)
+    
+    # 初始化检测器
+    print(f"\n{'='*60}")
+    print(f"🔧 初始化 DiT Layout Detector")
+    print(f"{'='*60}")
+    print(f"配置文件: {config.get('config_file', 'N/A')}")
+    print(f"模型权重: {config.get('model_weights', 'N/A')}")
+    print(f"设备: {config['device']}")
+    print(f"置信度阈值: {config['conf']}")
+    print(f"重叠框处理: {config['remove_overlap']}")
+    
+    try:
+        detector = DitLayoutDetector(config)
+        detector.initialize()
+        print("✅ 检测器初始化成功")
+    except Exception as e:
+        print(f"❌ 检测器初始化失败: {e}")
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
+    
+    # 处理图像
+    try:
+        stats = process_images(
+            detector,
+            image_files,
+            output_dir,
+            save_json=args.save_json,
+            min_confidence=args.min_confidence
+        )
+        
+        # 打印摘要
+        print_summary(stats)
+        
+        # 保存总体统计
+        summary_path = output_dir / "detection_summary.json"
+        with open(summary_path, 'w', encoding='utf-8') as f:
+            json.dump(stats['total_stats'], f, ensure_ascii=False, indent=2)
+        print(f"\n💾 统计摘要已保存: {summary_path}")
+        
+    except Exception as e:
+        print(f"❌ 处理过程中出错: {e}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        # 清理资源
+        detector.cleanup()
+        print("\n✅ 测试完成!")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) == 1:
+        # 没有命令行参数时,使用默认配置运行
+        print("ℹ️  未提供命令行参数,使用默认配置运行...")
+        
+        # 默认配置
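+        # 注意:键名使用命令行参数的拼写(含连字符),下方循环会拼成 "--{key}" 交给 argparse 解析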
+        default_config = {
+            # 测试输入
+            "input": "/Users/zhch158/workspace/data/流水分析/2023年度报告母公司.pdf",
+            "output-dir": "./output/2023年度报告母公司_dit_layout_adapter",
+
+            # 页面范围(可选)
+            # "pages": "2-7,24,26,29-34",  # 处理多个页面区间
+            "pages": "32",  # 仅处理指定页面
+
+            # 如需禁用重叠框处理,取消下一行注释
+            # "no-remove-overlap": True,
+        }
+        
+        # 构造参数
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            if isinstance(value, bool):
+                if value:
+                    sys.argv.append(f"--{key}")
+            else:
+                sys.argv.extend([f"--{key}", str(value)])
+    
+    sys.exit(main())