# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict
import copy

import paddle.fluid as fluid

from .fpn import (FPN, HRFPN)
from .rpn_head import (RPNHead, FPNRPNHead)
from .roi_extractor import (RoIAlign, FPNRoIAlign)
from .bbox_head import (BBoxHead, TwoFCHead)
from .mask_head import MaskHead
from ..resnet import ResNetC5
from .loss.diou_loss import DiouLoss
from .ops import BBoxAssigner, LibraBBoxAssigner

__all__ = ['MaskRCNN']


class MaskRCNN(object):
    """
    Mask R-CNN architecture, see https://arxiv.org/abs/1703.06870

    Args:
        backbone (object): backbone instance
        rpn_head (object): `RPNHead` instance
        roi_extractor (object): ROI extractor instance
        bbox_head (object): `BBoxHead` instance
        mask_head (object): `MaskHead` instance
        fpn (object): feature pyramid network instance
    """
    def __init__(
            self,
            backbone,
            input_channel=3,
            mode='train',
            num_classes=81,
            with_fpn=False,
            fpn=None,
            min_level=2,
            max_level=6,
            spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. / 4.],
            # rpn_head
            rpn_only=False,
            rpn_head=None,
            anchor_sizes=[32, 64, 128, 256, 512],
            aspect_ratios=[0.5, 1.0, 2.0],
            rpn_batch_size_per_im=256,
            rpn_fg_fraction=0.5,
            rpn_positive_overlap=0.7,
            rpn_negative_overlap=0.3,
            train_pre_nms_top_n=12000,
            train_post_nms_top_n=2000,
            train_nms_thresh=0.7,
            test_pre_nms_top_n=6000,
            test_post_nms_top_n=1000,
            test_nms_thresh=0.7,
            rpn_cls_loss='SigmoidCrossEntropy',
            rpn_focal_loss_alpha=0.25,
            rpn_focal_loss_gamma=2,
            # roi_extractor
            roi_extractor=None,
            # bbox_head
            bbox_head=None,
            keep_top_k=100,
            nms_threshold=0.5,
            score_threshold=0.05,
            rcnn_nms='MultiClassNMS',
            softnms_sigma=0.5,
            post_threshold=.05,
            # MaskHead
            mask_head=None,
            num_convs=0,
            mask_head_resolution=14,
            # bbox_assigner
            batch_size_per_im=512,
            fg_fraction=.25,
            fg_thresh=.5,
            bg_thresh_hi=.5,
            bg_thresh_lo=0.,
            bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
            fixed_input_shape=None,
            rcnn_bbox_loss='SmoothL1Loss',
            diouloss_weight=10.0,
            diouloss_is_cls_agnostic=False,
            diouloss_use_complete_iou_loss=True,
            bbox_assigner='BBoxAssigner',
            fpn_num_channels=256):
        super(MaskRCNN, self).__init__()
        self.backbone = backbone
        self.mode = mode
        if with_fpn and fpn is None:
            if self.backbone.__class__.__name__.startswith('HRNet'):
                fpn = HRFPN()
                fpn.min_level = 2
                fpn.max_level = 6
            else:
                fpn = FPN(num_chan=fpn_num_channels,
                          min_level=min_level,
                          max_level=max_level,
                          spatial_scale=spatial_scale)
        self.fpn = fpn
        if self.fpn is not None:
            self.fpn.num_chan = fpn_num_channels
        self.num_classes = num_classes
        if rpn_head is None:
            if self.fpn is None:
                rpn_head = RPNHead(
                    anchor_sizes=anchor_sizes,
                    aspect_ratios=aspect_ratios,
                    rpn_batch_size_per_im=rpn_batch_size_per_im,
                    rpn_fg_fraction=rpn_fg_fraction,
                    rpn_positive_overlap=rpn_positive_overlap,
                    rpn_negative_overlap=rpn_negative_overlap,
                    train_pre_nms_top_n=train_pre_nms_top_n,
                    train_post_nms_top_n=train_post_nms_top_n,
                    train_nms_thresh=train_nms_thresh,
                    test_pre_nms_top_n=test_pre_nms_top_n,
                    test_post_nms_top_n=test_post_nms_top_n,
                    test_nms_thresh=test_nms_thresh,
                    rpn_cls_loss=rpn_cls_loss,
                    rpn_focal_loss_alpha=rpn_focal_loss_alpha,
                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
            else:
                rpn_head = FPNRPNHead(
                    anchor_start_size=anchor_sizes[0],
                    aspect_ratios=aspect_ratios,
                    num_chan=self.fpn.num_chan,
                    min_level=self.fpn.min_level,
                    max_level=self.fpn.max_level,
                    rpn_batch_size_per_im=rpn_batch_size_per_im,
                    rpn_fg_fraction=rpn_fg_fraction,
                    rpn_positive_overlap=rpn_positive_overlap,
                    rpn_negative_overlap=rpn_negative_overlap,
                    train_pre_nms_top_n=train_pre_nms_top_n,
                    train_post_nms_top_n=train_post_nms_top_n,
                    train_nms_thresh=train_nms_thresh,
                    test_pre_nms_top_n=test_pre_nms_top_n,
                    test_post_nms_top_n=test_post_nms_top_n,
                    test_nms_thresh=test_nms_thresh,
                    rpn_cls_loss=rpn_cls_loss,
                    rpn_focal_loss_alpha=rpn_focal_loss_alpha,
                    rpn_focal_loss_gamma=rpn_focal_loss_gamma)
        self.rpn_head = rpn_head
        if roi_extractor is None:
            if self.fpn is None:
                roi_extractor = RoIAlign(
                    resolution=14,
                    spatial_scale=1. / 2**self.backbone.feature_maps[0])
            else:
                roi_extractor = FPNRoIAlign(sampling_ratio=2)
        self.roi_extractor = roi_extractor
        if bbox_head is None:
            if self.fpn is None:
                head = ResNetC5(
                    layers=self.backbone.layers,
                    norm_type=self.backbone.norm_type,
                    freeze_norm=self.backbone.freeze_norm,
                    variant=self.backbone.variant)
            else:
                head = TwoFCHead()
            bbox_head = BBoxHead(
                head=head,
                keep_top_k=keep_top_k,
                nms_threshold=nms_threshold,
                score_threshold=score_threshold,
                rcnn_nms=rcnn_nms,
                softnms_sigma=softnms_sigma,
                post_threshold=post_threshold,
                num_classes=num_classes,
                rcnn_bbox_loss=rcnn_bbox_loss,
                diouloss_weight=diouloss_weight,
                diouloss_is_cls_agnostic=diouloss_is_cls_agnostic,
                diouloss_use_complete_iou_loss=diouloss_use_complete_iou_loss)
        self.bbox_head = bbox_head
        if mask_head is None:
            mask_head = MaskHead(
                num_convs=num_convs,
                resolution=mask_head_resolution,
                num_classes=num_classes)
        self.mask_head = mask_head
        self.batch_size_per_im = batch_size_per_im
        self.fg_fraction = fg_fraction
        self.fg_thresh = fg_thresh
        self.bg_thresh_hi = bg_thresh_hi
        self.bg_thresh_lo = bg_thresh_lo
        self.bbox_reg_weights = bbox_reg_weights
        self.rpn_only = rpn_only
        self.fixed_input_shape = fixed_input_shape
        if bbox_assigner == 'BBoxAssigner':
            self.bbox_assigner = BBoxAssigner(
                batch_size_per_im=batch_size_per_im,
                fg_fraction=fg_fraction,
                fg_thresh=fg_thresh,
                bg_thresh_hi=bg_thresh_hi,
                bg_thresh_lo=bg_thresh_lo,
                bbox_reg_weights=bbox_reg_weights,
                num_classes=num_classes,
                shuffle_before_sample=self.rpn_head.use_random)
        elif bbox_assigner == 'LibraBBoxAssigner':
            self.bbox_assigner = LibraBBoxAssigner(
                batch_size_per_im=batch_size_per_im,
                fg_fraction=fg_fraction,
                fg_thresh=fg_thresh,
                bg_thresh_hi=bg_thresh_hi,
                bg_thresh_lo=bg_thresh_lo,
                bbox_reg_weights=bbox_reg_weights,
                num_classes=num_classes,
                shuffle_before_sample=self.rpn_head.use_random)
        self.input_channel = input_channel
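
    # build_net wires the full graph: backbone -> (optional) FPN -> RPN
    # proposals, then either the training losses (RPN + bbox head + mask head,
    # summed into 'loss') or, in eval/test mode, the predicted boxes and masks.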
    def build_net(self, inputs):
        im = inputs['image']
        im_info = inputs['im_info']
        if self.mode == 'train':
            gt_bbox = inputs['gt_box']
            is_crowd = inputs['is_crowd']
        else:
            im_shape = inputs['im_shape']

        # backbone
        body_feats = self.backbone(im)
        body_feat_names = list(body_feats.keys())

        # FPN
        spatial_scale = None
        if self.fpn is not None:
            body_feats, spatial_scale = self.fpn.get_output(body_feats)

        # RPN proposals
        rois = self.rpn_head.get_proposals(body_feats, im_info, mode=self.mode)

        if self.mode == 'train':
            rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
            outputs = self.bbox_assigner(
                rpn_rois=rois,
                gt_classes=inputs['gt_label'],
                is_crowd=inputs['is_crowd'],
                gt_boxes=inputs['gt_box'],
                im_info=inputs['im_info'])
            rois = outputs[0]
            labels_int32 = outputs[1]

            if self.fpn is None:
                last_feat = body_feats[body_feat_names[-1]]
                roi_feat = self.roi_extractor(last_feat, rois)
            else:
                roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)

            loss = self.bbox_head.get_loss(roi_feat, labels_int32,
                                           *outputs[2:])
            loss.update(rpn_loss)

            mask_rois, roi_has_mask_int32, mask_int32 = fluid.layers.generate_mask_labels(
                rois=rois,
                gt_classes=inputs['gt_label'],
                is_crowd=inputs['is_crowd'],
                gt_segms=inputs['gt_mask'],
                im_info=inputs['im_info'],
                labels_int32=labels_int32,
                num_classes=self.num_classes,
                resolution=self.mask_head.resolution)
            if self.fpn is None:
                bbox_head_feat = self.bbox_head.get_head_feat()
                feat = fluid.layers.gather(bbox_head_feat, roi_has_mask_int32)
            else:
                feat = self.roi_extractor(
                    body_feats, mask_rois, spatial_scale, is_mask=True)

            mask_loss = self.mask_head.get_loss(feat, mask_int32)
            loss.update(mask_loss)

            total_loss = fluid.layers.sum(list(loss.values()))
            loss.update({'loss': total_loss})
            return loss
        else:
            if self.rpn_only:
                im_scale = fluid.layers.slice(
                    im_info, [1], starts=[2], ends=[3])
                im_scale = fluid.layers.sequence_expand(im_scale, rois)
                rois = rois / im_scale
                return {'proposal': rois}
            mask_name = 'mask_pred'
            mask_pred, bbox_pred = self._eval(body_feats, mask_name, rois,
                                              im_info, im_shape, spatial_scale)
            return OrderedDict(zip(['bbox', 'mask'], [bbox_pred, mask_pred]))
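
    # _eval guards against the "no detections" case: when bbox_pred holds
    # fewer than 6 elements (fewer values than one [label, score, x1, y1,
    # x2, y2] row), the Switch takes its first branch and simply copies
    # bbox_pred into the mask output variable; otherwise the predicted boxes
    # are rescaled back to the network input resolution and fed through the
    # RoI extractor and mask head.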
    def _eval(self,
              body_feats,
              mask_name,
              rois,
              im_info,
              im_shape,
              spatial_scale,
              bbox_pred=None):
        if not bbox_pred:
            if self.fpn is None:
                last_feat = body_feats[list(body_feats.keys())[-1]]
                roi_feat = self.roi_extractor(last_feat, rois)
            else:
                roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)
            bbox_pred = self.bbox_head.get_prediction(roi_feat, rois, im_info,
                                                      im_shape)
            bbox_pred = bbox_pred['bbox']

        # share weight
        bbox_shape = fluid.layers.shape(bbox_pred)
        bbox_size = fluid.layers.reduce_prod(bbox_shape)
        bbox_size = fluid.layers.reshape(bbox_size, [1, 1])
        size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32')
        cond = fluid.layers.less_than(x=bbox_size, y=size)

        mask_pred = fluid.layers.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=False,
            name=mask_name)

        with fluid.layers.control_flow.Switch() as switch:
            with switch.case(cond):
                fluid.layers.assign(input=bbox_pred, output=mask_pred)
            with switch.default():
                bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])

                im_scale = fluid.layers.slice(
                    im_info, [1], starts=[2], ends=[3])
                im_scale = fluid.layers.sequence_expand(im_scale, bbox)

                mask_rois = bbox * im_scale
                if self.fpn is None:
                    last_feat = body_feats[list(body_feats.keys())[-1]]
                    mask_feat = self.roi_extractor(last_feat, mask_rois)
                    mask_feat = self.bbox_head.get_head_feat(mask_feat)
                else:
                    mask_feat = self.roi_extractor(
                        body_feats, mask_rois, spatial_scale, is_mask=True)

                mask_out = self.mask_head.get_prediction(mask_feat, bbox)
                fluid.layers.assign(input=mask_out, output=mask_pred)
        return mask_pred, bbox_pred
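
    # generate_inputs declares the fluid.data placeholders this graph expects:
    # 'image' plus 'im_info' in every mode, ground-truth boxes / labels /
    # is_crowd / masks in 'train' mode, 'im_id' and 'im_shape' in 'eval' mode,
    # and 'im_shape' in 'test' mode.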
    def generate_inputs(self):
        inputs = OrderedDict()

        if self.fixed_input_shape is not None:
            input_shape = [
                None, self.input_channel, self.fixed_input_shape[1],
                self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32',
                shape=[None, self.input_channel, None, None],
                name='image')

        if self.mode == 'train':
            inputs['im_info'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_info')
            inputs['gt_box'] = fluid.data(
                dtype='float32', shape=[None, 4], lod_level=1, name='gt_box')
            inputs['gt_label'] = fluid.data(
                dtype='int32', shape=[None, 1], lod_level=1, name='gt_label')
            inputs['is_crowd'] = fluid.data(
                dtype='int32', shape=[None, 1], lod_level=1, name='is_crowd')
            inputs['gt_mask'] = fluid.data(
                dtype='float32', shape=[None, 2], lod_level=3, name='gt_mask')
        elif self.mode == 'eval':
            inputs['im_info'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_info')
            inputs['im_id'] = fluid.data(
                dtype='int64', shape=[None, 1], name='im_id')
            inputs['im_shape'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_shape')
        elif self.mode == 'test':
            inputs['im_info'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_info')
            inputs['im_shape'] = fluid.data(
                dtype='float32', shape=[None, 3], name='im_shape')
        return inputs
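

# ---------------------------------------------------------------------------
# Usage sketch (a hedged example, not part of the upstream module). It assumes
# PaddlePaddle 1.x static-graph mode; `build_backbone()` is a hypothetical
# placeholder for constructing a backbone instance (e.g. a ResNet from the
# sibling `resnet` module), since backbone construction lives outside this
# file.
#
#   import paddle.fluid as fluid
#
#   backbone = build_backbone()            # hypothetical helper
#   model = MaskRCNN(backbone, mode='test', num_classes=81, with_fpn=True)
#
#   startup_prog = fluid.Program()
#   test_prog = fluid.Program()
#   with fluid.program_guard(test_prog, startup_prog):
#       inputs = model.generate_inputs()   # 'image', 'im_info', 'im_shape'
#       outputs = model.build_net(inputs)  # OrderedDict: {'bbox': ..., 'mask': ...}
#
#   exe = fluid.Executor(fluid.CPUPlace())
#   exe.run(startup_prog)
# ---------------------------------------------------------------------------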