rec_pphgnetv2.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642
import math
import warnings
from typing import Callable, Dict, List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .rec_donut_swin import DonutSwinModelOutput
  8. class IdentityBasedConv1x1(nn.Conv2d):
  9. def __init__(self, channels, groups=1):
  10. super(IdentityBasedConv1x1, self).__init__(
  11. in_channels=channels,
  12. out_channels=channels,
  13. kernel_size=1,
  14. stride=1,
  15. padding=0,
  16. groups=groups,
  17. bias_attr=False,
  18. )
  19. assert channels % groups == 0
  20. input_dim = channels // groups
  21. id_value = np.zeros((channels, input_dim, 1, 1))
  22. for i in range(channels):
  23. id_value[i, i % input_dim, 0, 0] = 1
  24. self.id_tensor = torch.Tensor(id_value)
  25. self.weight.set_value(torch.zeros_like(self.weight))
  26. def forward(self, input):
  27. kernel = self.weight + self.id_tensor
  28. result = F.conv2d(
  29. input,
  30. kernel,
  31. None,
  32. stride=1,
  33. padding=0,
  34. dilation=self._dilation,
  35. groups=self._groups,
  36. )
  37. return result
  38. def get_actual_kernel(self):
  39. return self.weight + self.id_tensor
  40. class BNAndPad(nn.Module):
  41. def __init__(
  42. self,
  43. pad_pixels,
  44. num_features,
  45. epsilon=1e-5,
  46. momentum=0.1,
  47. last_conv_bias=None,
  48. bn=nn.BatchNorm2d,
  49. ):
  50. super().__init__()
  51. self.bn = bn(num_features, momentum=momentum, epsilon=epsilon)
  52. self.pad_pixels = pad_pixels
  53. self.last_conv_bias = last_conv_bias
  54. def forward(self, input):
  55. output = self.bn(input)
  56. if self.pad_pixels > 0:
  57. bias = -self.bn._mean
  58. if self.last_conv_bias is not None:
  59. bias += self.last_conv_bias
  60. pad_values = self.bn.bias + self.bn.weight * (
  61. bias / torch.sqrt(self.bn._variance + self.bn._epsilon)
  62. )
  63. """ pad """
  64. # TODO: n,h,w,c format is not supported yet
  65. n, c, h, w = output.shape
  66. values = pad_values.reshape([1, -1, 1, 1])
  67. w_values = values.expand([n, -1, self.pad_pixels, w])
  68. x = torch.cat([w_values, output, w_values], dim=2)
  69. h = h + self.pad_pixels * 2
  70. h_values = values.expand([n, -1, h, self.pad_pixels])
  71. x = torch.cat([h_values, x, h_values], dim=3)
  72. output = x
  73. return output
  74. @property
  75. def weight(self):
  76. return self.bn.weight
  77. @property
  78. def bias(self):
  79. return self.bn.bias
  80. @property
  81. def _mean(self):
  82. return self.bn._mean
  83. @property
  84. def _variance(self):
  85. return self.bn._variance
  86. @property
  87. def _epsilon(self):
  88. return self.bn._epsilon
  89. def conv_bn(
  90. in_channels,
  91. out_channels,
  92. kernel_size,
  93. stride=1,
  94. padding=0,
  95. dilation=1,
  96. groups=1,
  97. padding_mode="zeros",
  98. ):
  99. conv_layer = nn.Conv2d(
  100. in_channels=in_channels,
  101. out_channels=out_channels,
  102. kernel_size=kernel_size,
  103. stride=stride,
  104. padding=padding,
  105. dilation=dilation,
  106. groups=groups,
  107. bias_attr=False,
  108. padding_mode=padding_mode,
  109. )
  110. bn_layer = nn.BatchNorm2D(num_features=out_channels)
  111. se = nn.Sequential()
  112. se.add_sublayer("conv", conv_layer)
  113. se.add_sublayer("bn", bn_layer)
  114. return se
  115. def transI_fusebn(kernel, bn):
  116. gamma = bn.weight
  117. std = (bn._variance + bn._epsilon).sqrt()
  118. return (
  119. kernel * ((gamma / std).reshape([-1, 1, 1, 1])),
  120. bn.bias - bn._mean * gamma / std,
  121. )
  122. def transII_addbranch(kernels, biases):
  123. return sum(kernels), sum(biases)
  124. def transIII_1x1_kxk(k1, b1, k2, b2, groups):
  125. if groups == 1:
  126. k = F.conv2d(k2, k1.transpose([1, 0, 2, 3]))
  127. b_hat = (k2 * b1.reshape([1, -1, 1, 1])).sum((1, 2, 3))
  128. else:
  129. k_slices = []
  130. b_slices = []
  131. k1_T = k1.transpose([1, 0, 2, 3])
  132. k1_group_width = k1.shape[0] // groups
  133. k2_group_width = k2.shape[0] // groups
  134. for g in range(groups):
  135. k1_T_slice = k1_T[:, g * k1_group_width : (g + 1) * k1_group_width, :, :]
  136. k2_slice = k2[g * k2_group_width : (g + 1) * k2_group_width, :, :, :]
  137. k_slices.append(F.conv2d(k2_slice, k1_T_slice))
  138. b_slices.append(
  139. (
  140. k2_slice
  141. * b1[g * k1_group_width : (g + 1) * k1_group_width].reshape(
  142. [1, -1, 1, 1]
  143. )
  144. ).sum((1, 2, 3))
  145. )
  146. k, b_hat = transIV_depthconcat(k_slices, b_slices)
  147. return k, b_hat + b2
  148. def transIV_depthconcat(kernels, biases):
  149. return torch.cat(kernels, dim=0), torch.cat(biases)
  150. def transV_avg(channels, kernel_size, groups):
  151. input_dim = channels // groups
  152. k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
  153. k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = (
  154. 1.0 / kernel_size**2
  155. )
  156. return k
  157. def transVI_multiscale(kernel, target_kernel_size):
  158. H_pixels_to_pad = (target_kernel_size - kernel.shape[2]) // 2
  159. W_pixels_to_pad = (target_kernel_size - kernel.shape[3]) // 2
  160. return F.pad(
  161. kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad]
  162. )
  163. class DiverseBranchBlock(nn.Module):
  164. def __init__(
  165. self,
  166. num_channels,
  167. num_filters,
  168. filter_size,
  169. stride=1,
  170. groups=1,
  171. act=None,
  172. is_repped=False,
  173. single_init=False,
  174. **kwargs,
  175. ):
  176. super().__init__()
  177. padding = (filter_size - 1) // 2
  178. dilation = 1
  179. in_channels = num_channels
  180. out_channels = num_filters
  181. kernel_size = filter_size
  182. internal_channels_1x1_3x3 = None
  183. nonlinear = act
  184. self.is_repped = is_repped
  185. if nonlinear is None:
  186. self.nonlinear = nn.Identity()
  187. else:
  188. self.nonlinear = nn.ReLU()
  189. self.kernel_size = kernel_size
  190. self.out_channels = out_channels
  191. self.groups = groups
  192. assert padding == kernel_size // 2
  193. if is_repped:
  194. self.dbb_reparam = nn.Conv2d(
  195. in_channels=in_channels,
  196. out_channels=out_channels,
  197. kernel_size=kernel_size,
  198. stride=stride,
  199. padding=padding,
  200. dilation=dilation,
  201. groups=groups,
  202. bias=True,
  203. )
  204. else:
  205. self.dbb_origin = conv_bn(
  206. in_channels=in_channels,
  207. out_channels=out_channels,
  208. kernel_size=kernel_size,
  209. stride=stride,
  210. padding=padding,
  211. dilation=dilation,
  212. groups=groups,
  213. )
  214. self.dbb_avg = nn.Sequential()
  215. if groups < out_channels:
  216. self.dbb_avg.add_sublayer(
  217. "conv",
  218. nn.Conv2d(
  219. in_channels=in_channels,
  220. out_channels=out_channels,
  221. kernel_size=1,
  222. stride=1,
  223. padding=0,
  224. groups=groups,
  225. bias=False,
  226. ),
  227. )
  228. self.dbb_avg.add_sublayer(
  229. "bn", BNAndPad(pad_pixels=padding, num_features=out_channels)
  230. )
  231. self.dbb_avg.add_sublayer(
  232. "avg",
  233. nn.AvgPool2D(kernel_size=kernel_size, stride=stride, padding=0),
  234. )
  235. self.dbb_1x1 = conv_bn(
  236. in_channels=in_channels,
  237. out_channels=out_channels,
  238. kernel_size=1,
  239. stride=stride,
  240. padding=0,
  241. groups=groups,
  242. )
  243. else:
  244. self.dbb_avg.add_sublayer(
  245. "avg",
  246. nn.AvgPool2D(
  247. kernel_size=kernel_size, stride=stride, padding=padding
  248. ),
  249. )
  250. self.dbb_avg.add_sublayer("avgbn", nn.BatchNorm2D(out_channels))
  251. if internal_channels_1x1_3x3 is None:
  252. internal_channels_1x1_3x3 = (
  253. in_channels if groups < out_channels else 2 * in_channels
  254. ) # For mobilenet, it is better to have 2X internal channels
  255. self.dbb_1x1_kxk = nn.Sequential()
  256. if internal_channels_1x1_3x3 == in_channels:
  257. self.dbb_1x1_kxk.add_sublayer(
  258. "idconv1", IdentityBasedConv1x1(channels=in_channels, groups=groups)
  259. )
  260. else:
  261. self.dbb_1x1_kxk.add_sublayer(
  262. "conv1",
  263. nn.Conv2d(
  264. in_channels=in_channels,
  265. out_channels=internal_channels_1x1_3x3,
  266. kernel_size=1,
  267. stride=1,
  268. padding=0,
  269. groups=groups,
  270. bias=False,
  271. ),
  272. )
  273. self.dbb_1x1_kxk.add_sublayer(
  274. "bn1",
  275. BNAndPad(pad_pixels=padding, num_features=internal_channels_1x1_3x3),
  276. )
  277. self.dbb_1x1_kxk.add_sublayer(
  278. "conv2",
  279. nn.Conv2d(
  280. in_channels=internal_channels_1x1_3x3,
  281. out_channels=out_channels,
  282. kernel_size=kernel_size,
  283. stride=stride,
  284. padding=0,
  285. groups=groups,
  286. bias=False,
  287. ),
  288. )
  289. self.dbb_1x1_kxk.add_sublayer("bn2", nn.BatchNorm2D(out_channels))
  290. # The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
  291. if single_init:
  292. # Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
  293. self.single_init()
  294. def forward(self, inputs):
  295. if self.is_repped:
  296. return self.nonlinear(self.dbb_reparam(inputs))
  297. out = self.dbb_origin(inputs)
  298. if hasattr(self, "dbb_1x1"):
  299. out += self.dbb_1x1(inputs)
  300. out += self.dbb_avg(inputs)
  301. out += self.dbb_1x1_kxk(inputs)
  302. return self.nonlinear(out)
  303. def init_gamma(self, gamma_value):
  304. if hasattr(self, "dbb_origin"):
  305. torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
  306. if hasattr(self, "dbb_1x1"):
  307. torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
  308. if hasattr(self, "dbb_avg"):
  309. torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
  310. if hasattr(self, "dbb_1x1_kxk"):
  311. torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)
  312. def single_init(self):
  313. self.init_gamma(0.0)
  314. if hasattr(self, "dbb_origin"):
  315. torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
  316. def get_equivalent_kernel_bias(self):
  317. k_origin, b_origin = transI_fusebn(
  318. self.dbb_origin.conv.weight, self.dbb_origin.bn
  319. )
  320. if hasattr(self, "dbb_1x1"):
  321. k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
  322. k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
  323. else:
  324. k_1x1, b_1x1 = 0, 0
  325. if hasattr(self.dbb_1x1_kxk, "idconv1"):
  326. k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
  327. else:
  328. k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
  329. k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(
  330. k_1x1_kxk_first, self.dbb_1x1_kxk.bn1
  331. )
  332. k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(
  333. self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2
  334. )
  335. k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(
  336. k_1x1_kxk_first,
  337. b_1x1_kxk_first,
  338. k_1x1_kxk_second,
  339. b_1x1_kxk_second,
  340. groups=self.groups,
  341. )
  342. k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
  343. k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg, self.dbb_avg.avgbn)
  344. if hasattr(self.dbb_avg, "conv"):
  345. k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(
  346. self.dbb_avg.conv.weight, self.dbb_avg.bn
  347. )
  348. k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(
  349. k_1x1_avg_first,
  350. b_1x1_avg_first,
  351. k_1x1_avg_second,
  352. b_1x1_avg_second,
  353. groups=self.groups,
  354. )
  355. else:
  356. k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
  357. return transII_addbranch(
  358. (k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
  359. (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged),
  360. )
  361. def re_parameterize(self):
  362. if self.is_repped:
  363. return
  364. kernel, bias = self.get_equivalent_kernel_bias()
  365. self.dbb_reparam = nn.Conv2d(
  366. in_channels=self.dbb_origin.conv._in_channels,
  367. out_channels=self.dbb_origin.conv._out_channels,
  368. kernel_size=self.dbb_origin.conv._kernel_size,
  369. stride=self.dbb_origin.conv._stride,
  370. padding=self.dbb_origin.conv._padding,
  371. dilation=self.dbb_origin.conv._dilation,
  372. groups=self.dbb_origin.conv._groups,
  373. bias=True,
  374. )
  375. self.dbb_reparam.weight.set_value(kernel)
  376. self.dbb_reparam.bias.set_value(bias)
  377. self.__delattr__("dbb_origin")
  378. self.__delattr__("dbb_avg")
  379. if hasattr(self, "dbb_1x1"):
  380. self.__delattr__("dbb_1x1")
  381. self.__delattr__("dbb_1x1_kxk")
  382. self.is_repped = True
  383. class Identity(nn.Module):
  384. def __init__(self):
  385. super(Identity, self).__init__()
  386. def forward(self, inputs):
  387. return inputs
  388. class TheseusLayer(nn.Module):
  389. def __init__(self, *args, **kwargs):
  390. super().__init__()
  391. self.res_dict = {}
  392. # self.res_name = self.full_name()
  393. self.res_name = self.__class__.__name__.lower()
  394. self.pruner = None
  395. self.quanter = None
  396. self.init_net(*args, **kwargs)
  397. def _return_dict_hook(self, layer, input, output):
  398. res_dict = {"logits": output}
  399. # 'list' is needed to avoid error raised by popping self.res_dict
  400. for res_key in list(self.res_dict):
  401. # clear the res_dict because the forward process may change according to input
  402. res_dict[res_key] = self.res_dict.pop(res_key)
  403. return res_dict
  404. def init_net(
  405. self,
  406. stages_pattern=None,
  407. return_patterns=None,
  408. return_stages=None,
  409. freeze_befor=None,
  410. stop_after=None,
  411. *args,
  412. **kwargs,
  413. ):
  414. # init the output of net
  415. if return_patterns or return_stages:
  416. if return_patterns and return_stages:
  417. msg = f"The 'return_patterns' would be ignored when 'return_stages' is set."
  418. return_stages = None
  419. if return_stages is True:
  420. return_patterns = stages_pattern
  421. # return_stages is int or bool
  422. if type(return_stages) is int:
  423. return_stages = [return_stages]
  424. if isinstance(return_stages, list):
  425. if max(return_stages) > len(stages_pattern) or min(return_stages) < 0:
  426. msg = f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
  427. return_stages = [
  428. val
  429. for val in return_stages
  430. if val >= 0 and val < len(stages_pattern)
  431. ]
  432. return_patterns = [stages_pattern[i] for i in return_stages]
  433. if return_patterns:
  434. # call update_res function after the __init__ of the object has completed execution, that is, the constructing of layer or model has been completed.
  435. def update_res_hook(layer, input):
  436. self.update_res(return_patterns)
  437. self.register_forward_pre_hook(update_res_hook)
  438. # freeze subnet
  439. if freeze_befor is not None:
  440. self.freeze_befor(freeze_befor)
  441. # set subnet to Identity
  442. if stop_after is not None:
  443. self.stop_after(stop_after)
  444. def init_res(self, stages_pattern, return_patterns=None, return_stages=None):
  445. if return_patterns and return_stages:
  446. return_stages = None
  447. if return_stages is True:
  448. return_patterns = stages_pattern
  449. # return_stages is int or bool
  450. if type(return_stages) is int:
  451. return_stages = [return_stages]
  452. if isinstance(return_stages, list):
  453. if max(return_stages) > len(stages_pattern) or min(return_stages) < 0:
  454. return_stages = [
  455. val
  456. for val in return_stages
  457. if val >= 0 and val < len(stages_pattern)
  458. ]
  459. return_patterns = [stages_pattern[i] for i in return_stages]
  460. if return_patterns:
  461. self.update_res(return_patterns)
  462. def replace_sub(self, *args, **kwargs) -> None:
  463. msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
  464. raise DeprecationWarning(msg)
  465. def upgrade_sublayer(
  466. self,
  467. layer_name_pattern: Union[str, List[str]],
  468. handle_func: Callable[[nn.Module, str], nn.Module],
  469. ) -> Dict[str, nn.Module]:
  470. """use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.
  471. Args:
  472. layer_name_pattern (Union[str, List[str]]): The name of layer to be modified by 'handle_func'.
  473. handle_func (Callable[[nn.Module, str], nn.Module]): The function to modify target layer specified by 'layer_name_pattern'. The formal params are the layer(nn.Module) and pattern(str) that is (a member of) layer_name_pattern (when layer_name_pattern is List type). And the return is the layer processed.
  474. Returns:
  475. Dict[str, nn.Module]: The key is the pattern and corresponding value is the result returned by 'handle_func()'.
  476. Examples:
  477. from paddle import nn
  478. import paddleclas
  479. def rep_func(layer: nn.Module, pattern: str):
  480. new_layer = nn.Conv2d(
  481. in_channels=layer._in_channels,
  482. out_channels=layer._out_channels,
  483. kernel_size=5,
  484. padding=2
  485. )
  486. return new_layer
  487. net = paddleclas.MobileNetV1()
  488. res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
  489. print(res)
  490. # {'blocks[11].depthwise_conv.conv': the corresponding new_layer, 'blocks[12].depthwise_conv.conv': the corresponding new_layer}
  491. """
  492. if not isinstance(layer_name_pattern, list):
  493. layer_name_pattern = [layer_name_pattern]
  494. hit_layer_pattern_list = []
  495. for pattern in layer_name_pattern:
  496. # parse pattern to find target layer and its parent
  497. layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
  498. if not layer_list:
  499. continue
  500. sub_layer_parent = layer_list[-2]["layer"] if len(layer_list) > 1 else self
  501. sub_layer = layer_list[-1]["layer"]
  502. sub_layer_name = layer_list[-1]["name"]
  503. sub_layer_index_list = layer_list[-1]["index_list"]
  504. new_sub_layer = handle_func(sub_layer, pattern)
  505. if sub_layer_index_list:
  506. if len(sub_layer_index_list) > 1:
  507. sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[
  508. sub_layer_index_list[0]
  509. ]
  510. for sub_layer_index in sub_layer_index_list[1:-1]:
  511. sub_layer_parent = sub_layer_parent[sub_layer_index]
  512. sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
  513. else:
  514. getattr(sub_layer_parent, sub_layer_name)[
  515. sub_layer_index_list[0]
  516. ] = new_sub_layer
  517. else:
  518. setattr(sub_layer_parent, sub_layer_name, new_sub_layer)
  519. hit_layer_pattern_list.append(pattern)
  520. return hit_layer_pattern_list
  521. def stop_after(self, stop_layer_name: str) -> bool:
  522. """stop forward and backward after 'stop_layer_name'.
  523. Args:
  524. stop_layer_name (str): The name of layer that stop forward and backward after this layer.
  525. Returns:
  526. bool: 'True' if successful, 'False' otherwise.
  527. """
  528. layer_list = parse_pattern_str(stop_layer_name, self)
  529. if not layer_list:
  530. return False
  531. parent_layer = self
  532. for layer_dict in layer_list:
  533. name, index_list = layer_dict["name"], layer_dict["index_list"]
  534. if not set_identity(parent_layer, name, index_list):
  535. msg = f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
  536. return False
  537. parent_layer = layer_dict["layer"]
  538. return True
  539. def freeze_befor(self, layer_name: str) -> bool:
  540. """freeze the layer named layer_name and its previous layer.
  541. Args:
  542. layer_name (str): The name of layer that would be freezed.
  543. Returns:
  544. bool: 'True' if successful, 'False' otherwise.
  545. """
  546. def stop_grad(layer, pattern):
  547. class StopGradLayer(nn.Module):
  548. def __init__(self):
  549. super().__init__()
  550. self.layer = layer
  551. def forward(self, x):
  552. x = self.layer(x)
  553. x.stop_gradient = True
  554. return x
  555. new_layer = StopGradLayer()
  556. return new_layer
  557. res = self.upgrade_sublayer(layer_name, stop_grad)
  558. if len(res) == 0:
  559. msg = "Failed to stop the gradient before the layer named '{layer_name}'"
  560. return False
  561. return True
  562. def update_res(self, return_patterns: Union[str, List[str]]) -> Dict[str, nn.Module]:
  563. """update the result(s) to be returned.
  564. Args:
  565. return_patterns (Union[str, List[str]]): The name of layer to return output.
  566. Returns:
  567. Dict[str, nn.Module]: The pattern(str) and corresponding layer(nn.Module) that have been set successfully.
  568. """
  569. # clear res_dict that could have been set
  570. self.res_dict = {}
  571. class Handler(object):
  572. def __init__(self, res_dict):
  573. # res_dict is a reference
  574. self.res_dict = res_dict
  575. def __call__(self, layer, pattern):
  576. layer.res_dict = self.res_dict
  577. layer.res_name = pattern
  578. if hasattr(layer, "hook_remove_helper"):
  579. layer.hook_remove_helper.remove()
  580. layer.hook_remove_helper = layer.register_forward_post_hook(
  581. save_sub_res_hook
  582. )
  583. return layer
  584. handle_func = Handler(self.res_dict)
  585. hit_layer_pattern_list = self.upgrade_sublayer(
  586. return_patterns, handle_func=handle_func
  587. )
  588. if hasattr(self, "hook_remove_helper"):
  589. self.hook_remove_helper.remove()
  590. self.hook_remove_helper = self.register_forward_post_hook(
  591. self._return_dict_hook
  592. )
  593. return hit_layer_pattern_list
  594. def save_sub_res_hook(layer, input, output):
  595. layer.res_dict[layer.res_name] = output
  596. def set_identity(
  597. parent_layer: nn.Module, layer_name: str, layer_index_list: str = None
  598. ) -> bool:
  599. """set the layer specified by layer_name and layer_index_list to Identity.
  600. Args:
  601. parent_layer (nn.Module): The parent layer of target layer specified by layer_name and layer_index_list.
  602. layer_name (str): The name of target layer to be set to Identity.
  603. layer_index_list (str, optional): The index of target layer to be set to Identity in parent_layer. Defaults to None.
  604. Returns:
  605. bool: True if successfully, False otherwise.
  606. """
  607. stop_after = False
  608. for sub_layer_name in parent_layer._sub_layers:
  609. if stop_after:
  610. parent_layer._sub_layers[sub_layer_name] = Identity()
  611. continue
  612. if sub_layer_name == layer_name:
  613. stop_after = True
  614. if layer_index_list and stop_after:
  615. layer_container = parent_layer._sub_layers[layer_name]
  616. for num, layer_index in enumerate(layer_index_list):
  617. stop_after = False
  618. for i in range(num):
  619. layer_container = layer_container[layer_index_list[i]]
  620. for sub_layer_index in layer_container._sub_layers:
  621. if stop_after:
  622. parent_layer._sub_layers[layer_name][sub_layer_index] = Identity()
  623. continue
  624. if layer_index == sub_layer_index:
  625. stop_after = True
  626. return stop_after
  627. def parse_pattern_str(
  628. pattern: str, parent_layer: nn.Module
  629. ) -> Union[None, List[Dict[str, Union[nn.Module, str, None]]]]:
  630. """parse the string type pattern.
  631. Args:
  632. pattern (str): The pattern to describe layer.
  633. parent_layer (nn.Module): The root layer relative to the pattern.
  634. Returns:
  635. Union[None, List[Dict[str, Union[nn.Module, str, None]]]]: None if failed. If successfully, the members are layers parsed in order:
  636. [
  637. {"layer": first layer, "name": first layer's name parsed, "index": first layer's index parsed if exist},
  638. {"layer": second layer, "name": second layer's name parsed, "index": second layer's index parsed if exist},
  639. ...
  640. ]
  641. """
  642. pattern_list = pattern.split(".")
  643. if not pattern_list:
  644. msg = f"The pattern('{pattern}') is illegal. Please check and retry."
  645. return None
  646. layer_list = []
  647. while len(pattern_list) > 0:
  648. if "[" in pattern_list[0]:
  649. target_layer_name = pattern_list[0].split("[")[0]
  650. target_layer_index_list = list(
  651. index.split("]")[0] for index in pattern_list[0].split("[")[1:]
  652. )
  653. else:
  654. target_layer_name = pattern_list[0]
  655. target_layer_index_list = None
  656. target_layer = getattr(parent_layer, target_layer_name, None)
  657. if target_layer is None:
  658. msg = f"Not found layer named('{target_layer_name}') specified in pattern('{pattern}')."
  659. return None
  660. if target_layer_index_list:
  661. for target_layer_index in target_layer_index_list:
  662. if int(target_layer_index) < 0 or int(target_layer_index) >= len(
  663. target_layer
  664. ):
  665. msg = f"Not found layer by index('{target_layer_index}') specified in pattern('{pattern}'). The index should < {len(target_layer)} and > 0."
  666. return None
  667. target_layer = target_layer[target_layer_index]
  668. layer_list.append(
  669. {
  670. "layer": target_layer,
  671. "name": target_layer_name,
  672. "index_list": target_layer_index_list,
  673. }
  674. )
  675. pattern_list = pattern_list[1:]
  676. parent_layer = target_layer
  677. return layer_list
  678. class LearnableAffineBlock(TheseusLayer):
  679. """
  680. Create a learnable affine block module. This module can significantly improve accuracy on smaller models.
  681. Args:
  682. scale_value (float): The initial value of the scale parameter, default is 1.0.
  683. bias_value (float): The initial value of the bias parameter, default is 0.0.
  684. lr_mult (float): The learning rate multiplier, default is 1.0.
  685. lab_lr (float): The learning rate, default is 0.01.
  686. """
  687. def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.01):
  688. super().__init__()
  689. # self.scale = self.create_parameter(
  690. # shape=[
  691. # 1,
  692. # ],
  693. # default_initializer=nn.init.Constant(value=scale_value),
  694. # # attr=ParamAttr(learning_rate=lr_mult * lab_lr),
  695. # )
  696. # self.add_parameter("scale", self.scale)
  697. self.scale = torch.Parameter(
  698. nn.init.constant_(
  699. torch.ones(1).to(torch.float32), val=scale_value
  700. )
  701. )
  702. self.register_parameter("scale", self.scale)
  703. # self.bias = self.create_parameter(
  704. # shape=[
  705. # 1,
  706. # ],
  707. # default_initializer=nn.init.Constant(value=bias_value),
  708. # # attr=ParamAttr(learning_rate=lr_mult * lab_lr),
  709. # )
  710. # self.add_parameter("bias", self.bias)
  711. self.bias = torch.Parameter(
  712. nn.init.constant_(
  713. torch.ones(1).to(torch.float32), val=bias_value
  714. )
  715. )
  716. self.register_parameter("bias", self.bias)
  717. def forward(self, x):
  718. return self.scale * x + self.bias
  719. class ConvBNAct(TheseusLayer):
  720. """
  721. ConvBNAct is a combination of convolution and batchnorm layers.
  722. Args:
  723. in_channels (int): Number of input channels.
  724. out_channels (int): Number of output channels.
  725. kernel_size (int): Size of the convolution kernel. Defaults to 3.
  726. stride (int): Stride of the convolution. Defaults to 1.
  727. padding (int/str): Padding or padding type for the convolution. Defaults to 1.
  728. groups (int): Number of groups for the convolution. Defaults to 1.
  729. use_act: (bool): Whether to use activation function. Defaults to True.
  730. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  731. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  732. """
  733. def __init__(
  734. self,
  735. in_channels,
  736. out_channels,
  737. kernel_size=3,
  738. stride=1,
  739. padding=1,
  740. groups=1,
  741. use_act=True,
  742. use_lab=False,
  743. lr_mult=1.0,
  744. ):
  745. super().__init__()
  746. self.use_act = use_act
  747. self.use_lab = use_lab
  748. self.conv = nn.Conv2d(
  749. in_channels,
  750. out_channels,
  751. kernel_size,
  752. stride,
  753. padding=padding if isinstance(padding, str) else (kernel_size - 1) // 2,
  754. groups=groups,
  755. bias=False,
  756. )
  757. self.bn = nn.BatchNorm2d(
  758. out_channels,
  759. )
  760. if self.use_act:
  761. self.act = nn.ReLU()
  762. if self.use_lab:
  763. self.lab = LearnableAffineBlock(lr_mult=lr_mult)
  764. def forward(self, x):
  765. x = self.conv(x)
  766. x = self.bn(x)
  767. if self.use_act:
  768. x = self.act(x)
  769. if self.use_lab:
  770. x = self.lab(x)
  771. return x
  772. class LightConvBNAct(TheseusLayer):
  773. """
  774. LightConvBNAct is a combination of pw and dw layers.
  775. Args:
  776. in_channels (int): Number of input channels.
  777. out_channels (int): Number of output channels.
  778. kernel_size (int): Size of the depth-wise convolution kernel.
  779. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  780. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  781. """
  782. def __init__(
  783. self,
  784. in_channels,
  785. out_channels,
  786. kernel_size,
  787. use_lab=False,
  788. lr_mult=1.0,
  789. **kwargs,
  790. ):
  791. super().__init__()
  792. self.conv1 = ConvBNAct(
  793. in_channels=in_channels,
  794. out_channels=out_channels,
  795. kernel_size=1,
  796. use_act=False,
  797. use_lab=use_lab,
  798. lr_mult=lr_mult,
  799. )
  800. self.conv2 = ConvBNAct(
  801. in_channels=out_channels,
  802. out_channels=out_channels,
  803. kernel_size=kernel_size,
  804. groups=out_channels,
  805. use_act=True,
  806. use_lab=use_lab,
  807. lr_mult=lr_mult,
  808. )
  809. def forward(self, x):
  810. x = self.conv1(x)
  811. x = self.conv2(x)
  812. return x
  813. class PaddingSameAsPaddleMaxPool2d(torch.nn.Module):
  814. def __init__(self, kernel_size, stride=1):
  815. super().__init__()
  816. self.kernel_size = kernel_size
  817. self.stride = stride
  818. self.pool = torch.nn.MaxPool2d(kernel_size, stride, padding=0, ceil_mode=True)
  819. def forward(self, x):
  820. _, _, h, w = x.shape
  821. pad_h_total = max(0, (math.ceil(h / self.stride) - 1) * self.stride + self.kernel_size - h)
  822. pad_w_total = max(0, (math.ceil(w / self.stride) - 1) * self.stride + self.kernel_size - w)
  823. pad_h = pad_h_total // 2
  824. pad_w = pad_w_total // 2
  825. x = torch.nn.functional.pad(x, [pad_w, pad_w_total - pad_w, pad_h, pad_h_total - pad_h])
  826. return self.pool(x)
  827. class StemBlock(TheseusLayer):
  828. """
  829. StemBlock for PP-HGNetV2.
  830. Args:
  831. in_channels (int): Number of input channels.
  832. mid_channels (int): Number of middle channels.
  833. out_channels (int): Number of output channels.
  834. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  835. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  836. """
  837. def __init__(
  838. self,
  839. in_channels,
  840. mid_channels,
  841. out_channels,
  842. use_lab=False,
  843. lr_mult=1.0,
  844. text_rec=False,
  845. ):
  846. super().__init__()
  847. self.stem1 = ConvBNAct(
  848. in_channels=in_channels,
  849. out_channels=mid_channels,
  850. kernel_size=3,
  851. stride=2,
  852. use_lab=use_lab,
  853. lr_mult=lr_mult,
  854. )
  855. self.stem2a = ConvBNAct(
  856. in_channels=mid_channels,
  857. out_channels=mid_channels // 2,
  858. kernel_size=2,
  859. stride=1,
  860. padding="same",
  861. use_lab=use_lab,
  862. lr_mult=lr_mult,
  863. )
  864. self.stem2b = ConvBNAct(
  865. in_channels=mid_channels // 2,
  866. out_channels=mid_channels,
  867. kernel_size=2,
  868. stride=1,
  869. padding="same",
  870. use_lab=use_lab,
  871. lr_mult=lr_mult,
  872. )
  873. self.stem3 = ConvBNAct(
  874. in_channels=mid_channels * 2,
  875. out_channels=mid_channels,
  876. kernel_size=3,
  877. stride=1 if text_rec else 2,
  878. use_lab=use_lab,
  879. lr_mult=lr_mult,
  880. )
  881. self.stem4 = ConvBNAct(
  882. in_channels=mid_channels,
  883. out_channels=out_channels,
  884. kernel_size=1,
  885. stride=1,
  886. use_lab=use_lab,
  887. lr_mult=lr_mult,
  888. )
  889. self.pool = PaddingSameAsPaddleMaxPool2d(
  890. kernel_size=2, stride=1,
  891. )
  892. def forward(self, x):
  893. x = self.stem1(x)
  894. x2 = self.stem2a(x)
  895. x2 = self.stem2b(x2)
  896. x1 = self.pool(x)
  897. x = torch.cat([x1, x2], 1)
  898. x = self.stem3(x)
  899. x = self.stem4(x)
  900. return x
  901. class HGV2_Block(TheseusLayer):
  902. """
  903. HGV2_Block, the basic unit that constitutes the HGV2_Stage.
  904. Args:
  905. in_channels (int): Number of input channels.
  906. mid_channels (int): Number of middle channels.
  907. out_channels (int): Number of output channels.
  908. kernel_size (int): Size of the convolution kernel. Defaults to 3.
  909. layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
  910. stride (int): Stride of the convolution. Defaults to 1.
  911. padding (int/str): Padding or padding type for the convolution. Defaults to 1.
  912. groups (int): Number of groups for the convolution. Defaults to 1.
  913. use_act (bool): Whether to use activation function. Defaults to True.
  914. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  915. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  916. """
  917. def __init__(
  918. self,
  919. in_channels,
  920. mid_channels,
  921. out_channels,
  922. kernel_size=3,
  923. layer_num=6,
  924. identity=False,
  925. light_block=True,
  926. use_lab=False,
  927. lr_mult=1.0,
  928. ):
  929. super().__init__()
  930. self.identity = identity
  931. self.layers = nn.ModuleList()
  932. block_type = "LightConvBNAct" if light_block else "ConvBNAct"
  933. for i in range(layer_num):
  934. self.layers.append(
  935. eval(block_type)(
  936. in_channels=in_channels if i == 0 else mid_channels,
  937. out_channels=mid_channels,
  938. stride=1,
  939. kernel_size=kernel_size,
  940. use_lab=use_lab,
  941. lr_mult=lr_mult,
  942. )
  943. )
  944. # feature aggregation
  945. total_channels = in_channels + layer_num * mid_channels
  946. self.aggregation_squeeze_conv = ConvBNAct(
  947. in_channels=total_channels,
  948. out_channels=out_channels // 2,
  949. kernel_size=1,
  950. stride=1,
  951. use_lab=use_lab,
  952. lr_mult=lr_mult,
  953. )
  954. self.aggregation_excitation_conv = ConvBNAct(
  955. in_channels=out_channels // 2,
  956. out_channels=out_channels,
  957. kernel_size=1,
  958. stride=1,
  959. use_lab=use_lab,
  960. lr_mult=lr_mult,
  961. )
  962. def forward(self, x):
  963. identity = x
  964. output = []
  965. output.append(x)
  966. for layer in self.layers:
  967. x = layer(x)
  968. output.append(x)
  969. x = torch.cat(output, dim=1)
  970. x = self.aggregation_squeeze_conv(x)
  971. x = self.aggregation_excitation_conv(x)
  972. if self.identity:
  973. x += identity
  974. return x
  975. class HGV2_Stage(TheseusLayer):
  976. """
  977. HGV2_Stage, the basic unit that constitutes the PPHGNetV2.
  978. Args:
  979. in_channels (int): Number of input channels.
  980. mid_channels (int): Number of middle channels.
  981. out_channels (int): Number of output channels.
  982. block_num (int): Number of blocks in the HGV2 stage.
  983. layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
  984. is_downsample (bool): Whether to use downsampling operation. Defaults to False.
  985. light_block (bool): Whether to use light block. Defaults to True.
  986. kernel_size (int): Size of the convolution kernel. Defaults to 3.
  987. use_lab (bool, optional): Whether to use the LAB operation. Defaults to False.
  988. lr_mult (float, optional): Learning rate multiplier for the layer. Defaults to 1.0.
  989. """
  990. def __init__(
  991. self,
  992. in_channels,
  993. mid_channels,
  994. out_channels,
  995. block_num,
  996. layer_num=6,
  997. is_downsample=True,
  998. light_block=True,
  999. kernel_size=3,
  1000. use_lab=False,
  1001. stride=2,
  1002. lr_mult=1.0,
  1003. ):
  1004. super().__init__()
  1005. self.is_downsample = is_downsample
  1006. if self.is_downsample:
  1007. self.downsample = ConvBNAct(
  1008. in_channels=in_channels,
  1009. out_channels=in_channels,
  1010. kernel_size=3,
  1011. stride=stride,
  1012. groups=in_channels,
  1013. use_act=False,
  1014. use_lab=use_lab,
  1015. lr_mult=lr_mult,
  1016. )
  1017. blocks_list = []
  1018. for i in range(block_num):
  1019. blocks_list.append(
  1020. HGV2_Block(
  1021. in_channels=in_channels if i == 0 else out_channels,
  1022. mid_channels=mid_channels,
  1023. out_channels=out_channels,
  1024. kernel_size=kernel_size,
  1025. layer_num=layer_num,
  1026. identity=False if i == 0 else True,
  1027. light_block=light_block,
  1028. use_lab=use_lab,
  1029. lr_mult=lr_mult,
  1030. )
  1031. )
  1032. self.blocks = nn.Sequential(*blocks_list)
  1033. def forward(self, x):
  1034. if self.is_downsample:
  1035. x = self.downsample(x)
  1036. x = self.blocks(x)
  1037. return x
  1038. class PPHGNetV2(TheseusLayer):
  1039. """
  1040. PPHGNetV2
  1041. Args:
  1042. stage_config (dict): Config for PPHGNetV2 stages. such as the number of channels, stride, etc.
  1043. stem_channels: (list): Number of channels of the stem of the PPHGNetV2.
  1044. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  1045. use_last_conv (bool): Whether to use the last conv layer as the output channel. Defaults to True.
  1046. class_expand (int): Number of channels for the last 1x1 convolutional layer.
  1047. drop_prob (float): Dropout probability for the last 1x1 convolutional layer. Defaults to 0.0.
  1048. class_num (int): The number of classes for the classification layer. Defaults to 1000.
  1049. lr_mult_list (list): Learning rate multiplier for the stages. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
  1050. Returns:
  1051. model: nn.Module. Specific PPHGNetV2 model depends on args.
  1052. """
  1053. def __init__(
  1054. self,
  1055. stage_config,
  1056. stem_channels=[3, 32, 64],
  1057. use_lab=False,
  1058. use_last_conv=True,
  1059. class_expand=2048,
  1060. dropout_prob=0.0,
  1061. class_num=1000,
  1062. lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
  1063. det=False,
  1064. text_rec=False,
  1065. out_indices=None,
  1066. **kwargs,
  1067. ):
  1068. super().__init__()
  1069. self.det = det
  1070. self.text_rec = text_rec
  1071. self.use_lab = use_lab
  1072. self.use_last_conv = use_last_conv
  1073. self.class_expand = class_expand
  1074. self.class_num = class_num
  1075. self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
  1076. self.out_channels = []
  1077. # stem
  1078. self.stem = StemBlock(
  1079. in_channels=stem_channels[0],
  1080. mid_channels=stem_channels[1],
  1081. out_channels=stem_channels[2],
  1082. use_lab=use_lab,
  1083. lr_mult=lr_mult_list[0],
  1084. text_rec=text_rec,
  1085. )
  1086. # stages
  1087. self.stages = nn.ModuleList()
  1088. for i, k in enumerate(stage_config):
  1089. (
  1090. in_channels,
  1091. mid_channels,
  1092. out_channels,
  1093. block_num,
  1094. is_downsample,
  1095. light_block,
  1096. kernel_size,
  1097. layer_num,
  1098. stride,
  1099. ) = stage_config[k]
  1100. self.stages.append(
  1101. HGV2_Stage(
  1102. in_channels,
  1103. mid_channels,
  1104. out_channels,
  1105. block_num,
  1106. layer_num,
  1107. is_downsample,
  1108. light_block,
  1109. kernel_size,
  1110. use_lab,
  1111. stride,
  1112. lr_mult=lr_mult_list[i + 1],
  1113. )
  1114. )
  1115. if i in self.out_indices:
  1116. self.out_channels.append(out_channels)
  1117. if not self.det:
  1118. self.out_channels = stage_config["stage4"][2]
  1119. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  1120. if self.use_last_conv:
  1121. self.last_conv = nn.Conv2d(
  1122. in_channels=out_channels,
  1123. out_channels=self.class_expand,
  1124. kernel_size=1,
  1125. stride=1,
  1126. padding=0,
  1127. bias=False,
  1128. )
  1129. self.act = nn.ReLU()
  1130. if self.use_lab:
  1131. self.lab = LearnableAffineBlock()
  1132. # self.dropout = nn.Dropout(p=dropout_prob, mode="downscale_in_infer")
  1133. self.dropout = nn.Dropout(p=dropout_prob)
  1134. self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
  1135. if not self.det:
  1136. self.fc = nn.Linear(
  1137. self.class_expand if self.use_last_conv else out_channels,
  1138. self.class_num,
  1139. )
  1140. self._init_weights()
  1141. def _init_weights(self):
  1142. for m in self.modules():
  1143. if isinstance(m, nn.Conv2d):
  1144. nn.init.kaiming_normal_(m.weight)
  1145. elif isinstance(m, (nn.BatchNorm2d)):
  1146. nn.init.ones_(m.weight)
  1147. nn.init.zeros_(m.bias)
  1148. elif isinstance(m, nn.Linear):
  1149. nn.init.zeros_(m.bias)
  1150. def forward(self, x):
  1151. x = self.stem(x)
  1152. out = []
  1153. for i, stage in enumerate(self.stages):
  1154. x = stage(x)
  1155. if self.det and i in self.out_indices:
  1156. out.append(x)
  1157. if self.det:
  1158. return out
  1159. if self.text_rec:
  1160. if self.training:
  1161. x = F.adaptive_avg_pool2d(x, [1, 40])
  1162. else:
  1163. x = F.avg_pool2d(x, [3, 2])
  1164. return x
  1165. def PPHGNetV2_B0(pretrained=False, use_ssld=False, **kwargs):
  1166. """
  1167. PPHGNetV2_B0
  1168. Args:
  1169. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1170. If str, means the path of the pretrained model.
  1171. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1172. Returns:
  1173. model: nn.Module. Specific `PPHGNetV2_B0` model depends on args.
  1174. """
  1175. stage_config = {
  1176. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1177. "stage1": [16, 16, 64, 1, False, False, 3, 3],
  1178. "stage2": [64, 32, 256, 1, True, False, 3, 3],
  1179. "stage3": [256, 64, 512, 2, True, True, 5, 3],
  1180. "stage4": [512, 128, 1024, 1, True, True, 5, 3],
  1181. }
  1182. model = PPHGNetV2(
  1183. stem_channels=[3, 16, 16], stage_config=stage_config, use_lab=True, **kwargs
  1184. )
  1185. return model
  1186. def PPHGNetV2_B1(pretrained=False, use_ssld=False, **kwargs):
  1187. """
  1188. PPHGNetV2_B1
  1189. Args:
  1190. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1191. If str, means the path of the pretrained model.
  1192. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1193. Returns:
  1194. model: nn.Module. Specific `PPHGNetV2_B1` model depends on args.
  1195. """
  1196. stage_config = {
  1197. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1198. "stage1": [32, 32, 64, 1, False, False, 3, 3],
  1199. "stage2": [64, 48, 256, 1, True, False, 3, 3],
  1200. "stage3": [256, 96, 512, 2, True, True, 5, 3],
  1201. "stage4": [512, 192, 1024, 1, True, True, 5, 3],
  1202. }
  1203. model = PPHGNetV2(
  1204. stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
  1205. )
  1206. return model
  1207. def PPHGNetV2_B2(pretrained=False, use_ssld=False, **kwargs):
  1208. """
  1209. PPHGNetV2_B2
  1210. Args:
  1211. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1212. If str, means the path of the pretrained model.
  1213. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1214. Returns:
  1215. model: nn.Module. Specific `PPHGNetV2_B2` model depends on args.
  1216. """
  1217. stage_config = {
  1218. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1219. "stage1": [32, 32, 96, 1, False, False, 3, 4],
  1220. "stage2": [96, 64, 384, 1, True, False, 3, 4],
  1221. "stage3": [384, 128, 768, 3, True, True, 5, 4],
  1222. "stage4": [768, 256, 1536, 1, True, True, 5, 4],
  1223. }
  1224. model = PPHGNetV2(
  1225. stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
  1226. )
  1227. return model
  1228. def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
  1229. """
  1230. PPHGNetV2_B3
  1231. Args:
  1232. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1233. If str, means the path of the pretrained model.
  1234. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1235. Returns:
  1236. model: nn.Module. Specific `PPHGNetV2_B3` model depends on args.
  1237. """
  1238. stage_config = {
  1239. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1240. "stage1": [32, 32, 128, 1, False, False, 3, 5],
  1241. "stage2": [128, 64, 512, 1, True, False, 3, 5],
  1242. "stage3": [512, 128, 1024, 3, True, True, 5, 5],
  1243. "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
  1244. }
  1245. model = PPHGNetV2(
  1246. stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
  1247. )
  1248. return model
  1249. def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
  1250. """
  1251. PPHGNetV2_B4
  1252. Args:
  1253. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1254. If str, means the path of the pretrained model.
  1255. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1256. Returns:
  1257. model: nn.Module. Specific `PPHGNetV2_B4` model depends on args.
  1258. """
  1259. stage_config_rec = {
  1260. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
  1261. "stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
  1262. "stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
  1263. "stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
  1264. "stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
  1265. }
  1266. stage_config_det = {
  1267. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1268. "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
  1269. "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
  1270. "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
  1271. "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
  1272. }
  1273. model = PPHGNetV2(
  1274. stem_channels=[3, 32, 48],
  1275. stage_config=stage_config_det if det else stage_config_rec,
  1276. use_lab=False,
  1277. det=det,
  1278. text_rec=text_rec,
  1279. **kwargs,
  1280. )
  1281. return model
  1282. def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
  1283. """
  1284. PPHGNetV2_B5
  1285. Args:
  1286. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1287. If str, means the path of the pretrained model.
  1288. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1289. Returns:
  1290. model: nn.Module. Specific `PPHGNetV2_B5` model depends on args.
  1291. """
  1292. stage_config = {
  1293. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1294. "stage1": [64, 64, 128, 1, False, False, 3, 6],
  1295. "stage2": [128, 128, 512, 2, True, False, 3, 6],
  1296. "stage3": [512, 256, 1024, 5, True, True, 5, 6],
  1297. "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
  1298. }
  1299. model = PPHGNetV2(
  1300. stem_channels=[3, 32, 64], stage_config=stage_config, use_lab=False, **kwargs
  1301. )
  1302. return model
  1303. def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
  1304. """
  1305. PPHGNetV2_B6
  1306. Args:
  1307. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1308. If str, means the path of the pretrained model.
  1309. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1310. Returns:
  1311. model: nn.Module. Specific `PPHGNetV2_B6` model depends on args.
  1312. """
  1313. stage_config = {
  1314. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1315. "stage1": [96, 96, 192, 2, False, False, 3, 6],
  1316. "stage2": [192, 192, 512, 3, True, False, 3, 6],
  1317. "stage3": [512, 384, 1024, 6, True, True, 5, 6],
  1318. "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
  1319. }
  1320. model = PPHGNetV2(
  1321. stem_channels=[3, 48, 96], stage_config=stage_config, use_lab=False, **kwargs
  1322. )
  1323. return model
  1324. class PPHGNetV2_B4_Formula(nn.Module):
  1325. """
  1326. PPHGNetV2_B4_Formula
  1327. Args:
  1328. in_channels (int): Number of input channels. Default is 3 (for RGB images).
  1329. class_num (int): Number of classes for classification. Default is 1000.
  1330. Returns:
  1331. model: nn.Module. Specific `PPHGNetV2_B4` model with defined architecture.
  1332. """
  1333. def __init__(self, in_channels=3, class_num=1000):
  1334. super().__init__()
  1335. self.in_channels = in_channels
  1336. self.out_channels = 2048
  1337. stage_config = {
  1338. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1339. "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
  1340. "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
  1341. "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
  1342. "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
  1343. }
  1344. self.pphgnet_b4 = PPHGNetV2(
  1345. stem_channels=[3, 32, 48],
  1346. stage_config=stage_config,
  1347. class_num=class_num,
  1348. use_lab=False,
  1349. )
  1350. def forward(self, input_data):
  1351. if self.training:
  1352. pixel_values, label, attention_mask = input_data
  1353. else:
  1354. if isinstance(input_data, list):
  1355. pixel_values = input_data[0]
  1356. else:
  1357. pixel_values = input_data
  1358. num_channels = pixel_values.shape[1]
  1359. if num_channels == 1:
  1360. pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
  1361. pphgnet_b4_output = self.pphgnet_b4(pixel_values)
  1362. b, c, h, w = pphgnet_b4_output.shape
  1363. pphgnet_b4_output = pphgnet_b4_output.reshape([b, c, h * w]).transpose(
  1364. [0, 2, 1]
  1365. )
  1366. pphgnet_b4_output = DonutSwinModelOutput(
  1367. last_hidden_state=pphgnet_b4_output,
  1368. pooler_output=None,
  1369. hidden_states=None,
  1370. attentions=False,
  1371. reshaped_hidden_states=None,
  1372. )
  1373. if self.training:
  1374. return pphgnet_b4_output, label, attention_mask
  1375. else:
  1376. return pphgnet_b4_output
  1377. class PPHGNetV2_B6_Formula(nn.Module):
  1378. """
  1379. PPHGNetV2_B6_Formula
  1380. Args:
  1381. in_channels (int): Number of input channels. Default is 3 (for RGB images).
  1382. class_num (int): Number of classes for classification. Default is 1000.
  1383. Returns:
  1384. model: nn.Module. Specific `PPHGNetV2_B6` model with defined architecture.
  1385. """
  1386. def __init__(self, in_channels=3, class_num=1000):
  1387. super().__init__()
  1388. self.in_channels = in_channels
  1389. self.out_channels = 2048
  1390. stage_config = {
  1391. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1392. "stage1": [96, 96, 192, 2, False, False, 3, 6, 2],
  1393. "stage2": [192, 192, 512, 3, True, False, 3, 6, 2],
  1394. "stage3": [512, 384, 1024, 6, True, True, 5, 6, 2],
  1395. "stage4": [1024, 768, 2048, 3, True, True, 5, 6, 2],
  1396. }
  1397. self.pphgnet_b6 = PPHGNetV2(
  1398. stem_channels=[3, 48, 96],
  1399. class_num=class_num,
  1400. stage_config=stage_config,
  1401. use_lab=False,
  1402. )
  1403. def forward(self, input_data):
  1404. if self.training:
  1405. pixel_values, label, attention_mask = input_data
  1406. else:
  1407. if isinstance(input_data, list):
  1408. pixel_values = input_data[0]
  1409. else:
  1410. pixel_values = input_data
  1411. num_channels = pixel_values.shape[1]
  1412. if num_channels == 1:
  1413. pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
  1414. pphgnet_b6_output = self.pphgnet_b6(pixel_values)
  1415. b, c, h, w = pphgnet_b6_output.shape
  1416. pphgnet_b6_output = pphgnet_b6_output.reshape([b, c, h * w]).permute(
  1417. 0, 2, 1
  1418. )
  1419. pphgnet_b6_output = DonutSwinModelOutput(
  1420. last_hidden_state=pphgnet_b6_output,
  1421. pooler_output=None,
  1422. hidden_states=None,
  1423. attentions=False,
  1424. reshaped_hidden_states=None,
  1425. )
  1426. if self.training:
  1427. return pphgnet_b6_output, label, attention_mask
  1428. else:
  1429. return pphgnet_b6_output