  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. from paddleslim import L1NormFilterPruner
  16. from . import cv
  17. from .cv.models.utils.visualize import visualize_segmentation
  18. from paddlex.cv.transforms import seg_transforms
  19. import paddlex.utils.logging as logging
  20. from paddlex.utils.checkpoint import seg_pretrain_weights_dict
# Backward-compatible aliases preserved from the PaddleX 1.x API:
# expose the segmentation transforms module and the segmentation
# visualization helper under the names this module historically exported.
transforms = seg_transforms
visualize = visualize_segmentation
  23. class UNet(cv.models.UNet):
  24. def __init__(self,
  25. num_classes=2,
  26. upsample_mode='bilinear',
  27. use_bce_loss=False,
  28. use_dice_loss=False,
  29. class_weight=None,
  30. ignore_index=None,
  31. input_channel=None):
  32. if num_classes > 2 and (use_bce_loss or use_dice_loss):
  33. raise ValueError(
  34. "dice loss and bce loss is only applicable to binary classification"
  35. )
  36. elif num_classes == 2:
  37. if use_bce_loss and use_dice_loss:
  38. use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
  39. elif use_bce_loss:
  40. use_mixed_loss = [('CrossEntropyLoss', 1)]
  41. elif use_dice_loss:
  42. use_mixed_loss = [('DiceLoss', 1)]
  43. else:
  44. use_mixed_loss = False
  45. else:
  46. use_mixed_loss = False
  47. if class_weight is not None:
  48. logging.warning(
  49. "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
  50. )
  51. if ignore_index is not None:
  52. logging.warning(
  53. "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
  54. )
  55. if input_channel is not None:
  56. logging.warning(
  57. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  58. )
  59. if upsample_mode == 'bilinear':
  60. use_deconv = False
  61. else:
  62. use_deconv = True
  63. super(UNet, self).__init__(
  64. num_classes=num_classes,
  65. use_mixed_loss=use_mixed_loss,
  66. use_deconv=use_deconv)
  67. def train(self,
  68. num_epochs,
  69. train_dataset,
  70. train_batch_size=2,
  71. eval_dataset=None,
  72. save_interval_epochs=1,
  73. log_interval_steps=2,
  74. save_dir='output',
  75. pretrain_weights='COCO',
  76. optimizer=None,
  77. learning_rate=0.01,
  78. lr_decay_power=0.9,
  79. use_vdl=False,
  80. sensitivities_file=None,
  81. pruned_flops=.2,
  82. early_stop=False,
  83. early_stop_patience=5):
  84. _legacy_train(
  85. self,
  86. num_epochs=num_epochs,
  87. train_dataset=train_dataset,
  88. train_batch_size=train_batch_size,
  89. eval_dataset=eval_dataset,
  90. save_interval_epochs=save_interval_epochs,
  91. log_interval_steps=log_interval_steps,
  92. save_dir=save_dir,
  93. pretrain_weights=pretrain_weights,
  94. optimizer=optimizer,
  95. learning_rate=learning_rate,
  96. lr_decay_power=lr_decay_power,
  97. use_vdl=use_vdl,
  98. sensitivities_file=sensitivities_file,
  99. pruned_flops=pruned_flops,
  100. early_stop=early_stop,
  101. early_stop_patience=early_stop_patience)
  102. class DeepLabv3p(cv.models.DeepLabV3P):
  103. def __init__(self,
  104. num_classes=2,
  105. backbone='ResNet50_vd',
  106. output_stride=8,
  107. aspp_with_sep_conv=None,
  108. decoder_use_sep_conv=None,
  109. encoder_with_aspp=None,
  110. enable_decoder=None,
  111. use_bce_loss=False,
  112. use_dice_loss=False,
  113. class_weight=None,
  114. ignore_index=None,
  115. pooling_crop_size=None,
  116. input_channel=None):
  117. if num_classes > 2 and (use_bce_loss or use_dice_loss):
  118. raise ValueError(
  119. "dice loss and bce loss is only applicable to binary classification"
  120. )
  121. elif num_classes == 2:
  122. if use_bce_loss and use_dice_loss:
  123. use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
  124. elif use_bce_loss:
  125. use_mixed_loss = [('CrossEntropyLoss', 1)]
  126. elif use_dice_loss:
  127. use_mixed_loss = [('DiceLoss', 1)]
  128. else:
  129. use_mixed_loss = False
  130. else:
  131. use_mixed_loss = False
  132. if aspp_with_sep_conv is not None:
  133. logging.warning(
  134. "`aspp_with_sep_conv` is deprecated in PaddleX 2.0 and will not take effect. "
  135. "Defaults to True")
  136. if decoder_use_sep_conv is not None:
  137. logging.warning(
  138. "`decoder_use_sep_conv` is deprecated in PaddleX 2.0 and will not take effect. "
  139. "Defaults to True")
  140. if encoder_with_aspp is not None:
  141. logging.warning(
  142. "`encoder_with_aspp` is deprecated in PaddleX 2.0 and will not take effect. "
  143. "Defaults to True")
  144. if enable_decoder is not None:
  145. logging.warning(
  146. "`enable_decoder` is deprecated in PaddleX 2.0 and will not take effect. "
  147. "Defaults to True")
  148. if class_weight is not None:
  149. logging.warning(
  150. "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
  151. )
  152. if ignore_index is not None:
  153. logging.warning(
  154. "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
  155. )
  156. if pooling_crop_size is not None:
  157. logging.warning(
  158. "Backbone 'MobileNetV3_large_x1_0_ssld' is currently not supported in PaddleX 2.0. "
  159. "`pooling_crop_size` will not take effect. Defaults to None")
  160. if input_channel is not None:
  161. logging.warning(
  162. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  163. )
  164. super(DeepLabv3p, self).__init__(
  165. num_classes=num_classes,
  166. backbone=backbone,
  167. use_mixed_loss=use_mixed_loss,
  168. output_stride=output_stride)
  169. def train(self,
  170. num_epochs,
  171. train_dataset,
  172. train_batch_size=2,
  173. eval_dataset=None,
  174. save_interval_epochs=1,
  175. log_interval_steps=2,
  176. save_dir='output',
  177. pretrain_weights='IMAGENET',
  178. optimizer=None,
  179. learning_rate=0.01,
  180. lr_decay_power=0.9,
  181. use_vdl=False,
  182. sensitivities_file=None,
  183. pruned_flops=.2,
  184. early_stop=False,
  185. early_stop_patience=5):
  186. _legacy_train(
  187. self,
  188. num_epochs=num_epochs,
  189. train_dataset=train_dataset,
  190. train_batch_size=train_batch_size,
  191. eval_dataset=eval_dataset,
  192. save_interval_epochs=save_interval_epochs,
  193. log_interval_steps=log_interval_steps,
  194. save_dir=save_dir,
  195. pretrain_weights=pretrain_weights,
  196. optimizer=optimizer,
  197. learning_rate=learning_rate,
  198. lr_decay_power=lr_decay_power,
  199. use_vdl=use_vdl,
  200. sensitivities_file=sensitivities_file,
  201. pruned_flops=pruned_flops,
  202. early_stop=early_stop,
  203. early_stop_patience=early_stop_patience)
  204. class HRNet(cv.models.HRNet):
  205. def __init__(self,
  206. num_classes=2,
  207. width=18,
  208. use_bce_loss=False,
  209. use_dice_loss=False,
  210. class_weight=None,
  211. ignore_index=None,
  212. input_channel=None):
  213. if num_classes > 2 and (use_bce_loss or use_dice_loss):
  214. raise ValueError(
  215. "dice loss and bce loss is only applicable to binary classification"
  216. )
  217. elif num_classes == 2:
  218. if use_bce_loss and use_dice_loss:
  219. use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
  220. elif use_bce_loss:
  221. use_mixed_loss = [('CrossEntropyLoss', 1)]
  222. elif use_dice_loss:
  223. use_mixed_loss = [('DiceLoss', 1)]
  224. else:
  225. use_mixed_loss = False
  226. else:
  227. use_mixed_loss = False
  228. if class_weight is not None:
  229. logging.warning(
  230. "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
  231. )
  232. if ignore_index is not None:
  233. logging.warning(
  234. "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
  235. )
  236. if input_channel is not None:
  237. logging.warning(
  238. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  239. )
  240. super(HRNet, self).__init__(
  241. num_classes=num_classes,
  242. width=width,
  243. use_mixed_loss=use_mixed_loss)
  244. def train(self,
  245. num_epochs,
  246. train_dataset,
  247. train_batch_size=2,
  248. eval_dataset=None,
  249. save_interval_epochs=1,
  250. log_interval_steps=2,
  251. save_dir='output',
  252. pretrain_weights='IMAGENET',
  253. optimizer=None,
  254. learning_rate=0.01,
  255. lr_decay_power=0.9,
  256. use_vdl=False,
  257. sensitivities_file=None,
  258. pruned_flops=.2,
  259. early_stop=False,
  260. early_stop_patience=5):
  261. _legacy_train(
  262. self,
  263. num_epochs=num_epochs,
  264. train_dataset=train_dataset,
  265. train_batch_size=train_batch_size,
  266. eval_dataset=eval_dataset,
  267. save_interval_epochs=save_interval_epochs,
  268. log_interval_steps=log_interval_steps,
  269. save_dir=save_dir,
  270. pretrain_weights=pretrain_weights,
  271. optimizer=optimizer,
  272. learning_rate=learning_rate,
  273. lr_decay_power=lr_decay_power,
  274. use_vdl=use_vdl,
  275. sensitivities_file=sensitivities_file,
  276. pruned_flops=pruned_flops,
  277. early_stop=early_stop,
  278. early_stop_patience=early_stop_patience)
  279. class FastSCNN(cv.models.FastSCNN):
  280. def __init__(self,
  281. num_classes=2,
  282. use_bce_loss=False,
  283. use_dice_loss=False,
  284. class_weight=None,
  285. ignore_index=255,
  286. multi_loss_weight=None,
  287. input_channel=3):
  288. if num_classes > 2 and (use_bce_loss or use_dice_loss):
  289. raise ValueError(
  290. "dice loss and bce loss is only applicable to binary classification"
  291. )
  292. elif num_classes == 2:
  293. if use_bce_loss and use_dice_loss:
  294. use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
  295. elif use_bce_loss:
  296. use_mixed_loss = [('CrossEntropyLoss', 1)]
  297. elif use_dice_loss:
  298. use_mixed_loss = [('DiceLoss', 1)]
  299. else:
  300. use_mixed_loss = False
  301. else:
  302. use_mixed_loss = False
  303. if class_weight is not None:
  304. logging.warning(
  305. "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
  306. )
  307. if ignore_index is not None:
  308. logging.warning(
  309. "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
  310. )
  311. if multi_loss_weight is not None:
  312. logging.warning(
  313. "`multi_loss_weight` is deprecated in PaddleX 2.0 and will not take effect. "
  314. "Defaults to [1.0, 0.4]")
  315. if input_channel is not None:
  316. logging.warning(
  317. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  318. )
  319. super(FastSCNN, self).__init__(
  320. num_classes=num_classes, use_mixed_loss=use_mixed_loss)
  321. def train(self,
  322. num_epochs,
  323. train_dataset,
  324. train_batch_size=2,
  325. eval_dataset=None,
  326. save_interval_epochs=1,
  327. log_interval_steps=2,
  328. save_dir='output',
  329. pretrain_weights='CITYSCAPES',
  330. optimizer=None,
  331. learning_rate=0.01,
  332. lr_decay_power=0.9,
  333. use_vdl=False,
  334. sensitivities_file=None,
  335. pruned_flops=.2,
  336. early_stop=False,
  337. early_stop_patience=5):
  338. _legacy_train(
  339. self,
  340. num_epochs=num_epochs,
  341. train_dataset=train_dataset,
  342. train_batch_size=train_batch_size,
  343. eval_dataset=eval_dataset,
  344. save_interval_epochs=save_interval_epochs,
  345. log_interval_steps=log_interval_steps,
  346. save_dir=save_dir,
  347. pretrain_weights=pretrain_weights,
  348. optimizer=optimizer,
  349. learning_rate=learning_rate,
  350. lr_decay_power=lr_decay_power,
  351. use_vdl=use_vdl,
  352. sensitivities_file=sensitivities_file,
  353. pruned_flops=pruned_flops,
  354. early_stop=early_stop,
  355. early_stop_patience=early_stop_patience)
def _legacy_train(model, num_epochs, train_dataset, train_batch_size,
                  eval_dataset, save_interval_epochs, log_interval_steps,
                  save_dir, pretrain_weights, optimizer, learning_rate,
                  lr_decay_power, use_vdl, sensitivities_file, pruned_flops,
                  early_stop, early_stop_patience):
    """Shared training driver for the legacy segmentation wrappers above.

    Configures labels, losses, pretrained weights, optional sensitivity-based
    pruning and the optimizer on ``model``, then hands control to
    ``model.train_loop``. ``model`` is expected to be an instance of one of
    the wrapper classes in this module (UNet, DeepLabv3p, HRNet, FastSCNN).
    """
    model.labels = train_dataset.labels
    # Use the model's built-in default loss when none has been configured.
    if model.losses is None:
        model.losses = model.default_loss()
    # initiate weights
    if pretrain_weights is not None and not osp.exists(pretrain_weights):
        # Not a local file and not a registered weight key (e.g. 'COCO',
        # 'IMAGENET', 'CITYSCAPES'): warn, then fall back to the first
        # registered pretrained-weight name for this model.
        if pretrain_weights not in seg_pretrain_weights_dict[model.model_name]:
            logging.warning("Path of pretrain_weights('{}') does not exist!".
                            format(pretrain_weights))
            logging.warning("Pretrain_weights is forcibly set to '{}'. "
                            "If don't want to use pretrain weights, "
                            "set pretrain_weights to be None.".format(
                                seg_pretrain_weights_dict[model.model_name][
                                    0]))
            pretrain_weights = seg_pretrain_weights_dict[model.model_name][0]
    pretrained_dir = osp.join(save_dir, 'pretrain')
    model.net_initialize(
        pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
    # Optional sensitivity-based filter pruning before training starts.
    if sensitivities_file is not None:
        dataset = eval_dataset or train_dataset
        # NOTE(review): assumes each dataset sample is a dict whose 'image'
        # entry is an HWC array, yielding a [1, 3, H, W] input spec for the
        # pruner — confirm against the dataset implementation.
        inputs = [1, 3] + list(dataset[0]['image'].shape[:2])
        model.pruner = L1NormFilterPruner(
            model.net, inputs=inputs, sen_file=sensitivities_file)
        model.pruner.sensitive_prune(pruned_flops=pruned_flops)
    # build optimizer if not defined
    if optimizer is None:
        num_steps_each_epoch = train_dataset.num_samples // train_batch_size
        model.optimizer = model.default_optimizer(
            model.net.parameters(), learning_rate, num_epochs,
            num_steps_each_epoch, lr_decay_power)
    else:
        model.optimizer = optimizer
    model.train_loop(
        num_epochs=num_epochs,
        train_dataset=train_dataset,
        train_batch_size=train_batch_size,
        eval_dataset=eval_dataset,
        save_interval_epochs=save_interval_epochs,
        log_interval_steps=log_interval_steps,
        save_dir=save_dir,
        early_stop=early_stop,
        early_stop_patience=early_stop_patience,
        use_vdl=use_vdl)