# yolo_loss.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid

try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence


class YOLOv3Loss(object):
    """
    Combined loss for YOLOv3 network

    Args:
        batch_size (int): training batch size
        ignore_thresh (float): threshold to ignore confidence loss
        label_smooth (bool): whether to use label smoothing
        use_fine_grained_loss (bool): whether to use the fine-grained YOLOv3
                                      loss instead of fluid.layers.yolov3_loss
    """

    def __init__(self,
                 batch_size=8,
                 ignore_thresh=0.7,
                 label_smooth=True,
                 use_fine_grained_loss=False,
                 iou_loss=None,
                 iou_aware_loss=None,
                 downsample=[32, 16, 8],
                 scale_x_y=1.,
                 match_score=False):
        self._batch_size = batch_size
        self._ignore_thresh = ignore_thresh
        self._label_smooth = label_smooth
        self._use_fine_grained_loss = use_fine_grained_loss
        self._iou_loss = iou_loss
        self._iou_aware_loss = iou_aware_loss
        self.downsample = downsample
        self.scale_x_y = scale_x_y
        self.match_score = match_score
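    # `iou_loss` and `iou_aware_loss` are expected to be callables (or None).
    # Judging from their use in _get_fine_grained_loss below, they are invoked
    # as iou_loss(x, y, w, h, tx, ty, tw, th, anchors, downsample, batch_size,
    # scale_x_y) and iou_aware_loss(ioup, x, y, w, h, tx, ty, tw, th, anchors,
    # downsample, batch_size, scale_x_y), each returning an elementwise loss
    # tensor over the feature-map grid.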

    def __call__(self, outputs, gt_box, gt_label, gt_score, targets, anchors,
                 anchor_masks, mask_anchors, num_classes, prefix_name):
        if self._use_fine_grained_loss:
            return self._get_fine_grained_loss(
                outputs, targets, gt_box, self._batch_size, num_classes,
                mask_anchors, self._ignore_thresh)
        else:
            losses = []
            for i, output in enumerate(outputs):
                scale_x_y = self.scale_x_y if not isinstance(
                    self.scale_x_y, Sequence) else self.scale_x_y[i]
                anchor_mask = anchor_masks[i]
                loss = fluid.layers.yolov3_loss(
                    x=output,
                    gt_box=gt_box,
                    gt_label=gt_label,
                    gt_score=gt_score,
                    anchors=anchors,
                    anchor_mask=anchor_mask,
                    class_num=num_classes,
                    ignore_thresh=self._ignore_thresh,
                    downsample_ratio=self.downsample[i],
                    use_label_smooth=self._label_smooth,
                    scale_x_y=scale_x_y,
                    name=prefix_name + "yolo_loss" + str(i))
                losses.append(fluid.layers.reduce_mean(loss))
            return {'loss': sum(losses)}
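    # NOTE: `scale_x_y` may be a single float shared by all output scales or a
    # per-scale sequence indexed by output level (for example an illustrative
    # setting such as scale_x_y=[1.05, 1.1, 1.2]); the same per-scale selection
    # logic is reused in the fine-grained loss below.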

    def _get_fine_grained_loss(self,
                               outputs,
                               targets,
                               gt_box,
                               batch_size,
                               num_classes,
                               mask_anchors,
                               ignore_thresh,
                               eps=1.e-10):
        """
        Calculate fine-grained YOLOv3 loss

        Args:
            outputs ([Variables]): List of Variables, output of backbone stages
            targets ([Variables]): List of Variables, the targets for YOLO
                                   loss calculation.
            gt_box (Variable): The ground-truth bounding boxes.
            batch_size (int): The training batch size
            num_classes (int): class num of dataset
            mask_anchors ([[float]]): list of anchors in each output layer
            ignore_thresh (float): if a predicted bbox overlaps any gt_box by
                                   more than ignore_thresh, its objectness
                                   loss will be ignored.

        Returns:
            Type: dict
                xy_loss (Variable): YOLOv3 (x, y) coordinates loss
                wh_loss (Variable): YOLOv3 (w, h) coordinates loss
                obj_loss (Variable): YOLOv3 objectness score loss
                cls_loss (Variable): YOLOv3 classification loss
        """
        assert len(outputs) == len(targets), \
            "number of YOLOv3 output layers and targets must be equal"

        loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], []
        if self._iou_loss is not None:
            loss_ious = []
        if self._iou_aware_loss is not None:
            loss_iou_awares = []
        for i, (output, target,
                anchors) in enumerate(zip(outputs, targets, mask_anchors)):
            downsample = self.downsample[i]
            an_num = len(anchors) // 2
            if self._iou_aware_loss is not None:
                ioup, output = self._split_ioup(output, an_num, num_classes)
            x, y, w, h, obj, cls = self._split_output(output, an_num,
                                                      num_classes)
            tx, ty, tw, th, tscale, tobj, tcls = self._split_target(target)

            tscale_tobj = tscale * tobj

            scale_x_y = self.scale_x_y if not isinstance(
                self.scale_x_y, Sequence) else self.scale_x_y[i]

            if (abs(scale_x_y - 1.0) < eps):
                loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
                    x, tx) * tscale_tobj
                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
                loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
                    y, ty) * tscale_tobj
                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
            else:
                dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y -
                                                                  1.0)
                dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y -
                                                                  1.0)
                loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
                loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
                loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
                loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])

            # NOTE: the (w, h) loss is refined to an L1 loss
            loss_w = fluid.layers.abs(w - tw) * tscale_tobj
            loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
            loss_h = fluid.layers.abs(h - th) * tscale_tobj
            loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])

            if self._iou_loss is not None:
                loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors,
                                          downsample, self._batch_size,
                                          scale_x_y)
                loss_iou = loss_iou * tscale_tobj
                loss_iou = fluid.layers.reduce_sum(loss_iou, dim=[1, 2, 3])
                loss_ious.append(fluid.layers.reduce_mean(loss_iou))

            if self._iou_aware_loss is not None:
                loss_iou_aware = self._iou_aware_loss(
                    ioup, x, y, w, h, tx, ty, tw, th, anchors, downsample,
                    self._batch_size, scale_x_y)
                loss_iou_aware = loss_iou_aware * tobj
                loss_iou_aware = fluid.layers.reduce_sum(
                    loss_iou_aware, dim=[1, 2, 3])
                loss_iou_awares.append(
                    fluid.layers.reduce_mean(loss_iou_aware))

            loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
                output, obj, tobj, gt_box, self._batch_size, anchors,
                num_classes, downsample, self._ignore_thresh, scale_x_y)

            loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls,
                                                                      tcls)
            loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
            loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])

            loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y))
            loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h))
            loss_objs.append(
                fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg))
            loss_clss.append(fluid.layers.reduce_mean(loss_cls))

        losses_all = {
            "loss_xy": fluid.layers.sum(loss_xys),
            "loss_wh": fluid.layers.sum(loss_whs),
            "loss_obj": fluid.layers.sum(loss_objs),
            "loss_cls": fluid.layers.sum(loss_clss),
        }
        if self._iou_loss is not None:
            losses_all["loss_iou"] = fluid.layers.sum(loss_ious)
        if self._iou_aware_loss is not None:
            losses_all["loss_iou_aware"] = fluid.layers.sum(loss_iou_awares)
        return losses_all

    def _split_ioup(self, output, an_num, num_classes):
        """
        Split the output feature map into the predicted IoU and the remaining
        output along the channel dimension
        """
        ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
        ioup = fluid.layers.sigmoid(ioup)
        oriout = fluid.layers.slice(
            output,
            axes=[1],
            starts=[an_num],
            ends=[an_num * (num_classes + 6)])
        return (ioup, oriout)

    def _split_output(self, output, an_num, num_classes):
        """
        Split the output feature map into x, y, w, h, objectness and
        classification along the channel dimension
        """
        x = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[0],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        y = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[1],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        w = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[2],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        h = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[3],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        obj = fluid.layers.strided_slice(
            output,
            axes=[1],
            starts=[4],
            ends=[output.shape[1]],
            strides=[5 + num_classes])
        clss = []
        stride = output.shape[1] // an_num
        for m in range(an_num):
            clss.append(
                fluid.layers.slice(
                    output,
                    axes=[1],
                    starts=[stride * m + 5],
                    ends=[stride * m + 5 + num_classes]))
        cls = fluid.layers.transpose(
            fluid.layers.stack(
                clss, axis=1), perm=[0, 1, 3, 4, 2])
        return (x, y, w, h, obj, cls)

    def _split_target(self, target):
        """
        Split target into x, y, w, h, objectness and classification
        along dimension 2

        target is in shape [N, an_num, 6 + class_num, H, W]
        """
        tx = target[:, :, 0, :, :]
        ty = target[:, :, 1, :, :]
        tw = target[:, :, 2, :, :]
        th = target[:, :, 3, :, :]
        tscale = target[:, :, 4, :, :]
        tobj = target[:, :, 5, :, :]
        tcls = fluid.layers.transpose(
            target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
        tcls.stop_gradient = True
        return (tx, ty, tw, th, tscale, tobj, tcls)

    def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
                       num_classes, downsample, ignore_thresh, scale_x_y):
        # A predicted bbox whose overlap with any gt bbox exceeds
        # ignore_thresh has its objectness loss ignored; this is done as
        # follows:
        # 1. get pred bboxes, same as in YOLOv3 infer mode, using yolo_box
        # NOTE: img_size is set to 1.0 to get normalized pred bboxes
        bbox, prob = fluid.layers.yolo_box(
            x=output,
            img_size=fluid.layers.ones(
                shape=[batch_size, 2], dtype="int32"),
            anchors=anchors,
            class_num=num_classes,
            conf_thresh=0.,
            downsample_ratio=downsample,
            clip_bbox=False,
            scale_x_y=scale_x_y)
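        # `bbox` holds the decoded boxes for every grid cell and anchor with
        # shape [N, M, 4], in corner (xmin, ymin, xmax, ymax) form (normalized,
        # since img_size is 1), which is consistent with the xywh-to-xyxy
        # conversion applied to gt boxes below; `prob` holds the per-class
        # scores with shape [N, M, num_classes].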
        # 2. split pred bboxes and gt bboxes by sample, then calculate the IoU
        #    between pred bboxes and gt bboxes within each sample
        if batch_size > 1:
            preds = fluid.layers.split(bbox, batch_size, dim=0)
            gts = fluid.layers.split(gt_box, batch_size, dim=0)
        else:
            preds = [bbox]
            gts = [gt_box]
            probs = [prob]
        ious = []
        for pred, gt in zip(preds, gts):

            def box_xywh2xyxy(box):
                x = box[:, 0]
                y = box[:, 1]
                w = box[:, 2]
                h = box[:, 3]
                return fluid.layers.stack(
                    [
                        x - w / 2.,
                        y - h / 2.,
                        x + w / 2.,
                        y + h / 2.,
                    ], axis=1)

            pred = fluid.layers.squeeze(pred, axes=[0])
            gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
            ious.append(fluid.layers.iou_similarity(pred, gt))

        iou = fluid.layers.stack(ious, axis=0)
        # 3. Build iou_mask from the IoU between gt bboxes and predicted
        #    bboxes, build obj_mask from tobj (which holds gt_score), then
        #    calculate the objectness loss
        max_iou = fluid.layers.reduce_max(iou, dim=-1)
        iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
        if self.match_score:
            max_prob = fluid.layers.reduce_max(prob, dim=-1)
            iou_mask = iou_mask * fluid.layers.cast(
                max_prob <= 0.25, dtype="float32")
        output_shape = fluid.layers.shape(output)
        an_num = len(anchors) // 2
        iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
                                                   output_shape[3]))
        iou_mask.stop_gradient = True

        # NOTE: tobj holds gt_score, obj_mask holds the object existence mask
        obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True

        # For positive objectness grids, the objectness loss is always
        # calculated; for negative objectness grids, it is calculated only
        # where iou_mask == 1.0
        loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj,
                                                                  obj_mask)
        loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
        loss_obj_neg = fluid.layers.reduce_sum(
            loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
        return loss_obj_pos, loss_obj_neg