# yolo_v3.py
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from paddle import fluid
  17. from paddle.fluid.param_attr import ParamAttr
  18. from paddle.fluid.regularizer import L2Decay
  19. from collections import OrderedDict
  20. from .ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
  21. from .ops import DropBlock
  22. from .loss.yolo_loss import YOLOv3Loss
  23. from .loss.iou_loss import IouLoss
  24. from .loss.iou_aware_loss import IouAwareLoss
  25. from .iou_aware import get_iou_aware_score
  26. try:
  27. from collections.abc import Sequence
  28. except Exception:
  29. from collections import Sequence
  30. class YOLOv3:
  31. def __init__(
  32. self,
  33. backbone,
  34. mode='train',
  35. # YOLOv3Head
  36. num_classes=80,
  37. anchors=None,
  38. anchor_masks=None,
  39. coord_conv=False,
  40. iou_aware=False,
  41. iou_aware_factor=0.4,
  42. scale_x_y=1.,
  43. spp=False,
  44. drop_block=False,
  45. use_matrix_nms=False,
  46. # YOLOv3Loss
  47. batch_size=8,
  48. ignore_threshold=0.7,
  49. label_smooth=False,
  50. use_fine_grained_loss=False,
  51. use_iou_loss=False,
  52. iou_loss_weight=2.5,
  53. iou_aware_loss_weight=1.0,
  54. max_height=608,
  55. max_width=608,
  56. # NMS
  57. nms_score_threshold=0.01,
  58. nms_topk=1000,
  59. nms_keep_topk=100,
  60. nms_iou_threshold=0.45,
  61. fixed_input_shape=None):
  62. if anchors is None:
  63. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  64. [59, 119], [116, 90], [156, 198], [373, 326]]
  65. if anchor_masks is None:
  66. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  67. self.anchors = anchors
  68. self.anchor_masks = anchor_masks
  69. self._parse_anchors(anchors)
  70. self.mode = mode
  71. self.num_classes = num_classes
  72. self.backbone = backbone
  73. self.norm_decay = 0.0
  74. self.prefix_name = ''
  75. self.use_fine_grained_loss = use_fine_grained_loss
  76. self.fixed_input_shape = fixed_input_shape
  77. self.coord_conv = coord_conv
  78. self.iou_aware = iou_aware
  79. self.iou_aware_factor = iou_aware_factor
  80. self.scale_x_y = scale_x_y
  81. self.use_spp = spp
  82. self.drop_block = drop_block
  83. if use_matrix_nms:
  84. self.nms = MatrixNMS(
  85. background_label=-1,
  86. keep_top_k=nms_keep_topk,
  87. normalized=False,
  88. score_threshold=nms_score_threshold,
  89. post_threshold=0.01)
  90. else:
  91. self.nms = MultiClassNMS(
  92. background_label=-1,
  93. keep_top_k=nms_keep_topk,
  94. nms_threshold=nms_iou_threshold,
  95. nms_top_k=nms_topk,
  96. normalized=False,
  97. score_threshold=nms_score_threshold)
  98. self.iou_loss = None
  99. self.iou_aware_loss = None
  100. if use_iou_loss:
  101. self.iou_loss = IouLoss(
  102. loss_weight=iou_loss_weight,
  103. max_height=max_height,
  104. max_width=max_width)
  105. if iou_aware:
  106. self.iou_aware_loss = IouAwareLoss(
  107. loss_weight=iou_aware_loss_weight,
  108. max_height=max_height,
  109. max_width=max_width)
  110. self.yolo_loss = YOLOv3Loss(
  111. batch_size=batch_size,
  112. ignore_thresh=ignore_threshold,
  113. scale_x_y=scale_x_y,
  114. label_smooth=label_smooth,
  115. use_fine_grained_loss=self.use_fine_grained_loss,
  116. iou_loss=self.iou_loss,
  117. iou_aware_loss=self.iou_aware_loss)
  118. self.conv_block_num = 2
  119. self.block_size = 3
  120. self.keep_prob = 0.9
  121. self.downsample = [32, 16, 8]
  122. self.clip_bbox = True
  123. def _head(self, input, is_train=True):
  124. outputs = []
  125. # get last out_layer_num blocks in reverse order
  126. out_layer_num = len(self.anchor_masks)
  127. blocks = input[-1:-out_layer_num - 1:-1]
  128. route = None
  129. for i, block in enumerate(blocks):
  130. if i > 0: # perform concat in first 2 detection_block
  131. block = fluid.layers.concat(input=[route, block], axis=1)
  132. route, tip = self._detection_block(
  133. block,
  134. channel=64 * (2**out_layer_num) // (2**i),
  135. is_first=i == 0,
  136. is_test=(not is_train),
  137. conv_block_num=self.conv_block_num,
  138. name=self.prefix_name + "yolo_block.{}".format(i))
  139. # out channel number = mask_num * (5 + class_num)
  140. if self.iou_aware:
  141. num_filters = len(self.anchor_masks[i]) * (
  142. self.num_classes + 6)
  143. else:
  144. num_filters = len(self.anchor_masks[i]) * (
  145. self.num_classes + 5)
  146. with fluid.name_scope('yolo_output'):
  147. block_out = fluid.layers.conv2d(
  148. input=tip,
  149. num_filters=num_filters,
  150. filter_size=1,
  151. stride=1,
  152. padding=0,
  153. act=None,
  154. param_attr=ParamAttr(
  155. name=self.prefix_name +
  156. "yolo_output.{}.conv.weights".format(i)),
  157. bias_attr=ParamAttr(
  158. regularizer=L2Decay(0.),
  159. name=self.prefix_name +
  160. "yolo_output.{}.conv.bias".format(i)))
  161. outputs.append(block_out)
  162. if i < len(blocks) - 1:
  163. # do not perform upsample in the last detection_block
  164. route = self._conv_bn(
  165. input=route,
  166. ch_out=256 // (2**i),
  167. filter_size=1,
  168. stride=1,
  169. padding=0,
  170. is_test=(not is_train),
  171. name=self.prefix_name + "yolo_transition.{}".format(i))
  172. # upsample
  173. route = self._upsample(route)
  174. return outputs
  175. def _parse_anchors(self, anchors):
  176. self.anchors = []
  177. self.mask_anchors = []
  178. assert len(anchors) > 0, "ANCHORS not set."
  179. assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
  180. for anchor in anchors:
  181. assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
  182. self.anchors.extend(anchor)
  183. anchor_num = len(anchors)
  184. for masks in self.anchor_masks:
  185. self.mask_anchors.append([])
  186. for mask in masks:
  187. assert mask < anchor_num, "anchor mask index overflow"
  188. self.mask_anchors[-1].extend(anchors[mask])
  189. def _create_tensor_from_numpy(self, numpy_array):
  190. paddle_array = fluid.layers.create_global_var(
  191. shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
  192. fluid.layers.assign(numpy_array, paddle_array)
  193. return paddle_array
  194. def _add_coord(self, input, is_test=True):
  195. if not self.coord_conv:
  196. return input
  197. # NOTE: here is used for exporting model for TensorRT inference,
  198. # only support batch_size=1 for input shape should be fixed,
  199. # and we create tensor with fixed shape from numpy array
  200. if is_test and input.shape[2] > 0 and input.shape[3] > 0:
  201. batch_size = 1
  202. grid_x = int(input.shape[3])
  203. grid_y = int(input.shape[2])
  204. idx_i = np.array(
  205. [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
  206. dtype='float32')
  207. gi_np = np.repeat(idx_i, grid_y, axis=0)
  208. gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
  209. gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
  210. x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
  211. x_range.stop_gradient = True
  212. y_range = self._create_tensor_from_numpy(
  213. gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
  214. y_range.stop_gradient = True
  215. # NOTE: in training mode, H and W is variable for random shape,
  216. # implement add_coord with shape as Variable
  217. else:
  218. input_shape = fluid.layers.shape(input)
  219. b = input_shape[0]
  220. h = input_shape[2]
  221. w = input_shape[3]
  222. x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
  223. x_range = x_range - 1.
  224. x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
  225. x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
  226. x_range.stop_gradient = True
  227. y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
  228. y_range.stop_gradient = True
  229. return fluid.layers.concat([input, x_range, y_range], axis=1)
  230. def _conv_bn(self,
  231. input,
  232. ch_out,
  233. filter_size,
  234. stride,
  235. padding,
  236. act='leaky',
  237. is_test=False,
  238. name=None):
  239. conv = fluid.layers.conv2d(
  240. input=input,
  241. num_filters=ch_out,
  242. filter_size=filter_size,
  243. stride=stride,
  244. padding=padding,
  245. act=None,
  246. param_attr=ParamAttr(name=name + '.conv.weights'),
  247. bias_attr=False)
  248. bn_name = name + '.bn'
  249. bn_param_attr = ParamAttr(
  250. regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
  251. bn_bias_attr = ParamAttr(
  252. regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
  253. out = fluid.layers.batch_norm(
  254. input=conv,
  255. act=None,
  256. is_test=is_test,
  257. param_attr=bn_param_attr,
  258. bias_attr=bn_bias_attr,
  259. moving_mean_name=bn_name + '.mean',
  260. moving_variance_name=bn_name + '.var')
  261. if act == 'leaky':
  262. out = fluid.layers.leaky_relu(x=out, alpha=0.1)
  263. return out
  264. def _spp_module(self, input, is_test=True, name=""):
  265. output1 = input
  266. output2 = fluid.layers.pool2d(
  267. input=output1,
  268. pool_size=5,
  269. pool_stride=1,
  270. pool_padding=2,
  271. ceil_mode=False,
  272. pool_type='max')
  273. output3 = fluid.layers.pool2d(
  274. input=output1,
  275. pool_size=9,
  276. pool_stride=1,
  277. pool_padding=4,
  278. ceil_mode=False,
  279. pool_type='max')
  280. output4 = fluid.layers.pool2d(
  281. input=output1,
  282. pool_size=13,
  283. pool_stride=1,
  284. pool_padding=6,
  285. ceil_mode=False,
  286. pool_type='max')
  287. output = fluid.layers.concat(
  288. input=[output1, output2, output3, output4], axis=1)
  289. return output
  290. def _upsample(self, input, scale=2, name=None):
  291. out = fluid.layers.resize_nearest(
  292. input=input, scale=float(scale), name=name, align_corners=False)
  293. return out
  294. def _detection_block(self,
  295. input,
  296. channel,
  297. conv_block_num=2,
  298. is_first=False,
  299. is_test=True,
  300. name=None):
  301. assert channel % 2 == 0, \
  302. "channel {} cannot be divided by 2 in detection block {}" \
  303. .format(channel, name)
  304. conv = input
  305. for j in range(conv_block_num):
  306. conv = self._add_coord(conv, is_test=is_test)
  307. conv = self._conv_bn(
  308. conv,
  309. channel,
  310. filter_size=1,
  311. stride=1,
  312. padding=0,
  313. is_test=is_test,
  314. name='{}.{}.0'.format(name, j))
  315. if self.use_spp and is_first and j == 1:
  316. conv = self._spp_module(conv, is_test=is_test, name="spp")
  317. conv = self._conv_bn(
  318. conv,
  319. 512,
  320. filter_size=1,
  321. stride=1,
  322. padding=0,
  323. is_test=is_test,
  324. name='{}.{}.spp.conv'.format(name, j))
  325. conv = self._conv_bn(
  326. conv,
  327. channel * 2,
  328. filter_size=3,
  329. stride=1,
  330. padding=1,
  331. is_test=is_test,
  332. name='{}.{}.1'.format(name, j))
  333. if self.drop_block and j == 0 and not is_first:
  334. conv = DropBlock(
  335. conv,
  336. block_size=self.block_size,
  337. keep_prob=self.keep_prob,
  338. is_test=is_test)
  339. if self.drop_block and is_first:
  340. conv = DropBlock(
  341. conv,
  342. block_size=self.block_size,
  343. keep_prob=self.keep_prob,
  344. is_test=is_test)
  345. conv = self._add_coord(conv, is_test=is_test)
  346. route = self._conv_bn(
  347. conv,
  348. channel,
  349. filter_size=1,
  350. stride=1,
  351. padding=0,
  352. is_test=is_test,
  353. name='{}.2'.format(name))
  354. new_route = self._add_coord(route, is_test=is_test)
  355. tip = self._conv_bn(
  356. new_route,
  357. channel * 2,
  358. filter_size=3,
  359. stride=1,
  360. padding=1,
  361. is_test=is_test,
  362. name='{}.tip'.format(name))
  363. return route, tip
  364. def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
  365. loss = self.yolo_loss(inputs, gt_box, gt_label, gt_score, targets,
  366. self.anchors, self.anchor_masks,
  367. self.mask_anchors, self.num_classes,
  368. self.prefix_name)
  369. total_loss = fluid.layers.sum(list(loss.values()))
  370. return total_loss
  371. def _get_prediction(self, inputs, im_size):
  372. boxes = []
  373. scores = []
  374. for i, input in enumerate(inputs):
  375. if self.iou_aware:
  376. input = get_iou_aware_score(input,
  377. len(self.anchor_masks[i]),
  378. self.num_classes,
  379. self.iou_aware_factor)
  380. scale_x_y = self.scale_x_y if not isinstance(
  381. self.scale_x_y, Sequence) else self.scale_x_y[i]
  382. if paddle.__version__ < '1.8.4' and paddle.__version__ != '0.0.0':
  383. box, score = fluid.layers.yolo_box(
  384. x=input,
  385. img_size=im_size,
  386. anchors=self.mask_anchors[i],
  387. class_num=self.num_classes,
  388. conf_thresh=self.nms.score_threshold,
  389. downsample_ratio=self.downsample[i],
  390. name=self.prefix_name + 'yolo_box' + str(i),
  391. clip_bbox=self.clip_bbox)
  392. else:
  393. box, score = fluid.layers.yolo_box(
  394. x=input,
  395. img_size=im_size,
  396. anchors=self.mask_anchors[i],
  397. class_num=self.num_classes,
  398. conf_thresh=self.nms.score_threshold,
  399. downsample_ratio=self.downsample[i],
  400. name=self.prefix_name + 'yolo_box' + str(i),
  401. clip_bbox=self.clip_bbox,
  402. scale_x_y=self.scale_x_y)
  403. boxes.append(box)
  404. scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
  405. yolo_boxes = fluid.layers.concat(boxes, axis=1)
  406. yolo_scores = fluid.layers.concat(scores, axis=2)
  407. if type(self.nms) is MultiClassSoftNMS:
  408. yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
  409. pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
  410. return pred
  411. def generate_inputs(self):
  412. inputs = OrderedDict()
  413. if self.fixed_input_shape is not None:
  414. input_shape = [
  415. None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
  416. ]
  417. inputs['image'] = fluid.data(
  418. dtype='float32', shape=input_shape, name='image')
  419. else:
  420. inputs['image'] = fluid.data(
  421. dtype='float32', shape=[None, 3, None, None], name='image')
  422. if self.mode == 'train':
  423. inputs['gt_box'] = fluid.data(
  424. dtype='float32', shape=[None, None, 4], name='gt_box')
  425. inputs['gt_label'] = fluid.data(
  426. dtype='int32', shape=[None, None], name='gt_label')
  427. inputs['gt_score'] = fluid.data(
  428. dtype='float32', shape=[None, None], name='gt_score')
  429. inputs['im_size'] = fluid.data(
  430. dtype='int32', shape=[None, 2], name='im_size')
  431. if self.use_fine_grained_loss:
  432. downsample = 32
  433. for i, mask in enumerate(self.anchor_masks):
  434. if self.fixed_input_shape is not None:
  435. target_shape = [
  436. self.fixed_input_shape[1] // downsample,
  437. self.fixed_input_shape[0] // downsample
  438. ]
  439. else:
  440. target_shape = [None, None]
  441. inputs['target{}'.format(i)] = fluid.data(
  442. dtype='float32',
  443. lod_level=0,
  444. shape=[
  445. None, len(mask), 6 + self.num_classes,
  446. target_shape[0], target_shape[1]
  447. ],
  448. name='target{}'.format(i))
  449. downsample //= 2
  450. elif self.mode == 'eval':
  451. inputs['im_size'] = fluid.data(
  452. dtype='int32', shape=[None, 2], name='im_size')
  453. inputs['im_id'] = fluid.data(
  454. dtype='int32', shape=[None, 1], name='im_id')
  455. inputs['gt_box'] = fluid.data(
  456. dtype='float32', shape=[None, None, 4], name='gt_box')
  457. inputs['gt_label'] = fluid.data(
  458. dtype='int32', shape=[None, None], name='gt_label')
  459. inputs['is_difficult'] = fluid.data(
  460. dtype='int32', shape=[None, None], name='is_difficult')
  461. elif self.mode == 'test':
  462. inputs['im_size'] = fluid.data(
  463. dtype='int32', shape=[None, 2], name='im_size')
  464. return inputs
  465. def build_net(self, inputs):
  466. image = inputs['image']
  467. feats = self.backbone(image)
  468. if isinstance(feats, OrderedDict):
  469. feat_names = list(feats.keys())
  470. feats = [feats[name] for name in feat_names]
  471. head_outputs = self._head(feats, self.mode == 'train')
  472. if self.mode == 'train':
  473. gt_box = inputs['gt_box']
  474. gt_label = inputs['gt_label']
  475. gt_score = inputs['gt_score']
  476. im_size = inputs['im_size']
  477. num_boxes = fluid.layers.shape(gt_box)[1]
  478. im_size_wh = fluid.layers.reverse(im_size, axis=1)
  479. whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
  480. whwh = fluid.layers.unsqueeze(whwh, axes=[1])
  481. whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
  482. whwh = fluid.layers.cast(whwh, dtype='float32')
  483. whwh.stop_gradient = True
  484. normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
  485. targets = []
  486. if self.use_fine_grained_loss:
  487. for i, mask in enumerate(self.anchor_masks):
  488. k = 'target{}'.format(i)
  489. if k in inputs:
  490. targets.append(inputs[k])
  491. return self._get_loss(head_outputs, normalized_box, gt_label,
  492. gt_score, targets)
  493. else:
  494. im_size = inputs['im_size']
  495. return self._get_prediction(head_outputs, im_size)