faster_rcnn.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict
import copy

from paddle import fluid

from .fpn import (FPN, HRFPN)
from .rpn_head import (RPNHead, FPNRPNHead)
from .roi_extractor import (RoIAlign, FPNRoIAlign)
from .bbox_head import (BBoxHead, TwoFCHead)
from ..resnet import ResNetC5
from .loss.diou_loss import DiouLoss
from .ops import BBoxAssigner, LibraBBoxAssigner

__all__ = ['FasterRCNN']


class FasterRCNN(object):
    """
    Faster R-CNN architecture, see https://arxiv.org/abs/1506.01497

    Args:
        backbone (object): backbone instance
        rpn_head (object): `RPNHead` instance
        roi_extractor (object): RoI extractor instance
        bbox_head (object): `BBoxHead` instance
        fpn (object): feature pyramid network instance
    """

    def __init__(
            self,
            backbone,
            input_channel=3,
            mode='train',
            num_classes=81,
            with_fpn=False,
            fpn=None,
            # rpn_head
            rpn_only=False,
            rpn_head=None,
            anchor_sizes=[32, 64, 128, 256, 512],
            aspect_ratios=[0.5, 1.0, 2.0],
            rpn_batch_size_per_im=256,
            rpn_fg_fraction=0.5,
            rpn_positive_overlap=0.7,
            rpn_negative_overlap=0.3,
            train_pre_nms_top_n=12000,
            train_post_nms_top_n=2000,
            train_nms_thresh=0.7,
            test_pre_nms_top_n=6000,
            test_post_nms_top_n=1000,
            test_nms_thresh=0.7,
            rpn_cls_loss='SigmoidCrossEntropy',
            rpn_focal_loss_alpha=0.25,
            rpn_focal_loss_gamma=2,
            # roi_extractor
            roi_extractor=None,
            # bbox_head
            bbox_head=None,
            keep_top_k=100,
            nms_threshold=0.5,
            score_threshold=0.05,
            rcnn_nms='MultiClassNMS',
            softnms_sigma=0.5,
            post_threshold=.05,
            # bbox_assigner
            batch_size_per_im=512,
            fg_fraction=.25,
            fg_thresh=.5,
            bg_thresh_hi=.5,
            bg_thresh_lo=0.,
            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
            fixed_input_shape=None,
            rcnn_bbox_loss='SmoothL1Loss',
            diouloss_weight=10.0,
            diouloss_is_cls_agnostic=False,
            diouloss_use_complete_iou_loss=True,
            bbox_assigner='BBoxAssigner',
            fpn_num_channels=256):
        super(FasterRCNN, self).__init__()
        self.backbone = backbone
        self.mode = mode
        # Build a default FPN when requested: HRFPN for HRNet backbones,
        # plain FPN otherwise.
        if with_fpn and fpn is None:
            if self.backbone.__class__.__name__.startswith('HRNet'):
                fpn = HRFPN()
                fpn.min_level = 2
                fpn.max_level = 6
            else:
                fpn = FPN()
        self.fpn = fpn
        if self.fpn is not None:
            self.fpn.num_chan = fpn_num_channels
        self.num_classes = num_classes
        # Default RPN head: RPNHead without FPN, FPNRPNHead with FPN.
        if rpn_head is None:
            if self.fpn is None:
                rpn_head = RPNHead(
                    anchor_sizes=anchor_sizes,
                    aspect_ratios=aspect_ratios,
                    rpn_batch_size_per_im=rpn_batch_size_per_im,
                    rpn_fg_fraction=rpn_fg_fraction,
                    rpn_positive_overlap=rpn_positive_overlap,
                    rpn_negative_overlap=rpn_negative_overlap,
                    train_pre_nms_top_n=train_pre_nms_top_n,
                    train_post_nms_top_n=train_post_nms_top_n,
                    train_nms_thresh=train_nms_thresh,
                    test_pre_nms_top_n=test_pre_nms_top_n,
                    test_post_nms_top_n=test_post_nms_top_n,
                    test_nms_thresh=test_nms_thresh,
                    rpn_cls_loss=rpn_cls_loss,
                    rpn_focal_loss_alpha=rpn_focal_loss_alpha,
                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
            else:
                rpn_head = FPNRPNHead(
                    anchor_start_size=anchor_sizes[0],
                    aspect_ratios=aspect_ratios,
                    num_chan=self.fpn.num_chan,
                    min_level=self.fpn.min_level,
                    max_level=self.fpn.max_level,
                    rpn_batch_size_per_im=rpn_batch_size_per_im,
                    rpn_fg_fraction=rpn_fg_fraction,
                    rpn_positive_overlap=rpn_positive_overlap,
                    rpn_negative_overlap=rpn_negative_overlap,
                    train_pre_nms_top_n=train_pre_nms_top_n,
                    train_post_nms_top_n=train_post_nms_top_n,
                    train_nms_thresh=train_nms_thresh,
                    test_pre_nms_top_n=test_pre_nms_top_n,
                    test_post_nms_top_n=test_post_nms_top_n,
                    test_nms_thresh=test_nms_thresh,
                    rpn_cls_loss=rpn_cls_loss,
                    rpn_focal_loss_alpha=rpn_focal_loss_alpha,
                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
        self.rpn_head = rpn_head
        # Default RoI extractor: single-level RoIAlign without FPN,
        # FPNRoIAlign with FPN.
        if roi_extractor is None:
            if self.fpn is None:
                roi_extractor = RoIAlign(
                    resolution=14,
                    spatial_scale=1. / 2**self.backbone.feature_maps[0])
            else:
                roi_extractor = FPNRoIAlign(sampling_ratio=2)
        self.roi_extractor = roi_extractor
        # Default bbox head: ResNetC5 head without FPN, TwoFCHead with FPN.
        if bbox_head is None:
            if self.fpn is None:
                head = ResNetC5(
                    layers=self.backbone.layers,
                    norm_type=self.backbone.norm_type,
                    freeze_norm=self.backbone.freeze_norm,
                    variant=self.backbone.variant)
            else:
                head = TwoFCHead()
            bbox_head = BBoxHead(
                head=head,
                keep_top_k=keep_top_k,
                nms_threshold=nms_threshold,
                score_threshold=score_threshold,
                rcnn_nms=rcnn_nms,
                softnms_sigma=softnms_sigma,
                post_threshold=post_threshold,
                num_classes=num_classes,
                rcnn_bbox_loss=rcnn_bbox_loss,
                diouloss_weight=diouloss_weight,
                diouloss_is_cls_agnostic=diouloss_is_cls_agnostic,
                diouloss_use_complete_iou_loss=diouloss_use_complete_iou_loss)
        self.bbox_head = bbox_head
        self.batch_size_per_im = batch_size_per_im
        self.fg_fraction = fg_fraction
        self.fg_thresh = fg_thresh
        self.bg_thresh_hi = bg_thresh_hi
        self.bg_thresh_lo = bg_thresh_lo
        self.bbox_reg_weights = bbox_reg_weights
        self.rpn_only = rpn_only
        self.fixed_input_shape = fixed_input_shape
        # Proposal-to-ground-truth assigner used at training time.
        if bbox_assigner == 'BBoxAssigner':
            self.bbox_assigner = BBoxAssigner(
                batch_size_per_im=batch_size_per_im,
                fg_fraction=fg_fraction,
                fg_thresh=fg_thresh,
                bg_thresh_hi=bg_thresh_hi,
                bg_thresh_lo=bg_thresh_lo,
                bbox_reg_weights=bbox_reg_weights,
                num_classes=num_classes,
                shuffle_before_sample=self.rpn_head.use_random)
        elif bbox_assigner == 'LibraBBoxAssigner':
            self.bbox_assigner = LibraBBoxAssigner(
                batch_size_per_im=batch_size_per_im,
                fg_fraction=fg_fraction,
                fg_thresh=fg_thresh,
                bg_thresh_hi=bg_thresh_hi,
                bg_thresh_lo=bg_thresh_lo,
                bbox_reg_weights=bbox_reg_weights,
                num_classes=num_classes,
                shuffle_before_sample=self.rpn_head.use_random)
        self.input_channel = input_channel

    def build_net(self, inputs):
        """Build the network from the placeholders in `inputs`.

        Returns a dict of losses (including the summed 'loss') in 'train'
        mode, {'proposal': rois} when `rpn_only` is set, and the bbox head
        predictions otherwise.
        """
        im = inputs['image']
        im_info = inputs['im_info']
        if self.mode == 'train':
            gt_bbox = inputs['gt_box']
            is_crowd = inputs['is_crowd']
        else:
            im_shape = inputs['im_shape']
        body_feats = self.backbone(im)
        body_feat_names = list(body_feats.keys())
        if self.fpn is not None:
            body_feats, spatial_scale = self.fpn.get_output(body_feats)
        rois = self.rpn_head.get_proposals(body_feats, im_info, mode=self.mode)
        if self.mode == 'train':
            rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
            # Sample and label proposals for the second (RCNN) stage.
            outputs = self.bbox_assigner(
                rpn_rois=rois,
                gt_classes=inputs['gt_label'],
                is_crowd=inputs['is_crowd'],
                gt_boxes=inputs['gt_box'],
                im_info=inputs['im_info'])
            rois = outputs[0]
            labels_int32 = outputs[1]
            bbox_targets = outputs[2]
            bbox_inside_weights = outputs[3]
            bbox_outside_weights = outputs[4]
        else:
            if self.rpn_only:
                im_scale = fluid.layers.slice(
                    im_info, [1], starts=[2], ends=[3])
                im_scale = fluid.layers.sequence_expand(im_scale, rois)
                rois = rois / im_scale
                return {'proposal': rois}
        if self.fpn is None:
            # in models without FPN, roi extractor only uses the last level of
            # feature maps. And body_feat_names[-1] represents the name of
            # last feature map.
            body_feat = body_feats[body_feat_names[-1]]
            roi_feat = self.roi_extractor(body_feat, rois)
        else:
            roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)
        if self.mode == 'train':
            loss = self.bbox_head.get_loss(roi_feat, labels_int32,
                                           bbox_targets, bbox_inside_weights,
                                           bbox_outside_weights)
            loss.update(rpn_loss)
            total_loss = fluid.layers.sum(list(loss.values()))
            loss.update({'loss': total_loss})
            return loss
        else:
            pred = self.bbox_head.get_prediction(roi_feat, rois, im_info,
                                                 im_shape)
            return pred

    def generate_inputs(self):
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            input_shape = [
                None, self.input_channel, self.fixed_input_shape[1],
                self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32',
                shape=[None, self.input_channel, None, None],
                name='image')
        if self.mode == 'train':
            inputs['im_info'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_info')
            inputs['gt_box'] = fluid.data(
                dtype='float32', shape=[None, 4], lod_level=1, name='gt_box')
            inputs['gt_label'] = fluid.data(
                dtype='int32', shape=[None, 1], lod_level=1, name='gt_label')
            inputs['is_crowd'] = fluid.data(
                dtype='int32', shape=[None, 1], lod_level=1, name='is_crowd')
        elif self.mode == 'eval':
            inputs['im_info'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_info')
            inputs['im_id'] = fluid.data(
                dtype='int64', shape=[None, 1], name='im_id')
            inputs['im_shape'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_shape')
            inputs['gt_box'] = fluid.data(
                dtype='float32', shape=[None, 4], lod_level=1, name='gt_box')
            inputs['gt_label'] = fluid.data(
                dtype='int32', shape=[None, 1], lod_level=1, name='gt_label')
            inputs['is_difficult'] = fluid.data(
                dtype='int32',
                shape=[None, 1],
                lod_level=1,
                name='is_difficult')
        elif self.mode == 'test':
            inputs['im_info'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_info')
            inputs['im_shape'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_shape')
        return inputs
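

# Usage sketch: how the pieces above are typically wired together in the
# static-graph (fluid) workflow -- construct a backbone, create the input
# placeholders with generate_inputs(), then build the graph with build_net().
# This is a minimal sketch, assuming a `ResNet` backbone class is importable
# from the sibling `resnet` module (only `ResNetC5` is imported above) and
# that `layers=50` is a valid argument; adapt the import and arguments to the
# backbone classes actually available in this package.
#
#     from ..resnet import ResNet  # assumed backbone class
#
#     model = FasterRCNN(
#         backbone=ResNet(layers=50), with_fpn=True, mode='train',
#         num_classes=81)
#     inputs = model.generate_inputs()   # OrderedDict of fluid.data holders
#     outputs = model.build_net(inputs)  # loss dict in 'train' mode,
#                                        # predictions in 'eval'/'test' mode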