test_yolov3_loss.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import division
  15. import unittest
  16. import numpy as np
  17. from scipy.special import logit
  18. from scipy.special import expit
  19. import paddle
  20. from paddle import fluid
  21. from paddle.fluid import core
# Add the python path of PaddleDetection to sys.path.
# `parent_path` is this file's path followed by four '..' components, i.e. the
# directory three levels above the directory that contains this file.
# NOTE(review): presumably that is the root holding the `paddlex` package
# imported right below -- confirm if this file is ever moved.
import os
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
# Only append once so repeated imports do not grow sys.path.
if parent_path not in sys.path:
    sys.path.append(parent_path)
  28. from paddlex.ppdet.modeling.losses import YOLOv3Loss
  29. from paddlex.ppdet.data.transform.op_helper import jaccard_overlap
  30. import random
  31. import numpy as np
  32. def _split_ioup(output, an_num, num_classes):
  33. """
  34. Split output feature map to output, predicted iou
  35. along channel dimension
  36. """
  37. ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
  38. ioup = fluid.layers.sigmoid(ioup)
  39. oriout = fluid.layers.slice(
  40. output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
  41. return (ioup, oriout)
  42. def _split_output(output, an_num, num_classes):
  43. """
  44. Split output feature map to x, y, w, h, objectness, classification
  45. along channel dimension
  46. """
  47. x = fluid.layers.strided_slice(
  48. output,
  49. axes=[1],
  50. starts=[0],
  51. ends=[output.shape[1]],
  52. strides=[5 + num_classes])
  53. y = fluid.layers.strided_slice(
  54. output,
  55. axes=[1],
  56. starts=[1],
  57. ends=[output.shape[1]],
  58. strides=[5 + num_classes])
  59. w = fluid.layers.strided_slice(
  60. output,
  61. axes=[1],
  62. starts=[2],
  63. ends=[output.shape[1]],
  64. strides=[5 + num_classes])
  65. h = fluid.layers.strided_slice(
  66. output,
  67. axes=[1],
  68. starts=[3],
  69. ends=[output.shape[1]],
  70. strides=[5 + num_classes])
  71. obj = fluid.layers.strided_slice(
  72. output,
  73. axes=[1],
  74. starts=[4],
  75. ends=[output.shape[1]],
  76. strides=[5 + num_classes])
  77. clss = []
  78. stride = output.shape[1] // an_num
  79. for m in range(an_num):
  80. clss.append(
  81. fluid.layers.slice(
  82. output,
  83. axes=[1],
  84. starts=[stride * m + 5],
  85. ends=[stride * m + 5 + num_classes]))
  86. cls = fluid.layers.transpose(
  87. fluid.layers.stack(
  88. clss, axis=1), perm=[0, 1, 3, 4, 2])
  89. return (x, y, w, h, obj, cls)
  90. def _split_target(target):
  91. """
  92. split target to x, y, w, h, objectness, classification
  93. along dimension 2
  94. target is in shape [N, an_num, 6 + class_num, H, W]
  95. """
  96. tx = target[:, :, 0, :, :]
  97. ty = target[:, :, 1, :, :]
  98. tw = target[:, :, 2, :, :]
  99. th = target[:, :, 3, :, :]
  100. tscale = target[:, :, 4, :, :]
  101. tobj = target[:, :, 5, :, :]
  102. tcls = fluid.layers.transpose(target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
  103. tcls.stop_gradient = True
  104. return (tx, ty, tw, th, tscale, tobj, tcls)
  105. def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes,
  106. downsample, ignore_thresh, scale_x_y):
  107. # A prediction bbox overlap any gt_bbox over ignore_thresh,
  108. # objectness loss will be ignored, process as follows:
  109. # 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
  110. # NOTE: img_size is set as 1.0 to get noramlized pred bbox
  111. bbox, prob = fluid.layers.yolo_box(
  112. x=output,
  113. img_size=fluid.layers.ones(
  114. shape=[batch_size, 2], dtype="int32"),
  115. anchors=anchors,
  116. class_num=num_classes,
  117. conf_thresh=0.,
  118. downsample_ratio=downsample,
  119. clip_bbox=False,
  120. scale_x_y=scale_x_y)
  121. # 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
  122. # and gt bbox in each sample
  123. if batch_size > 1:
  124. preds = fluid.layers.split(bbox, batch_size, dim=0)
  125. gts = fluid.layers.split(gt_box, batch_size, dim=0)
  126. else:
  127. preds = [bbox]
  128. gts = [gt_box]
  129. probs = [prob]
  130. ious = []
  131. for pred, gt in zip(preds, gts):
  132. def box_xywh2xyxy(box):
  133. x = box[:, 0]
  134. y = box[:, 1]
  135. w = box[:, 2]
  136. h = box[:, 3]
  137. return fluid.layers.stack(
  138. [
  139. x - w / 2.,
  140. y - h / 2.,
  141. x + w / 2.,
  142. y + h / 2.,
  143. ], axis=1)
  144. pred = fluid.layers.squeeze(pred, axes=[0])
  145. gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
  146. ious.append(fluid.layers.iou_similarity(pred, gt))
  147. iou = fluid.layers.stack(ious, axis=0)
  148. # 3. Get iou_mask by IoU between gt bbox and prediction bbox,
  149. # Get obj_mask by tobj(holds gt_score), calculate objectness loss
  150. max_iou = fluid.layers.reduce_max(iou, dim=-1)
  151. iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
  152. output_shape = fluid.layers.shape(output)
  153. an_num = len(anchors) // 2
  154. iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
  155. output_shape[3]))
  156. iou_mask.stop_gradient = True
  157. # NOTE: tobj holds gt_score, obj_mask holds object existence mask
  158. obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
  159. obj_mask.stop_gradient = True
  160. # For positive objectness grids, objectness loss should be calculated
  161. # For negative objectness grids, objectness loss is calculated only iou_mask == 1.0
  162. loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask)
  163. loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
  164. loss_obj_neg = fluid.layers.reduce_sum(
  165. loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
  166. return loss_obj_pos, loss_obj_neg
  167. def fine_grained_loss(output,
  168. target,
  169. gt_box,
  170. batch_size,
  171. num_classes,
  172. anchors,
  173. ignore_thresh,
  174. downsample,
  175. scale_x_y=1.,
  176. eps=1e-10):
  177. an_num = len(anchors) // 2
  178. x, y, w, h, obj, cls = _split_output(output, an_num, num_classes)
  179. tx, ty, tw, th, tscale, tobj, tcls = _split_target(target)
  180. tscale_tobj = tscale * tobj
  181. scale_x_y = scale_x_y
  182. if (abs(scale_x_y - 1.0) < eps):
  183. loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
  184. x, tx) * tscale_tobj
  185. loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
  186. loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
  187. y, ty) * tscale_tobj
  188. loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
  189. else:
  190. dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - 1.0)
  191. dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - 1.0)
  192. loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
  193. loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
  194. loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
  195. loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
  196. # NOTE: we refined loss function of (w, h) as L1Loss
  197. loss_w = fluid.layers.abs(w - tw) * tscale_tobj
  198. loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
  199. loss_h = fluid.layers.abs(h - th) * tscale_tobj
  200. loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
  201. loss_obj_pos, loss_obj_neg = _calc_obj_loss(
  202. output, obj, tobj, gt_box, batch_size, anchors, num_classes, downsample,
  203. ignore_thresh, scale_x_y)
  204. loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
  205. loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
  206. loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
  207. loss_xys = fluid.layers.reduce_mean(loss_x + loss_y)
  208. loss_whs = fluid.layers.reduce_mean(loss_w + loss_h)
  209. loss_objs = fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg)
  210. loss_clss = fluid.layers.reduce_mean(loss_cls)
  211. losses_all = {
  212. "loss_xy": fluid.layers.sum(loss_xys),
  213. "loss_wh": fluid.layers.sum(loss_whs),
  214. "loss_loc": fluid.layers.sum(loss_xys) + fluid.layers.sum(loss_whs),
  215. "loss_obj": fluid.layers.sum(loss_objs),
  216. "loss_cls": fluid.layers.sum(loss_clss),
  217. }
  218. return losses_all, x, y, tx, ty
  219. def gt2yolotarget(gt_bbox, gt_class, gt_score, anchors, mask, num_classes, size,
  220. stride):
  221. grid_h, grid_w = size
  222. h, w = grid_h * stride, grid_w * stride
  223. an_hw = np.array(anchors) / np.array([[w, h]])
  224. target = np.zeros(
  225. (len(mask), 6 + num_classes, grid_h, grid_w), dtype=np.float32)
  226. for b in range(gt_bbox.shape[0]):
  227. gx, gy, gw, gh = gt_bbox[b, :]
  228. cls = gt_class[b]
  229. score = gt_score[b]
  230. if gw <= 0. or gh <= 0. or score <= 0.:
  231. continue
  232. # find best match anchor index
  233. best_iou = 0.
  234. best_idx = -1
  235. for an_idx in range(an_hw.shape[0]):
  236. iou = jaccard_overlap([0., 0., gw, gh],
  237. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  238. if iou > best_iou:
  239. best_iou = iou
  240. best_idx = an_idx
  241. gi = int(gx * grid_w)
  242. gj = int(gy * grid_h)
  243. # gtbox should be regresed in this layes if best match
  244. # anchor index in anchor mask of this layer
  245. if best_idx in mask:
  246. best_n = mask.index(best_idx)
  247. # x, y, w, h, scale
  248. target[best_n, 0, gj, gi] = gx * grid_w - gi
  249. target[best_n, 1, gj, gi] = gy * grid_h - gj
  250. target[best_n, 2, gj, gi] = np.log(gw * w / anchors[best_idx][0])
  251. target[best_n, 3, gj, gi] = np.log(gh * h / anchors[best_idx][1])
  252. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  253. # objectness record gt_score
  254. # if target[best_n, 5, gj, gi] > 0:
  255. # print('find 1 duplicate')
  256. target[best_n, 5, gj, gi] = score
  257. # classification
  258. target[best_n, 6 + cls, gj, gi] = 1.
  259. return target
  260. class TestYolov3LossOp(unittest.TestCase):
  261. def setUp(self):
  262. self.initTestCase()
  263. x = np.random.uniform(0, 1, self.x_shape).astype('float64')
  264. gtbox = np.random.random(size=self.gtbox_shape).astype('float64')
  265. gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2])
  266. gtmask = np.random.randint(0, 2, self.gtbox_shape[:2])
  267. gtbox = gtbox * gtmask[:, :, np.newaxis]
  268. gtlabel = gtlabel * gtmask
  269. gtscore = np.ones(self.gtbox_shape[:2]).astype('float64')
  270. if self.gtscore:
  271. gtscore = np.random.random(self.gtbox_shape[:2]).astype('float64')
  272. target = []
  273. for box, label, score in zip(gtbox, gtlabel, gtscore):
  274. target.append(
  275. gt2yolotarget(box, label, score, self.anchors, self.anchor_mask,
  276. self.class_num, (self.h, self.w
  277. ), self.downsample_ratio))
  278. self.target = np.array(target).astype('float64')
  279. self.mask_anchors = []
  280. for i in self.anchor_mask:
  281. self.mask_anchors.extend(self.anchors[i])
  282. self.x = x
  283. self.gtbox = gtbox
  284. self.gtlabel = gtlabel
  285. self.gtscore = gtscore
  286. def initTestCase(self):
  287. self.b = 8
  288. self.h = 19
  289. self.w = 19
  290. self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  291. [59, 119], [116, 90], [156, 198], [373, 326]]
  292. self.anchor_mask = [6, 7, 8]
  293. self.na = len(self.anchor_mask)
  294. self.class_num = 80
  295. self.ignore_thresh = 0.7
  296. self.downsample_ratio = 32
  297. self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
  298. self.h, self.w)
  299. self.gtbox_shape = (self.b, 40, 4)
  300. self.gtscore = True
  301. self.use_label_smooth = False
  302. self.scale_x_y = 1.
  303. def test_loss(self):
  304. x, gtbox, gtlabel, gtscore, target = self.x, self.gtbox, self.gtlabel, self.gtscore, self.target
  305. yolo_loss = YOLOv3Loss(
  306. ignore_thresh=self.ignore_thresh,
  307. label_smooth=self.use_label_smooth,
  308. num_classes=self.class_num,
  309. downsample=self.downsample_ratio,
  310. scale_x_y=self.scale_x_y)
  311. x = paddle.to_tensor(x.astype(np.float32))
  312. gtbox = paddle.to_tensor(gtbox.astype(np.float32))
  313. gtlabel = paddle.to_tensor(gtlabel.astype(np.float32))
  314. gtscore = paddle.to_tensor(gtscore.astype(np.float32))
  315. t = paddle.to_tensor(target.astype(np.float32))
  316. anchor = [self.anchors[i] for i in self.anchor_mask]
  317. (yolo_loss1, px, py, tx, ty) = fine_grained_loss(
  318. output=x,
  319. target=t,
  320. gt_box=gtbox,
  321. batch_size=self.b,
  322. num_classes=self.class_num,
  323. anchors=self.mask_anchors,
  324. ignore_thresh=self.ignore_thresh,
  325. downsample=self.downsample_ratio,
  326. scale_x_y=self.scale_x_y)
  327. yolo_loss2 = yolo_loss.yolov3_loss(
  328. x, t, gtbox, anchor, self.downsample_ratio, self.scale_x_y)
  329. for k in yolo_loss2:
  330. self.assertAlmostEqual(
  331. yolo_loss1[k].numpy()[0],
  332. yolo_loss2[k].numpy()[0],
  333. delta=1e-2,
  334. msg=k)
  335. class TestYolov3LossNoGTScore(TestYolov3LossOp):
  336. def initTestCase(self):
  337. self.b = 1
  338. self.h = 76
  339. self.w = 76
  340. self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  341. [59, 119], [116, 90], [156, 198], [373, 326]]
  342. self.anchor_mask = [0, 1, 2]
  343. self.na = len(self.anchor_mask)
  344. self.class_num = 80
  345. self.ignore_thresh = 0.7
  346. self.downsample_ratio = 8
  347. self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
  348. self.h, self.w)
  349. self.gtbox_shape = (self.b, 40, 4)
  350. self.gtscore = False
  351. self.use_label_smooth = False
  352. self.scale_x_y = 1.
  353. class TestYolov3LossWithScaleXY(TestYolov3LossOp):
  354. def initTestCase(self):
  355. self.b = 5
  356. self.h = 38
  357. self.w = 38
  358. self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  359. [59, 119], [116, 90], [156, 198], [373, 326]]
  360. self.anchor_mask = [3, 4, 5]
  361. self.na = len(self.anchor_mask)
  362. self.class_num = 80
  363. self.ignore_thresh = 0.7
  364. self.downsample_ratio = 16
  365. self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
  366. self.h, self.w)
  367. self.gtbox_shape = (self.b, 40, 4)
  368. self.gtscore = True
  369. self.use_label_smooth = False
  370. self.scale_x_y = 1.2
# Run the test cases with unittest's runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()