base_model.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from torch import nn
  2. from ..backbones import build_backbone
  3. from ..heads import build_head
  4. from ..necks import build_neck
  class BaseModel(nn.Module):
      """Composite model assembled from an optional backbone, neck and head.

      Each stage is built only when its section ("Backbone", "Neck", "Head")
      is present in the config and is not None. The channel count flows from
      one stage to the next through each built stage's ``out_channels``.
      """

      def __init__(self, config, **kwargs):
          """Build the configured stages and initialise their weights.

          Args:
              config (dict): model configuration. Must contain "model_type".
                  May contain "in_channels" (default 3), "Backbone", "Neck",
                  "Head" and "return_all_feats" (default False). The stage
                  sub-dicts are updated in place with an "in_channels" entry.
              **kwargs: forwarded unchanged to ``build_head``.
          """
          super(BaseModel, self).__init__()
          in_channels = config.get("in_channels", 3)
          model_type = config["model_type"]

          # Backbone: its out_channels feed the next stage.
          if "Backbone" not in config or config["Backbone"] is None:
              self.use_backbone = False
          else:
              self.use_backbone = True
              config["Backbone"]["in_channels"] = in_channels
              self.backbone = build_backbone(config["Backbone"], model_type)
              in_channels = self.backbone.out_channels

          # Neck: optional stage between backbone and head.
          if "Neck" not in config or config["Neck"] is None:
              self.use_neck = False
          else:
              self.use_neck = True
              config["Neck"]["in_channels"] = in_channels
              self.neck = build_neck(config["Neck"])
              in_channels = self.neck.out_channels

          # Head: consumes the channel count produced by the last built stage.
          if "Head" not in config or config["Head"] is None:
              self.use_head = False
          else:
              self.use_head = True
              config["Head"]["in_channels"] = in_channels
              self.head = build_head(config["Head"], **kwargs)

          self.return_all_feats = config.get("return_all_feats", False)
          self._initialize_weights()

      def _initialize_weights(self):
          """Re-initialise conv, transposed-conv, batch-norm and linear layers."""
          for m in self.modules():
              # Conv2d and ConvTranspose2d share the same initialisation.
              if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                  nn.init.kaiming_normal_(m.weight, mode="fan_out")
                  if m.bias is not None:
                      nn.init.zeros_(m.bias)
              elif isinstance(m, nn.BatchNorm2d):
                  nn.init.ones_(m.weight)
                  nn.init.zeros_(m.bias)
              elif isinstance(m, nn.Linear):
                  nn.init.normal_(m.weight, 0, 0.01)
                  if m.bias is not None:
                      nn.init.zeros_(m.bias)

      def forward(self, x):
          """Run the enabled stages in order and return their output.

          Args:
              x: input passed to the first enabled stage.

          Returns:
              When ``return_all_feats`` is false, the output of the last
              enabled stage. When it is true: in training mode, a dict of the
              collected stage outputs; otherwise the final output itself if it
              is a dict, else ``{final_name: output}``.
          """
          y = dict()
          if self.use_backbone:
              x = self.backbone(x)
          # Recorded even without a backbone so that final_name is always bound.
          if isinstance(x, dict):
              y.update(x)
          else:
              y["backbone_out"] = x
          final_name = "backbone_out"

          if self.use_neck:
              x = self.neck(x)
              if isinstance(x, dict):
                  y.update(x)
              else:
                  y["neck_out"] = x
              final_name = "neck_out"

          if self.use_head:
              x = self.head(x)
              # For a multi-head output, keep the CTC neck features separately.
              # The membership test must use the same key that is read below.
              if isinstance(x, dict) and "ctc_neck" in x:
                  y["neck_out"] = x["ctc_neck"]
                  y["head_out"] = x
              elif isinstance(x, dict):
                  y.update(x)
              else:
                  y["head_out"] = x
              # NOTE(review): final_name is not switched to a head-specific key,
              # so in eval mode a non-dict head output is returned under the
              # neck/backbone key. Left as-is; confirm what callers expect.

          if not self.return_all_feats:
              return x
          if self.training:
              return y
          if isinstance(x, dict):
              return x
          return {final_name: x}