yolo_v3.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import math
  16. import tqdm
  17. import paddlex
  18. from .ppyolo import PPYOLO
  19. class YOLOv3(PPYOLO):
  20. """构建YOLOv3,并实现其训练、评估、预测和模型导出。
  21. Args:
  22. num_classes (int): 类别数。默认为80。
  23. backbone (str): YOLOv3的backbone网络,取值范围为['DarkNet53',
  24. 'ResNet34', 'MobileNetV1', 'MobileNetV3_large']。默认为'MobileNetV1'。
  25. anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
  26. [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  27. [59, 119], [116, 90], [156, 198], [373, 326]]。
  28. anchor_masks (list|tuple): 在计算YOLOv3损失时,使用anchor的mask索引,为None时表示使用默认值
  29. [[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
  30. ignore_threshold (float): 在计算YOLOv3损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
  31. nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
  32. nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
  33. nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
  34. nms_iou_threshold (float): 进行NMS时,用于剔除检测框IoU的阈值。默认为0.45。
  35. label_smooth (bool): 是否使用label smooth。默认值为False。
  36. train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
  37. """
  38. def __init__(self,
  39. num_classes=80,
  40. backbone='MobileNetV1',
  41. anchors=None,
  42. anchor_masks=None,
  43. ignore_threshold=0.7,
  44. nms_score_threshold=0.01,
  45. nms_topk=1000,
  46. nms_keep_topk=100,
  47. nms_iou_threshold=0.45,
  48. label_smooth=False,
  49. train_random_shapes=[
  50. 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
  51. ]):
  52. self.init_params = locals()
  53. backbones = [
  54. 'DarkNet53', 'ResNet34', 'MobileNetV1', 'MobileNetV3_large'
  55. ]
  56. assert backbone in backbones, "backbone should be one of {}".format(
  57. backbones)
  58. super(PPYOLO, self).__init__('detector')
  59. self.backbone = backbone
  60. self.num_classes = num_classes
  61. self.anchors = anchors
  62. self.anchor_masks = anchor_masks
  63. self.ignore_threshold = ignore_threshold
  64. self.nms_score_threshold = nms_score_threshold
  65. self.nms_topk = nms_topk
  66. self.nms_keep_topk = nms_keep_topk
  67. self.nms_iou_threshold = nms_iou_threshold
  68. self.label_smooth = label_smooth
  69. self.sync_bn = True
  70. self.train_random_shapes = train_random_shapes
  71. self.fixed_input_shape = None
  72. self.use_fine_grained_loss = False
  73. self.use_coord_conv = False
  74. self.use_iou_aware = False
  75. self.use_spp = False
  76. self.use_drop_block = False
  77. self.use_iou_loss = False
  78. self.scale_x_y = 1.
  79. self.use_matrix_nms = False
  80. self.use_ema = False
  81. self.with_dcn_v2 = False
  82. def _get_backbone(self, backbone_name):
  83. if backbone_name == 'DarkNet53':
  84. backbone = paddlex.cv.nets.DarkNet(norm_type='sync_bn')
  85. elif backbone_name == 'ResNet34':
  86. backbone = paddlex.cv.nets.ResNet(
  87. norm_type='sync_bn',
  88. layers=34,
  89. freeze_norm=False,
  90. norm_decay=0.,
  91. feature_maps=[3, 4, 5],
  92. freeze_at=0)
  93. elif backbone_name == 'MobileNetV1':
  94. backbone = paddlex.cv.nets.MobileNetV1(norm_type='sync_bn')
  95. elif backbone_name.startswith('MobileNetV3'):
  96. model_name = backbone_name.split('_')[1]
  97. backbone = paddlex.cv.nets.MobileNetV3(
  98. norm_type='sync_bn', model_name=model_name)
  99. return backbone
  100. def train(self,
  101. num_epochs,
  102. train_dataset,
  103. train_batch_size=8,
  104. eval_dataset=None,
  105. save_interval_epochs=20,
  106. log_interval_steps=2,
  107. save_dir='output',
  108. pretrain_weights='IMAGENET',
  109. optimizer=None,
  110. learning_rate=1.0 / 8000,
  111. warmup_steps=1000,
  112. warmup_start_lr=0.0,
  113. lr_decay_epochs=[213, 240],
  114. lr_decay_gamma=0.1,
  115. metric=None,
  116. use_vdl=False,
  117. sensitivities_file=None,
  118. eval_metric_loss=0.05,
  119. early_stop=False,
  120. early_stop_patience=5,
  121. resume_checkpoint=None):
  122. """训练。
  123. Args:
  124. num_epochs (int): 训练迭代轮数。
  125. train_dataset (paddlex.datasets): 训练数据读取器。
  126. train_batch_size (int): 训练数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与显卡
  127. 数量之商为验证数据batch大小。默认值为8。
  128. eval_dataset (paddlex.datasets): 验证数据读取器。
  129. save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为20。
  130. log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为10。
  131. save_dir (str): 模型保存路径。默认值为'output'。
  132. pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
  133. 则自动下载在ImageNet图片数据上预训练的模型权重;若为字符串'COCO',
  134. 则自动下载在COCO数据集上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'。
  135. optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
  136. fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
  137. learning_rate (float): 默认优化器的学习率。默认为1.0/8000。
  138. warmup_steps (int): 默认优化器进行warmup过程的步数。默认为1000。
  139. warmup_start_lr (int): 默认优化器warmup的起始学习率。默认为0.0。
  140. lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[213, 240]。
  141. lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
  142. metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
  143. use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
  144. sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
  145. 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
  146. eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
  147. early_stop (bool): 是否使用提前终止训练策略。默认值为False。
  148. early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
  149. 连续下降或持平,则终止训练。默认值为5。
  150. resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
  151. Raises:
  152. ValueError: 评估类型不在指定列表中。
  153. ValueError: 模型从inference model进行加载。
  154. """
  155. return super(YOLOv3, self).train(
  156. num_epochs, train_dataset, train_batch_size, eval_dataset,
  157. save_interval_epochs, log_interval_steps, save_dir,
  158. pretrain_weights, optimizer, learning_rate, warmup_steps,
  159. warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric, use_vdl,
  160. sensitivities_file, eval_metric_loss, early_stop,
  161. early_stop_patience, resume_checkpoint, False)