det.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. from . import cv
  16. from .cv.models.utils.visualize import visualize_detection, draw_pr_curve
  17. from paddlex.cv.transforms import det_transforms
  18. from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
  19. from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
  20. _BatchPadding, _Gt2YoloTarget
  21. import paddlex.utils.logging as logging
# Module-level aliases: expose the detection transforms module and the
# visualization helpers under shorter names on this module.
transforms = det_transforms
visualize = visualize_detection
# NOTE(review): this self-assignment is a no-op at runtime (the name is
# already bound by the import above); presumably kept to mark the name as a
# deliberate re-export -- confirm.
draw_pr_curve = draw_pr_curve
  25. class FasterRCNN(cv.models.FasterRCNN):
  26. def __init__(self,
  27. num_classes=81,
  28. backbone='ResNet50',
  29. with_fpn=True,
  30. aspect_ratios=[0.5, 1.0, 2.0],
  31. anchor_sizes=[32, 64, 128, 256, 512],
  32. with_dcn=None,
  33. rpn_cls_loss=None,
  34. rpn_focal_loss_alpha=None,
  35. rpn_focal_loss_gamma=None,
  36. rcnn_bbox_loss=None,
  37. rcnn_nms=None,
  38. keep_top_k=100,
  39. nms_threshold=0.5,
  40. score_threshold=0.05,
  41. softnms_sigma=None,
  42. bbox_assigner=None,
  43. fpn_num_channels=256,
  44. input_channel=None,
  45. rpn_batch_size_per_im=256,
  46. rpn_fg_fraction=0.5,
  47. test_pre_nms_top_n=None,
  48. test_post_nms_top_n=1000):
  49. if with_dcn is not None:
  50. logging.warning(
  51. "`with_dcn` is deprecated in PaddleX 2.0 and won't take effect. Defaults to False."
  52. )
  53. if rpn_cls_loss is not None:
  54. logging.warning(
  55. "`rpn_cls_loss` is deprecated in PaddleX 2.0 and won't take effect. "
  56. "Defaults to 'SigmoidCrossEntropy'.")
  57. if rpn_focal_loss_alpha is not None or rpn_focal_loss_gamma is not None:
  58. logging.warning(
  59. "Focal loss is deprecated in PaddleX 2.0."
  60. " `rpn_focal_loss_alpha` and `rpn_focal_loss_gamma` won't take effect."
  61. )
  62. if rcnn_bbox_loss is not None:
  63. logging.warning(
  64. "`rcnn_bbox_loss` is deprecated in PaddleX 2.0 and won't take effect. "
  65. "Defaults to 'SmoothL1Loss'")
  66. if rcnn_nms is not None:
  67. logging.warning(
  68. "MultiClassSoftNMS is deprecated in PaddleX 2.0. "
  69. "`rcnn_nms` and `softnms_sigma` won't take effect. MultiClassNMS will be used by default"
  70. )
  71. if bbox_assigner is not None:
  72. logging.warning(
  73. "`bbox_assigner` is deprecated in PaddleX 2.0 and won't take effect. "
  74. "Defaults to 'BBoxAssigner'")
  75. if input_channel is not None:
  76. logging.warning(
  77. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  78. )
  79. if isinstance(anchor_sizes[0], int):
  80. anchor_sizes = [[size] for size in anchor_sizes]
  81. super(FasterRCNN, self).__init__(
  82. num_classes=num_classes - 1,
  83. backbone=backbone,
  84. with_fpn=with_fpn,
  85. aspect_ratios=aspect_ratios,
  86. anchor_sizes=anchor_sizes,
  87. keep_top_k=keep_top_k,
  88. nms_threshold=nms_threshold,
  89. score_threshold=score_threshold,
  90. fpn_num_channels=fpn_num_channels,
  91. rpn_batch_size_per_im=rpn_batch_size_per_im,
  92. rpn_fg_fraction=rpn_fg_fraction,
  93. test_pre_nms_top_n=test_pre_nms_top_n,
  94. test_post_nms_top_n=test_post_nms_top_n)
  95. class YOLOv3(cv.models.YOLOv3):
  96. def __init__(self,
  97. num_classes=80,
  98. backbone='MobileNetV1',
  99. anchors=None,
  100. anchor_masks=None,
  101. ignore_threshold=0.7,
  102. nms_score_threshold=0.01,
  103. nms_topk=1000,
  104. nms_keep_topk=100,
  105. nms_iou_threshold=0.45,
  106. label_smooth=False,
  107. train_random_shapes=[
  108. 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
  109. ],
  110. input_channel=None):
  111. if input_channel is not None:
  112. logging.warning(
  113. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  114. )
  115. if anchors is None:
  116. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  117. [59, 119], [116, 90], [156, 198], [373, 326]]
  118. if anchor_masks is None:
  119. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  120. super(YOLOv3, self).__init__(
  121. num_classes=num_classes,
  122. backbone=backbone,
  123. anchors=anchors,
  124. anchor_masks=anchor_masks,
  125. ignore_threshold=ignore_threshold,
  126. nms_score_threshold=nms_score_threshold,
  127. nms_topk=nms_topk,
  128. nms_keep_topk=nms_keep_topk,
  129. nms_iou_threshold=nms_iou_threshold,
  130. label_smooth=label_smooth)
  131. self.train_random_shapes = train_random_shapes
  132. def _compose_batch_transform(self, transforms, mode='train'):
  133. if mode == 'train':
  134. default_batch_transforms = [
  135. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  136. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  137. _Gt2YoloTarget(
  138. anchor_masks=self.anchor_masks,
  139. anchors=self.anchors,
  140. downsample_ratios=getattr(self, 'downsample_ratios',
  141. [32, 16, 8]),
  142. num_classes=self.num_classes)
  143. ]
  144. else:
  145. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  146. if mode == 'eval' and self.metric == 'voc':
  147. collate_batch = False
  148. else:
  149. collate_batch = True
  150. custom_batch_transforms = []
  151. random_shape_defined = False
  152. for i, op in enumerate(transforms.transforms):
  153. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  154. if mode != 'train':
  155. raise Exception(
  156. "{} cannot be present in the {} transforms. ".format(
  157. op.__class__.__name__, mode) +
  158. "Please check the {} transforms.".format(mode))
  159. custom_batch_transforms.insert(0, copy.deepcopy(op))
  160. random_shape_defined = True
  161. if not random_shape_defined:
  162. default_batch_transforms.insert(
  163. 0,
  164. BatchRandomResize(
  165. target_sizes=self.train_random_shapes, interp='RANDOM'))
  166. batch_transforms = BatchCompose(
  167. custom_batch_transforms + default_batch_transforms,
  168. collate_batch=collate_batch)
  169. return batch_transforms
  170. class PPYOLO(cv.models.PPYOLO):
  171. def __init__(
  172. self,
  173. num_classes=80,
  174. backbone='ResNet50_vd_ssld',
  175. with_dcn_v2=None,
  176. # YOLO Head
  177. anchors=None,
  178. anchor_masks=None,
  179. use_coord_conv=True,
  180. use_iou_aware=True,
  181. use_spp=True,
  182. use_drop_block=True,
  183. scale_x_y=1.05,
  184. # PPYOLO Loss
  185. ignore_threshold=0.7,
  186. label_smooth=False,
  187. use_iou_loss=True,
  188. # NMS
  189. use_matrix_nms=True,
  190. nms_score_threshold=0.01,
  191. nms_topk=1000,
  192. nms_keep_topk=100,
  193. nms_iou_threshold=0.45,
  194. train_random_shapes=[
  195. 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
  196. ],
  197. input_channel=None):
  198. if backbone == 'ResNet50_vd_ssld':
  199. backbone = 'ResNet50_vd_dcn'
  200. if with_dcn_v2 is not None:
  201. logging.warning(
  202. "`with_dcn_v2` is deprecated in PaddleX 2.0 and will not take effect. "
  203. "To use backbone with deformable convolutional networks, "
  204. "please specify in `backbone_name`. "
  205. "Currently the only backbone with dcn is 'ResNet50_vd_dcn'.")
  206. if input_channel is not None:
  207. logging.warning(
  208. "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
  209. )
  210. super(PPYOLO, self).__init__(
  211. num_classes=num_classes,
  212. backbone=backbone,
  213. anchors=anchors,
  214. anchor_masks=anchor_masks,
  215. use_coord_conv=use_coord_conv,
  216. use_iou_aware=use_iou_aware,
  217. use_spp=use_spp,
  218. use_drop_block=use_drop_block,
  219. scale_x_y=scale_x_y,
  220. ignore_threshold=ignore_threshold,
  221. label_smooth=label_smooth,
  222. use_iou_loss=use_iou_loss,
  223. use_matrix_nms=use_matrix_nms,
  224. nms_score_threshold=nms_score_threshold,
  225. nms_topk=nms_topk,
  226. nms_keep_topk=nms_keep_topk,
  227. nms_iou_threshold=nms_iou_threshold)
  228. self.train_random_shapes = train_random_shapes
  229. def _compose_batch_transform(self, transforms, mode='train'):
  230. if mode == 'train':
  231. default_batch_transforms = [
  232. _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
  233. _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
  234. _Gt2YoloTarget(
  235. anchor_masks=self.anchor_masks,
  236. anchors=self.anchors,
  237. downsample_ratios=getattr(self, 'downsample_ratios',
  238. [32, 16, 8]),
  239. num_classes=self.num_classes)
  240. ]
  241. else:
  242. default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
  243. if mode == 'eval' and self.metric == 'voc':
  244. collate_batch = False
  245. else:
  246. collate_batch = True
  247. custom_batch_transforms = []
  248. random_shape_defined = False
  249. for i, op in enumerate(transforms.transforms):
  250. if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
  251. if mode != 'train':
  252. raise Exception(
  253. "{} cannot be present in the {} transforms. ".format(
  254. op.__class__.__name__, mode) +
  255. "Please check the {} transforms.".format(mode))
  256. custom_batch_transforms.insert(0, copy.deepcopy(op))
  257. random_shape_defined = True
  258. if not random_shape_defined:
  259. default_batch_transforms.insert(
  260. 0,
  261. BatchRandomResize(
  262. target_sizes=self.train_random_shapes, interp='RANDOM'))
  263. batch_transforms = BatchCompose(
  264. custom_batch_transforms + default_batch_transforms,
  265. collate_batch=collate_batch)
  266. return batch_transforms