rnn.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import torch
  2. from torch import nn
  3. from ..backbones.rec_svtrnet import Block, ConvBNLayer
  4. class Im2Seq(nn.Module):
  5. def __init__(self, in_channels, **kwargs):
  6. super().__init__()
  7. self.out_channels = in_channels
  8. # def forward(self, x):
  9. # B, C, H, W = x.shape
  10. # # assert H == 1
  11. # x = x.squeeze(dim=2)
  12. # # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
  13. # x = x.permute(0, 2, 1)
  14. # return x
  15. def forward(self, x):
  16. B, C, H, W = x.shape
  17. # 处理四维张量,将空间维度展平为序列
  18. if H == 1:
  19. # 原来的处理逻辑,适用于H=1的情况
  20. x = x.squeeze(dim=2)
  21. x = x.permute(0, 2, 1) # (B, W, C)
  22. else:
  23. # 处理H不为1的情况
  24. x = x.permute(0, 2, 3, 1) # (B, H, W, C)
  25. x = x.reshape(B, H * W, C) # (B, H*W, C)
  26. return x
  27. class EncoderWithRNN_(nn.Module):
  28. def __init__(self, in_channels, hidden_size):
  29. super(EncoderWithRNN_, self).__init__()
  30. self.out_channels = hidden_size * 2
  31. self.rnn1 = nn.LSTM(
  32. in_channels,
  33. hidden_size,
  34. bidirectional=False,
  35. batch_first=True,
  36. num_layers=2,
  37. )
  38. self.rnn2 = nn.LSTM(
  39. in_channels,
  40. hidden_size,
  41. bidirectional=False,
  42. batch_first=True,
  43. num_layers=2,
  44. )
  45. def forward(self, x):
  46. self.rnn1.flatten_parameters()
  47. self.rnn2.flatten_parameters()
  48. out1, h1 = self.rnn1(x)
  49. out2, h2 = self.rnn2(torch.flip(x, [1]))
  50. return torch.cat([out1, torch.flip(out2, [1])], 2)
  51. class EncoderWithRNN(nn.Module):
  52. def __init__(self, in_channels, hidden_size):
  53. super(EncoderWithRNN, self).__init__()
  54. self.out_channels = hidden_size * 2
  55. self.lstm = nn.LSTM(
  56. in_channels, hidden_size, num_layers=2, batch_first=True, bidirectional=True
  57. ) # batch_first:=True
  58. def forward(self, x):
  59. x, _ = self.lstm(x)
  60. return x
  61. class EncoderWithFC(nn.Module):
  62. def __init__(self, in_channels, hidden_size):
  63. super(EncoderWithFC, self).__init__()
  64. self.out_channels = hidden_size
  65. self.fc = nn.Linear(
  66. in_channels,
  67. hidden_size,
  68. bias=True,
  69. )
  70. def forward(self, x):
  71. x = self.fc(x)
  72. return x
  73. class EncoderWithSVTR(nn.Module):
  74. def __init__(
  75. self,
  76. in_channels,
  77. dims=64, # XS
  78. depth=2,
  79. hidden_dims=120,
  80. use_guide=False,
  81. num_heads=8,
  82. qkv_bias=True,
  83. mlp_ratio=2.0,
  84. drop_rate=0.1,
  85. kernel_size=[3, 3],
  86. attn_drop_rate=0.1,
  87. drop_path=0.0,
  88. qk_scale=None,
  89. ):
  90. super(EncoderWithSVTR, self).__init__()
  91. self.depth = depth
  92. self.use_guide = use_guide
  93. self.conv1 = ConvBNLayer(
  94. in_channels,
  95. in_channels // 8,
  96. kernel_size=kernel_size,
  97. padding=[kernel_size[0] // 2, kernel_size[1] // 2],
  98. act="swish",
  99. )
  100. self.conv2 = ConvBNLayer(
  101. in_channels // 8, hidden_dims, kernel_size=1, act="swish"
  102. )
  103. self.svtr_block = nn.ModuleList(
  104. [
  105. Block(
  106. dim=hidden_dims,
  107. num_heads=num_heads,
  108. mixer="Global",
  109. HW=None,
  110. mlp_ratio=mlp_ratio,
  111. qkv_bias=qkv_bias,
  112. qk_scale=qk_scale,
  113. drop=drop_rate,
  114. act_layer="swish",
  115. attn_drop=attn_drop_rate,
  116. drop_path=drop_path,
  117. norm_layer="nn.LayerNorm",
  118. epsilon=1e-05,
  119. prenorm=False,
  120. )
  121. for i in range(depth)
  122. ]
  123. )
  124. self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
  125. self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
  126. # last conv-nxn, the input is concat of input tensor and conv3 output tensor
  127. self.conv4 = ConvBNLayer(
  128. 2 * in_channels, in_channels // 8, padding=1, act="swish"
  129. )
  130. self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
  131. self.out_channels = dims
  132. self.apply(self._init_weights)
  133. def _init_weights(self, m):
  134. # weight initialization
  135. if isinstance(m, nn.Conv2d):
  136. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  137. if m.bias is not None:
  138. nn.init.zeros_(m.bias)
  139. elif isinstance(m, nn.BatchNorm2d):
  140. nn.init.ones_(m.weight)
  141. nn.init.zeros_(m.bias)
  142. elif isinstance(m, nn.Linear):
  143. nn.init.normal_(m.weight, 0, 0.01)
  144. if m.bias is not None:
  145. nn.init.zeros_(m.bias)
  146. elif isinstance(m, nn.ConvTranspose2d):
  147. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  148. if m.bias is not None:
  149. nn.init.zeros_(m.bias)
  150. elif isinstance(m, nn.LayerNorm):
  151. nn.init.ones_(m.weight)
  152. nn.init.zeros_(m.bias)
  153. def forward(self, x):
  154. # for use guide
  155. if self.use_guide:
  156. z = x.clone()
  157. z.stop_gradient = True
  158. else:
  159. z = x
  160. # for short cut
  161. h = z
  162. # reduce dim
  163. z = self.conv1(z)
  164. z = self.conv2(z)
  165. # SVTR global block
  166. B, C, H, W = z.shape
  167. z = z.flatten(2).permute(0, 2, 1)
  168. for blk in self.svtr_block:
  169. z = blk(z)
  170. z = self.norm(z)
  171. # last stage
  172. z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
  173. z = self.conv3(z)
  174. z = torch.cat((h, z), dim=1)
  175. z = self.conv1x1(self.conv4(z))
  176. return z
  177. class SequenceEncoder(nn.Module):
  178. def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
  179. super(SequenceEncoder, self).__init__()
  180. self.encoder_reshape = Im2Seq(in_channels)
  181. self.out_channels = self.encoder_reshape.out_channels
  182. self.encoder_type = encoder_type
  183. if encoder_type == "reshape":
  184. self.only_reshape = True
  185. else:
  186. support_encoder_dict = {
  187. "reshape": Im2Seq,
  188. "fc": EncoderWithFC,
  189. "rnn": EncoderWithRNN,
  190. "svtr": EncoderWithSVTR,
  191. }
  192. assert encoder_type in support_encoder_dict, "{} must in {}".format(
  193. encoder_type, support_encoder_dict.keys()
  194. )
  195. if encoder_type == "svtr":
  196. self.encoder = support_encoder_dict[encoder_type](
  197. self.encoder_reshape.out_channels, **kwargs
  198. )
  199. else:
  200. self.encoder = support_encoder_dict[encoder_type](
  201. self.encoder_reshape.out_channels, hidden_size
  202. )
  203. self.out_channels = self.encoder.out_channels
  204. self.only_reshape = False
  205. def forward(self, x):
  206. if self.encoder_type != "svtr":
  207. x = self.encoder_reshape(x)
  208. if not self.only_reshape:
  209. x = self.encoder(x)
  210. return x
  211. else:
  212. x = self.encoder(x)
  213. x = self.encoder_reshape(x)
  214. return x