cls.py 45 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. from paddleslim import L1NormFilterPruner
  16. from . import cv
  17. from paddlex.cv.transforms import cls_transforms
  18. import paddlex.utils.logging as logging
  19. transforms = cls_transforms
  20. class ResNet18(cv.models.ResNet18):
  21. def __init__(self, num_classes=1000, input_channel=None):
  22. if input_channel is not None:
  23. logging.warning(
  24. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  25. )
  26. super(ResNet18, self).__init__(num_classes=num_classes)
  27. def train(self,
  28. num_epochs,
  29. train_dataset,
  30. train_batch_size=64,
  31. eval_dataset=None,
  32. save_interval_epochs=1,
  33. log_interval_steps=2,
  34. save_dir='output',
  35. pretrain_weights='IMAGENET',
  36. optimizer=None,
  37. learning_rate=0.025,
  38. warmup_steps=0,
  39. warmup_start_lr=0.0,
  40. lr_decay_epochs=[30, 60, 90],
  41. lr_decay_gamma=0.1,
  42. use_vdl=False,
  43. sensitivities_file=None,
  44. pruned_flops=.2,
  45. early_stop=False,
  46. early_stop_patience=5):
  47. _legacy_train(
  48. self,
  49. num_epochs=num_epochs,
  50. train_dataset=train_dataset,
  51. train_batch_size=train_batch_size,
  52. eval_dataset=eval_dataset,
  53. save_interval_epochs=save_interval_epochs,
  54. log_interval_steps=log_interval_steps,
  55. save_dir=save_dir,
  56. pretrain_weights=pretrain_weights,
  57. optimizer=optimizer,
  58. learning_rate=learning_rate,
  59. warmup_steps=warmup_steps,
  60. warmup_start_lr=warmup_start_lr,
  61. lr_decay_epochs=lr_decay_epochs,
  62. lr_decay_gamma=lr_decay_gamma,
  63. use_vdl=use_vdl,
  64. sensitivities_file=sensitivities_file,
  65. pruned_flops=pruned_flops,
  66. early_stop=early_stop,
  67. early_stop_patience=early_stop_patience)
  68. class ResNet34(cv.models.ResNet34):
  69. def __init__(self, num_classes=1000, input_channel=None):
  70. if input_channel is not None:
  71. logging.warning(
  72. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  73. )
  74. super(ResNet34, self).__init__(num_classes=num_classes)
  75. def train(self,
  76. num_epochs,
  77. train_dataset,
  78. train_batch_size=64,
  79. eval_dataset=None,
  80. save_interval_epochs=1,
  81. log_interval_steps=2,
  82. save_dir='output',
  83. pretrain_weights='IMAGENET',
  84. optimizer=None,
  85. learning_rate=0.025,
  86. warmup_steps=0,
  87. warmup_start_lr=0.0,
  88. lr_decay_epochs=[30, 60, 90],
  89. lr_decay_gamma=0.1,
  90. use_vdl=False,
  91. sensitivities_file=None,
  92. pruned_flops=.2,
  93. early_stop=False,
  94. early_stop_patience=5):
  95. _legacy_train(
  96. self,
  97. num_epochs=num_epochs,
  98. train_dataset=train_dataset,
  99. train_batch_size=train_batch_size,
  100. eval_dataset=eval_dataset,
  101. save_interval_epochs=save_interval_epochs,
  102. log_interval_steps=log_interval_steps,
  103. save_dir=save_dir,
  104. pretrain_weights=pretrain_weights,
  105. optimizer=optimizer,
  106. learning_rate=learning_rate,
  107. warmup_steps=warmup_steps,
  108. warmup_start_lr=warmup_start_lr,
  109. lr_decay_epochs=lr_decay_epochs,
  110. lr_decay_gamma=lr_decay_gamma,
  111. use_vdl=use_vdl,
  112. sensitivities_file=sensitivities_file,
  113. pruned_flops=pruned_flops,
  114. early_stop=early_stop,
  115. early_stop_patience=early_stop_patience)
  116. class ResNet50(cv.models.ResNet50):
  117. def __init__(self, num_classes=1000, input_channel=None):
  118. if input_channel is not None:
  119. logging.warning(
  120. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  121. )
  122. super(ResNet50, self).__init__(num_classes=num_classes)
  123. def train(self,
  124. num_epochs,
  125. train_dataset,
  126. train_batch_size=64,
  127. eval_dataset=None,
  128. save_interval_epochs=1,
  129. log_interval_steps=2,
  130. save_dir='output',
  131. pretrain_weights='IMAGENET',
  132. optimizer=None,
  133. learning_rate=0.025,
  134. warmup_steps=0,
  135. warmup_start_lr=0.0,
  136. lr_decay_epochs=[30, 60, 90],
  137. lr_decay_gamma=0.1,
  138. use_vdl=False,
  139. sensitivities_file=None,
  140. pruned_flops=.2,
  141. early_stop=False,
  142. early_stop_patience=5):
  143. _legacy_train(
  144. self,
  145. num_epochs=num_epochs,
  146. train_dataset=train_dataset,
  147. train_batch_size=train_batch_size,
  148. eval_dataset=eval_dataset,
  149. save_interval_epochs=save_interval_epochs,
  150. log_interval_steps=log_interval_steps,
  151. save_dir=save_dir,
  152. pretrain_weights=pretrain_weights,
  153. optimizer=optimizer,
  154. learning_rate=learning_rate,
  155. warmup_steps=warmup_steps,
  156. warmup_start_lr=warmup_start_lr,
  157. lr_decay_epochs=lr_decay_epochs,
  158. lr_decay_gamma=lr_decay_gamma,
  159. use_vdl=use_vdl,
  160. sensitivities_file=sensitivities_file,
  161. pruned_flops=pruned_flops,
  162. early_stop=early_stop,
  163. early_stop_patience=early_stop_patience)
  164. class ResNet101(cv.models.ResNet101):
  165. def __init__(self, num_classes=1000, input_channel=None):
  166. if input_channel is not None:
  167. logging.warning(
  168. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  169. )
  170. super(ResNet101, self).__init__(num_classes=num_classes)
  171. def train(self,
  172. num_epochs,
  173. train_dataset,
  174. train_batch_size=64,
  175. eval_dataset=None,
  176. save_interval_epochs=1,
  177. log_interval_steps=2,
  178. save_dir='output',
  179. pretrain_weights='IMAGENET',
  180. optimizer=None,
  181. learning_rate=0.025,
  182. warmup_steps=0,
  183. warmup_start_lr=0.0,
  184. lr_decay_epochs=[30, 60, 90],
  185. lr_decay_gamma=0.1,
  186. use_vdl=False,
  187. sensitivities_file=None,
  188. pruned_flops=.2,
  189. early_stop=False,
  190. early_stop_patience=5):
  191. _legacy_train(
  192. self,
  193. num_epochs=num_epochs,
  194. train_dataset=train_dataset,
  195. train_batch_size=train_batch_size,
  196. eval_dataset=eval_dataset,
  197. save_interval_epochs=save_interval_epochs,
  198. log_interval_steps=log_interval_steps,
  199. save_dir=save_dir,
  200. pretrain_weights=pretrain_weights,
  201. optimizer=optimizer,
  202. learning_rate=learning_rate,
  203. warmup_steps=warmup_steps,
  204. warmup_start_lr=warmup_start_lr,
  205. lr_decay_epochs=lr_decay_epochs,
  206. lr_decay_gamma=lr_decay_gamma,
  207. use_vdl=use_vdl,
  208. sensitivities_file=sensitivities_file,
  209. pruned_flops=pruned_flops,
  210. early_stop=early_stop,
  211. early_stop_patience=early_stop_patience)
  212. class ResNet50_vd(cv.models.ResNet50_vd):
  213. def __init__(self, num_classes=1000, input_channel=None):
  214. if input_channel is not None:
  215. logging.warning(
  216. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  217. )
  218. super(ResNet50_vd, self).__init__(num_classes=num_classes)
  219. def train(self,
  220. num_epochs,
  221. train_dataset,
  222. train_batch_size=64,
  223. eval_dataset=None,
  224. save_interval_epochs=1,
  225. log_interval_steps=2,
  226. save_dir='output',
  227. pretrain_weights='IMAGENET',
  228. optimizer=None,
  229. learning_rate=0.025,
  230. warmup_steps=0,
  231. warmup_start_lr=0.0,
  232. lr_decay_epochs=[30, 60, 90],
  233. lr_decay_gamma=0.1,
  234. use_vdl=False,
  235. sensitivities_file=None,
  236. pruned_flops=.2,
  237. early_stop=False,
  238. early_stop_patience=5):
  239. _legacy_train(
  240. self,
  241. num_epochs=num_epochs,
  242. train_dataset=train_dataset,
  243. train_batch_size=train_batch_size,
  244. eval_dataset=eval_dataset,
  245. save_interval_epochs=save_interval_epochs,
  246. log_interval_steps=log_interval_steps,
  247. save_dir=save_dir,
  248. pretrain_weights=pretrain_weights,
  249. optimizer=optimizer,
  250. learning_rate=learning_rate,
  251. warmup_steps=warmup_steps,
  252. warmup_start_lr=warmup_start_lr,
  253. lr_decay_epochs=lr_decay_epochs,
  254. lr_decay_gamma=lr_decay_gamma,
  255. use_vdl=use_vdl,
  256. sensitivities_file=sensitivities_file,
  257. pruned_flops=pruned_flops,
  258. early_stop=early_stop,
  259. early_stop_patience=early_stop_patience)
  260. class ResNet101_vd(cv.models.ResNet101_vd):
  261. def __init__(self, num_classes=1000, input_channel=None):
  262. if input_channel is not None:
  263. logging.warning(
  264. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  265. )
  266. super(ResNet101_vd, self).__init__(num_classes=num_classes)
  267. def train(self,
  268. num_epochs,
  269. train_dataset,
  270. train_batch_size=64,
  271. eval_dataset=None,
  272. save_interval_epochs=1,
  273. log_interval_steps=2,
  274. save_dir='output',
  275. pretrain_weights='IMAGENET',
  276. optimizer=None,
  277. learning_rate=0.025,
  278. warmup_steps=0,
  279. warmup_start_lr=0.0,
  280. lr_decay_epochs=[30, 60, 90],
  281. lr_decay_gamma=0.1,
  282. use_vdl=False,
  283. sensitivities_file=None,
  284. pruned_flops=.2,
  285. early_stop=False,
  286. early_stop_patience=5):
  287. _legacy_train(
  288. self,
  289. num_epochs=num_epochs,
  290. train_dataset=train_dataset,
  291. train_batch_size=train_batch_size,
  292. eval_dataset=eval_dataset,
  293. save_interval_epochs=save_interval_epochs,
  294. log_interval_steps=log_interval_steps,
  295. save_dir=save_dir,
  296. pretrain_weights=pretrain_weights,
  297. optimizer=optimizer,
  298. learning_rate=learning_rate,
  299. warmup_steps=warmup_steps,
  300. warmup_start_lr=warmup_start_lr,
  301. lr_decay_epochs=lr_decay_epochs,
  302. lr_decay_gamma=lr_decay_gamma,
  303. use_vdl=use_vdl,
  304. sensitivities_file=sensitivities_file,
  305. pruned_flops=pruned_flops,
  306. early_stop=early_stop,
  307. early_stop_patience=early_stop_patience)
  308. class ResNet50_vd_ssld(cv.models.ResNet50_vd_ssld):
  309. def __init__(self, num_classes=1000, input_channel=None):
  310. if input_channel is not None:
  311. logging.warning(
  312. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  313. )
  314. super(ResNet50_vd_ssld, self).__init__(num_classes=num_classes)
  315. def train(self,
  316. num_epochs,
  317. train_dataset,
  318. train_batch_size=64,
  319. eval_dataset=None,
  320. save_interval_epochs=1,
  321. log_interval_steps=2,
  322. save_dir='output',
  323. pretrain_weights='IMAGENET',
  324. optimizer=None,
  325. learning_rate=0.025,
  326. warmup_steps=0,
  327. warmup_start_lr=0.0,
  328. lr_decay_epochs=[30, 60, 90],
  329. lr_decay_gamma=0.1,
  330. use_vdl=False,
  331. sensitivities_file=None,
  332. pruned_flops=.2,
  333. early_stop=False,
  334. early_stop_patience=5):
  335. _legacy_train(
  336. self,
  337. num_epochs=num_epochs,
  338. train_dataset=train_dataset,
  339. train_batch_size=train_batch_size,
  340. eval_dataset=eval_dataset,
  341. save_interval_epochs=save_interval_epochs,
  342. log_interval_steps=log_interval_steps,
  343. save_dir=save_dir,
  344. pretrain_weights=pretrain_weights,
  345. optimizer=optimizer,
  346. learning_rate=learning_rate,
  347. warmup_steps=warmup_steps,
  348. warmup_start_lr=warmup_start_lr,
  349. lr_decay_epochs=lr_decay_epochs,
  350. lr_decay_gamma=lr_decay_gamma,
  351. use_vdl=use_vdl,
  352. sensitivities_file=sensitivities_file,
  353. pruned_flops=pruned_flops,
  354. early_stop=early_stop,
  355. early_stop_patience=early_stop_patience)
  356. class ResNet101_vd_ssld(cv.models.ResNet101_vd_ssld):
  357. def __init__(self, num_classes=1000, input_channel=None):
  358. if input_channel is not None:
  359. logging.warning(
  360. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  361. )
  362. super(ResNet101_vd_ssld, self).__init__(num_classes=num_classes)
  363. def train(self,
  364. num_epochs,
  365. train_dataset,
  366. train_batch_size=64,
  367. eval_dataset=None,
  368. save_interval_epochs=1,
  369. log_interval_steps=2,
  370. save_dir='output',
  371. pretrain_weights='IMAGENET',
  372. optimizer=None,
  373. learning_rate=0.025,
  374. warmup_steps=0,
  375. warmup_start_lr=0.0,
  376. lr_decay_epochs=[30, 60, 90],
  377. lr_decay_gamma=0.1,
  378. use_vdl=False,
  379. sensitivities_file=None,
  380. pruned_flops=.2,
  381. early_stop=False,
  382. early_stop_patience=5):
  383. _legacy_train(
  384. self,
  385. num_epochs=num_epochs,
  386. train_dataset=train_dataset,
  387. train_batch_size=train_batch_size,
  388. eval_dataset=eval_dataset,
  389. save_interval_epochs=save_interval_epochs,
  390. log_interval_steps=log_interval_steps,
  391. save_dir=save_dir,
  392. pretrain_weights=pretrain_weights,
  393. optimizer=optimizer,
  394. learning_rate=learning_rate,
  395. warmup_steps=warmup_steps,
  396. warmup_start_lr=warmup_start_lr,
  397. lr_decay_epochs=lr_decay_epochs,
  398. lr_decay_gamma=lr_decay_gamma,
  399. use_vdl=use_vdl,
  400. sensitivities_file=sensitivities_file,
  401. pruned_flops=pruned_flops,
  402. early_stop=early_stop,
  403. early_stop_patience=early_stop_patience)
  404. class DarkNet53(cv.models.DarkNet53):
  405. def __init__(self, num_classes=1000, input_channel=None):
  406. if input_channel is not None:
  407. logging.warning(
  408. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  409. )
  410. super(DarkNet53, self).__init__(num_classes=num_classes)
  411. def train(self,
  412. num_epochs,
  413. train_dataset,
  414. train_batch_size=64,
  415. eval_dataset=None,
  416. save_interval_epochs=1,
  417. log_interval_steps=2,
  418. save_dir='output',
  419. pretrain_weights='IMAGENET',
  420. optimizer=None,
  421. learning_rate=0.025,
  422. warmup_steps=0,
  423. warmup_start_lr=0.0,
  424. lr_decay_epochs=[30, 60, 90],
  425. lr_decay_gamma=0.1,
  426. use_vdl=False,
  427. sensitivities_file=None,
  428. pruned_flops=.2,
  429. early_stop=False,
  430. early_stop_patience=5):
  431. _legacy_train(
  432. self,
  433. num_epochs=num_epochs,
  434. train_dataset=train_dataset,
  435. train_batch_size=train_batch_size,
  436. eval_dataset=eval_dataset,
  437. save_interval_epochs=save_interval_epochs,
  438. log_interval_steps=log_interval_steps,
  439. save_dir=save_dir,
  440. pretrain_weights=pretrain_weights,
  441. optimizer=optimizer,
  442. learning_rate=learning_rate,
  443. warmup_steps=warmup_steps,
  444. warmup_start_lr=warmup_start_lr,
  445. lr_decay_epochs=lr_decay_epochs,
  446. lr_decay_gamma=lr_decay_gamma,
  447. use_vdl=use_vdl,
  448. sensitivities_file=sensitivities_file,
  449. pruned_flops=pruned_flops,
  450. early_stop=early_stop,
  451. early_stop_patience=early_stop_patience)
  452. class MobileNetV1(cv.models.MobileNetV1):
  453. def __init__(self, num_classes=1000, input_channel=None):
  454. if input_channel is not None:
  455. logging.warning(
  456. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  457. )
  458. super(MobileNetV1, self).__init__(num_classes=num_classes)
  459. def train(self,
  460. num_epochs,
  461. train_dataset,
  462. train_batch_size=64,
  463. eval_dataset=None,
  464. save_interval_epochs=1,
  465. log_interval_steps=2,
  466. save_dir='output',
  467. pretrain_weights='IMAGENET',
  468. optimizer=None,
  469. learning_rate=0.025,
  470. warmup_steps=0,
  471. warmup_start_lr=0.0,
  472. lr_decay_epochs=[30, 60, 90],
  473. lr_decay_gamma=0.1,
  474. use_vdl=False,
  475. sensitivities_file=None,
  476. pruned_flops=.2,
  477. early_stop=False,
  478. early_stop_patience=5):
  479. _legacy_train(
  480. self,
  481. num_epochs=num_epochs,
  482. train_dataset=train_dataset,
  483. train_batch_size=train_batch_size,
  484. eval_dataset=eval_dataset,
  485. save_interval_epochs=save_interval_epochs,
  486. log_interval_steps=log_interval_steps,
  487. save_dir=save_dir,
  488. pretrain_weights=pretrain_weights,
  489. optimizer=optimizer,
  490. learning_rate=learning_rate,
  491. warmup_steps=warmup_steps,
  492. warmup_start_lr=warmup_start_lr,
  493. lr_decay_epochs=lr_decay_epochs,
  494. lr_decay_gamma=lr_decay_gamma,
  495. use_vdl=use_vdl,
  496. sensitivities_file=sensitivities_file,
  497. pruned_flops=pruned_flops,
  498. early_stop=early_stop,
  499. early_stop_patience=early_stop_patience)
  500. class MobileNetV2(cv.models.MobileNetV2):
  501. def __init__(self, num_classes=1000, input_channel=None):
  502. if input_channel is not None:
  503. logging.warning(
  504. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  505. )
  506. super(MobileNetV2, self).__init__(num_classes=num_classes)
  507. def train(self,
  508. num_epochs,
  509. train_dataset,
  510. train_batch_size=64,
  511. eval_dataset=None,
  512. save_interval_epochs=1,
  513. log_interval_steps=2,
  514. save_dir='output',
  515. pretrain_weights='IMAGENET',
  516. optimizer=None,
  517. learning_rate=0.025,
  518. warmup_steps=0,
  519. warmup_start_lr=0.0,
  520. lr_decay_epochs=[30, 60, 90],
  521. lr_decay_gamma=0.1,
  522. use_vdl=False,
  523. sensitivities_file=None,
  524. pruned_flops=.2,
  525. early_stop=False,
  526. early_stop_patience=5):
  527. _legacy_train(
  528. self,
  529. num_epochs=num_epochs,
  530. train_dataset=train_dataset,
  531. train_batch_size=train_batch_size,
  532. eval_dataset=eval_dataset,
  533. save_interval_epochs=save_interval_epochs,
  534. log_interval_steps=log_interval_steps,
  535. save_dir=save_dir,
  536. pretrain_weights=pretrain_weights,
  537. optimizer=optimizer,
  538. learning_rate=learning_rate,
  539. warmup_steps=warmup_steps,
  540. warmup_start_lr=warmup_start_lr,
  541. lr_decay_epochs=lr_decay_epochs,
  542. lr_decay_gamma=lr_decay_gamma,
  543. use_vdl=use_vdl,
  544. sensitivities_file=sensitivities_file,
  545. pruned_flops=pruned_flops,
  546. early_stop=early_stop,
  547. early_stop_patience=early_stop_patience)
  548. class MobileNetV3_small(cv.models.MobileNetV3_small):
  549. def __init__(self, num_classes=1000, input_channel=None):
  550. if input_channel is not None:
  551. logging.warning(
  552. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  553. )
  554. super(MobileNetV3_small, self).__init__(num_classes=num_classes)
  555. def train(self,
  556. num_epochs,
  557. train_dataset,
  558. train_batch_size=64,
  559. eval_dataset=None,
  560. save_interval_epochs=1,
  561. log_interval_steps=2,
  562. save_dir='output',
  563. pretrain_weights='IMAGENET',
  564. optimizer=None,
  565. learning_rate=0.025,
  566. warmup_steps=0,
  567. warmup_start_lr=0.0,
  568. lr_decay_epochs=[30, 60, 90],
  569. lr_decay_gamma=0.1,
  570. use_vdl=False,
  571. sensitivities_file=None,
  572. pruned_flops=.2,
  573. early_stop=False,
  574. early_stop_patience=5):
  575. _legacy_train(
  576. self,
  577. num_epochs=num_epochs,
  578. train_dataset=train_dataset,
  579. train_batch_size=train_batch_size,
  580. eval_dataset=eval_dataset,
  581. save_interval_epochs=save_interval_epochs,
  582. log_interval_steps=log_interval_steps,
  583. save_dir=save_dir,
  584. pretrain_weights=pretrain_weights,
  585. optimizer=optimizer,
  586. learning_rate=learning_rate,
  587. warmup_steps=warmup_steps,
  588. warmup_start_lr=warmup_start_lr,
  589. lr_decay_epochs=lr_decay_epochs,
  590. lr_decay_gamma=lr_decay_gamma,
  591. use_vdl=use_vdl,
  592. sensitivities_file=sensitivities_file,
  593. pruned_flops=pruned_flops,
  594. early_stop=early_stop,
  595. early_stop_patience=early_stop_patience)
  596. class MobileNetV3_large(cv.models.MobileNetV3_large):
  597. def __init__(self, num_classes=1000, input_channel=None):
  598. if input_channel is not None:
  599. logging.warning(
  600. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  601. )
  602. super(MobileNetV3_large, self).__init__(num_classes=num_classes)
  603. def train(self,
  604. num_epochs,
  605. train_dataset,
  606. train_batch_size=64,
  607. eval_dataset=None,
  608. save_interval_epochs=1,
  609. log_interval_steps=2,
  610. save_dir='output',
  611. pretrain_weights='IMAGENET',
  612. optimizer=None,
  613. learning_rate=0.025,
  614. warmup_steps=0,
  615. warmup_start_lr=0.0,
  616. lr_decay_epochs=[30, 60, 90],
  617. lr_decay_gamma=0.1,
  618. use_vdl=False,
  619. sensitivities_file=None,
  620. pruned_flops=.2,
  621. early_stop=False,
  622. early_stop_patience=5):
  623. _legacy_train(
  624. self,
  625. num_epochs=num_epochs,
  626. train_dataset=train_dataset,
  627. train_batch_size=train_batch_size,
  628. eval_dataset=eval_dataset,
  629. save_interval_epochs=save_interval_epochs,
  630. log_interval_steps=log_interval_steps,
  631. save_dir=save_dir,
  632. pretrain_weights=pretrain_weights,
  633. optimizer=optimizer,
  634. learning_rate=learning_rate,
  635. warmup_steps=warmup_steps,
  636. warmup_start_lr=warmup_start_lr,
  637. lr_decay_epochs=lr_decay_epochs,
  638. lr_decay_gamma=lr_decay_gamma,
  639. use_vdl=use_vdl,
  640. sensitivities_file=sensitivities_file,
  641. pruned_flops=pruned_flops,
  642. early_stop=early_stop,
  643. early_stop_patience=early_stop_patience)
  644. class MobileNetV3_small_ssld(cv.models.MobileNetV3_small_ssld):
  645. def __init__(self, num_classes=1000, input_channel=None):
  646. if input_channel is not None:
  647. logging.warning(
  648. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  649. )
  650. super(MobileNetV3_small_ssld, self).__init__(num_classes=num_classes)
  651. def train(self,
  652. num_epochs,
  653. train_dataset,
  654. train_batch_size=64,
  655. eval_dataset=None,
  656. save_interval_epochs=1,
  657. log_interval_steps=2,
  658. save_dir='output',
  659. pretrain_weights='IMAGENET',
  660. optimizer=None,
  661. learning_rate=0.025,
  662. warmup_steps=0,
  663. warmup_start_lr=0.0,
  664. lr_decay_epochs=[30, 60, 90],
  665. lr_decay_gamma=0.1,
  666. use_vdl=False,
  667. sensitivities_file=None,
  668. pruned_flops=.2,
  669. early_stop=False,
  670. early_stop_patience=5):
  671. _legacy_train(
  672. self,
  673. num_epochs=num_epochs,
  674. train_dataset=train_dataset,
  675. train_batch_size=train_batch_size,
  676. eval_dataset=eval_dataset,
  677. save_interval_epochs=save_interval_epochs,
  678. log_interval_steps=log_interval_steps,
  679. save_dir=save_dir,
  680. pretrain_weights=pretrain_weights,
  681. optimizer=optimizer,
  682. learning_rate=learning_rate,
  683. warmup_steps=warmup_steps,
  684. warmup_start_lr=warmup_start_lr,
  685. lr_decay_epochs=lr_decay_epochs,
  686. lr_decay_gamma=lr_decay_gamma,
  687. use_vdl=use_vdl,
  688. sensitivities_file=sensitivities_file,
  689. pruned_flops=pruned_flops,
  690. early_stop=early_stop,
  691. early_stop_patience=early_stop_patience)
  692. class MobileNetV3_large_ssld(cv.models.MobileNetV3_large_ssld):
  693. def __init__(self, num_classes=1000, input_channel=None):
  694. if input_channel is not None:
  695. logging.warning(
  696. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  697. )
  698. super(MobileNetV3_large_ssld, self).__init__(num_classes=num_classes)
  699. def train(self,
  700. num_epochs,
  701. train_dataset,
  702. train_batch_size=64,
  703. eval_dataset=None,
  704. save_interval_epochs=1,
  705. log_interval_steps=2,
  706. save_dir='output',
  707. pretrain_weights='IMAGENET',
  708. optimizer=None,
  709. learning_rate=0.025,
  710. warmup_steps=0,
  711. warmup_start_lr=0.0,
  712. lr_decay_epochs=[30, 60, 90],
  713. lr_decay_gamma=0.1,
  714. use_vdl=False,
  715. sensitivities_file=None,
  716. pruned_flops=.2,
  717. early_stop=False,
  718. early_stop_patience=5):
  719. _legacy_train(
  720. self,
  721. num_epochs=num_epochs,
  722. train_dataset=train_dataset,
  723. train_batch_size=train_batch_size,
  724. eval_dataset=eval_dataset,
  725. save_interval_epochs=save_interval_epochs,
  726. log_interval_steps=log_interval_steps,
  727. save_dir=save_dir,
  728. pretrain_weights=pretrain_weights,
  729. optimizer=optimizer,
  730. learning_rate=learning_rate,
  731. warmup_steps=warmup_steps,
  732. warmup_start_lr=warmup_start_lr,
  733. lr_decay_epochs=lr_decay_epochs,
  734. lr_decay_gamma=lr_decay_gamma,
  735. use_vdl=use_vdl,
  736. sensitivities_file=sensitivities_file,
  737. pruned_flops=pruned_flops,
  738. early_stop=early_stop,
  739. early_stop_patience=early_stop_patience)
  740. class Xception41(cv.models.Xception41):
  741. def __init__(self, num_classes=1000, input_channel=None):
  742. if input_channel is not None:
  743. logging.warning(
  744. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  745. )
  746. super(Xception41, self).__init__(num_classes=num_classes)
  747. def train(self,
  748. num_epochs,
  749. train_dataset,
  750. train_batch_size=64,
  751. eval_dataset=None,
  752. save_interval_epochs=1,
  753. log_interval_steps=2,
  754. save_dir='output',
  755. pretrain_weights='IMAGENET',
  756. optimizer=None,
  757. learning_rate=0.025,
  758. warmup_steps=0,
  759. warmup_start_lr=0.0,
  760. lr_decay_epochs=[30, 60, 90],
  761. lr_decay_gamma=0.1,
  762. use_vdl=False,
  763. sensitivities_file=None,
  764. pruned_flops=.2,
  765. early_stop=False,
  766. early_stop_patience=5):
  767. _legacy_train(
  768. self,
  769. num_epochs=num_epochs,
  770. train_dataset=train_dataset,
  771. train_batch_size=train_batch_size,
  772. eval_dataset=eval_dataset,
  773. save_interval_epochs=save_interval_epochs,
  774. log_interval_steps=log_interval_steps,
  775. save_dir=save_dir,
  776. pretrain_weights=pretrain_weights,
  777. optimizer=optimizer,
  778. learning_rate=learning_rate,
  779. warmup_steps=warmup_steps,
  780. warmup_start_lr=warmup_start_lr,
  781. lr_decay_epochs=lr_decay_epochs,
  782. lr_decay_gamma=lr_decay_gamma,
  783. use_vdl=use_vdl,
  784. sensitivities_file=sensitivities_file,
  785. pruned_flops=pruned_flops,
  786. early_stop=early_stop,
  787. early_stop_patience=early_stop_patience)
  788. class Xception65(cv.models.Xception65):
  789. def __init__(self, num_classes=1000, input_channel=None):
  790. if input_channel is not None:
  791. logging.warning(
  792. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  793. )
  794. super(Xception65, self).__init__(num_classes=num_classes)
  795. def train(self,
  796. num_epochs,
  797. train_dataset,
  798. train_batch_size=64,
  799. eval_dataset=None,
  800. save_interval_epochs=1,
  801. log_interval_steps=2,
  802. save_dir='output',
  803. pretrain_weights='IMAGENET',
  804. optimizer=None,
  805. learning_rate=0.025,
  806. warmup_steps=0,
  807. warmup_start_lr=0.0,
  808. lr_decay_epochs=[30, 60, 90],
  809. lr_decay_gamma=0.1,
  810. use_vdl=False,
  811. sensitivities_file=None,
  812. pruned_flops=.2,
  813. early_stop=False,
  814. early_stop_patience=5):
  815. _legacy_train(
  816. self,
  817. num_epochs=num_epochs,
  818. train_dataset=train_dataset,
  819. train_batch_size=train_batch_size,
  820. eval_dataset=eval_dataset,
  821. save_interval_epochs=save_interval_epochs,
  822. log_interval_steps=log_interval_steps,
  823. save_dir=save_dir,
  824. pretrain_weights=pretrain_weights,
  825. optimizer=optimizer,
  826. learning_rate=learning_rate,
  827. warmup_steps=warmup_steps,
  828. warmup_start_lr=warmup_start_lr,
  829. lr_decay_epochs=lr_decay_epochs,
  830. lr_decay_gamma=lr_decay_gamma,
  831. use_vdl=use_vdl,
  832. sensitivities_file=sensitivities_file,
  833. pruned_flops=pruned_flops,
  834. early_stop=early_stop,
  835. early_stop_patience=early_stop_patience)
  836. class DenseNet121(cv.models.DenseNet121):
  837. def __init__(self, num_classes=1000, input_channel=None):
  838. if input_channel is not None:
  839. logging.warning(
  840. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  841. )
  842. super(DenseNet121, self).__init__(num_classes=num_classes)
  843. def train(self,
  844. num_epochs,
  845. train_dataset,
  846. train_batch_size=64,
  847. eval_dataset=None,
  848. save_interval_epochs=1,
  849. log_interval_steps=2,
  850. save_dir='output',
  851. pretrain_weights='IMAGENET',
  852. optimizer=None,
  853. learning_rate=0.025,
  854. warmup_steps=0,
  855. warmup_start_lr=0.0,
  856. lr_decay_epochs=[30, 60, 90],
  857. lr_decay_gamma=0.1,
  858. use_vdl=False,
  859. sensitivities_file=None,
  860. pruned_flops=.2,
  861. early_stop=False,
  862. early_stop_patience=5):
  863. _legacy_train(
  864. self,
  865. num_epochs=num_epochs,
  866. train_dataset=train_dataset,
  867. train_batch_size=train_batch_size,
  868. eval_dataset=eval_dataset,
  869. save_interval_epochs=save_interval_epochs,
  870. log_interval_steps=log_interval_steps,
  871. save_dir=save_dir,
  872. pretrain_weights=pretrain_weights,
  873. optimizer=optimizer,
  874. learning_rate=learning_rate,
  875. warmup_steps=warmup_steps,
  876. warmup_start_lr=warmup_start_lr,
  877. lr_decay_epochs=lr_decay_epochs,
  878. lr_decay_gamma=lr_decay_gamma,
  879. use_vdl=use_vdl,
  880. sensitivities_file=sensitivities_file,
  881. pruned_flops=pruned_flops,
  882. early_stop=early_stop,
  883. early_stop_patience=early_stop_patience)
  884. class DenseNet161(cv.models.DenseNet161):
  885. def __init__(self, num_classes=1000, input_channel=None):
  886. if input_channel is not None:
  887. logging.warning(
  888. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  889. )
  890. super(DenseNet161, self).__init__(num_classes=num_classes)
  891. def train(self,
  892. num_epochs,
  893. train_dataset,
  894. train_batch_size=64,
  895. eval_dataset=None,
  896. save_interval_epochs=1,
  897. log_interval_steps=2,
  898. save_dir='output',
  899. pretrain_weights='IMAGENET',
  900. optimizer=None,
  901. learning_rate=0.025,
  902. warmup_steps=0,
  903. warmup_start_lr=0.0,
  904. lr_decay_epochs=[30, 60, 90],
  905. lr_decay_gamma=0.1,
  906. use_vdl=False,
  907. sensitivities_file=None,
  908. pruned_flops=.2,
  909. early_stop=False,
  910. early_stop_patience=5):
  911. _legacy_train(
  912. self,
  913. num_epochs=num_epochs,
  914. train_dataset=train_dataset,
  915. train_batch_size=train_batch_size,
  916. eval_dataset=eval_dataset,
  917. save_interval_epochs=save_interval_epochs,
  918. log_interval_steps=log_interval_steps,
  919. save_dir=save_dir,
  920. pretrain_weights=pretrain_weights,
  921. optimizer=optimizer,
  922. learning_rate=learning_rate,
  923. warmup_steps=warmup_steps,
  924. warmup_start_lr=warmup_start_lr,
  925. lr_decay_epochs=lr_decay_epochs,
  926. lr_decay_gamma=lr_decay_gamma,
  927. use_vdl=use_vdl,
  928. sensitivities_file=sensitivities_file,
  929. pruned_flops=pruned_flops,
  930. early_stop=early_stop,
  931. early_stop_patience=early_stop_patience)
  932. class DenseNet201(cv.models.DenseNet201):
  933. def __init__(self, num_classes=1000, input_channel=None):
  934. if input_channel is not None:
  935. logging.warning(
  936. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  937. )
  938. super(DenseNet201, self).__init__(num_classes=num_classes)
  939. def train(self,
  940. num_epochs,
  941. train_dataset,
  942. train_batch_size=64,
  943. eval_dataset=None,
  944. save_interval_epochs=1,
  945. log_interval_steps=2,
  946. save_dir='output',
  947. pretrain_weights='IMAGENET',
  948. optimizer=None,
  949. learning_rate=0.025,
  950. warmup_steps=0,
  951. warmup_start_lr=0.0,
  952. lr_decay_epochs=[30, 60, 90],
  953. lr_decay_gamma=0.1,
  954. use_vdl=False,
  955. sensitivities_file=None,
  956. pruned_flops=.2,
  957. early_stop=False,
  958. early_stop_patience=5):
  959. _legacy_train(
  960. self,
  961. num_epochs=num_epochs,
  962. train_dataset=train_dataset,
  963. train_batch_size=train_batch_size,
  964. eval_dataset=eval_dataset,
  965. save_interval_epochs=save_interval_epochs,
  966. log_interval_steps=log_interval_steps,
  967. save_dir=save_dir,
  968. pretrain_weights=pretrain_weights,
  969. optimizer=optimizer,
  970. learning_rate=learning_rate,
  971. warmup_steps=warmup_steps,
  972. warmup_start_lr=warmup_start_lr,
  973. lr_decay_epochs=lr_decay_epochs,
  974. lr_decay_gamma=lr_decay_gamma,
  975. use_vdl=use_vdl,
  976. sensitivities_file=sensitivities_file,
  977. pruned_flops=pruned_flops,
  978. early_stop=early_stop,
  979. early_stop_patience=early_stop_patience)
  980. class ShuffleNetV2(cv.models.ShuffleNetV2):
  981. def __init__(self, num_classes=1000, input_channel=None):
  982. if input_channel is not None:
  983. logging.warning(
  984. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  985. )
  986. super(ShuffleNetV2, self).__init__(num_classes=num_classes)
  987. def train(self,
  988. num_epochs,
  989. train_dataset,
  990. train_batch_size=64,
  991. eval_dataset=None,
  992. save_interval_epochs=1,
  993. log_interval_steps=2,
  994. save_dir='output',
  995. pretrain_weights='IMAGENET',
  996. optimizer=None,
  997. learning_rate=0.025,
  998. warmup_steps=0,
  999. warmup_start_lr=0.0,
  1000. lr_decay_epochs=[30, 60, 90],
  1001. lr_decay_gamma=0.1,
  1002. use_vdl=False,
  1003. sensitivities_file=None,
  1004. pruned_flops=.2,
  1005. early_stop=False,
  1006. early_stop_patience=5):
  1007. _legacy_train(
  1008. self,
  1009. num_epochs=num_epochs,
  1010. train_dataset=train_dataset,
  1011. train_batch_size=train_batch_size,
  1012. eval_dataset=eval_dataset,
  1013. save_interval_epochs=save_interval_epochs,
  1014. log_interval_steps=log_interval_steps,
  1015. save_dir=save_dir,
  1016. pretrain_weights=pretrain_weights,
  1017. optimizer=optimizer,
  1018. learning_rate=learning_rate,
  1019. warmup_steps=warmup_steps,
  1020. warmup_start_lr=warmup_start_lr,
  1021. lr_decay_epochs=lr_decay_epochs,
  1022. lr_decay_gamma=lr_decay_gamma,
  1023. use_vdl=use_vdl,
  1024. sensitivities_file=sensitivities_file,
  1025. pruned_flops=pruned_flops,
  1026. early_stop=early_stop,
  1027. early_stop_patience=early_stop_patience)
  1028. class HRNet_W18(cv.models.HRNet_W18_C):
  1029. def __init__(self, num_classes=1000, input_channel=None):
  1030. if input_channel is not None:
  1031. logging.warning(
  1032. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  1033. )
  1034. super(HRNet_W18, self).__init__(num_classes=num_classes)
  1035. def train(self,
  1036. num_epochs,
  1037. train_dataset,
  1038. train_batch_size=64,
  1039. eval_dataset=None,
  1040. save_interval_epochs=1,
  1041. log_interval_steps=2,
  1042. save_dir='output',
  1043. pretrain_weights='IMAGENET',
  1044. optimizer=None,
  1045. learning_rate=0.025,
  1046. warmup_steps=0,
  1047. warmup_start_lr=0.0,
  1048. lr_decay_epochs=[30, 60, 90],
  1049. lr_decay_gamma=0.1,
  1050. use_vdl=False,
  1051. sensitivities_file=None,
  1052. pruned_flops=.2,
  1053. early_stop=False,
  1054. early_stop_patience=5):
  1055. _legacy_train(
  1056. self,
  1057. num_epochs=num_epochs,
  1058. train_dataset=train_dataset,
  1059. train_batch_size=train_batch_size,
  1060. eval_dataset=eval_dataset,
  1061. save_interval_epochs=save_interval_epochs,
  1062. log_interval_steps=log_interval_steps,
  1063. save_dir=save_dir,
  1064. pretrain_weights=pretrain_weights,
  1065. optimizer=optimizer,
  1066. learning_rate=learning_rate,
  1067. warmup_steps=warmup_steps,
  1068. warmup_start_lr=warmup_start_lr,
  1069. lr_decay_epochs=lr_decay_epochs,
  1070. lr_decay_gamma=lr_decay_gamma,
  1071. use_vdl=use_vdl,
  1072. sensitivities_file=sensitivities_file,
  1073. pruned_flops=pruned_flops,
  1074. early_stop=early_stop,
  1075. early_stop_patience=early_stop_patience)
  1076. class AlexNet(cv.models.AlexNet):
  1077. def __init__(self, num_classes=1000, input_channel=None):
  1078. if input_channel is not None:
  1079. logging.warning(
  1080. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  1081. )
  1082. super(AlexNet, self).__init__(num_classes=num_classes)
  1083. def train(self,
  1084. num_epochs,
  1085. train_dataset,
  1086. train_batch_size=64,
  1087. eval_dataset=None,
  1088. save_interval_epochs=1,
  1089. log_interval_steps=2,
  1090. save_dir='output',
  1091. pretrain_weights='IMAGENET',
  1092. optimizer=None,
  1093. learning_rate=0.025,
  1094. warmup_steps=0,
  1095. warmup_start_lr=0.0,
  1096. lr_decay_epochs=[30, 60, 90],
  1097. lr_decay_gamma=0.1,
  1098. use_vdl=False,
  1099. sensitivities_file=None,
  1100. pruned_flops=.2,
  1101. early_stop=False,
  1102. early_stop_patience=5):
  1103. _legacy_train(
  1104. self,
  1105. num_epochs=num_epochs,
  1106. train_dataset=train_dataset,
  1107. train_batch_size=train_batch_size,
  1108. eval_dataset=eval_dataset,
  1109. save_interval_epochs=save_interval_epochs,
  1110. log_interval_steps=log_interval_steps,
  1111. save_dir=save_dir,
  1112. pretrain_weights=pretrain_weights,
  1113. optimizer=optimizer,
  1114. learning_rate=learning_rate,
  1115. warmup_steps=warmup_steps,
  1116. warmup_start_lr=warmup_start_lr,
  1117. lr_decay_epochs=lr_decay_epochs,
  1118. lr_decay_gamma=lr_decay_gamma,
  1119. use_vdl=use_vdl,
  1120. sensitivities_file=sensitivities_file,
  1121. pruned_flops=pruned_flops,
  1122. early_stop=early_stop,
  1123. early_stop_patience=early_stop_patience)
  1124. def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
  1125. eval_dataset, save_interval_epochs, log_interval_steps,
  1126. save_dir, pretrain_weights, optimizer, learning_rate,
  1127. warmup_steps, warmup_start_lr, lr_decay_epochs,
  1128. lr_decay_gamma, use_vdl, sensitivities_file, pruned_flops,
  1129. early_stop, early_stop_patience):
  1130. model.labels = train_dataset.labels
  1131. # initiate weights
  1132. if pretrain_weights is not None and not osp.exists(pretrain_weights):
  1133. if pretrain_weights not in ['IMAGENET']:
  1134. logging.warning("Path of pretrain_weights('{}') does not exist!".
  1135. format(pretrain_weights))
  1136. logging.warning("Pretrain_weights is forcibly set to 'IMAGENET'. "
  1137. "If don't want to use pretrain weights, "
  1138. "set pretrain_weights to be None.")
  1139. pretrain_weights = 'IMAGENET'
  1140. pretrained_dir = osp.join(save_dir, 'pretrain')
  1141. model.net_initialize(
  1142. pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
  1143. if sensitivities_file is not None:
  1144. dataset = eval_dataset or train_dataset
  1145. inputs = [1, 3] + list(dataset[0]['image'].shape[:2])
  1146. model.pruner = L1NormFilterPruner(
  1147. model.net, inputs=inputs, sen_file=sensitivities_file)
  1148. model.pruner.sensitive_prune(pruned_flops=pruned_flops)
  1149. # build optimizer if not defined
  1150. if optimizer is None:
  1151. num_steps_each_epoch = len(train_dataset) // train_batch_size
  1152. model.optimizer = model.default_optimizer(
  1153. parameters=model.net.parameters(),
  1154. learning_rate=learning_rate,
  1155. warmup_steps=warmup_steps,
  1156. warmup_start_lr=warmup_start_lr,
  1157. lr_decay_epochs=lr_decay_epochs,
  1158. lr_decay_gamma=lr_decay_gamma,
  1159. num_steps_each_epoch=num_steps_each_epoch)
  1160. else:
  1161. model.optimizer = optimizer
  1162. model.train_loop(
  1163. num_epochs=num_epochs,
  1164. train_dataset=train_dataset,
  1165. train_batch_size=train_batch_size,
  1166. eval_dataset=eval_dataset,
  1167. save_interval_epochs=save_interval_epochs,
  1168. log_interval_steps=log_interval_steps,
  1169. save_dir=save_dir,
  1170. early_stop=early_stop,
  1171. early_stop_patience=early_stop_patience,
  1172. use_vdl=use_vdl)