params.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import platform
  15. import os
  16. class Params(object):
  17. def __init__(self):
  18. self.init_train_params()
  19. self.init_transform_params()
  20. def init_train_params(self):
  21. self.cuda_visible_devices = ''
  22. self.batch_size = 2
  23. self.save_interval_epochs = 1
  24. self.pretrain_weights = 'IMAGENET'
  25. self.model = 'MobileNetV2'
  26. self.num_epochs = 4
  27. self.learning_rate = 0.000125
  28. self.lr_decay_epochs = [2, 3]
  29. self.train_num = 0
  30. self.resume_checkpoint = None
  31. self.sensitivities_path = None
  32. self.eval_metric_loss = None
  33. def init_transform_params(self):
  34. self.image_shape = [224, 224]
  35. self.image_mean = [0.485, 0.456, 0.406]
  36. self.image_std = [0.229, 0.224, 0.225]
  37. self.horizontal_flip_prob = 0.5
  38. self.brightness_range = 0.9
  39. self.brightness_prob = 0.5
  40. self.contrast_range = 0.9
  41. self.contrast_prob = 0.5
  42. self.saturation_range = 0.9
  43. self.saturation_prob = 0.5
  44. self.hue_range = 18
  45. self.hue_prob = 0.5
  46. self.horizontal_flip = True
  47. self.brightness = True
  48. self.contrast = True
  49. self.saturation = True
  50. self.hue = True
  51. def load_from_dict(self, params_dict):
  52. for attr in params_dict:
  53. if hasattr(self, attr):
  54. method = getattr(self, "set_" + attr)
  55. method(params_dict[attr])
  56. def set_cuda_visible_devices(self, cuda_visible_devices):
  57. self.cuda_visible_devices = cuda_visible_devices
  58. def set_batch_size(self, batch_size):
  59. self.batch_size = batch_size
  60. def set_save_interval_epochs(self, save_interval_epochs):
  61. self.save_interval_epochs = save_interval_epochs
  62. def set_pretrain_weights(self, pretrain_weights):
  63. self.pretrain_weights = pretrain_weights
  64. def set_model(self, model):
  65. self.model = model
  66. def set_num_epochs(self, num_epochs):
  67. self.num_epochs = num_epochs
  68. def set_learning_rate(self, learning_rate):
  69. self.learning_rate = learning_rate
  70. def set_lr_decay_epochs(self, lr_decay_epochs):
  71. self.lr_decay_epochs = lr_decay_epochs
  72. def set_resume_checkpoint(self, resume_checkpoint):
  73. self.resume_checkpoint = resume_checkpoint
  74. def set_sensitivities_path(self, sensitivities_path):
  75. self.sensitivities_path = sensitivities_path
  76. def set_eval_metric_loss(self, eval_metric_loss):
  77. self.eval_metric_loss = eval_metric_loss
  78. def set_image_shape(self, image_shape):
  79. self.image_shape = image_shape
  80. def set_image_mean(self, image_mean):
  81. self.image_mean = image_mean
  82. def set_image_std(self, image_std):
  83. self.image_std = image_std
  84. def set_horizontal_flip(self, horizontal_flip):
  85. self.horizontal_flip = horizontal_flip
  86. if not horizontal_flip:
  87. self.horizontal_flip_prob = 0.0
  88. def set_horizontal_flip_prob(self, horizontal_flip_prob):
  89. if self.horizontal_flip:
  90. self.horizontal_flip_prob = horizontal_flip_prob
  91. def set_brightness_range(self, brightness_range):
  92. self.brightness_range = brightness_range
  93. def set_brightness_prob(self, brightness_prob):
  94. if self.brightness:
  95. self.brightness_prob = brightness_prob
  96. def set_brightness(self, brightness):
  97. self.brightness = brightness
  98. if not brightness:
  99. self.brightness_prob = 0.0
  100. def set_contrast(self, contrast):
  101. self.contrast = contrast
  102. if not contrast:
  103. self.contrast_prob = 0.0
  104. def set_contrast_prob(self, contrast_prob):
  105. if self.contrast:
  106. self.contrast_prob = contrast_prob
  107. def set_contrast_range(self, contrast_range):
  108. self.contrast_range = contrast_range
  109. def set_saturation(self, saturation):
  110. self.saturation = saturation
  111. if not saturation:
  112. self.saturation_prob = 0.0
  113. def set_saturation_prob(self, saturation_prob):
  114. if self.saturation:
  115. self.saturation_prob = saturation_prob
  116. def set_saturation_range(self, saturation_range):
  117. self.saturation_range = saturation_range
  118. def set_hue(self, hue):
  119. self.hue = hue
  120. if not hue:
  121. self.hue_prob = 0.0
  122. def set_hue_prob(self, hue_prob):
  123. if self.hue_prob:
  124. self.hue_prob = hue_prob
  125. def set_hue_range(self, hue_range):
  126. self.hue_range = hue_range
  127. def set_train_num(self, train_num):
  128. self.train_num = train_num
  129. class ClsParams(Params):
  130. def __init__(self):
  131. super(ClsParams, self).__init__()
  132. self.lr_policy = 'Piecewise'
  133. self.vertical_flip_prob = 0.0
  134. self.vertical_flip = True
  135. self.rotate_prob = 0.0
  136. self.rotate_range = 30
  137. self.rotate = True
  138. def set_lr_policy(self, lr_policy):
  139. self.lr_policy = lr_policy
  140. def set_vertical_flip(self, vertical_flip):
  141. self.vertical_flip = vertical_flip
  142. if not self.vertical_flip:
  143. self.vertical_flip_prob = 0.0
  144. def set_vertical_flip_prob(self, vertical_flip_prob):
  145. if self.vertical_flip:
  146. self.vertical_flip_prob = vertical_flip_prob
  147. def set_rotate(self, rotate):
  148. self.rotate = rotate
  149. if not rotate:
  150. self.rotate_prob = 0.0
  151. def set_rotate_prob(self, rotate_prob):
  152. if self.rotate:
  153. self.rotate_prob = rotate_prob
  154. def set_rotate_range(self, rotate_range):
  155. self.rotate_range = rotate_range
  156. class DetParams(Params):
  157. def __init__(self):
  158. super(DetParams, self).__init__()
  159. self.warmup_steps = 10
  160. self.warmup_start_lr = 0.
  161. self.use_mixup = True
  162. self.mixup_alpha = 1.5
  163. self.mixup_beta = 1.5
  164. self.expand_prob = 0.5
  165. self.expand_image = True
  166. self.crop_image = True
  167. self.backbone = 'ResNet18'
  168. self.model = 'FasterRCNN'
  169. self.with_fpn = True
  170. self.random_shape = True
  171. self.random_shape_sizes = [
  172. 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
  173. ]
  174. def set_warmup_steps(self, warmup_steps):
  175. self.warmup_steps = warmup_steps
  176. def set_warmup_start_lr(self, warmup_start_lr):
  177. self.warmup_start_lr = warmup_start_lr
  178. def set_use_mixup(self, use_mixup):
  179. self.use_mixup = use_mixup
  180. def set_mixup_alpha(self, mixup_alpha):
  181. self.mixup_alpha = mixup_alpha
  182. def set_mixup_beta(self, mixup_beta):
  183. self.mixup_beta = mixup_beta
  184. def set_expand_image(self, expand_image):
  185. self.expand_image = expand_image
  186. if not expand_image:
  187. self.expand_prob = 0.0
  188. def set_expand_prob(self, expand_prob):
  189. if self.expand_image:
  190. self.expand_prob = expand_prob
  191. def set_crop_image(self, crop_image):
  192. self.crop_image = crop_image
  193. def set_backbone(self, backbone):
  194. self.backbone = backbone
  195. def set_with_fpn(self, with_fpn):
  196. self.with_fpn = with_fpn
  197. def set_random_shape(self, random_shape):
  198. self.random_shape = random_shape
  199. def set_random_shape_sizes(self, random_shape_sizes):
  200. self.random_shape_sizes = random_shape_sizes
  201. class SegParams(Params):
  202. def __init__(self):
  203. super(SegParams, self).__init__()
  204. self.loss_type = [True, True]
  205. self.lr_policy = 'Piecewise'
  206. self.optimizer = 'Adam'
  207. self.backbone = 'MobileNetV2_x1.0'
  208. self.blur = True
  209. self.blur_prob = 0.
  210. self.rotate = False
  211. self.max_rotation = 15
  212. self.scale_aspect = False
  213. self.min_ratio = 0.5
  214. self.aspect_ratio = 0.33
  215. self.vertical_flip_prob = 0.0
  216. self.vertical_flip = True
  217. self.model = 'UNet'
  218. def set_loss_type(self, loss_type):
  219. self.loss_type = loss_type
  220. def set_lr_policy(self, lr_policy):
  221. self.lr_policy = lr_policy
  222. def set_optimizer(self, optimizer):
  223. self.optimizer = optimizer
  224. def set_backbone(self, backbone):
  225. self.backbone = backbone
  226. def set_blur(self, blur):
  227. self.blur = blur
  228. if not blur:
  229. self.blur_prob = 0.
  230. def set_blur_prob(self, blur_prob):
  231. if self.blur:
  232. self.blur_prob = blur_prob
  233. def set_rotate(self, rotate):
  234. self.rotate = rotate
  235. def set_max_rotation(self, max_rotation):
  236. self.max_rotation = max_rotation
  237. def set_scale_aspect(self, scale_aspect):
  238. self.scale_aspect = scale_aspect
  239. def set_min_ratio(self, min_ratio):
  240. self.min_ratio = min_ratio
  241. def set_aspect_ratio(self, aspect_ratio):
  242. self.aspect_ratio = aspect_ratio
  243. def set_vertical_flip(self, vertical_flip):
  244. self.vertical_flip = vertical_flip
  245. if not vertical_flip:
  246. self.vertical_flip_prob = 0.0
  247. def set_vertical_flip_prob(self, vertical_flip_prob):
  248. if vertical_flip_prob:
  249. self.vertical_flip_prob = vertical_flip_prob
  250. PARAMS_CLASS_LIST = [ClsParams, DetParams, SegParams, DetParams, SegParams]
  251. def recommend_parameters(params, train_nums, class_nums, memory_size_per_gpu):
  252. model_type = params['model']
  253. gpu_list = params['cuda_visible_devices']
  254. if 'cpu_num' in params:
  255. cpu_num = params['cpu_num']
  256. else:
  257. cpu_num = int(os.environ.get('CPU_NUM', 1))
  258. if cpu_num > 8:
  259. os.environ['CPU_NUM'] = '8'
  260. if not params['use_gpu']:
  261. gpu_nums = 0
  262. else:
  263. gpu_nums = len(gpu_list.split(','))
  264. # set batch_size
  265. if gpu_nums == 0 or platform.platform().startswith("Darwin"):
  266. if model_type.startswith('MobileNet'):
  267. batch_size = 8 * cpu_num
  268. elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
  269. or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
  270. or model_type.startswith('ShuffleNet'):
  271. batch_size = 4 * cpu_num
  272. elif model_type.startswith('YOLOv3') or model_type.startswith(
  273. 'PPYOLO'):
  274. batch_size = 2 * cpu_num
  275. elif model_type.startswith('FasterRCNN') or model_type.startswith(
  276. 'MaskRCNN'):
  277. batch_size = 1 * cpu_num
  278. elif model_type.startswith('DeepLab') or model_type.startswith('UNet') \
  279. or model_type.startswith('HRNet_W18') or model_type.startswith('FastSCNN'):
  280. batch_size = 2 * cpu_num
  281. else:
  282. if model_type.startswith('MobileNet'):
  283. batch_size = (memory_size_per_gpu - 513) // 57 * gpu_nums
  284. batch_size = min(batch_size, gpu_nums * 125)
  285. elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
  286. or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
  287. or model_type.startswith('ShuffleNet'):
  288. batch_size = (memory_size_per_gpu - 739) // 211 * gpu_nums
  289. batch_size = min(batch_size, gpu_nums * 16)
  290. elif model_type.startswith('YOLOv3'):
  291. batch_size = (memory_size_per_gpu - 1555) // 943 * gpu_nums
  292. batch_size = min(batch_size, gpu_nums * 8)
  293. elif model_type.startswith('PPYOLO'):
  294. batch_size = (memory_size_per_gpu - 1691) // 1025 * gpu_nums
  295. batch_size = min(batch_size, gpu_nums * 8)
  296. elif model_type.startswith('FasterRCNN'):
  297. batch_size = (memory_size_per_gpu - 1755) // 915 * gpu_nums
  298. batch_size = min(batch_size, gpu_nums * 2)
  299. elif model_type.startswith('MaskRCNN'):
  300. batch_size = (memory_size_per_gpu - 2702) // 1188 * gpu_nums
  301. batch_size = min(batch_size, gpu_nums * 2)
  302. elif model_type.startswith('DeepLab'):
  303. batch_size = (memory_size_per_gpu - 1469) // 1605 * gpu_nums
  304. batch_size = min(batch_size, gpu_nums * 4)
  305. elif model_type.startswith('UNet'):
  306. batch_size = (memory_size_per_gpu - 1275) // 1256 * gpu_nums
  307. batch_size = min(batch_size, gpu_nums * 4)
  308. elif model_type.startswith('HRNet_W18'):
  309. batch_size = (memory_size_per_gpu - 800) // 682 * gpu_nums
  310. batch_size = min(batch_size, gpu_nums * 4)
  311. elif model_type.startswith('FastSCNN'):
  312. batch_size = (memory_size_per_gpu - 636) // 144 * gpu_nums
  313. batch_size = min(batch_size, gpu_nums * 4)
  314. if batch_size > train_nums // 2:
  315. batch_size = train_nums // 2
  316. gpu_list = '{}'.format(gpu_list.split(',')[0]) if gpu_nums > 0 else ''
  317. if batch_size <= 0:
  318. batch_size = 1
  319. # set learning_rate
  320. if model_type.startswith('MobileNet'):
  321. lr = (batch_size / 500.0) * 0.1
  322. elif model_type.startswith('DenseNet') or model_type.startswith('ResNet') \
  323. or model_type.startswith('Xception') or model_type.startswith('DarkNet') \
  324. or model_type.startswith('ShuffleNet'):
  325. lr = (batch_size / 256.0) * 0.1
  326. elif model_type.startswith('YOLOv3') or model_type.startswith('PPYOLO'):
  327. lr = 0.001 * batch_size / 64
  328. num_steps_each_epoch = train_nums // batch_size
  329. min_warmup_step = max(3 * num_steps_each_epoch, 50 * class_nums)
  330. if gpu_nums == 0:
  331. gpu_nums = 1
  332. warmup_step = min(min_warmup_step, int(400 * class_nums / gpu_nums))
  333. elif model_type.startswith('FasterRCNN') or model_type.startswith(
  334. 'MaskRCNN'):
  335. lr = 0.02 * batch_size / 16
  336. num_steps_each_epoch = train_nums // batch_size
  337. min_warmup_step = max(num_steps_each_epoch, 50)
  338. if gpu_nums == 0:
  339. gpu_nums = 1
  340. warmup_step = min(min_warmup_step, int(4000 / gpu_nums))
  341. elif model_type.startswith('DeepLab') or model_type.startswith('UNet') \
  342. or model_type.startswith('HRNet_W18') or model_type.startswith('FastSCNN'):
  343. lr = 0.01 * batch_size / 2
  344. loss_type = [False, False]
  345. params['batch_size'] = batch_size
  346. params['learning_rate'] = lr
  347. params['cuda_visible_devices'] = gpu_list
  348. if model_type in ['YOLOv3', 'PPYOLO', 'FasterRCNN', 'MaskRCNN']:
  349. num_epochs = params['num_epochs']
  350. lr_decay_epochs = params['lr_decay_epochs']
  351. if warmup_step >= lr_decay_epochs[0] * num_steps_each_epoch:
  352. for i in range(len(lr_decay_epochs)):
  353. lr_decay_epochs[i] += warmup_step // num_steps_each_epoch
  354. num_epochs += warmup_step // num_steps_each_epoch
  355. params['num_epochs'] = num_epochs
  356. params['lr_decay_epochs'] = lr_decay_epochs
  357. params['warmup_steps'] = warmup_step
  358. if 'loss_type' in params:
  359. params['loss_type'] = loss_type