params.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import platform
  15. import os
  16. class Params(object):
  17. def __init__(self):
  18. self.init_train_params()
  19. self.init_transform_params()
  20. def init_train_params(self):
  21. self.cuda_visible_devices = ''
  22. self.batch_size = 2
  23. self.save_interval_epochs = 1
  24. self.pretrain_weights = 'IMAGENET'
  25. self.model = 'MobileNetV2'
  26. self.num_epochs = 4
  27. self.learning_rate = 0.000125
  28. self.lr_decay_epochs = [2, 3]
  29. self.train_num = 0
  30. self.resume_checkpoint = None
  31. self.sensitivities_path = None
  32. self.pruned_flops = None
  33. def init_transform_params(self):
  34. self.image_shape = [224, 224]
  35. self.image_mean = [0.485, 0.456, 0.406]
  36. self.image_std = [0.229, 0.224, 0.225]
  37. self.horizontal_flip_prob = 0.5
  38. self.brightness_range = 0.9
  39. self.brightness_prob = 0.5
  40. self.contrast_range = 0.9
  41. self.contrast_prob = 0.5
  42. self.saturation_range = 0.9
  43. self.saturation_prob = 0.5
  44. self.hue_range = 18
  45. self.hue_prob = 0.5
  46. self.horizontal_flip = True
  47. self.brightness = True
  48. self.contrast = True
  49. self.saturation = True
  50. self.hue = True
  51. def load_from_dict(self, params_dict):
  52. for attr in params_dict:
  53. if hasattr(self, attr):
  54. method = getattr(self, "set_" + attr)
  55. method(params_dict[attr])
  56. def set_cuda_visible_devices(self, cuda_visible_devices):
  57. self.cuda_visible_devices = cuda_visible_devices
  58. def set_batch_size(self, batch_size):
  59. self.batch_size = batch_size
  60. def set_save_interval_epochs(self, save_interval_epochs):
  61. self.save_interval_epochs = save_interval_epochs
  62. def set_pretrain_weights(self, pretrain_weights):
  63. self.pretrain_weights = pretrain_weights
  64. def set_model(self, model):
  65. self.model = model
  66. def set_num_epochs(self, num_epochs):
  67. self.num_epochs = num_epochs
  68. def set_learning_rate(self, learning_rate):
  69. self.learning_rate = learning_rate
  70. def set_lr_decay_epochs(self, lr_decay_epochs):
  71. self.lr_decay_epochs = lr_decay_epochs
  72. def set_resume_checkpoint(self, resume_checkpoint):
  73. self.resume_checkpoint = resume_checkpoint
  74. def set_sensitivities_path(self, sensitivities_path):
  75. self.sensitivities_path = sensitivities_path
  76. def set_eval_metric_loss(self, eval_metric_loss):
  77. self.eval_metric_loss = eval_metric_loss
  78. def set_image_shape(self, image_shape):
  79. self.image_shape = image_shape
  80. def set_image_mean(self, image_mean):
  81. self.image_mean = image_mean
  82. def set_image_std(self, image_std):
  83. self.image_std = image_std
  84. def set_horizontal_flip(self, horizontal_flip):
  85. self.horizontal_flip = horizontal_flip
  86. if not horizontal_flip:
  87. self.horizontal_flip_prob = 0.0
  88. def set_horizontal_flip_prob(self, horizontal_flip_prob):
  89. if self.horizontal_flip:
  90. self.horizontal_flip_prob = horizontal_flip_prob
  91. def set_brightness_range(self, brightness_range):
  92. self.brightness_range = brightness_range
  93. def set_brightness_prob(self, brightness_prob):
  94. if self.brightness:
  95. self.brightness_prob = brightness_prob
  96. def set_brightness(self, brightness):
  97. self.brightness = brightness
  98. if not brightness:
  99. self.brightness_prob = 0.0
  100. def set_contrast(self, contrast):
  101. self.contrast = contrast
  102. if not contrast:
  103. self.contrast_prob = 0.0
  104. def set_contrast_prob(self, contrast_prob):
  105. if self.contrast:
  106. self.contrast_prob = contrast_prob
  107. def set_contrast_range(self, contrast_range):
  108. self.contrast_range = contrast_range
  109. def set_saturation(self, saturation):
  110. self.saturation = saturation
  111. if not saturation:
  112. self.saturation_prob = 0.0
  113. def set_saturation_prob(self, saturation_prob):
  114. if self.saturation:
  115. self.saturation_prob = saturation_prob
  116. def set_saturation_range(self, saturation_range):
  117. self.saturation_range = saturation_range
  118. def set_hue(self, hue):
  119. self.hue = hue
  120. if not hue:
  121. self.hue_prob = 0.0
  122. def set_hue_prob(self, hue_prob):
  123. if self.hue_prob:
  124. self.hue_prob = hue_prob
  125. def set_hue_range(self, hue_range):
  126. self.hue_range = hue_range
  127. def set_train_num(self, train_num):
  128. self.train_num = train_num
  129. class ClsParams(Params):
  130. def __init__(self):
  131. super(ClsParams, self).__init__()
  132. self.lr_policy = 'Piecewise'
  133. self.vertical_flip_prob = 0.0
  134. self.vertical_flip = True
  135. def set_lr_policy(self, lr_policy):
  136. self.lr_policy = lr_policy
  137. def set_vertical_flip(self, vertical_flip):
  138. self.vertical_flip = vertical_flip
  139. if not self.vertical_flip:
  140. self.vertical_flip_prob = 0.0
  141. def set_vertical_flip_prob(self, vertical_flip_prob):
  142. if self.vertical_flip:
  143. self.vertical_flip_prob = vertical_flip_prob
  144. class DetParams(Params):
  145. def __init__(self):
  146. super(DetParams, self).__init__()
  147. self.warmup_steps = 10
  148. self.warmup_start_lr = 0.
  149. self.use_mixup = True
  150. self.mixup_alpha = 1.5
  151. self.mixup_beta = 1.5
  152. self.expand_prob = 0.5
  153. self.expand_image = True
  154. self.crop_image = True
  155. self.backbone = 'ResNet18'
  156. self.pretrain_weights = 'COCO'
  157. self.model = 'FasterRCNN'
  158. self.with_fpn = True
  159. self.random_shape = True
  160. self.random_shape_sizes = [
  161. 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
  162. ]
  163. def set_warmup_steps(self, warmup_steps):
  164. self.warmup_steps = warmup_steps
  165. def set_warmup_start_lr(self, warmup_start_lr):
  166. self.warmup_start_lr = warmup_start_lr
  167. def set_use_mixup(self, use_mixup):
  168. self.use_mixup = use_mixup
  169. def set_mixup_alpha(self, mixup_alpha):
  170. self.mixup_alpha = mixup_alpha
  171. def set_mixup_beta(self, mixup_beta):
  172. self.mixup_beta = mixup_beta
  173. def set_expand_image(self, expand_image):
  174. self.expand_image = expand_image
  175. if not expand_image:
  176. self.expand_prob = 0.0
  177. def set_expand_prob(self, expand_prob):
  178. if self.expand_image:
  179. self.expand_prob = expand_prob
  180. def set_crop_image(self, crop_image):
  181. self.crop_image = crop_image
  182. def set_backbone(self, backbone):
  183. self.backbone = backbone
  184. def set_with_fpn(self, with_fpn):
  185. self.with_fpn = with_fpn
  186. def set_random_shape(self, random_shape):
  187. self.random_shape = random_shape
  188. def set_random_shape_sizes(self, random_shape_sizes):
  189. self.random_shape_sizes = random_shape_sizes
  190. class SegParams(Params):
  191. def __init__(self):
  192. super(SegParams, self).__init__()
  193. self.loss_type = [True, True]
  194. self.lr_policy = 'Piecewise'
  195. self.optimizer = 'Adam'
  196. self.backbone = 'ResNet50_vd'
  197. self.blur = True
  198. self.blur_prob = 0.
  199. self.scale_aspect = False
  200. self.min_ratio = 0.5
  201. self.aspect_ratio = 0.33
  202. self.vertical_flip_prob = 0.0
  203. self.vertical_flip = True
  204. self.model = 'UNet'
  205. def set_loss_type(self, loss_type):
  206. self.loss_type = loss_type
  207. def set_lr_policy(self, lr_policy):
  208. self.lr_policy = lr_policy
  209. def set_optimizer(self, optimizer):
  210. self.optimizer = optimizer
  211. def set_backbone(self, backbone):
  212. self.backbone = backbone
  213. def set_blur(self, blur):
  214. self.blur = blur
  215. if not blur:
  216. self.blur_prob = 0.
  217. def set_blur_prob(self, blur_prob):
  218. if self.blur:
  219. self.blur_prob = blur_prob
  220. def set_scale_aspect(self, scale_aspect):
  221. self.scale_aspect = scale_aspect
  222. def set_min_ratio(self, min_ratio):
  223. self.min_ratio = min_ratio
  224. def set_aspect_ratio(self, aspect_ratio):
  225. self.aspect_ratio = aspect_ratio
  226. def set_vertical_flip(self, vertical_flip):
  227. self.vertical_flip = vertical_flip
  228. if not vertical_flip:
  229. self.vertical_flip_prob = 0.0
  230. def set_vertical_flip_prob(self, vertical_flip_prob):
  231. if vertical_flip_prob:
  232. self.vertical_flip_prob = vertical_flip_prob
  233. PARAMS_CLASS_LIST = [ClsParams, DetParams, SegParams, DetParams, SegParams]
  234. def recommend_parameters(params, train_nums, class_nums, memory_size_per_gpu):
  235. model_type = params['model']
  236. gpu_list = params['cuda_visible_devices']
  237. if 'cpu_num' in params:
  238. cpu_num = params['cpu_num']
  239. else:
  240. cpu_num = int(os.environ.get('CPU_NUM', 1))
  241. if cpu_num > 8:
  242. os.environ['CPU_NUM'] = '8'
  243. if not params['use_gpu']:
  244. gpu_nums = 0
  245. else:
  246. gpu_nums = len(gpu_list.split(','))
  247. # set batch_size
  248. if gpu_nums == 0 or platform.platform().startswith("Darwin"):
  249. if model_type.startswith('MobileNet'):
  250. batch_size = 8 * cpu_num
  251. elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
  252. or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
  253. or model_type.startswith('ShuffleNet'):
  254. batch_size = 4 * cpu_num
  255. elif model_type.startswith('YOLOv3') or model_type.startswith(
  256. 'PPYOLO'):
  257. batch_size = 2 * cpu_num
  258. elif model_type.startswith('FasterRCNN') or model_type.startswith(
  259. 'MaskRCNN'):
  260. batch_size = 1 * cpu_num
  261. elif model_type.startswith('DeepLab') or model_type.startswith('UNet') \
  262. or model_type.startswith('HRNet_W18') or model_type.startswith('FastSCNN') or model_type.startswith('BiSeNetV2'):
  263. batch_size = 2 * cpu_num
  264. else:
  265. if model_type.startswith('MobileNet'):
  266. batch_size = (memory_size_per_gpu - 513) // 57 * gpu_nums
  267. batch_size = min(batch_size, gpu_nums * 125)
  268. elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
  269. or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
  270. or model_type.startswith('ShuffleNet'):
  271. batch_size = (memory_size_per_gpu - 739) // 211 * gpu_nums
  272. batch_size = min(batch_size, gpu_nums * 16)
  273. elif model_type.startswith('YOLOv3'):
  274. batch_size = (memory_size_per_gpu - 1555) // 943 * gpu_nums
  275. batch_size = min(batch_size, gpu_nums * 8)
  276. elif model_type.startswith('PPYOLOTiny'):
  277. batch_size = (memory_size_per_gpu - 579) // 1025 * gpu_nums
  278. batch_size = min(batch_size, gpu_nums * 16)
  279. elif model_type.startswith('PPYOLO'):
  280. batch_size = (memory_size_per_gpu - 1691) // 1025 * gpu_nums
  281. batch_size = min(batch_size, gpu_nums * 4)
  282. elif model_type.startswith('FasterRCNN'):
  283. batch_size = (memory_size_per_gpu - 1755) // 915 * gpu_nums
  284. batch_size = min(batch_size, gpu_nums * 2)
  285. elif model_type.startswith('MaskRCNN'):
  286. batch_size = (memory_size_per_gpu - 2702) // 1188 * gpu_nums
  287. batch_size = min(batch_size, gpu_nums * 2)
  288. elif model_type.startswith('DeepLab'):
  289. batch_size = (memory_size_per_gpu - 1469) // 1605 * gpu_nums
  290. batch_size = min(batch_size, gpu_nums * 4)
  291. elif model_type.startswith('BiSeNetV2'):
  292. batch_size = (memory_size_per_gpu - 591) // 1605 * gpu_nums
  293. batch_size = min(batch_size, gpu_nums * 4)
  294. elif model_type.startswith('UNet'):
  295. batch_size = (memory_size_per_gpu - 1275) // 1256 * gpu_nums
  296. batch_size = min(batch_size, gpu_nums * 4)
  297. elif model_type.startswith('HRNet_W18'):
  298. batch_size = (memory_size_per_gpu - 800) // 682 * gpu_nums
  299. batch_size = min(batch_size, gpu_nums * 4)
  300. elif model_type.startswith('FastSCNN'):
  301. batch_size = (memory_size_per_gpu - 636) // 144 * gpu_nums
  302. batch_size = min(batch_size, gpu_nums * 4)
  303. if batch_size > train_nums // 2:
  304. batch_size = train_nums // 2
  305. gpu_list = '{}'.format(gpu_list.split(',')[0]) if gpu_nums > 0 else ''
  306. if batch_size <= 0:
  307. batch_size = 1
  308. # set learning_rate
  309. if model_type.startswith('MobileNet'):
  310. lr = (batch_size / 500.0) * 0.1
  311. elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
  312. or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
  313. or model_type.startswith('ShuffleNet'):
  314. lr = (batch_size / 256.0) * 0.1
  315. elif model_type.startswith('PPYOLOTiny'):
  316. lr = .005 * batch_size / 16
  317. num_steps_each_epoch = train_nums // batch_size
  318. min_warmup_step = max(3 * num_steps_each_epoch, 50 * class_nums)
  319. if gpu_nums == 0:
  320. gpu_nums = 1
  321. warmup_step = min(min_warmup_step, int(400 * class_nums / gpu_nums))
  322. elif model_type.startswith('YOLOv3') or model_type.startswith('PPYOLO'):
  323. lr = .005 / 12 * batch_size / 8
  324. num_steps_each_epoch = train_nums // batch_size
  325. min_warmup_step = max(3 * num_steps_each_epoch, 50 * class_nums)
  326. if gpu_nums == 0:
  327. gpu_nums = 1
  328. warmup_step = min(min_warmup_step, int(400 * class_nums / gpu_nums))
  329. elif model_type.startswith('FasterRCNN') or model_type.startswith(
  330. 'MaskRCNN'):
  331. lr = 0.02 * batch_size / 16
  332. num_steps_each_epoch = train_nums // batch_size
  333. min_warmup_step = max(num_steps_each_epoch, 50)
  334. if gpu_nums == 0:
  335. gpu_nums = 1
  336. warmup_step = min(min_warmup_step, int(4000 / gpu_nums))
  337. elif model_type.startswith('DeepLab') or model_type.startswith('UNet') \
  338. or model_type.startswith('HRNet_W18') or model_type.startswith('FastSCNN') or model_type.startswith('BiSeNetV2'):
  339. lr = 0.01 * batch_size / 2
  340. loss_type = [False, False]
  341. params['batch_size'] = batch_size
  342. params['learning_rate'] = lr
  343. params['cuda_visible_devices'] = gpu_list
  344. if model_type in [
  345. 'YOLOv3', 'PPYOLO', 'PPYOLOTiny', 'PPYOLOv2', 'FasterRCNN',
  346. 'MaskRCNN'
  347. ]:
  348. num_epochs = params['num_epochs']
  349. lr_decay_epochs = params['lr_decay_epochs']
  350. if warmup_step >= lr_decay_epochs[0] * num_steps_each_epoch:
  351. for i in range(len(lr_decay_epochs)):
  352. lr_decay_epochs[i] += warmup_step // num_steps_each_epoch
  353. num_epochs += warmup_step // num_steps_each_epoch
  354. params['num_epochs'] = num_epochs
  355. params['lr_decay_epochs'] = lr_decay_epochs
  356. params['warmup_steps'] = warmup_step
  357. if 'loss_type' in params:
  358. params['loss_type'] = loss_type