yolo_v3.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

import numpy as np
import paddle
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay

from .ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
from .ops import DropBlock
from .loss.yolo_loss import YOLOv3Loss
from .loss.iou_loss import IouLoss
from .loss.iou_aware_loss import IouAwareLoss
from .iou_aware import get_iou_aware_score


  29. class YOLOv3:
  30. def __init__(
  31. self,
  32. backbone,
  33. mode='train',
  34. # YOLOv3Head
  35. num_classes=80,
  36. anchors=None,
  37. anchor_masks=None,
  38. coord_conv=False,
  39. iou_aware=False,
  40. iou_aware_factor=0.4,
  41. scale_x_y=1.,
  42. spp=False,
  43. drop_block=False,
  44. use_matrix_nms=False,
  45. # YOLOv3Loss
  46. batch_size=8,
  47. ignore_threshold=0.7,
  48. label_smooth=False,
  49. use_fine_grained_loss=False,
  50. use_iou_loss=False,
  51. iou_loss_weight=2.5,
  52. iou_aware_loss_weight=1.0,
  53. max_height=608,
  54. max_width=608,
  55. # NMS
  56. nms_score_threshold=0.01,
  57. nms_topk=1000,
  58. nms_keep_topk=100,
  59. nms_iou_threshold=0.45,
  60. fixed_input_shape=None):
  61. if anchors is None:
  62. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  63. [59, 119], [116, 90], [156, 198], [373, 326]]
  64. if anchor_masks is None:
  65. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  66. self.anchors = anchors
  67. self.anchor_masks = anchor_masks
  68. self._parse_anchors(anchors)
  69. self.mode = mode
  70. self.num_classes = num_classes
  71. self.backbone = backbone
  72. self.norm_decay = 0.0
  73. self.prefix_name = ''
  74. self.use_fine_grained_loss = use_fine_grained_loss
  75. self.fixed_input_shape = fixed_input_shape
  76. self.coord_conv = coord_conv
  77. self.iou_aware = iou_aware
  78. self.iou_aware_factor = iou_aware_factor
  79. self.scale_x_y = scale_x_y
  80. self.use_spp = spp
  81. self.drop_block = drop_block
  82. if use_matrix_nms:
  83. self.nms = MatrixNMS(
  84. background_label=-1,
  85. keep_top_k=nms_keep_topk,
  86. normalized=False,
  87. score_threshold=nms_score_threshold,
  88. post_threshold=0.01)
  89. else:
  90. self.nms = MultiClassNMS(
  91. background_label=-1,
  92. keep_top_k=nms_keep_topk,
  93. nms_threshold=nms_iou_threshold,
  94. nms_top_k=nms_topk,
  95. normalized=False,
  96. score_threshold=nms_score_threshold)
  97. self.iou_loss = None
  98. self.iou_aware_loss = None
  99. if use_iou_loss:
  100. self.iou_loss = IouLoss(
  101. loss_weight=iou_loss_weight,
  102. max_height=max_height,
  103. max_width=max_width)
  104. if iou_aware:
  105. self.iou_aware_loss = IouAwareLoss(
  106. loss_weight=iou_aware_loss_weight,
  107. max_height=max_height,
  108. max_width=max_width)
  109. self.yolo_loss = YOLOv3Loss(
  110. batch_size=batch_size,
  111. ignore_thresh=ignore_threshold,
  112. scale_x_y=scale_x_y,
  113. label_smooth=label_smooth,
  114. use_fine_grained_loss=self.use_fine_grained_loss,
  115. iou_loss=self.iou_loss,
  116. iou_aware_loss=self.iou_aware_loss)
  117. self.conv_block_num = 2
  118. self.block_size = 3
  119. self.keep_prob = 0.9
  120. self.downsample = [32, 16, 8]
  121. self.clip_bbox = True
  122. def _head(self, input, is_train=True):
  123. outputs = []
  124. # get last out_layer_num blocks in reverse order
  125. out_layer_num = len(self.anchor_masks)
  126. blocks = input[-1:-out_layer_num - 1:-1]
  127. route = None
  128. for i, block in enumerate(blocks):
  129. if i > 0: # perform concat in first 2 detection_block
  130. block = fluid.layers.concat(input=[route, block], axis=1)
  131. route, tip = self._detection_block(
  132. block,
  133. channel=64 * (2**out_layer_num) // (2**i),
  134. is_first=i == 0,
  135. is_test=(not is_train),
  136. conv_block_num=self.conv_block_num,
  137. name=self.prefix_name + "yolo_block.{}".format(i))
  138. # out channel number = mask_num * (5 + class_num)
  139. if self.iou_aware:
  140. num_filters = len(self.anchor_masks[i]) * (
  141. self.num_classes + 6)
  142. else:
  143. num_filters = len(self.anchor_masks[i]) * (
  144. self.num_classes + 5)
  145. with fluid.name_scope('yolo_output'):
  146. block_out = fluid.layers.conv2d(
  147. input=tip,
  148. num_filters=num_filters,
  149. filter_size=1,
  150. stride=1,
  151. padding=0,
  152. act=None,
  153. param_attr=ParamAttr(
  154. name=self.prefix_name +
  155. "yolo_output.{}.conv.weights".format(i)),
  156. bias_attr=ParamAttr(
  157. regularizer=L2Decay(0.),
  158. name=self.prefix_name +
  159. "yolo_output.{}.conv.bias".format(i)))
  160. outputs.append(block_out)
  161. if i < len(blocks) - 1:
  162. # do not perform upsample in the last detection_block
  163. route = self._conv_bn(
  164. input=route,
  165. ch_out=256 // (2**i),
  166. filter_size=1,
  167. stride=1,
  168. padding=0,
  169. is_test=(not is_train),
  170. name=self.prefix_name + "yolo_transition.{}".format(i))
  171. # upsample
  172. route = self._upsample(route)
  173. return outputs
  174. def _parse_anchors(self, anchors):
  175. self.anchors = []
  176. self.mask_anchors = []
  177. assert len(anchors) > 0, "ANCHORS not set."
  178. assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
  179. for anchor in anchors:
  180. assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
  181. self.anchors.extend(anchor)
  182. anchor_num = len(anchors)
  183. for masks in self.anchor_masks:
  184. self.mask_anchors.append([])
  185. for mask in masks:
  186. assert mask < anchor_num, "anchor mask index overflow"
  187. self.mask_anchors[-1].extend(anchors[mask])
  188. def _create_tensor_from_numpy(self, numpy_array):
  189. paddle_array = fluid.layers.create_global_var(
  190. shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
  191. fluid.layers.assign(numpy_array, paddle_array)
  192. return paddle_array
  193. def _add_coord(self, input, is_test=True):
  194. if not self.coord_conv:
  195. return input
  196. # NOTE: here is used for exporting model for TensorRT inference,
  197. # only support batch_size=1 for input shape should be fixed,
  198. # and we create tensor with fixed shape from numpy array
  199. if is_test and input.shape[2] > 0 and input.shape[3] > 0:
  200. batch_size = 1
  201. grid_x = int(input.shape[3])
  202. grid_y = int(input.shape[2])
  203. idx_i = np.array(
  204. [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
  205. dtype='float32')
  206. gi_np = np.repeat(idx_i, grid_y, axis=0)
  207. gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
  208. gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
  209. x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
  210. x_range.stop_gradient = True
  211. y_range = self._create_tensor_from_numpy(
  212. gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
  213. y_range.stop_gradient = True
  214. # NOTE: in training mode, H and W is variable for random shape,
  215. # implement add_coord with shape as Variable
  216. else:
  217. input_shape = fluid.layers.shape(input)
  218. b = input_shape[0]
  219. h = input_shape[2]
  220. w = input_shape[3]
  221. x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
  222. x_range = x_range - 1.
  223. x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
  224. x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
  225. x_range.stop_gradient = True
  226. y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
  227. y_range.stop_gradient = True
  228. return fluid.layers.concat([input, x_range, y_range], axis=1)
  229. def _conv_bn(self,
  230. input,
  231. ch_out,
  232. filter_size,
  233. stride,
  234. padding,
  235. act='leaky',
  236. is_test=False,
  237. name=None):
  238. conv = fluid.layers.conv2d(
  239. input=input,
  240. num_filters=ch_out,
  241. filter_size=filter_size,
  242. stride=stride,
  243. padding=padding,
  244. act=None,
  245. param_attr=ParamAttr(name=name + '.conv.weights'),
  246. bias_attr=False)
  247. bn_name = name + '.bn'
  248. bn_param_attr = ParamAttr(
  249. regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
  250. bn_bias_attr = ParamAttr(
  251. regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
  252. out = fluid.layers.batch_norm(
  253. input=conv,
  254. act=None,
  255. is_test=is_test,
  256. param_attr=bn_param_attr,
  257. bias_attr=bn_bias_attr,
  258. moving_mean_name=bn_name + '.mean',
  259. moving_variance_name=bn_name + '.var')
  260. if act == 'leaky':
  261. out = fluid.layers.leaky_relu(x=out, alpha=0.1)
  262. return out
  263. def _spp_module(self, input, is_test=True, name=""):
  264. output1 = input
  265. output2 = fluid.layers.pool2d(
  266. input=output1,
  267. pool_size=5,
  268. pool_stride=1,
  269. pool_padding=2,
  270. ceil_mode=False,
  271. pool_type='max')
  272. output3 = fluid.layers.pool2d(
  273. input=output1,
  274. pool_size=9,
  275. pool_stride=1,
  276. pool_padding=4,
  277. ceil_mode=False,
  278. pool_type='max')
  279. output4 = fluid.layers.pool2d(
  280. input=output1,
  281. pool_size=13,
  282. pool_stride=1,
  283. pool_padding=6,
  284. ceil_mode=False,
  285. pool_type='max')
  286. output = fluid.layers.concat(
  287. input=[output1, output2, output3, output4], axis=1)
  288. return output
  289. def _upsample(self, input, scale=2, name=None):
  290. out = fluid.layers.resize_nearest(
  291. input=input, scale=float(scale), name=name, align_corners=False)
  292. return out
  293. def _detection_block(self,
  294. input,
  295. channel,
  296. conv_block_num=2,
  297. is_first=False,
  298. is_test=True,
  299. name=None):
  300. assert channel % 2 == 0, \
  301. "channel {} cannot be divided by 2 in detection block {}" \
  302. .format(channel, name)
  303. conv = input
  304. for j in range(conv_block_num):
  305. conv = self._add_coord(conv, is_test=is_test)
  306. conv = self._conv_bn(
  307. conv,
  308. channel,
  309. filter_size=1,
  310. stride=1,
  311. padding=0,
  312. is_test=is_test,
  313. name='{}.{}.0'.format(name, j))
  314. if self.use_spp and is_first and j == 1:
  315. conv = self._spp_module(conv, is_test=is_test, name="spp")
  316. conv = self._conv_bn(
  317. conv,
  318. 512,
  319. filter_size=1,
  320. stride=1,
  321. padding=0,
  322. is_test=is_test,
  323. name='{}.{}.spp.conv'.format(name, j))
  324. conv = self._conv_bn(
  325. conv,
  326. channel * 2,
  327. filter_size=3,
  328. stride=1,
  329. padding=1,
  330. is_test=is_test,
  331. name='{}.{}.1'.format(name, j))
  332. if self.drop_block and j == 0 and not is_first:
  333. conv = DropBlock(
  334. conv,
  335. block_size=self.block_size,
  336. keep_prob=self.keep_prob,
  337. is_test=is_test)
  338. if self.drop_block and is_first:
  339. conv = DropBlock(
  340. conv,
  341. block_size=self.block_size,
  342. keep_prob=self.keep_prob,
  343. is_test=is_test)
  344. conv = self._add_coord(conv, is_test=is_test)
  345. route = self._conv_bn(
  346. conv,
  347. channel,
  348. filter_size=1,
  349. stride=1,
  350. padding=0,
  351. is_test=is_test,
  352. name='{}.2'.format(name))
  353. new_route = self._add_coord(route, is_test=is_test)
  354. tip = self._conv_bn(
  355. new_route,
  356. channel * 2,
  357. filter_size=3,
  358. stride=1,
  359. padding=1,
  360. is_test=is_test,
  361. name='{}.tip'.format(name))
  362. return route, tip
  363. def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
  364. loss = self.yolo_loss(inputs, gt_box, gt_label, gt_score, targets,
  365. self.anchors, self.anchor_masks,
  366. self.mask_anchors, self.num_classes,
  367. self.prefix_name)
  368. total_loss = fluid.layers.sum(list(loss.values()))
  369. return total_loss
  370. def _get_prediction(self, inputs, im_size):
  371. boxes = []
  372. scores = []
  373. for i, input in enumerate(inputs):
  374. if self.iou_aware:
  375. input = get_iou_aware_score(input,
  376. len(self.anchor_masks[i]),
  377. self.num_classes,
  378. self.iou_aware_factor)
  379. scale_x_y = self.scale_x_y if not isinstance(
  380. self.scale_x_y, Sequence) else self.scale_x_y[i]
  381. if paddle.__version__ < '1.8.4' and paddle.__version__ != '0.0.0':
  382. box, score = fluid.layers.yolo_box(
  383. x=input,
  384. img_size=im_size,
  385. anchors=self.mask_anchors[i],
  386. class_num=self.num_classes,
  387. conf_thresh=self.nms.score_threshold,
  388. downsample_ratio=self.downsample[i],
  389. name=self.prefix_name + 'yolo_box' + str(i),
  390. clip_bbox=self.clip_bbox)
  391. else:
  392. box, score = fluid.layers.yolo_box(
  393. x=input,
  394. img_size=im_size,
  395. anchors=self.mask_anchors[i],
  396. class_num=self.num_classes,
  397. conf_thresh=self.nms.score_threshold,
  398. downsample_ratio=self.downsample[i],
  399. name=self.prefix_name + 'yolo_box' + str(i),
  400. clip_bbox=self.clip_bbox,
  401. scale_x_y=self.scale_x_y)
  402. boxes.append(box)
  403. scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
  404. yolo_boxes = fluid.layers.concat(boxes, axis=1)
  405. yolo_scores = fluid.layers.concat(scores, axis=2)
  406. if type(self.nms) is MultiClassSoftNMS:
  407. yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
  408. pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
  409. return pred
  410. def generate_inputs(self):
  411. inputs = OrderedDict()
  412. if self.fixed_input_shape is not None:
  413. input_shape = [
  414. None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
  415. ]
  416. inputs['image'] = fluid.data(
  417. dtype='float32', shape=input_shape, name='image')
  418. else:
  419. inputs['image'] = fluid.data(
  420. dtype='float32', shape=[None, 3, None, None], name='image')
  421. if self.mode == 'train':
  422. inputs['gt_box'] = fluid.data(
  423. dtype='float32', shape=[None, None, 4], name='gt_box')
  424. inputs['gt_label'] = fluid.data(
  425. dtype='int32', shape=[None, None], name='gt_label')
  426. inputs['gt_score'] = fluid.data(
  427. dtype='float32', shape=[None, None], name='gt_score')
  428. inputs['im_size'] = fluid.data(
  429. dtype='int32', shape=[None, 2], name='im_size')
  430. if self.use_fine_grained_loss:
  431. downsample = 32
  432. for i, mask in enumerate(self.anchor_masks):
  433. if self.fixed_input_shape is not None:
  434. target_shape = [
  435. self.fixed_input_shape[1] // downsample,
  436. self.fixed_input_shape[0] // downsample
  437. ]
  438. else:
  439. target_shape = [None, None]
  440. inputs['target{}'.format(i)] = fluid.data(
  441. dtype='float32',
  442. lod_level=0,
  443. shape=[
  444. None, len(mask), 6 + self.num_classes,
  445. target_shape[0], target_shape[1]
  446. ],
  447. name='target{}'.format(i))
  448. downsample //= 2
  449. elif self.mode == 'eval':
  450. inputs['im_size'] = fluid.data(
  451. dtype='int32', shape=[None, 2], name='im_size')
  452. inputs['im_id'] = fluid.data(
  453. dtype='int32', shape=[None, 1], name='im_id')
  454. inputs['gt_box'] = fluid.data(
  455. dtype='float32', shape=[None, None, 4], name='gt_box')
  456. inputs['gt_label'] = fluid.data(
  457. dtype='int32', shape=[None, None], name='gt_label')
  458. inputs['is_difficult'] = fluid.data(
  459. dtype='int32', shape=[None, None], name='is_difficult')
  460. elif self.mode == 'test':
  461. inputs['im_size'] = fluid.data(
  462. dtype='int32', shape=[None, 2], name='im_size')
  463. return inputs
  464. def build_net(self, inputs):
  465. image = inputs['image']
  466. feats = self.backbone(image)
  467. if isinstance(feats, OrderedDict):
  468. feat_names = list(feats.keys())
  469. feats = [feats[name] for name in feat_names]
  470. head_outputs = self._head(feats, self.mode == 'train')
  471. if self.mode == 'train':
  472. gt_box = inputs['gt_box']
  473. gt_label = inputs['gt_label']
  474. gt_score = inputs['gt_score']
  475. im_size = inputs['im_size']
  476. num_boxes = fluid.layers.shape(gt_box)[1]
  477. im_size_wh = fluid.layers.reverse(im_size, axis=1)
  478. whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
  479. whwh = fluid.layers.unsqueeze(whwh, axes=[1])
  480. whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
  481. whwh = fluid.layers.cast(whwh, dtype='float32')
  482. whwh.stop_gradient = True
  483. normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
  484. targets = []
  485. if self.use_fine_grained_loss:
  486. for i, mask in enumerate(self.anchor_masks):
  487. k = 'target{}'.format(i)
  488. if k in inputs:
  489. targets.append(inputs[k])
  490. return self._get_loss(head_outputs, normalized_box, gt_label,
  491. gt_score, targets)
  492. else:
  493. im_size = inputs['im_size']
  494. return self._get_prediction(head_outputs, im_size)