yolo_v3.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from collections import OrderedDict
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence

import numpy as np
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay

from .ops import MultiClassNMS, MultiClassSoftNMS, MatrixNMS
from .ops import DropBlock
from .loss.yolo_loss import YOLOv3Loss
from .loss.iou_loss import IouLoss
from .loss.iou_aware_loss import IouAwareLoss
from .iou_aware import get_iou_aware_score
  28. class YOLOv3:
  29. def __init__(
  30. self,
  31. backbone,
  32. mode='train',
  33. # YOLOv3Head
  34. num_classes=80,
  35. anchors=None,
  36. anchor_masks=None,
  37. coord_conv=False,
  38. iou_aware=False,
  39. iou_aware_factor=0.4,
  40. scale_x_y=1.,
  41. spp=False,
  42. drop_block=False,
  43. use_matrix_nms=False,
  44. # YOLOv3Loss
  45. batch_size=8,
  46. ignore_threshold=0.7,
  47. label_smooth=False,
  48. use_fine_grained_loss=False,
  49. use_iou_loss=False,
  50. iou_loss_weight=2.5,
  51. iou_aware_loss_weight=1.0,
  52. max_height=608,
  53. max_width=608,
  54. # NMS
  55. nms_score_threshold=0.01,
  56. nms_topk=1000,
  57. nms_keep_topk=100,
  58. nms_iou_threshold=0.45,
  59. fixed_input_shape=None):
  60. if anchors is None:
  61. anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  62. [59, 119], [116, 90], [156, 198], [373, 326]]
  63. if anchor_masks is None:
  64. anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  65. self.anchors = anchors
  66. self.anchor_masks = anchor_masks
  67. self._parse_anchors(anchors)
  68. self.mode = mode
  69. self.num_classes = num_classes
  70. self.backbone = backbone
  71. self.norm_decay = 0.0
  72. self.prefix_name = ''
  73. self.use_fine_grained_loss = use_fine_grained_loss
  74. self.fixed_input_shape = fixed_input_shape
  75. self.coord_conv = coord_conv
  76. self.iou_aware = iou_aware
  77. self.iou_aware_factor = iou_aware_factor
  78. self.scale_x_y = scale_x_y
  79. self.use_spp = spp
  80. self.drop_block = drop_block
  81. if use_matrix_nms:
  82. self.nms = MatrixNMS(
  83. background_label=-1,
  84. keep_top_k=nms_keep_topk,
  85. normalized=False,
  86. score_threshold=nms_score_threshold,
  87. post_threshold=0.01)
  88. else:
  89. self.nms = MultiClassNMS(
  90. background_label=-1,
  91. keep_top_k=nms_keep_topk,
  92. nms_threshold=nms_iou_threshold,
  93. nms_top_k=nms_topk,
  94. normalized=False,
  95. score_threshold=nms_score_threshold)
  96. self.iou_loss = None
  97. self.iou_aware_loss = None
  98. if use_iou_loss:
  99. self.iou_loss = IouLoss(
  100. loss_weight=iou_loss_weight,
  101. max_height=max_height,
  102. max_width=max_width)
  103. if iou_aware:
  104. self.iou_aware_loss = IouAwareLoss(
  105. loss_weight=iou_aware_loss_weight,
  106. max_height=max_height,
  107. max_width=max_width)
  108. self.yolo_loss = YOLOv3Loss(
  109. batch_size=batch_size,
  110. ignore_thresh=ignore_threshold,
  111. scale_x_y=scale_x_y,
  112. label_smooth=label_smooth,
  113. use_fine_grained_loss=self.use_fine_grained_loss,
  114. iou_loss=self.iou_loss,
  115. iou_aware_loss=self.iou_aware_loss)
  116. self.conv_block_num = 2
  117. self.block_size = 3
  118. self.keep_prob = 0.9
  119. self.downsample = [32, 16, 8]
  120. self.clip_bbox = True
  121. def _head(self, input, is_train=True):
  122. outputs = []
  123. # get last out_layer_num blocks in reverse order
  124. out_layer_num = len(self.anchor_masks)
  125. blocks = input[-1:-out_layer_num - 1:-1]
  126. route = None
  127. for i, block in enumerate(blocks):
  128. if i > 0: # perform concat in first 2 detection_block
  129. block = fluid.layers.concat(input=[route, block], axis=1)
  130. route, tip = self._detection_block(
  131. block,
  132. channel=64 * (2**out_layer_num) // (2**i),
  133. is_first=i == 0,
  134. is_test=(not is_train),
  135. conv_block_num=self.conv_block_num,
  136. name=self.prefix_name + "yolo_block.{}".format(i))
  137. # out channel number = mask_num * (5 + class_num)
  138. if self.iou_aware:
  139. num_filters = len(self.anchor_masks[i]) * (
  140. self.num_classes + 6)
  141. else:
  142. num_filters = len(self.anchor_masks[i]) * (
  143. self.num_classes + 5)
  144. with fluid.name_scope('yolo_output'):
  145. block_out = fluid.layers.conv2d(
  146. input=tip,
  147. num_filters=num_filters,
  148. filter_size=1,
  149. stride=1,
  150. padding=0,
  151. act=None,
  152. param_attr=ParamAttr(
  153. name=self.prefix_name +
  154. "yolo_output.{}.conv.weights".format(i)),
  155. bias_attr=ParamAttr(
  156. regularizer=L2Decay(0.),
  157. name=self.prefix_name +
  158. "yolo_output.{}.conv.bias".format(i)))
  159. outputs.append(block_out)
  160. if i < len(blocks) - 1:
  161. # do not perform upsample in the last detection_block
  162. route = self._conv_bn(
  163. input=route,
  164. ch_out=256 // (2**i),
  165. filter_size=1,
  166. stride=1,
  167. padding=0,
  168. is_test=(not is_train),
  169. name=self.prefix_name + "yolo_transition.{}".format(i))
  170. # upsample
  171. route = self._upsample(route)
  172. return outputs
  173. def _parse_anchors(self, anchors):
  174. self.anchors = []
  175. self.mask_anchors = []
  176. assert len(anchors) > 0, "ANCHORS not set."
  177. assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
  178. for anchor in anchors:
  179. assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
  180. self.anchors.extend(anchor)
  181. anchor_num = len(anchors)
  182. for masks in self.anchor_masks:
  183. self.mask_anchors.append([])
  184. for mask in masks:
  185. assert mask < anchor_num, "anchor mask index overflow"
  186. self.mask_anchors[-1].extend(anchors[mask])
  187. def _create_tensor_from_numpy(self, numpy_array):
  188. paddle_array = fluid.layers.create_global_var(
  189. shape=numpy_array.shape, value=0., dtype=numpy_array.dtype)
  190. fluid.layers.assign(numpy_array, paddle_array)
  191. return paddle_array
  192. def _add_coord(self, input, is_test=True):
  193. if not self.coord_conv:
  194. return input
  195. # NOTE: here is used for exporting model for TensorRT inference,
  196. # only support batch_size=1 for input shape should be fixed,
  197. # and we create tensor with fixed shape from numpy array
  198. if is_test and input.shape[2] > 0 and input.shape[3] > 0:
  199. batch_size = 1
  200. grid_x = int(input.shape[3])
  201. grid_y = int(input.shape[2])
  202. idx_i = np.array(
  203. [[i / (grid_x - 1) * 2.0 - 1 for i in range(grid_x)]],
  204. dtype='float32')
  205. gi_np = np.repeat(idx_i, grid_y, axis=0)
  206. gi_np = np.reshape(gi_np, newshape=[1, 1, grid_y, grid_x])
  207. gi_np = np.tile(gi_np, reps=[batch_size, 1, 1, 1])
  208. x_range = self._create_tensor_from_numpy(gi_np.astype(np.float32))
  209. x_range.stop_gradient = True
  210. y_range = self._create_tensor_from_numpy(
  211. gi_np.transpose([0, 1, 3, 2]).astype(np.float32))
  212. y_range.stop_gradient = True
  213. # NOTE: in training mode, H and W is variable for random shape,
  214. # implement add_coord with shape as Variable
  215. else:
  216. input_shape = fluid.layers.shape(input)
  217. b = input_shape[0]
  218. h = input_shape[2]
  219. w = input_shape[3]
  220. x_range = fluid.layers.range(0, w, 1, 'float32') / ((w - 1.) / 2.)
  221. x_range = x_range - 1.
  222. x_range = fluid.layers.unsqueeze(x_range, [0, 1, 2])
  223. x_range = fluid.layers.expand(x_range, [b, 1, h, 1])
  224. x_range.stop_gradient = True
  225. y_range = fluid.layers.transpose(x_range, [0, 1, 3, 2])
  226. y_range.stop_gradient = True
  227. return fluid.layers.concat([input, x_range, y_range], axis=1)
  228. def _conv_bn(self,
  229. input,
  230. ch_out,
  231. filter_size,
  232. stride,
  233. padding,
  234. act='leaky',
  235. is_test=False,
  236. name=None):
  237. conv = fluid.layers.conv2d(
  238. input=input,
  239. num_filters=ch_out,
  240. filter_size=filter_size,
  241. stride=stride,
  242. padding=padding,
  243. act=None,
  244. param_attr=ParamAttr(name=name + '.conv.weights'),
  245. bias_attr=False)
  246. bn_name = name + '.bn'
  247. bn_param_attr = ParamAttr(
  248. regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
  249. bn_bias_attr = ParamAttr(
  250. regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
  251. out = fluid.layers.batch_norm(
  252. input=conv,
  253. act=None,
  254. is_test=is_test,
  255. param_attr=bn_param_attr,
  256. bias_attr=bn_bias_attr,
  257. moving_mean_name=bn_name + '.mean',
  258. moving_variance_name=bn_name + '.var')
  259. if act == 'leaky':
  260. out = fluid.layers.leaky_relu(x=out, alpha=0.1)
  261. return out
  262. def _spp_module(self, input, is_test=True, name=""):
  263. output1 = input
  264. output2 = fluid.layers.pool2d(
  265. input=output1,
  266. pool_size=5,
  267. pool_stride=1,
  268. pool_padding=2,
  269. ceil_mode=False,
  270. pool_type='max')
  271. output3 = fluid.layers.pool2d(
  272. input=output1,
  273. pool_size=9,
  274. pool_stride=1,
  275. pool_padding=4,
  276. ceil_mode=False,
  277. pool_type='max')
  278. output4 = fluid.layers.pool2d(
  279. input=output1,
  280. pool_size=13,
  281. pool_stride=1,
  282. pool_padding=6,
  283. ceil_mode=False,
  284. pool_type='max')
  285. output = fluid.layers.concat(
  286. input=[output1, output2, output3, output4], axis=1)
  287. return output
  288. def _upsample(self, input, scale=2, name=None):
  289. out = fluid.layers.resize_nearest(
  290. input=input, scale=float(scale), name=name)
  291. return out
  292. def _detection_block(self,
  293. input,
  294. channel,
  295. conv_block_num=2,
  296. is_first=False,
  297. is_test=True,
  298. name=None):
  299. assert channel % 2 == 0, \
  300. "channel {} cannot be divided by 2 in detection block {}" \
  301. .format(channel, name)
  302. conv = input
  303. for j in range(conv_block_num):
  304. conv = self._add_coord(conv, is_test=is_test)
  305. conv = self._conv_bn(
  306. conv,
  307. channel,
  308. filter_size=1,
  309. stride=1,
  310. padding=0,
  311. is_test=is_test,
  312. name='{}.{}.0'.format(name, j))
  313. if self.use_spp and is_first and j == 1:
  314. conv = self._spp_module(conv, is_test=is_test, name="spp")
  315. conv = self._conv_bn(
  316. conv,
  317. 512,
  318. filter_size=1,
  319. stride=1,
  320. padding=0,
  321. is_test=is_test,
  322. name='{}.{}.spp.conv'.format(name, j))
  323. conv = self._conv_bn(
  324. conv,
  325. channel * 2,
  326. filter_size=3,
  327. stride=1,
  328. padding=1,
  329. is_test=is_test,
  330. name='{}.{}.1'.format(name, j))
  331. if self.drop_block and j == 0 and not is_first:
  332. conv = DropBlock(
  333. conv,
  334. block_size=self.block_size,
  335. keep_prob=self.keep_prob,
  336. is_test=is_test)
  337. if self.drop_block and is_first:
  338. conv = DropBlock(
  339. conv,
  340. block_size=self.block_size,
  341. keep_prob=self.keep_prob,
  342. is_test=is_test)
  343. conv = self._add_coord(conv, is_test=is_test)
  344. route = self._conv_bn(
  345. conv,
  346. channel,
  347. filter_size=1,
  348. stride=1,
  349. padding=0,
  350. is_test=is_test,
  351. name='{}.2'.format(name))
  352. new_route = self._add_coord(route, is_test=is_test)
  353. tip = self._conv_bn(
  354. new_route,
  355. channel * 2,
  356. filter_size=3,
  357. stride=1,
  358. padding=1,
  359. is_test=is_test,
  360. name='{}.tip'.format(name))
  361. return route, tip
  362. def _get_loss(self, inputs, gt_box, gt_label, gt_score, targets):
  363. loss = self.yolo_loss(inputs, gt_box, gt_label, gt_score, targets,
  364. self.anchors, self.anchor_masks,
  365. self.mask_anchors, self.num_classes,
  366. self.prefix_name)
  367. total_loss = fluid.layers.sum(list(loss.values()))
  368. return total_loss
  369. def _get_prediction(self, inputs, im_size):
  370. boxes = []
  371. scores = []
  372. for i, input in enumerate(inputs):
  373. if self.iou_aware:
  374. input = get_iou_aware_score(input,
  375. len(self.anchor_masks[i]),
  376. self.num_classes,
  377. self.iou_aware_factor)
  378. scale_x_y = self.scale_x_y if not isinstance(
  379. self.scale_x_y, Sequence) else self.scale_x_y[i]
  380. box, score = fluid.layers.yolo_box(
  381. x=input,
  382. img_size=im_size,
  383. anchors=self.mask_anchors[i],
  384. class_num=self.num_classes,
  385. conf_thresh=self.nms.score_threshold,
  386. downsample_ratio=self.downsample[i],
  387. name=self.prefix_name + 'yolo_box' + str(i),
  388. clip_bbox=self.clip_bbox,
  389. scale_x_y=self.scale_x_y)
  390. boxes.append(box)
  391. scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
  392. yolo_boxes = fluid.layers.concat(boxes, axis=1)
  393. yolo_scores = fluid.layers.concat(scores, axis=2)
  394. if type(self.nms) is MultiClassSoftNMS:
  395. yolo_scores = fluid.layers.transpose(yolo_scores, perm=[0, 2, 1])
  396. pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
  397. return pred
  398. def generate_inputs(self):
  399. inputs = OrderedDict()
  400. if self.fixed_input_shape is not None:
  401. input_shape = [
  402. None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
  403. ]
  404. inputs['image'] = fluid.data(
  405. dtype='float32', shape=input_shape, name='image')
  406. else:
  407. inputs['image'] = fluid.data(
  408. dtype='float32', shape=[None, 3, None, None], name='image')
  409. if self.mode == 'train':
  410. inputs['gt_box'] = fluid.data(
  411. dtype='float32', shape=[None, None, 4], name='gt_box')
  412. inputs['gt_label'] = fluid.data(
  413. dtype='int32', shape=[None, None], name='gt_label')
  414. inputs['gt_score'] = fluid.data(
  415. dtype='float32', shape=[None, None], name='gt_score')
  416. inputs['im_size'] = fluid.data(
  417. dtype='int32', shape=[None, 2], name='im_size')
  418. if self.use_fine_grained_loss:
  419. downsample = 32
  420. for i, mask in enumerate(self.anchor_masks):
  421. if self.fixed_input_shape is not None:
  422. target_shape = [
  423. self.fixed_input_shape[1] // downsample,
  424. self.fixed_input_shape[0] // downsample
  425. ]
  426. else:
  427. target_shape = [None, None]
  428. inputs['target{}'.format(i)] = fluid.data(
  429. dtype='float32',
  430. lod_level=0,
  431. shape=[
  432. None, len(mask), 6 + self.num_classes,
  433. target_shape[0], target_shape[1]
  434. ],
  435. name='target{}'.format(i))
  436. downsample //= 2
  437. elif self.mode == 'eval':
  438. inputs['im_size'] = fluid.data(
  439. dtype='int32', shape=[None, 2], name='im_size')
  440. inputs['im_id'] = fluid.data(
  441. dtype='int32', shape=[None, 1], name='im_id')
  442. inputs['gt_box'] = fluid.data(
  443. dtype='float32', shape=[None, None, 4], name='gt_box')
  444. inputs['gt_label'] = fluid.data(
  445. dtype='int32', shape=[None, None], name='gt_label')
  446. inputs['is_difficult'] = fluid.data(
  447. dtype='int32', shape=[None, None], name='is_difficult')
  448. elif self.mode == 'test':
  449. inputs['im_size'] = fluid.data(
  450. dtype='int32', shape=[None, 2], name='im_size')
  451. return inputs
  452. def build_net(self, inputs):
  453. image = inputs['image']
  454. feats = self.backbone(image)
  455. if isinstance(feats, OrderedDict):
  456. feat_names = list(feats.keys())
  457. feats = [feats[name] for name in feat_names]
  458. head_outputs = self._head(feats, self.mode == 'train')
  459. if self.mode == 'train':
  460. gt_box = inputs['gt_box']
  461. gt_label = inputs['gt_label']
  462. gt_score = inputs['gt_score']
  463. im_size = inputs['im_size']
  464. num_boxes = fluid.layers.shape(gt_box)[1]
  465. im_size_wh = fluid.layers.reverse(im_size, axis=1)
  466. whwh = fluid.layers.concat([im_size_wh, im_size_wh], axis=1)
  467. whwh = fluid.layers.unsqueeze(whwh, axes=[1])
  468. whwh = fluid.layers.expand(whwh, expand_times=[1, num_boxes, 1])
  469. whwh = fluid.layers.cast(whwh, dtype='float32')
  470. whwh.stop_gradient = True
  471. normalized_box = fluid.layers.elementwise_div(gt_box, whwh)
  472. targets = []
  473. if self.use_fine_grained_loss:
  474. for i, mask in enumerate(self.anchor_masks):
  475. k = 'target{}'.format(i)
  476. if k in inputs:
  477. targets.append(inputs[k])
  478. return self._get_loss(head_outputs, normalized_box, gt_label,
  479. gt_score, targets)
  480. else:
  481. im_size = inputs['im_size']
  482. return self._get_prediction(head_outputs, im_size)