backbone.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # --------------------------------------------------------------------------------
  2. # VIT: Multi-Path Vision Transformer for Dense Prediction
  3. # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
  4. # All Rights Reserved.
  5. # Written by Youngwan Lee
  6. # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. # --------------------------------------------------------------------------------
  9. # References:
  10. # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
  11. # CoaT: https://github.com/mlpc-ucsd/CoaT
  12. # --------------------------------------------------------------------------------
  13. import torch
  14. from detectron2.layers import (
  15. ShapeSpec,
  16. )
  17. from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
  18. from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
  19. from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
  20. from .deit import deit_base_patch16, mae_base_patch16
  21. from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
  22. from transformers import AutoConfig
  23. __all__ = [
  24. "build_vit_fpn_backbone",
  25. ]
  26. class VIT_Backbone(Backbone):
  27. """
  28. Implement VIT backbone.
  29. """
  30. def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
  31. config_path=None, image_only=False, cfg=None):
  32. super().__init__()
  33. self._out_features = out_features
  34. if 'base' in name:
  35. self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
  36. self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
  37. else:
  38. self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
  39. self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
  40. if name == 'beit_base_patch16':
  41. model_func = beit_base_patch16
  42. elif name == 'dit_base_patch16':
  43. model_func = dit_base_patch16
  44. elif name == "deit_base_patch16":
  45. model_func = deit_base_patch16
  46. elif name == "mae_base_patch16":
  47. model_func = mae_base_patch16
  48. elif name == "dit_large_patch16":
  49. model_func = dit_large_patch16
  50. elif name == "beit_large_patch16":
  51. model_func = beit_large_patch16
  52. if 'beit' in name or 'dit' in name:
  53. if pos_type == "abs":
  54. self.backbone = model_func(img_size=img_size,
  55. out_features=out_features,
  56. drop_path_rate=drop_path,
  57. use_abs_pos_emb=True,
  58. **model_kwargs)
  59. elif pos_type == "shared_rel":
  60. self.backbone = model_func(img_size=img_size,
  61. out_features=out_features,
  62. drop_path_rate=drop_path,
  63. use_shared_rel_pos_bias=True,
  64. **model_kwargs)
  65. elif pos_type == "rel":
  66. self.backbone = model_func(img_size=img_size,
  67. out_features=out_features,
  68. drop_path_rate=drop_path,
  69. use_rel_pos_bias=True,
  70. **model_kwargs)
  71. else:
  72. raise ValueError()
  73. elif "layoutlmv3" in name:
  74. config = AutoConfig.from_pretrained(config_path)
  75. # disable relative bias as DiT
  76. config.has_spatial_attention_bias = False
  77. config.has_relative_attention_bias = False
  78. self.backbone = LayoutLMv3Model(config, detection=True,
  79. out_features=out_features, image_only=image_only)
  80. else:
  81. self.backbone = model_func(img_size=img_size,
  82. out_features=out_features,
  83. drop_path_rate=drop_path,
  84. **model_kwargs)
  85. self.name = name
  86. def forward(self, x):
  87. """
  88. Args:
  89. x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
  90. Returns:
  91. dict[str->Tensor]: names and the corresponding features
  92. """
  93. if "layoutlmv3" in self.name:
  94. return self.backbone.forward(
  95. input_ids=x["input_ids"] if "input_ids" in x else None,
  96. bbox=x["bbox"] if "bbox" in x else None,
  97. images=x["images"] if "images" in x else None,
  98. attention_mask=x["attention_mask"] if "attention_mask" in x else None,
  99. # output_hidden_states=True,
  100. )
  101. assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
  102. return self.backbone.forward_features(x)
  103. def output_shape(self):
  104. return {
  105. name: ShapeSpec(
  106. channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
  107. )
  108. for name in self._out_features
  109. }
  110. def build_VIT_backbone(cfg):
  111. """
  112. Create a VIT instance from config.
  113. Args:
  114. cfg: a detectron2 CfgNode
  115. Returns:
  116. A VIT backbone instance.
  117. """
  118. # fmt: off
  119. name = cfg.MODEL.VIT.NAME
  120. out_features = cfg.MODEL.VIT.OUT_FEATURES
  121. drop_path = cfg.MODEL.VIT.DROP_PATH
  122. img_size = cfg.MODEL.VIT.IMG_SIZE
  123. pos_type = cfg.MODEL.VIT.POS_TYPE
  124. model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
  125. if 'layoutlmv3' in name:
  126. if cfg.MODEL.CONFIG_PATH != '':
  127. config_path = cfg.MODEL.CONFIG_PATH
  128. else:
  129. config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '') # layoutlmv3 pre-trained models
  130. config_path = config_path.replace('model_final.pth', '') # detection fine-tuned models
  131. else:
  132. config_path = None
  133. return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
  134. config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
  135. @BACKBONE_REGISTRY.register()
  136. def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
  137. """
  138. Create a VIT w/ FPN backbone.
  139. Args:
  140. cfg: a detectron2 CfgNode
  141. Returns:
  142. backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
  143. """
  144. bottom_up = build_VIT_backbone(cfg)
  145. in_features = cfg.MODEL.FPN.IN_FEATURES
  146. out_channels = cfg.MODEL.FPN.OUT_CHANNELS
  147. backbone = FPN(
  148. bottom_up=bottom_up,
  149. in_features=in_features,
  150. out_channels=out_channels,
  151. norm=cfg.MODEL.FPN.NORM,
  152. top_block=LastLevelMaxPool(),
  153. fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
  154. )
  155. return backbone