yolo_v3.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from paddle import fluid
  17. from paddle.fluid.param_attr import ParamAttr
  18. from paddle.fluid.regularizer import L2Decay
  19. from collections import OrderedDict
  20. from .ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
  21. from .ops import DropBlock
  22. from .loss.yolo_loss import YOLOv3Loss
  23. from .loss.iou_loss import IouLoss
  24. from .loss.iou_aware_loss import IouAwareLoss
  25. from .iou_aware import get_iou_aware_score
  26. try:
  27. from collections.abc import Sequence
  28. except Exception:
  29. from collections import Sequence
  30. class YOLOv3:
  31. def __init__(
  32. self,
  33. backbone,
  34. mode='train',
  35. # YOLOv3Head
  36. num_classes=80,
  37. anchors=None,
  38. anchor_masks=None,
  39. coord_conv=False,
  40. iou_aware=False,
  41. iou_aware_factor=0.4,
  42. scale_x_y=1.,
  43. spp=False,
  44. drop_block=False,
  45. use_matrix_nms=False,
  46. # YOLOv3Loss
  47. batch_size=8,
  48. ignore_threshold=0.7,
  49. label_smooth=False,
  50. use_fine_grained_loss=False,
  51. use_iou_loss=False,
  52. iou_loss_weight=2.5,
  53. iou_aware_loss_weight=1.0,
  54. max_height=608,
  55. max_width=608,
  56. # NMS
  57. nms_score_threshold=0.01,
  58. nms_topk=1000,
  59. nms_keep_topk=100,
  60. nms_iou_threshold=0.45,
  61. fixed_input_shape=None,
  62. input_channel=3):
  63. if anchors is None:
  64. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  65. [59, 119], [116, 90], [156, 198], [373, 326]]
  66. if anchor_masks is None:
  67. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  68. self.anchors = anchors
  69. self.anchor_masks = anchor_masks
  70. self._parse_anchors(anchors)
  71. self.mode = mode
  72. self.num_classes = num_classes
  73. self.backbone = backbone
  74. self.norm_decay = 0.0
  75. self.prefix_name = ''
  76. self.use_fine_grained_loss = use_fine_grained_loss
  77. self.fixed_input_shape = fixed_input_shape
  78. self.coord_conv = coord_conv
  79. self.iou_aware = iou_aware
  80. self.iou_aware_factor = iou_aware_factor
  81. self.scale_x_y = scale_x_y
  82. self.use_spp = spp
  83. self.drop_block = drop_block
  84. if use_matrix_nms:
  85. self.nms = MatrixNMS(
  86. background_label=-1,
  87. keep_top_k=nms_keep_topk,
  88. normalized=False,
  89. score_threshold=nms_score_threshold,
  90. post_threshold=0.01)
  91. else:
  92. self.nms = MultiClassNMS(
  93. background_label=-1,
  94. keep_top_k=nms_keep_topk,
  95. nms_threshold=nms_iou_threshold,
  96. nms_top_k=nms_topk,
  97. normalized=False,
  98. score_threshold=nms_score_threshold)
  99. self.iou_loss = None
  100. self.iou_aware_loss = None
  101. if use_iou_loss:
  102. self.iou_loss = IouLoss(
  103. loss_weight=iou_loss_weight,
  104. max_height=max_height,
  105. max_width=max_width)
  106. if iou_aware:
  107. self.iou_aware_loss = IouAwareLoss(
  108. loss_weight=iou_aware_loss_weight,
  109. max_height=max_height,
  110. max_width=max_width)
  111. self.yolo_loss = YOLOv3Loss(
  112. batch_size=batch_size,
  113. ignore_thresh=ignore_threshold,
  114. scale_x_y=scale_x_y,
  115. label_smooth=label_smooth,
  116. use_fine_grained_loss=self.use_fine_grained_loss,
  117. iou_loss=self.iou_loss,
  118. iou_aware_loss=self.iou_aware_loss)
  119. self.conv_block_num = 2
  120. self.block_size = 3
  121. self.keep_prob = 0.9
  122. self.downsample = [32, 16, 8]
  123. self.clip_bbox = True
  124. self.input_channel = input_channel
  125. def _head(self, input, is_train=True):
  126. outputs = []
  127. # get last out_layer_num blocks in reverse order
  128. out_layer_num = len(self.anchor_masks)
  129. blocks = input[-1:-out_layer_num - 1:-1]
  130. route = None
  131. for i, block in enumerate(blocks):
  132. if i > 0: # perform concat in first 2 detection_block
  133. block = fluid.layers.concat(input=[route, block], axis=1)
  134. route, tip = self._detection_block(
  135. block,
  136. channel=64 * (2**out_layer_num) // (2**i),
  137. is_first=i == 0,
  138. is_test=(not is_train),
  139. conv_block_num=self.conv_block_num,
  140. name=self.prefix_name + "yolo_block.{}".format(i))
  141. # out channel number = mask_num * (5 + class_num)
  142. if self.iou_aware:
  143. num_filters = len(self.anchor_masks[i]) * (
  144. self.num_classes + 6)
  145. else:
  146. num_filters = len(self.anchor_masks[i]) * (
  147. self.num_classes + 5)
  148. with fluid.name_scope('yolo_output'):
  149. block_out = fluid.layers.conv2d(
  150. input=tip,
  151. num_filters=num_filters,
  152. filter_size=1,
  153. stride=1,
  154. padding=0,
  155. act=None,
  156. param_attr=ParamAttr(
  157. name=self.prefix_name +
  158. "yolo_output.{}.conv.weights".format(i)),
  159. bias_attr=ParamAttr(
  160. regularizer=L2Decay(0.),
  161. name=self.prefix_name +
  162. "yolo_output.{}.conv.bias".format(i)))
  163. outputs.append(block_out)
  164. if i < len(blocks) - 1:
  165. # do not perform upsample in the last detection_block
  166. route = self._conv_bn(
  167. input=route,
  168. ch_out=256 // (2**i),
  169. filter_size=1,
  170. stride=1,
  171. padding=0,
  172. is_test=(not is_train),
  173. name=self.prefix_name + "yolo_transition.{}".format(i))
  174. # upsample
  175. route = self._upsample(route)
  176. return outputs
  177. def _parse_anchors(self, anchors):
  178. self.anchors = []
  179. self.mask_anchors = []
  180. assert len(anchors) > 0, "ANCHORS not set."
  181. assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
  182. for anchor in anchors:
  183. assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
  184. self.anchors.extend(anchor)
  185. anchor_num = len(anchors)
  186. for masks in self.anchor_masks:
  187. self.mask_anchors.append([])
  188. for mask in masks:
  189. assert mask < anchor_num, "anchor mask index overflow"
  190. self.mask_anchors[-1].extend(anchors[mask])
  191. def _create_tensor_from_numpy(self, numpy_array):
  192. paddle_array = fluid.layers.create_global_var(
  193. shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
  194. fluid.layers.assign(numpy_array, paddle_array)
  195. return paddle_array
  196. def _add_coord(self, input, is_test=True):
  197. if not self.coord_conv:
  198. return input
  199. # NOTE: here is used for exporting model for TensorRT inference,
  200. # only support batch_size=1 for input shape should be fixed,
  201. # and we create tensor with fixed shape from numpy array
  202. if is_test and input.shape[2] > 0 and input.shape[3] > 0:
  203. batch_size = 1
  204. grid_x = int(input.shape[3])
  205. grid_y = int(input.shape[2])
  206. idx_i = np.array(
  207. [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
  208. dtype='float32')
  209. gi_np = np.repeat(idx_i, grid_y, axis=0)
  210. gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
  211. gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
  212. x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
  213. x_range.stop_gradient = True
  214. y_range = self._create_tensor_from_numpy(
  215. gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
  216. y_range.stop_gradient = True
  217. # NOTE: in training mode, H and W is variable for random shape,
  218. # implement add_coord with shape as Variable
  219. else:
  220. input_shape = fluid.layers.shape(input)
  221. b = input_shape[0]
  222. h = input_shape[2]
  223. w = input_shape[3]
  224. x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
  225. x_range = x_range - 1.
  226. x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
  227. x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
  228. x_range.stop_gradient = True
  229. y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
  230. y_range.stop_gradient = True
  231. return fluid.layers.concat([input, x_range, y_range], axis=1)
  232. def _conv_bn(self,
  233. input,
  234. ch_out,
  235. filter_size,
  236. stride,
  237. padding,
  238. act='leaky',
  239. is_test=False,
  240. name=None):
  241. conv = fluid.layers.conv2d(
  242. input=input,
  243. num_filters=ch_out,
  244. filter_size=filter_size,
  245. stride=stride,
  246. padding=padding,
  247. act=None,
  248. param_attr=ParamAttr(name=name + '.conv.weights'),
  249. bias_attr=False)
  250. bn_name = name + '.bn'
  251. bn_param_attr = ParamAttr(
  252. regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
  253. bn_bias_attr = ParamAttr(
  254. regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
  255. out = fluid.layers.batch_norm(
  256. input=conv,
  257. act=None,
  258. is_test=is_test,
  259. param_attr=bn_param_attr,
  260. bias_attr=bn_bias_attr,
  261. moving_mean_name=bn_name + '.mean',
  262. moving_variance_name=bn_name + '.var')
  263. if act == 'leaky':
  264. out = fluid.layers.leaky_relu(x=out, alpha=0.1)
  265. return out
  266. def _spp_module(self, input, is_test=True, name=""):
  267. output1 = input
  268. output2 = fluid.layers.pool2d(
  269. input=output1,
  270. pool_size=5,
  271. pool_stride=1,
  272. pool_padding=2,
  273. ceil_mode=False,
  274. pool_type='max')
  275. output3 = fluid.layers.pool2d(
  276. input=output1,
  277. pool_size=9,
  278. pool_stride=1,
  279. pool_padding=4,
  280. ceil_mode=False,
  281. pool_type='max')
  282. output4 = fluid.layers.pool2d(
  283. input=output1,
  284. pool_size=13,
  285. pool_stride=1,
  286. pool_padding=6,
  287. ceil_mode=False,
  288. pool_type='max')
  289. output = fluid.layers.concat(
  290. input=[output1, output2, output3, output4], axis=1)
  291. return output
  292. def _upsample(self, input, scale=2, name=None):
  293. out = fluid.layers.resize_nearest(
  294. input=input, scale=float(scale), name=name, align_corners=False)
  295. return out
  296. def _detection_block(self,
  297. input,
  298. channel,
  299. conv_block_num=2,
  300. is_first=False,
  301. is_test=True,
  302. name=None):
  303. assert channel % 2 == 0, \
  304. "channel {} cannot be divided by 2 in detection block {}" \
  305. .format(channel, name)
  306. conv = input
  307. for j in range(conv_block_num):
  308. conv = self._add_coord(conv, is_test=is_test)
  309. conv = self._conv_bn(
  310. conv,
  311. channel,
  312. filter_size=1,
  313. stride=1,
  314. padding=0,
  315. is_test=is_test,
  316. name='{}.{}.0'.format(name, j))
  317. if self.use_spp and is_first and j == 1:
  318. conv = self._spp_module(conv, is_test=is_test, name="spp")
  319. conv = self._conv_bn(
  320. conv,
  321. 512,
  322. filter_size=1,
  323. stride=1,
  324. padding=0,
  325. is_test=is_test,
  326. name='{}.{}.spp.conv'.format(name, j))
  327. conv = self._conv_bn(
  328. conv,
  329. channel * 2,
  330. filter_size=3,
  331. stride=1,
  332. padding=1,
  333. is_test=is_test,
  334. name='{}.{}.1'.format(name, j))
  335. if self.drop_block and j == 0 and not is_first:
  336. conv = DropBlock(
  337. conv,
  338. block_size=self.block_size,
  339. keep_prob=self.keep_prob,
  340. is_test=is_test)
  341. if self.drop_block and is_first:
  342. conv = DropBlock(
  343. conv,
  344. block_size=self.block_size,
  345. keep_prob=self.keep_prob,
  346. is_test=is_test)
  347. conv = self._add_coord(conv, is_test=is_test)
  348. route = self._conv_bn(
  349. conv,
  350. channel,
  351. filter_size=1,
  352. stride=1,
  353. padding=0,
  354. is_test=is_test,
  355. name='{}.2'.format(name))
  356. new_route = self._add_coord(route, is_test=is_test)
  357. tip = self._conv_bn(
  358. new_route,
  359. channel * 2,
  360. filter_size=3,
  361. stride=1,
  362. padding=1,
  363. is_test=is_test,
  364. name='{}.tip'.format(name))
  365. return route, tip
  366. def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
  367. loss = self.yolo_loss(inputs, gt_box, gt_label, gt_score, targets,
  368. self.anchors, self.anchor_masks,
  369. self.mask_anchors, self.num_classes,
  370. self.prefix_name)
  371. total_loss = fluid.layers.sum(list(loss.values()))
  372. return total_loss
  373. def _get_prediction(self, inputs, im_size):
  374. boxes = []
  375. scores = []
  376. for i, input in enumerate(inputs):
  377. if self.iou_aware:
  378. input = get_iou_aware_score(input,
  379. len(self.anchor_masks[i]),
  380. self.num_classes,
  381. self.iou_aware_factor)
  382. scale_x_y = self.scale_x_y if not isinstance(
  383. self.scale_x_y, Sequence) else self.scale_x_y[i]
  384. if paddle.__version__ < '1.8.4' and paddle.__version__ != '0.0.0':
  385. box, score = fluid.layers.yolo_box(
  386. x=input,
  387. img_size=im_size,
  388. anchors=self.mask_anchors[i],
  389. class_num=self.num_classes,
  390. conf_thresh=self.nms.score_threshold,
  391. downsample_ratio=self.downsample[i],
  392. name=self.prefix_name + 'yolo_box' + str(i),
  393. clip_bbox=self.clip_bbox)
  394. else:
  395. box, score = fluid.layers.yolo_box(
  396. x=input,
  397. img_size=im_size,
  398. anchors=self.mask_anchors[i],
  399. class_num=self.num_classes,
  400. conf_thresh=self.nms.score_threshold,
  401. downsample_ratio=self.downsample[i],
  402. name=self.prefix_name + 'yolo_box' + str(i),
  403. clip_bbox=self.clip_bbox,
  404. scale_x_y=self.scale_x_y)
  405. boxes.append(box)
  406. scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
  407. yolo_boxes = fluid.layers.concat(boxes, axis=1)
  408. yolo_scores = fluid.layers.concat(scores, axis=2)
  409. if type(self.nms) is MultiClassSoftNMS:
  410. yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
  411. pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
  412. return pred
  413. def generate_inputs(self):
  414. inputs = OrderedDict()
  415. if self.fixed_input_shape is not None:
  416. input_shape = [
  417. None, self.input_channel, self.fixed_input_shape[1],
  418. self.fixed_input_shape[0]
  419. ]
  420. inputs['image'] = fluid.data(
  421. dtype='float32', shape=input_shape, name='image')
  422. else:
  423. inputs['image'] = fluid.data(
  424. dtype='float32',
  425. shape=[None, self.input_channel, None, None],
  426. name='image')
  427. if self.mode == 'train':
  428. inputs['gt_box'] = fluid.data(
  429. dtype='float32', shape=[None, None, 4], name='gt_box')
  430. inputs['gt_label'] = fluid.data(
  431. dtype='int32', shape=[None, None], name='gt_label')
  432. inputs['gt_score'] = fluid.data(
  433. dtype='float32', shape=[None, None], name='gt_score')
  434. inputs['im_size'] = fluid.data(
  435. dtype='int32', shape=[None, 2], name='im_size')
  436. if self.use_fine_grained_loss:
  437. downsample = 32
  438. for i, mask in enumerate(self.anchor_masks):
  439. if self.fixed_input_shape is not None:
  440. target_shape = [
  441. self.fixed_input_shape[1] // downsample,
  442. self.fixed_input_shape[0] // downsample
  443. ]
  444. else:
  445. target_shape = [None, None]
  446. inputs['target{}'.format(i)] = fluid.data(
  447. dtype='float32',
  448. lod_level=0,
  449. shape=[
  450. None, len(mask), 6 + self.num_classes,
  451. target_shape[0], target_shape[1]
  452. ],
  453. name='target{}'.format(i))
  454. downsample //= 2
  455. elif self.mode == 'eval':
  456. inputs['im_size'] = fluid.data(
  457. dtype='int32', shape=[None, 2], name='im_size')
  458. inputs['im_id'] = fluid.data(
  459. dtype='int32', shape=[None, 1], name='im_id')
  460. inputs['gt_box'] = fluid.data(
  461. dtype='float32', shape=[None, None, 4], name='gt_box')
  462. inputs['gt_label'] = fluid.data(
  463. dtype='int32', shape=[None, None], name='gt_label')
  464. inputs['is_difficult'] = fluid.data(
  465. dtype='int32', shape=[None, None], name='is_difficult')
  466. elif self.mode == 'test':
  467. inputs['im_size'] = fluid.data(
  468. dtype='int32', shape=[None, 2], name='im_size')
  469. return inputs
  470. def build_net(self, inputs):
  471. image = inputs['image']
  472. feats = self.backbone(image)
  473. if isinstance(feats, OrderedDict):
  474. feat_names = list(feats.keys())
  475. feats = [feats[name] for name in feat_names]
  476. head_outputs = self._head(feats, self.mode == 'train')
  477. if self.mode == 'train':
  478. gt_box = inputs['gt_box']
  479. gt_label = inputs['gt_label']
  480. gt_score = inputs['gt_score']
  481. im_size = inputs['im_size']
  482. num_boxes = fluid.layers.shape(gt_box)[1]
  483. im_size_wh = fluid.layers.reverse(im_size, axis=1)
  484. whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
  485. whwh = fluid.layers.unsqueeze(whwh, axes=[1])
  486. whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
  487. whwh = fluid.layers.cast(whwh, dtype='float32')
  488. whwh.stop_gradient = True
  489. normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
  490. targets = []
  491. if self.use_fine_grained_loss:
  492. for i, mask in enumerate(self.anchor_masks):
  493. k = 'target{}'.format(i)
  494. if k in inputs:
  495. targets.append(inputs[k])
  496. return self._get_loss(head_outputs, normalized_box, gt_label,
  497. gt_score, targets)
  498. else:
  499. im_size = inputs['im_size']
  500. return self._get_prediction(head_outputs, im_size)