# rec_svtrnet.py

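"""SVTR recognition backbone in PyTorch.

Appears to be a port of the SVTR scene-text-recognition visual model
(Du et al., "SVTR: Scene Text Recognition with a Single Visual Model",
IJCAI 2022) from the original PaddleOCR implementation; Paddle-style
argument names such as `bias_attr` are kept for config compatibility.
"""
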
import numpy as np
import torch
from torch import nn

from ..common import Activation


def drop_path(x, drop_prob=0.0, training=False):
    """Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks.

    The original name is misleading: 'Drop Connect' is a different form of dropout
    from a separate paper. See the discussion at
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = torch.as_tensor(1 - drop_prob, dtype=x.dtype, device=x.device)
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor = torch.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class ConvBNLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        bias_attr=False,
        groups=1,
        act="gelu",
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias_attr,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = Activation(act_type=act, inplace=True)

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer="gelu",
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = Activation(act_type=act_layer, inplace=True)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvMixer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        HW=[8, 25],
        local_k=[3, 3],
    ):
        super().__init__()
        self.HW = HW
        self.dim = dim
        self.local_mixer = nn.Conv2d(
            dim,
            dim,
            local_k,
            1,
            [local_k[0] // 2, local_k[1] // 2],
            groups=num_heads,
        )

    def forward(self, x):
        h = self.HW[0]
        w = self.HW[1]
        # (B, N, C) -> (B, C, H, W) for the grouped local convolution.
        x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w)
        x = self.local_mixer(x)
        # Back to the (B, N, C) token layout.
        x = x.flatten(2).permute(0, 2, 1)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        mixer="Global",
        HW=[8, 25],
        local_k=[7, 11],
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.HW = HW
        if HW is not None:
            H = HW[0]
            W = HW[1]
            self.N = H * W
            self.C = dim
        if mixer == "Local" and HW is not None:
            hk = local_k[0]
            wk = local_k[1]
            # Additive attention mask: each token attends only to tokens inside
            # its hk x wk neighborhood; all other positions are set to -inf.
            mask = torch.ones(H * W, H + hk - 1, W + wk - 1, dtype=torch.float32)
            for h in range(0, H):
                for w in range(0, W):
                    mask[h * W + w, h : h + hk, w : w + wk] = 0.0
            mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
            mask_inf = torch.full(
                [H * W, H * W], fill_value=float("-inf"), dtype=torch.float32
            )
            mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
            # Registered as a buffer so the mask moves with the module across devices.
            self.register_buffer("mask", mask.unsqueeze(0).unsqueeze(1))
        self.mixer = mixer

    def forward(self, x):
        if self.HW is not None:
            N = self.N
            C = self.C
        else:
            _, N, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute(
            2, 0, 3, 1, 4
        )
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q.matmul(k.permute(0, 1, 3, 2))
        if self.mixer == "Local":
            attn += self.mask
        attn = nn.functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn.matmul(v)).permute(0, 2, 1, 3).reshape((-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mixer="Global",
        local_mixer=[7, 11],
        HW=None,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer="gelu",
        norm_layer="nn.LayerNorm",
        epsilon=1e-6,
        prenorm=True,
    ):
        super().__init__()
        if isinstance(norm_layer, str):
            self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        else:
            self.norm1 = norm_layer(dim)
        if mixer == "Global" or mixer == "Local":
            self.mixer = Attention(
                dim,
                num_heads=num_heads,
                mixer=mixer,
                HW=HW,
                local_k=local_mixer,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )
        elif mixer == "Conv":
            self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
        else:
            raise TypeError("The mixer must be one of [Global, Local, Conv]")
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        if isinstance(norm_layer, str):
            self.norm2 = eval(norm_layer)(dim, eps=epsilon)
        else:
            self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.prenorm = prenorm

    def forward(self, x):
        # Note: with prenorm=True the norm is applied *after* each residual sum,
        # mirroring the original implementation's (somewhat confusing) flag name.
        if self.prenorm:
            x = self.norm1(x + self.drop_path(self.mixer(x)))
            x = self.norm2(x + self.drop_path(self.mlp(x)))
        else:
            x = x + self.drop_path(self.mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self,
        img_size=[32, 100],
        in_channels=3,
        embed_dim=768,
        sub_num=2,
        patch_size=[4, 4],
        mode="pope",
    ):
        super().__init__()
        num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
        self.img_size = img_size
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.norm = None
        if mode == "pope":
            # "pope" mode: embed via a stack of stride-2 ConvBNLayer stages.
            if sub_num == 2:
                self.proj = nn.Sequential(
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=embed_dim // 2,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act="gelu",
                        bias_attr=True,
                    ),
                    ConvBNLayer(
                        in_channels=embed_dim // 2,
                        out_channels=embed_dim,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act="gelu",
                        bias_attr=True,
                    ),
                )
            if sub_num == 3:
                self.proj = nn.Sequential(
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=embed_dim // 4,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act="gelu",
                        bias_attr=True,
                    ),
                    ConvBNLayer(
                        in_channels=embed_dim // 4,
                        out_channels=embed_dim // 2,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act="gelu",
                        bias_attr=True,
                    ),
                    ConvBNLayer(
                        in_channels=embed_dim // 2,
                        out_channels=embed_dim,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        act="gelu",
                        bias_attr=True,
                    ),
                )
        elif mode == "linear":
            # ViT-style linear patch projection (was hard-coded to 1 input channel,
            # which breaks for the default in_channels=3).
            self.proj = nn.Conv2d(
                in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
            )
            self.num_patches = (
                img_size[0] // patch_size[0] * img_size[1] // patch_size[1]
            )

    def forward(self, x):
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), "Input image size ({}*{}) doesn't match model ({}*{}).".format(
            H, W, self.img_size[0], self.img_size[1]
        )
        x = self.proj(x).flatten(2).permute(0, 2, 1)
        return x


class SubSample(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        types="Pool",
        stride=[2, 1],
        sub_norm="nn.LayerNorm",
        act=None,
    ):
        super().__init__()
        self.types = types
        if types == "Pool":
            self.avgpool = nn.AvgPool2d(
                kernel_size=[3, 5], stride=stride, padding=[1, 2]
            )
            self.maxpool = nn.MaxPool2d(
                kernel_size=[3, 5], stride=stride, padding=[1, 2]
            )
            self.proj = nn.Linear(in_channels, out_channels)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
            )
        self.norm = eval(sub_norm)(out_channels)
        if act is not None:
            self.act = act()
        else:
            self.act = None

    def forward(self, x):
        if self.types == "Pool":
            x1 = self.avgpool(x)
            x2 = self.maxpool(x)
            x = (x1 + x2) * 0.5
            out = self.proj(x.flatten(2).permute(0, 2, 1))
        else:
            x = self.conv(x)
            out = x.flatten(2).permute(0, 2, 1)
        out = self.norm(out)
        if self.act is not None:
            out = self.act(out)
        return out


class SVTRNet(nn.Module):
    def __init__(
        self,
        img_size=[32, 100],
        in_channels=3,
        embed_dim=[64, 128, 256],
        depth=[3, 6, 3],
        num_heads=[2, 4, 8],
        mixer=["Local"] * 6 + ["Global"] * 6,  # Local attn, Global attn, Conv
        local_mixer=[[7, 11], [7, 11], [7, 11]],
        patch_merging="Conv",  # Conv, Pool, None
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        last_drop=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer="nn.LayerNorm",
        sub_norm="nn.LayerNorm",
        epsilon=1e-6,
        out_channels=192,
        out_char_num=25,
        block_unit="Block",
        act="gelu",
        last_stage=True,
        sub_num=2,
        prenorm=True,
        use_lenhead=False,
        **kwargs
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = out_channels
        self.prenorm = prenorm
        patch_merging = (
            None
            if patch_merging != "Conv" and patch_merging != "Pool"
            else patch_merging
        )
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            in_channels=in_channels,
            embed_dim=embed_dim[0],
            sub_num=sub_num,
        )
        num_patches = self.patch_embed.num_patches
        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)
        Block_unit = eval(block_unit)
        # Stochastic-depth rate increases linearly across all blocks.
        dpr = np.linspace(0, drop_path_rate, sum(depth))
        self.blocks1 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[0],
                    num_heads=num_heads[0],
                    mixer=mixer[0 : depth[0]][i],
                    HW=self.HW,
                    local_mixer=local_mixer[0],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=act,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[0 : depth[0]][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[0])
            ]
        )
        if patch_merging is not None:
            self.sub_sample1 = SubSample(
                embed_dim[0],
                embed_dim[1],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging,
            )
            HW = [self.HW[0] // 2, self.HW[1]]
        else:
            HW = self.HW
        self.patch_merging = patch_merging
        self.blocks2 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[1],
                    num_heads=num_heads[1],
                    mixer=mixer[depth[0] : depth[0] + depth[1]][i],
                    HW=HW,
                    local_mixer=local_mixer[1],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=act,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[1])
            ]
        )
        if patch_merging is not None:
            self.sub_sample2 = SubSample(
                embed_dim[1],
                embed_dim[2],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging,
            )
            HW = [self.HW[0] // 4, self.HW[1]]
        else:
            HW = self.HW
        self.blocks3 = nn.ModuleList(
            [
                Block_unit(
                    dim=embed_dim[2],
                    num_heads=num_heads[2],
                    mixer=mixer[depth[0] + depth[1] :][i],
                    HW=HW,
                    local_mixer=local_mixer[2],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=act,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[depth[0] + depth[1] :][i],
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for i in range(depth[2])
            ]
        )
        self.last_stage = last_stage
        if last_stage:
            self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num])
            self.last_conv = nn.Conv2d(
                in_channels=embed_dim[2],
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )
            self.hardswish = Activation("hard_swish", inplace=True)  # nn.Hardswish()
            self.dropout = nn.Dropout(p=last_drop)
        if not prenorm:
            self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
        self.use_lenhead = use_lenhead
        if use_lenhead:
            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
            self.hardswish_len = Activation("hard_swish", inplace=True)  # nn.Hardswish()
            self.dropout_len = nn.Dropout(p=last_drop)
        torch.nn.init.xavier_normal_(self.pos_embed)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Weight initialization
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample1(
                x.permute(0, 2, 1).reshape(
                    [-1, self.embed_dim[0], self.HW[0], self.HW[1]]
                )
            )
        for blk in self.blocks2:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample2(
                x.permute(0, 2, 1).reshape(
                    [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]
                )
            )
        for blk in self.blocks3:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if self.use_lenhead:
            len_x = self.len_conv(x.mean(1))
            len_x = self.dropout_len(self.hardswish_len(len_x))
        if self.last_stage:
            if self.patch_merging is not None:
                h = self.HW[0] // 4
            else:
                h = self.HW[0]
            x = self.avg_pool(
                x.permute(0, 2, 1).reshape([-1, self.embed_dim[2], h, self.HW[1]])
            )
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
        if self.use_lenhead:
            return x, len_x
        return x
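

# A minimal smoke test: a hedged usage sketch, not part of the original module.
# It assumes the relative import `from ..common import Activation` resolves, so
# run it with the package on the path (e.g. `python -m <package>.rec_svtrnet`,
# where the module path is an assumption).
if __name__ == "__main__":
    model = SVTRNet()  # defaults: img_size=[32, 100], in_channels=3
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(2, 3, 32, 100)  # (batch, channels, height, width)
        feats = model(dummy)
    # With the defaults (last_stage=True, out_channels=192, out_char_num=25)
    # this prints torch.Size([2, 192, 1, 25]).
    print(feats.shape)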