
feat: add DiT support module and its core functionality

- Add the core code and configuration files required for DiT (Document Image Transformer) layout detection, including model definitions, config files, and a usage example.
- Introduce several configuration files supporting different model and training parameter settings.
- Update the README to document the module structure, usage, and dependencies, making the module easier to understand and use.
zhch158_admin committed 1 week ago
Commit 20d936e629

+ 91 - 0
ocr_tools/universal_doc_parser/dit_support/README.md

@@ -0,0 +1,91 @@
+# DiT Support Module
+
+This directory contains the core code and configuration files required for DiT (Document Image Transformer) layout detection.
+
+## Directory Structure
+
+```
+dit_support/
+├── ditod/                    # DiT core modules
+│   ├── __init__.py          # module exports (inference-only)
+│   ├── config.py            # config extension (add_vit_config)
+│   ├── backbone.py          # ViT backbone implementation
+│   ├── beit.py              # BEiT/DiT model definitions
+│   └── deit.py              # DeiT model definition (optional)
+└── configs/                  # configuration files
+    ├── Base-RCNN-FPN.yaml   # base config
+    └── cascade/
+        └── cascade_dit_large.yaml  # Cascade R-CNN + DiT-large config
+```
+
+## Usage
+
+Using DiT layout detection from `universal_doc_parser`:
+
+```python
+from models.adapters import get_layout_detector
+
+# Configure the DiT detector
+config = {
+    'module': 'dit',
+    'config_file': 'dit_support/configs/cascade/cascade_dit_large.yaml',
+    'model_weights': 'https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth',
+    'device': 'cpu',  # or 'cuda'
+    'conf': 0.3,
+    'remove_overlap': True,
+    'iou_threshold': 0.8,
+    'overlap_ratio_threshold': 0.8,
+}
+
+# Create the detector
+detector = get_layout_detector(config)
+detector.initialize()
+
+# Detect the layout
+import cv2
+img = cv2.imread('image.jpg')
+results = detector.detect(img)
+
+# Clean up
+detector.cleanup()
+```
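+
+Internally, the adapter builds a standard detectron2 predictor from the files in this directory. The snippet below is a minimal sketch of that flow (it assumes `dit_support` is on the Python path; the adapter's actual implementation may differ):
+
+```python
+import cv2
+from detectron2.config import get_cfg
+from detectron2.engine import DefaultPredictor
+from ditod import add_vit_config
+
+cfg = get_cfg()
+add_vit_config(cfg)  # register the MODEL.VIT / AUG keys used by these configs
+cfg.merge_from_file("dit_support/configs/cascade/cascade_dit_large.yaml")
+cfg.MODEL.WEIGHTS = "https://huggingface.co/HYPJUDY/dit/resolve/main/dit-fts/publaynet_dit-l_cascade.pth"
+cfg.MODEL.DEVICE = "cpu"
+cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3
+
+predictor = DefaultPredictor(cfg)      # loads the weights (downloads them on first use)
+img = cv2.imread("image.jpg")          # BGR numpy array of shape (H, W, 3)
+outputs = predictor(img)
+instances = outputs["instances"]       # predicted boxes, scores, and classes
+```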
+
+## Dependencies
+
+Install the following Python packages:
+
+```bash
+# 1. PyTorch (must be installed first)
+pip install torch torchvision
+
+# 2. detectron2
+# Mac M4 Pro / Apple Silicon:
+CC=clang CXX=clang++ ARCHFLAGS="-arch arm64" pip install --no-build-isolation 'git+https://github.com/facebookresearch/detectron2.git'
+
+# Linux (CPU):
+pip install 'git+https://github.com/facebookresearch/detectron2.git'
+
+# Linux (CUDA):
+pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.6'
+
+# 3. timm (Vision Transformer model library)
+pip install timm
+
+# 4. Basic dependencies
+pip install numpy opencv-python Pillow einops
+```
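+
+A quick sanity check after installation is to import the main dependencies and print their versions (a minimal check, not part of the module itself):
+
+```python
+# Verify that the key packages are importable and report their versions.
+import torch, torchvision, detectron2, timm, cv2, PIL, einops
+
+print("torch       ", torch.__version__)
+print("torchvision ", torchvision.__version__)
+print("detectron2  ", detectron2.__version__)
+print("timm        ", timm.__version__)
+print("opencv      ", cv2.__version__)
+print("Pillow      ", PIL.__version__)
+```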
+
+## Migration Notes
+
+This module is a minimal port of `unilm/dit/object_detection/` containing only the code required for inference:
+
+- ✅ Ported: the ditod core modules (5 files) and the configuration files (2 files)
+- ❌ Not ported: training code (dataset_mapper.py, mytrainer.py, etc.) and evaluation code (icdar_evaluation.py, table_evaluation/)
+
+## Notes
+
+1. **Paths**: make sure the `dit_support` directory is on the Python path (the adapter handles this automatically)
+2. **Model weights**: downloaded automatically from HuggingFace on the first run; a network connection is required
+3. **PyTorch 2.6+**: the code already includes the necessary compatibility fixes
+4. **Overlapping boxes**: removal is enabled by default; it can be disabled or its thresholds adjusted in the config (see the sketch below)
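+
+The overlap handling is conceptually an IoU/containment filter over the detected boxes. The sketch below illustrates the idea using the `iou_threshold` and `overlap_ratio_threshold` values from the configuration example above; the function name is illustrative and the adapter's actual implementation may differ in its details:
+
+```python
+def remove_overlapping_boxes(boxes, scores, iou_threshold=0.8, overlap_ratio_threshold=0.8):
+    """Drop lower-scoring boxes that strongly overlap an already-kept, higher-scoring box.
+
+    boxes: list of [x1, y1, x2, y2]; scores: list of floats.
+    Returns the indices of the boxes to keep, ordered by descending score.
+    """
+    def area(b):
+        return max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])
+
+    def intersection(a, b):
+        x1, y1 = max(a[0], b[0]), max(a[1], b[1])
+        x2, y2 = min(a[2], b[2]), min(a[3], b[3])
+        return max(0.0, x2 - x1) * max(0.0, y2 - y1)
+
+    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
+    keep = []
+    for i in order:
+        suppressed = False
+        for j in keep:
+            inter = intersection(boxes[i], boxes[j])
+            iou = inter / (area(boxes[i]) + area(boxes[j]) - inter + 1e-6)
+            contained = inter / (area(boxes[i]) + 1e-6)  # fraction of box i inside box j
+            if iou > iou_threshold or contained > overlap_ratio_threshold:
+                suppressed = True
+                break
+        if not suppressed:
+            keep.append(i)
+    return keep
+```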
+

+ 69 - 0
ocr_tools/universal_doc_parser/dit_support/configs/Base-RCNN-FPN.yaml

@@ -0,0 +1,69 @@
+MODEL:
+  MASK_ON: True
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+  BACKBONE:
+    NAME: "build_vit_fpn_backbone"
+  VIT:
+    OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
+    DROP_PATH: 0.1
+    IMG_SIZE: [224,224]
+    POS_TYPE: "abs"
+  FPN:
+    IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
+  ANCHOR_GENERATOR:
+    SIZES: [[32], [64], [128], [256], [512]]  # One size for each in feature map
+    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
+  RPN:
+    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
+    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
+    PRE_NMS_TOPK_TEST: 1000  # Per FPN level
+    # Detectron1 uses 2000 proposals per-batch,
+    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
+    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
+    POST_NMS_TOPK_TRAIN: 1000
+    POST_NMS_TOPK_TEST: 1000
+  ROI_HEADS:
+    NAME: "StandardROIHeads"
+    IN_FEATURES: ["p2", "p3", "p4", "p5"]
+    NUM_CLASSES: 5
+  ROI_BOX_HEAD:
+    NAME: "FastRCNNConvFCHead"
+    NUM_FC: 2
+    POOLER_RESOLUTION: 7
+  ROI_MASK_HEAD:
+    NAME: "MaskRCNNConvUpsampleHead"
+    NUM_CONV: 4
+    POOLER_RESOLUTION: 14
+DATASETS:
+  TRAIN: ("publaynet_train",)
+  TEST: ("publaynet_val",)
+SOLVER:
+  LR_SCHEDULER_NAME: "WarmupCosineLR"
+  AMP:
+    ENABLED: True
+  OPTIMIZER: "ADAMW"
+  BACKBONE_MULTIPLIER: 1.0
+  CLIP_GRADIENTS:
+    ENABLED: True
+    CLIP_TYPE: "full_model"
+    CLIP_VALUE: 1.0
+    NORM_TYPE: 2.0
+  WARMUP_FACTOR: 0.01
+  BASE_LR: 0.0004
+  WEIGHT_DECAY: 0.05
+  IMS_PER_BATCH: 32
+INPUT:
+  CROP:
+    ENABLED: True
+    TYPE: "absolute_range"
+    SIZE: (384, 600)
+  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
+  FORMAT: "RGB"
+DATALOADER:
+  FILTER_EMPTY_ANNOTATIONS: False
+VERSION: 2
+AUG:
+  DETR: True
+SEED: 42

+ 28 - 0
ocr_tools/universal_doc_parser/dit_support/configs/cascade/cascade_dit_large.yaml

@@ -0,0 +1,28 @@
+_BASE_: "../Base-RCNN-FPN.yaml"
+MODEL:
+  PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
+  PIXEL_STD: [ 127.5, 127.5, 127.5 ]
+  WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-large-224-p16-500k-d7a2fb.pth"
+  VIT:
+    NAME: "dit_large_patch16"
+    OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
+    DROP_PATH: 0.2
+  FPN:
+    IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
+  ROI_HEADS:
+    NAME: CascadeROIHeads
+  ROI_BOX_HEAD:
+    CLS_AGNOSTIC_BBOX_REG: True
+  RPN:
+    POST_NMS_TOPK_TRAIN: 2000
+SOLVER:
+  WARMUP_ITERS: 1000
+  IMS_PER_BATCH: 16
+  MAX_ITER: 60000
+  CHECKPOINT_PERIOD: 2000
+  BASE_LR: 0.0001
+  STEPS: (40000, 53333)
+  AMP:
+    ENABLED: False
+TEST:
+  EVAL_PERIOD: 2000

+ 18 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/__init__.py

@@ -0,0 +1,18 @@
+# --------------------------------------------------------------------------------
+# MPViT: Multi-Path Vision Transformer for Dense Prediction
+# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+# All Rights Reserved.
+# Written by Youngwan Lee
+# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------------------
+# Minimal ported version: only the exports required for inference
+
+from .config import add_vit_config
+from .backbone import build_vit_fpn_backbone
+
+__all__ = [
+    "add_vit_config",
+    "build_vit_fpn_backbone",
+]
+

+ 156 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/backbone.py

@@ -0,0 +1,156 @@
+# --------------------------------------------------------------------------------
+# VIT: Multi-Path Vision Transformer for Dense Prediction
+# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+# All Rights Reserved.
+# Written by Youngwan Lee
+# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# CoaT: https://github.com/mlpc-ucsd/CoaT
+# --------------------------------------------------------------------------------
+
+
+import torch
+
+from detectron2.layers import (
+    ShapeSpec,
+)
+from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
+from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
+
+from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
+from .deit import deit_base_patch16, mae_base_patch16
+
+__all__ = [
+    "build_vit_fpn_backbone",
+]
+
+
+class VIT_Backbone(Backbone):
+    """
+    Implement VIT backbone.
+    """
+
+    def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs):
+        super().__init__()
+        self._out_features = out_features
+        if 'base' in name:
+            self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
+        else:
+            self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
+
+        if name == 'beit_base_patch16':
+            model_func = beit_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == 'dit_base_patch16':
+            model_func = dit_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == "deit_base_patch16":
+            model_func = deit_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == "mae_base_patch16":
+            model_func = mae_base_patch16
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        elif name == "dit_large_patch16":
+            model_func = dit_large_patch16
+            self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+        elif name == "beit_large_patch16":
+            model_func = beit_large_patch16
+            self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+        else:
+            raise ValueError("Unsupported VIT name yet.")
+
+        if 'beit' in name or 'dit' in name:
+            if pos_type == "abs":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_abs_pos_emb=True,
+                                           **model_kwargs)
+            elif pos_type == "shared_rel":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_shared_rel_pos_bias=True,
+                                           **model_kwargs)
+            elif pos_type == "rel":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_rel_pos_bias=True,
+                                           **model_kwargs)
+            else:
+                raise ValueError()
+        else:
+            self.backbone = model_func(img_size=img_size,
+                                       out_features=out_features,
+                                       drop_path_rate=drop_path,
+                                       **model_kwargs)
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+
+        Returns:
+            dict[str->Tensor]: names and the corresponding features
+        """
+        assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+        return self.backbone.forward_features(x)
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+            )
+            for name in self._out_features
+        }
+
+
+def build_VIT_backbone(cfg):
+    """
+    Create a VIT instance from config.
+
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        A VIT backbone instance.
+    """
+    # fmt: off
+    name = cfg.MODEL.VIT.NAME
+    out_features = cfg.MODEL.VIT.OUT_FEATURES
+    drop_path = cfg.MODEL.VIT.DROP_PATH
+    img_size = cfg.MODEL.VIT.IMG_SIZE
+    pos_type = cfg.MODEL.VIT.POS_TYPE
+
+    model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
+
+    return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs)
+
+
+@BACKBONE_REGISTRY.register()
+def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
+    """
+    Create a VIT w/ FPN backbone.
+
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+    """
+    bottom_up = build_VIT_backbone(cfg)
+    in_features = cfg.MODEL.FPN.IN_FEATURES
+    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
+    backbone = FPN(
+        bottom_up=bottom_up,
+        in_features=in_features,
+        out_channels=out_channels,
+        norm=cfg.MODEL.FPN.NORM,
+        top_block=LastLevelMaxPool(),
+        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
+    )
+    return backbone

+ 671 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/beit.py

@@ -0,0 +1,671 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implement of Vision Transformers as described in
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+Status/TODO:
+* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
+* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
+* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
+* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import warnings
+import math
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # dropout here is commented out to match the original BERT implementation
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+            proj_drop=0., window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # plus 3 extra entries: cls-to-token, token-to-cls, and cls-to-cls
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer("relative_position_index", relative_position_index)
+
+            # trunc_normal_(self.relative_position_bias_table, std=.0)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            if training_window_size == self.window_size:
+                relative_position_bias = \
+                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                        self.window_size[0] * self.window_size[1] + 1,
+                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+            else:
+                training_window_size = tuple(training_window_size.tolist())
+                new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+                # new_num_relative_distance covers all possible relative-position entries, including cls-cls, tok-cls, and cls-tok
+                new_relative_position_bias_table = F.interpolate(
+                    self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                                 2 * self.window_size[0] - 1,
+                                                                                 2 * self.window_size[1] - 1),
+                    size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                    align_corners=False)
+                new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                         new_num_relative_distance - 3).permute(
+                    1, 0)
+                new_relative_position_bias_table = torch.cat(
+                    [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+                # get pair-wise relative position index for each token inside the window
+                coords_h = torch.arange(training_window_size[0])
+                coords_w = torch.arange(training_window_size[1])
+                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+                relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+                relative_coords[:, :, 1] += training_window_size[1] - 1
+                relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+                relative_position_index = \
+                    torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                                dtype=relative_coords.dtype)
+                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+                relative_position_index[0, 0:] = new_num_relative_distance - 3
+                relative_position_index[0:, 0] = new_num_relative_distance - 2
+                relative_position_index[0, 0] = new_num_relative_distance - 1
+
+                relative_position_bias = \
+                    new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                        training_window_size[0] * training_window_size[1] + 1,
+                        training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(
+                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
+                                                            training_window_size=training_window_size))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches_w = self.patch_shape[0]
+        self.num_patches_h = self.patch_shape[1]
+        # the so-called patch_shape is the patch shape during pre-training
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, position_embedding=None, **kwargs):
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        Hp, Wp = x.shape[2], x.shape[3]
+
+        if position_embedding is not None:
+            # interpolate the position embedding to the corresponding size
+            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
+                                                                                                                  1, 2)
+            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
+            x = x + position_embedding
+
+        x = x.flatten(2).transpose(1, 2)
+        return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # plus 3 extra entries: cls-to-token, token-to-cls, and cls-to-cls
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self, training_window_size):
+        if training_window_size == self.window_size:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        else:
+            training_window_size = tuple(training_window_size.tolist())
+            new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+            # new_num_relative_distance covers all possible relative-position entries, including cls-cls, tok-cls, and cls-tok
+            new_relative_position_bias_table = F.interpolate(
+                self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                             2 * self.window_size[0] - 1,
+                                                                             2 * self.window_size[1] - 1),
+                size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                align_corners=False)
+            new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                     new_num_relative_distance - 3).permute(
+                1, 0)
+            new_relative_position_bias_table = torch.cat(
+                [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(training_window_size[0])
+            coords_w = torch.arange(training_window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += training_window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                            dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = new_num_relative_distance - 3
+            relative_position_index[0:, 0] = new_num_relative_distance - 2
+            relative_position_index[0, 0] = new_num_relative_distance - 1
+
+            relative_position_bias = \
+                new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                    training_window_size[0] * training_window_size[1] + 1,
+                    training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+        return relative_position_bias
+
+
+class BEiT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 img_size=[224, 224],
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=80,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=None,
+                 init_values=None,
+                 use_abs_pos_emb=False,
+                 use_rel_pos_bias=False,
+                 use_shared_rel_pos_bias=False,
+                 use_checkpoint=True,
+                 pretrained=None,
+                 out_features=None,
+                 ):
+
+        super(BEiT, self).__init__()
+
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.use_checkpoint = use_checkpoint
+
+        if hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+            for i in range(depth)])
+
+        # trunc_normal_(self.mask_token, std=.02)
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                # nn.SyncBatchNorm(embed_dim),
+                nn.BatchNorm2d(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        logger = get_root_logger()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self,
+                            filename=self.init_cfg['checkpoint'],
+                            strict=False,
+                            logger=logger,
+                            beit_spec_expand_rel_pos = self.use_rel_pos_bias,
+                            )
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward_features(self, x):
+        B, C, H, W = x.shape
+        x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+        # Hp, Wp are HW for patches
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.pos_embed is not None:
+            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.pos_drop(x)
+
+        features = []
+        training_window_size = torch.tensor([Hp, Wp])
+
+        rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
+
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
+            else:
+                x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+            if i in self.out_indices:
+                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def beit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def beit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=0.1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=1e-5,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+if __name__ == '__main__':
+    model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
+    model = model.to("cuda:0")
+    input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
+    input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
+    input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
+    output1 = model(input1)
+    output2 = model(input2)
+    output3 = model(input3)
+    print("all done")

+ 32 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/config.py

@@ -0,0 +1,32 @@
+from detectron2.config import CfgNode as CN
+
+
+def add_vit_config(cfg):
+    """
+    Add config for VIT.
+    """
+    _C = cfg
+
+    _C.MODEL.VIT = CN()
+
+    # CoaT model name.
+    _C.MODEL.VIT.NAME = ""
+
+    # Output features from CoaT backbone.
+    _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
+
+    _C.MODEL.VIT.IMG_SIZE = [224, 224]
+
+    _C.MODEL.VIT.POS_TYPE = "shared_rel"
+
+    _C.MODEL.VIT.DROP_PATH = 0.
+
+    _C.MODEL.VIT.MODEL_KWARGS = "{}"
+
+    _C.SOLVER.OPTIMIZER = "ADAMW"
+
+    _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
+
+    _C.AUG = CN()
+
+    _C.AUG.DETR = False

+ 476 - 0
ocr_tools/universal_doc_parser/dit_support/ditod/deit.py

@@ -0,0 +1,476 @@
+"""
+Mostly copy-paste from DINO and timm library:
+https://github.com/facebookresearch/dino
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+import warnings
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import trunc_normal_, drop_path, to_2tuple
+from functools import partial
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                      C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+
+        self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+
+        self.num_patches_w, self.num_patches_h = self.window_size
+
+        self.num_patches = self.window_size[0] * self.window_size[1]
+        self.img_size = img_size
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(in_chans, embed_dim,
+                              kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        x = self.proj(x)
+        return x
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(
+                    1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class ViT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 model_name='vit_base_patch16_224',
+                 img_size=384,
+                 patch_size=16,
+                 in_chans=3,
+                 embed_dim=1024,
+                 depth=24,
+                 num_heads=16,
+                 num_classes=19,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.1,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                 norm_cfg=None,
+                 pos_embed_interp=False,
+                 random_init=False,
+                 align_corners=False,
+                 use_checkpoint=False,
+                 num_extra_tokens=1,
+                 out_features=None,
+                 **kwargs,
+                 ):
+
+        super(ViT, self).__init__()
+        self.model_name = model_name
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.depth = depth
+        self.num_heads = num_heads
+        self.num_classes = num_classes
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.hybrid_backbone = hybrid_backbone
+        self.norm_layer = norm_layer
+        self.norm_cfg = norm_cfg
+        self.pos_embed_interp = pos_embed_interp
+        self.random_init = random_init
+        self.align_corners = align_corners
+        self.use_checkpoint = use_checkpoint
+        self.num_extra_tokens = num_extra_tokens
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        # self.num_stages = self.depth
+        # self.out_indices = tuple(range(self.num_stages))
+
+        if self.hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        self.num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        if self.num_extra_tokens == 2:
+            self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        self.pos_embed = nn.Parameter(torch.zeros(
+            1, self.num_patches + self.num_extra_tokens, self.embed_dim))
+        self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+        # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
+        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
+                                                self.depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
+                qk_scale=self.qk_scale,
+                drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
+            for i in range(self.depth)])
+
+        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+        # self.repr = nn.Linear(embed_dim, representation_size)
+        # self.repr_act = nn.Tanh()
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                nn.SyncBatchNorm(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        if self.num_extra_tokens==2:
+            trunc_normal_(self.dist_token, std=0.2)
+        self.apply(self._init_weights)
+        # self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        logger = get_root_logger()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def _conv_filter(self, state_dict, patch_size=16):
+        """ convert patch embedding weight from manual patchify + linear proj to conv"""
+        out_dict = {}
+        for k, v in state_dict.items():
+            if 'patch_embed.proj.weight' in k:
+                v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+            out_dict[k] = v
+        return out_dict
+
+    def to_2D(self, x):
+        n, hw, c = x.shape
+        h = w = int(math.sqrt(hw))
+        x = x.transpose(1, 2).reshape(n, c, h, w)
+        return x
+
+    def to_1D(self, x):
+        n, c, h, w = x.shape
+        x = x.reshape(n, c, -1).transpose(1, 2)
+        return x
+
+    def interpolate_pos_encoding(self, x, w, h):
+        npatch = x.shape[1] - self.num_extra_tokens
+        N = self.pos_embed.shape[1] - self.num_extra_tokens
+        if npatch == N and w == h:
+            return self.pos_embed
+
+        class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
+
+        patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
+
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size[0]
+        h0 = h // self.patch_embed.patch_size[1]
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x, mask=None):
+        B, nc, w, h = x.shape
+        # patch linear embedding
+        x = self.patch_embed(x)
+
+        # mask image modeling
+        if mask is not None:
+            x = self.mask_model(x, mask)
+        x = x.flatten(2).transpose(1, 2)
+
+        # add the [CLS] token to the embed patch tokens
+        all_tokens = [self.cls_token.expand(B, -1, -1)]
+
+        if self.num_extra_tokens == 2:
+            dist_tokens = self.dist_token.expand(B, -1, -1)
+            all_tokens.append(dist_tokens)
+        all_tokens.append(x)
+
+        x = torch.cat(all_tokens, dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return self.pos_drop(x)
+
+    def forward_features(self, x):
+        # print(f"==========shape of x is {x.shape}==========")
+        B, _, H, W = x.shape
+        Hp, Wp = H // self.patch_size, W // self.patch_size
+        x = self.prepare_tokens(x)
+
+        features = []
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+            if i in self.out_indices:
+                xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def deit_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=2,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def mae_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model