# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
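"""Training helpers for PaddleX object detection tasks.

Builds data transforms for the YOLO, PicoDet and RCNN model families, loads
VOC- or COCO-format datasets, assembles a warmup + piecewise-decay Momentum
optimizer, and exposes a train() entry point that can optionally prune the
network with an L1-norm filter pruner before training.
"""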
import os.path as osp

import numpy as np
import paddle
from paddleslim import L1NormFilterPruner


def build_yolo_transforms(params):
    """Build train/eval transforms for the YOLO-family detectors."""
    from paddlex import transforms as T
    target_size = params.image_shape[0]
    use_mixup = params.use_mixup
    dt_list = []
    if use_mixup:
        # Mixup is only applied for roughly the first 25/27 of the epochs.
        dt_list.append(
            T.MixupImage(
                alpha=params.mixup_alpha,
                beta=params.mixup_beta,
                mixup_epoch=int(params.num_epochs * 25. / 27)))
    dt_list.extend([
        T.RandomDistort(
            brightness_range=params.brightness_range,
            brightness_prob=params.brightness_prob,
            contrast_range=params.contrast_range,
            contrast_prob=params.contrast_prob,
            saturation_range=params.saturation_range,
            saturation_prob=params.saturation_prob,
            hue_range=params.hue_range,
            hue_prob=params.hue_prob),
        T.RandomExpand(
            prob=params.expand_prob,
            # Pad with the per-channel mean rescaled to the 0-255 range.
            im_padding_value=[float(int(x * 255)) for x in params.image_mean])
    ])
    if params.crop_image:
        dt_list.append(T.RandomCrop())
    dt_list.extend([
        T.Resize(target_size=target_size, interp='RANDOM'),
        T.RandomHorizontalFlip(prob=params.horizontal_flip_prob),
        T.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    train_transforms = T.Compose(dt_list)
    eval_transforms = T.Compose([
        T.Resize(target_size=target_size, interp='CUBIC'),
        T.Normalize(mean=params.image_mean, std=params.image_std),
    ])
    return train_transforms, eval_transforms


def build_rcnn_transforms(params):
    """Build train/eval transforms for FasterRCNN / MaskRCNN."""
    from paddlex import transforms as T
    short_size = min(params.image_shape)
    max_size = max(params.image_shape)
    train_transforms = T.Compose([
        T.RandomDistort(
            brightness_range=params.brightness_range,
            brightness_prob=params.brightness_prob,
            contrast_range=params.contrast_range,
            contrast_prob=params.contrast_prob,
            saturation_range=params.saturation_range,
            saturation_prob=params.saturation_prob,
            hue_range=params.hue_range,
            hue_prob=params.hue_prob),
        T.RandomHorizontalFlip(prob=params.horizontal_flip_prob),
        T.Normalize(mean=params.image_mean, std=params.image_std),
        T.ResizeByShort(short_size=short_size, max_size=max_size),
    ])
    eval_transforms = T.Compose([
        T.Normalize(),
        T.ResizeByShort(short_size=short_size, max_size=max_size),
    ])
    return train_transforms, eval_transforms


def build_pico_transforms(params):
    """Build train/eval transforms for PicoDet."""
    from paddlex import transforms as T
    target_size = params.image_shape[0]
    dt_list = [
        T.RandomDistort(
            brightness_range=params.brightness_range,
            brightness_prob=params.brightness_prob,
            contrast_range=params.contrast_range,
            contrast_prob=params.contrast_prob,
            saturation_range=params.saturation_range,
            saturation_prob=params.saturation_prob,
            hue_range=params.hue_range,
            hue_prob=params.hue_prob),
    ]
    if params.crop_image:
        dt_list.append(T.RandomCrop())
    dt_list.extend([
        T.Resize(target_size=target_size, interp='RANDOM'),
        T.RandomHorizontalFlip(prob=params.horizontal_flip_prob),
        T.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    train_transforms = T.Compose(dt_list)
    eval_transforms = T.Compose([
        T.Resize(target_size=target_size, interp='CUBIC'),
        T.Normalize(mean=params.image_mean, std=params.image_std),
    ])
    return train_transforms, eval_transforms


def build_voc_datasets(dataset_path, train_transforms, eval_transforms):
    """Build train/eval datasets from a VOC-format dataset directory."""
    import paddlex as pdx
    train_file_list = osp.join(dataset_path, 'train_list.txt')
    eval_file_list = osp.join(dataset_path, 'val_list.txt')
    label_list = osp.join(dataset_path, 'labels.txt')
    train_dataset = pdx.datasets.VOCDetection(
        data_dir=dataset_path,
        file_list=train_file_list,
        label_list=label_list,
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.VOCDetection(
        data_dir=dataset_path,
        file_list=eval_file_list,
        label_list=label_list,
        transforms=eval_transforms)
    return train_dataset, eval_dataset


def build_coco_datasets(dataset_path, train_transforms, eval_transforms):
    """Build train/eval datasets from a COCO-format dataset directory."""
    import paddlex as pdx
    data_dir = osp.join(dataset_path, 'JPEGImages')
    train_ann_file = osp.join(dataset_path, 'train.json')
    eval_ann_file = osp.join(dataset_path, 'val.json')
    train_dataset = pdx.datasets.CocoDetection(
        data_dir=data_dir,
        ann_file=train_ann_file,
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.CocoDetection(
        data_dir=data_dir, ann_file=eval_ann_file, transforms=eval_transforms)
    return train_dataset, eval_dataset
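
# The two builders above expect different on-disk layouts: build_voc_datasets
# reads 'train_list.txt', 'val_list.txt' and 'labels.txt' directly under
# dataset_path, while build_coco_datasets reads images from 'JPEGImages/' and
# annotations from 'train.json' / 'val.json'. train() below picks between the
# two by checking which of these files exist.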


def build_optimizer(parameters, step_each_epoch, params):
    """Momentum optimizer with warmup and piecewise learning-rate decay."""
    import paddle
    from paddle.regularizer import L2Decay
    learning_rate = params.learning_rate
    lr_decay_epochs = params.lr_decay_epochs
    warmup_steps = params.warmup_steps
    warmup_start_lr = params.warmup_start_lr
    # Convert epoch indices to iteration boundaries; decay the LR 10x at each.
    boundaries = [b * step_each_epoch for b in lr_decay_epochs]
    values = [
        learning_rate * (0.1**i) for i in range(len(lr_decay_epochs) + 1)
    ]
    lr = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=boundaries, values=values)
    lr = paddle.optimizer.lr.LinearWarmup(
        learning_rate=lr,
        warmup_steps=warmup_steps,
        start_lr=warmup_start_lr,
        end_lr=learning_rate)
    # RCNN models use a smaller weight-decay factor than the YOLO/PicoDet family.
    factor = 1e-04 if params.model in ['FasterRCNN', 'MaskRCNN'] else 5e-04
    optimizer = paddle.optimizer.Momentum(
        learning_rate=lr,
        momentum=0.9,
        weight_decay=L2Decay(factor),
        parameters=parameters)
    return optimizer
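
# Example of the schedule produced by build_optimizer (illustrative numbers,
# not taken from any configuration in this file): with learning_rate=0.001,
# lr_decay_epochs=[8, 11] and step_each_epoch=100, PiecewiseDecay is given
# boundaries=[800, 1100] and values=[0.001, 0.0001, 0.00001], so the rate
# drops 10x at each boundary; LinearWarmup first ramps the rate from
# warmup_start_lr up to 0.001 over the initial warmup_steps iterations.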


def train(task_path, dataset_path, params):
    """Train a PaddleX detector, optionally pruning it first."""
    import paddlex as pdx
    pdx.log_level = 3
    # Pick the transform pipeline that matches the model family.
    if params.model in ['YOLOv3', 'PPYOLO', 'PPYOLOTiny', 'PPYOLOv2']:
        train_transforms, eval_transforms = build_yolo_transforms(params)
    elif params.model in ['PicoDet']:
        train_transforms, eval_transforms = build_pico_transforms(params)
    elif params.model in ['FasterRCNN', 'MaskRCNN']:
        train_transforms, eval_transforms = build_rcnn_transforms(params)
    # Detect the dataset format from the files present on disk.
    if osp.exists(osp.join(dataset_path, 'JPEGImages')) and \
            osp.exists(osp.join(dataset_path, 'train.json')) and \
            osp.exists(osp.join(dataset_path, 'val.json')):
        train_dataset, eval_dataset = build_coco_datasets(
            dataset_path=dataset_path,
            train_transforms=train_transforms,
            eval_transforms=eval_transforms)
    elif osp.exists(osp.join(dataset_path, 'train_list.txt')) and \
            osp.exists(osp.join(dataset_path, 'val_list.txt')) and \
            osp.exists(osp.join(dataset_path, 'labels.txt')):
        train_dataset, eval_dataset = build_voc_datasets(
            dataset_path=dataset_path,
            train_transforms=train_transforms,
            eval_transforms=eval_transforms)
    step_each_epoch = train_dataset.num_samples // params.batch_size
    train_batch_size = params.batch_size
    save_interval_epochs = params.save_interval_epochs
    save_dir = osp.join(task_path, 'output')
    pretrain_weights = params.pretrain_weights
    if pretrain_weights is not None and osp.exists(pretrain_weights):
        pretrain_weights = osp.join(pretrain_weights, 'model.pdparams')
    detector = getattr(pdx.det, params.model)
    num_classes = len(train_dataset.labels)
    sensitivities_path = params.sensitivities_path
    pruned_flops = params.pruned_flops
    model = detector(num_classes=num_classes, backbone=params.backbone)
    if sensitivities_path is not None:
        # Load the pretrained weights now; training below must not reload them.
        model.net_initialize(pretrain_weights=pretrain_weights)
        pretrain_weights = None
        # Prune the network with an L1-norm filter pruner before training.
        dataset = eval_dataset or train_dataset
        # Use a list so the shape can be rounded up in place below.
        im_shape = list(dataset[0]['image'].shape[:2])
        if getattr(model, 'with_fpn',
                   False) or model.__class__.__name__ == 'PicoDet':
            # FPN and PicoDet need input sizes that are multiples of 32.
            im_shape[0] = int(np.ceil(im_shape[0] / 32) * 32)
            im_shape[1] = int(np.ceil(im_shape[1] / 32) * 32)
        inputs = [{
            "image": paddle.ones(
                shape=[1, 3] + list(im_shape), dtype='float32'),
            "im_shape": paddle.full(
                [1, 2], 640, dtype='float32'),
            "scale_factor": paddle.ones(
                shape=[1, 2], dtype='float32')
        }]
        model.net.eval()
        model.pruner = L1NormFilterPruner(
            model.net, inputs=inputs, sen_file=sensitivities_path)
        model.prune(pruned_flops=pruned_flops)
    optimizer = build_optimizer(model.net.parameters(), step_each_epoch,
                                params)
    model.train(
        num_epochs=params.num_epochs,
        train_dataset=train_dataset,
        train_batch_size=train_batch_size,
        eval_dataset=eval_dataset,
        save_interval_epochs=save_interval_epochs,
        log_interval_steps=2,
        save_dir=save_dir,
        pretrain_weights=pretrain_weights,
        optimizer=optimizer,
        use_vdl=True,
        resume_checkpoint=params.resume_checkpoint)
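

if __name__ == '__main__':
    # Minimal usage sketch. The attribute set below is inferred from the
    # accesses made in this module; every concrete value and path here is an
    # illustrative assumption, not a configuration shipped with this file.
    from types import SimpleNamespace

    demo_params = SimpleNamespace(
        model='YOLOv3',
        backbone='DarkNet53',
        image_shape=[608, 608],
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
        use_mixup=True,
        mixup_alpha=1.5,
        mixup_beta=1.5,
        brightness_range=0.5,
        brightness_prob=0.5,
        contrast_range=0.5,
        contrast_prob=0.5,
        saturation_range=0.5,
        saturation_prob=0.5,
        hue_range=18,
        hue_prob=0.5,
        expand_prob=0.5,
        crop_image=True,
        horizontal_flip_prob=0.5,
        num_epochs=270,
        batch_size=8,
        save_interval_epochs=5,
        learning_rate=0.000125,
        lr_decay_epochs=[216, 243],
        warmup_steps=1000,
        warmup_start_lr=0.0,
        pretrain_weights='IMAGENET',
        sensitivities_path=None,
        pruned_flops=None,
        resume_checkpoint=None)
    train(
        task_path='tasks/demo_task',       # hypothetical output location
        dataset_path='datasets/demo_voc',  # hypothetical VOC/COCO dataset dir
        params=demo_params)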