rnn.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import torch
  2. from torch import nn
  3. from ..backbones.rec_svtrnet import Block, ConvBNLayer
  4. class Im2Seq(nn.Module):
  5. def __init__(self, in_channels, **kwargs):
  6. super().__init__()
  7. self.out_channels = in_channels
  8. def forward(self, x):
  9. B, C, H, W = x.shape
  10. # assert H == 1
  11. x = x.squeeze(dim=2)
  12. # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
  13. x = x.permute(0, 2, 1)
  14. return x
  15. # def forward(self, x):
  16. # B, C, H, W = x.shape
  17. # # 处理四维张量,将空间维度展平为序列
  18. # if H == 1:
  19. # # 原来的处理逻辑,适用于H=1的情况
  20. # x = x.squeeze(dim=2)
  21. # x = x.permute(0, 2, 1) # (B, W, C)
  22. # else:
  23. # # 处理H不为1的情况
  24. # x = x.permute(0, 2, 3, 1) # (B, H, W, C)
  25. # x = x.reshape(B, H * W, C) # (B, H*W, C)
  26. #
  27. # return x
  28. class EncoderWithRNN_(nn.Module):
  29. def __init__(self, in_channels, hidden_size):
  30. super(EncoderWithRNN_, self).__init__()
  31. self.out_channels = hidden_size * 2
  32. self.rnn1 = nn.LSTM(
  33. in_channels,
  34. hidden_size,
  35. bidirectional=False,
  36. batch_first=True,
  37. num_layers=2,
  38. )
  39. self.rnn2 = nn.LSTM(
  40. in_channels,
  41. hidden_size,
  42. bidirectional=False,
  43. batch_first=True,
  44. num_layers=2,
  45. )
  46. def forward(self, x):
  47. self.rnn1.flatten_parameters()
  48. self.rnn2.flatten_parameters()
  49. out1, h1 = self.rnn1(x)
  50. out2, h2 = self.rnn2(torch.flip(x, [1]))
  51. return torch.cat([out1, torch.flip(out2, [1])], 2)
  52. class EncoderWithRNN(nn.Module):
  53. def __init__(self, in_channels, hidden_size):
  54. super(EncoderWithRNN, self).__init__()
  55. self.out_channels = hidden_size * 2
  56. self.lstm = nn.LSTM(
  57. in_channels, hidden_size, num_layers=2, batch_first=True, bidirectional=True
  58. ) # batch_first:=True
  59. def forward(self, x):
  60. x, _ = self.lstm(x)
  61. return x
  62. class EncoderWithFC(nn.Module):
  63. def __init__(self, in_channels, hidden_size):
  64. super(EncoderWithFC, self).__init__()
  65. self.out_channels = hidden_size
  66. self.fc = nn.Linear(
  67. in_channels,
  68. hidden_size,
  69. bias=True,
  70. )
  71. def forward(self, x):
  72. x = self.fc(x)
  73. return x
  74. class EncoderWithSVTR(nn.Module):
  75. def __init__(
  76. self,
  77. in_channels,
  78. dims=64, # XS
  79. depth=2,
  80. hidden_dims=120,
  81. use_guide=False,
  82. num_heads=8,
  83. qkv_bias=True,
  84. mlp_ratio=2.0,
  85. drop_rate=0.1,
  86. kernel_size=[3, 3],
  87. attn_drop_rate=0.1,
  88. drop_path=0.0,
  89. qk_scale=None,
  90. ):
  91. super(EncoderWithSVTR, self).__init__()
  92. self.depth = depth
  93. self.use_guide = use_guide
  94. self.conv1 = ConvBNLayer(
  95. in_channels,
  96. in_channels // 8,
  97. kernel_size=kernel_size,
  98. padding=[kernel_size[0] // 2, kernel_size[1] // 2],
  99. act="swish",
  100. )
  101. self.conv2 = ConvBNLayer(
  102. in_channels // 8, hidden_dims, kernel_size=1, act="swish"
  103. )
  104. self.svtr_block = nn.ModuleList(
  105. [
  106. Block(
  107. dim=hidden_dims,
  108. num_heads=num_heads,
  109. mixer="Global",
  110. HW=None,
  111. mlp_ratio=mlp_ratio,
  112. qkv_bias=qkv_bias,
  113. qk_scale=qk_scale,
  114. drop=drop_rate,
  115. act_layer="swish",
  116. attn_drop=attn_drop_rate,
  117. drop_path=drop_path,
  118. norm_layer="nn.LayerNorm",
  119. epsilon=1e-05,
  120. prenorm=False,
  121. )
  122. for i in range(depth)
  123. ]
  124. )
  125. self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
  126. self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
  127. # last conv-nxn, the input is concat of input tensor and conv3 output tensor
  128. self.conv4 = ConvBNLayer(
  129. 2 * in_channels, in_channels // 8, padding=1, act="swish"
  130. )
  131. self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
  132. self.out_channels = dims
  133. self.apply(self._init_weights)
  134. def _init_weights(self, m):
  135. # weight initialization
  136. if isinstance(m, nn.Conv2d):
  137. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  138. if m.bias is not None:
  139. nn.init.zeros_(m.bias)
  140. elif isinstance(m, nn.BatchNorm2d):
  141. nn.init.ones_(m.weight)
  142. nn.init.zeros_(m.bias)
  143. elif isinstance(m, nn.Linear):
  144. nn.init.normal_(m.weight, 0, 0.01)
  145. if m.bias is not None:
  146. nn.init.zeros_(m.bias)
  147. elif isinstance(m, nn.ConvTranspose2d):
  148. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  149. if m.bias is not None:
  150. nn.init.zeros_(m.bias)
  151. elif isinstance(m, nn.LayerNorm):
  152. nn.init.ones_(m.weight)
  153. nn.init.zeros_(m.bias)
  154. def forward(self, x):
  155. # for use guide
  156. if self.use_guide:
  157. z = x.clone()
  158. z.stop_gradient = True
  159. else:
  160. z = x
  161. # for short cut
  162. h = z
  163. # reduce dim
  164. z = self.conv1(z)
  165. z = self.conv2(z)
  166. # SVTR global block
  167. B, C, H, W = z.shape
  168. z = z.flatten(2).permute(0, 2, 1)
  169. for blk in self.svtr_block:
  170. z = blk(z)
  171. z = self.norm(z)
  172. # last stage
  173. z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
  174. z = self.conv3(z)
  175. z = torch.cat((h, z), dim=1)
  176. z = self.conv1x1(self.conv4(z))
  177. return z
  178. class SequenceEncoder(nn.Module):
  179. def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
  180. super(SequenceEncoder, self).__init__()
  181. self.encoder_reshape = Im2Seq(in_channels)
  182. self.out_channels = self.encoder_reshape.out_channels
  183. self.encoder_type = encoder_type
  184. if encoder_type == "reshape":
  185. self.only_reshape = True
  186. else:
  187. support_encoder_dict = {
  188. "reshape": Im2Seq,
  189. "fc": EncoderWithFC,
  190. "rnn": EncoderWithRNN,
  191. "svtr": EncoderWithSVTR,
  192. }
  193. assert encoder_type in support_encoder_dict, "{} must in {}".format(
  194. encoder_type, support_encoder_dict.keys()
  195. )
  196. if encoder_type == "svtr":
  197. self.encoder = support_encoder_dict[encoder_type](
  198. self.encoder_reshape.out_channels, **kwargs
  199. )
  200. else:
  201. self.encoder = support_encoder_dict[encoder_type](
  202. self.encoder_reshape.out_channels, hidden_size
  203. )
  204. self.out_channels = self.encoder.out_channels
  205. self.only_reshape = False
  206. def forward(self, x):
  207. if self.encoder_type != "svtr":
  208. x = self.encoder_reshape(x)
  209. if not self.only_reshape:
  210. x = self.encoder(x)
  211. return x
  212. else:
  213. x = self.encoder(x)
  214. x = self.encoder_reshape(x)
  215. return x