ops.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from numbers import Integral
  16. import math
  17. import six
  18. import paddle
  19. from paddle import fluid
  20. def DropBlock(input, block_size, keep_prob, is_test):
  21. if is_test:
  22. return input
  23. def CalculateGamma(input, block_size, keep_prob):
  24. input_shape = fluid.layers.shape(input)
  25. feat_shape_tmp = fluid.layers.slice(input_shape, [0], [3], [4])
  26. feat_shape_tmp = fluid.layers.cast(feat_shape_tmp, dtype="float32")
  27. feat_shape_t = fluid.layers.reshape(feat_shape_tmp, [1, 1, 1, 1])
  28. feat_area = fluid.layers.pow(feat_shape_t, factor=2)
  29. block_shape_t = fluid.layers.fill_constant(
  30. shape=[1, 1, 1, 1], value=block_size, dtype='float32')
  31. block_area = fluid.layers.pow(block_shape_t, factor=2)
  32. useful_shape_t = feat_shape_t - block_shape_t + 1
  33. useful_area = fluid.layers.pow(useful_shape_t, factor=2)
  34. upper_t = feat_area * (1 - keep_prob)
  35. bottom_t = block_area * useful_area
  36. output = upper_t / bottom_t
  37. return output
  38. gamma = CalculateGamma(input, block_size=block_size, keep_prob=keep_prob)
  39. input_shape = fluid.layers.shape(input)
  40. p = fluid.layers.expand_as(gamma, input)
  41. input_shape_tmp = fluid.layers.cast(input_shape, dtype="int64")
  42. random_matrix = fluid.layers.uniform_random(
  43. input_shape_tmp, dtype='float32', min=0.0, max=1.0)
  44. one_zero_m = fluid.layers.less_than(random_matrix, p)
  45. one_zero_m.stop_gradient = True
  46. one_zero_m = fluid.layers.cast(one_zero_m, dtype="float32")
  47. mask_flag = fluid.layers.pool2d(
  48. one_zero_m,
  49. pool_size=block_size,
  50. pool_type='max',
  51. pool_stride=1,
  52. pool_padding=block_size // 2)
  53. mask = 1.0 - mask_flag
  54. elem_numel = fluid.layers.reduce_prod(input_shape)
  55. elem_numel_m = fluid.layers.cast(elem_numel, dtype="float32")
  56. elem_numel_m.stop_gradient = True
  57. elem_sum = fluid.layers.reduce_sum(mask)
  58. elem_sum_m = fluid.layers.cast(elem_sum, dtype="float32")
  59. elem_sum_m.stop_gradient = True
  60. output = input * mask * elem_numel_m / elem_sum_m
  61. return output
  62. class MultiClassNMS(object):
  63. def __init__(self,
  64. score_threshold=.05,
  65. nms_top_k=-1,
  66. keep_top_k=100,
  67. nms_threshold=.5,
  68. normalized=False,
  69. nms_eta=1.0,
  70. background_label=0):
  71. super(MultiClassNMS, self).__init__()
  72. self.score_threshold = score_threshold
  73. self.nms_top_k = nms_top_k
  74. self.keep_top_k = keep_top_k
  75. self.nms_threshold = nms_threshold
  76. self.normalized = normalized
  77. self.nms_eta = nms_eta
  78. self.background_label = background_label
  79. def __call__(self, bboxes, scores):
  80. return fluid.layers.multiclass_nms(
  81. bboxes=bboxes,
  82. scores=scores,
  83. score_threshold=self.score_threshold,
  84. nms_top_k=self.nms_top_k,
  85. keep_top_k=self.keep_top_k,
  86. normalized=self.normalized,
  87. nms_threshold=self.nms_threshold,
  88. nms_eta=self.nms_eta,
  89. background_label=self.background_label)
  90. class MatrixNMS(object):
  91. def __init__(self,
  92. score_threshold=.05,
  93. post_threshold=.05,
  94. nms_top_k=-1,
  95. keep_top_k=100,
  96. use_gaussian=False,
  97. gaussian_sigma=2.,
  98. normalized=False,
  99. background_label=0):
  100. super(MatrixNMS, self).__init__()
  101. self.score_threshold = score_threshold
  102. self.post_threshold = post_threshold
  103. self.nms_top_k = nms_top_k
  104. self.keep_top_k = keep_top_k
  105. self.normalized = normalized
  106. self.use_gaussian = use_gaussian
  107. self.gaussian_sigma = gaussian_sigma
  108. self.background_label = background_label
  109. def __call__(self, bboxes, scores):
  110. return paddle.fluid.layers.matrix_nms(
  111. bboxes=bboxes,
  112. scores=scores,
  113. score_threshold=self.score_threshold,
  114. post_threshold=self.post_threshold,
  115. nms_top_k=self.nms_top_k,
  116. keep_top_k=self.keep_top_k,
  117. normalized=self.normalized,
  118. use_gaussian=self.use_gaussian,
  119. gaussian_sigma=self.gaussian_sigma,
  120. background_label=self.background_label)
  121. class MultiClassSoftNMS(object):
  122. def __init__(
  123. self,
  124. score_threshold=0.01,
  125. keep_top_k=300,
  126. softnms_sigma=0.5,
  127. normalized=False,
  128. background_label=0, ):
  129. super(MultiClassSoftNMS, self).__init__()
  130. self.score_threshold = score_threshold
  131. self.keep_top_k = keep_top_k
  132. self.softnms_sigma = softnms_sigma
  133. self.normalized = normalized
  134. self.background_label = background_label
  135. def __call__(self, bboxes, scores):
  136. def create_tmp_var(program, name, dtype, shape, lod_level):
  137. return program.current_block().create_var(
  138. name=name, dtype=dtype, shape=shape, lod_level=lod_level)
  139. def _soft_nms_for_cls(dets, sigma, thres):
  140. """soft_nms_for_cls"""
  141. dets_final = []
  142. while len(dets) > 0:
  143. maxpos = np.argmax(dets[:, 0])
  144. dets_final.append(dets[maxpos].copy())
  145. ts, tx1, ty1, tx2, ty2 = dets[maxpos]
  146. scores = dets[:, 0]
  147. # force remove bbox at maxpos
  148. scores[maxpos] = -1
  149. x1 = dets[:, 1]
  150. y1 = dets[:, 2]
  151. x2 = dets[:, 3]
  152. y2 = dets[:, 4]
  153. eta = 0 if self.normalized else 1
  154. areas = (x2 - x1 + eta) * (y2 - y1 + eta)
  155. xx1 = np.maximum(tx1, x1)
  156. yy1 = np.maximum(ty1, y1)
  157. xx2 = np.minimum(tx2, x2)
  158. yy2 = np.minimum(ty2, y2)
  159. w = np.maximum(0.0, xx2 - xx1 + eta)
  160. h = np.maximum(0.0, yy2 - yy1 + eta)
  161. inter = w * h
  162. ovr = inter / (areas + areas[maxpos] - inter)
  163. weight = np.exp(-(ovr * ovr) / sigma)
  164. scores = scores * weight
  165. idx_keep = np.where(scores >= thres)
  166. dets[:, 0] = scores
  167. dets = dets[idx_keep]
  168. dets_final = np.array(dets_final).reshape(-1, 5)
  169. return dets_final
  170. def _soft_nms(bboxes, scores):
  171. class_nums = scores.shape[-1]
  172. softnms_thres = self.score_threshold
  173. softnms_sigma = self.softnms_sigma
  174. keep_top_k = self.keep_top_k
  175. cls_boxes = [[] for _ in range(class_nums)]
  176. cls_ids = [[] for _ in range(class_nums)]
  177. start_idx = 1 if self.background_label == 0 else 0
  178. for j in range(start_idx, class_nums):
  179. inds = np.where(scores[:, j] >= softnms_thres)[0]
  180. scores_j = scores[inds, j]
  181. rois_j = bboxes[inds, j, :] if len(
  182. bboxes.shape) > 2 else bboxes[inds, :]
  183. dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
  184. np.float32, copy=False)
  185. cls_rank = np.argsort(-dets_j[:, 0])
  186. dets_j = dets_j[cls_rank]
  187. cls_boxes[j] = _soft_nms_for_cls(
  188. dets_j, sigma=softnms_sigma, thres=softnms_thres)
  189. cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
  190. 1)
  191. cls_boxes = np.vstack(cls_boxes[start_idx:])
  192. cls_ids = np.vstack(cls_ids[start_idx:])
  193. pred_result = np.hstack([cls_ids, cls_boxes])
  194. # Limit to max_per_image detections **over all classes**
  195. image_scores = cls_boxes[:, 0]
  196. if len(image_scores) > keep_top_k:
  197. image_thresh = np.sort(image_scores)[-keep_top_k]
  198. keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
  199. pred_result = pred_result[keep, :]
  200. return pred_result
  201. def _batch_softnms(bboxes, scores):
  202. batch_offsets = bboxes.lod()
  203. bboxes = np.array(bboxes)
  204. scores = np.array(scores)
  205. out_offsets = [0]
  206. pred_res = []
  207. if len(batch_offsets) > 0:
  208. batch_offset = batch_offsets[0]
  209. for i in range(len(batch_offset) - 1):
  210. s, e = batch_offset[i], batch_offset[i + 1]
  211. pred = _soft_nms(bboxes[s:e], scores[s:e])
  212. out_offsets.append(pred.shape[0] + out_offsets[-1])
  213. pred_res.append(pred)
  214. else:
  215. assert len(bboxes.shape) == 3
  216. assert len(scores.shape) == 3
  217. for i in range(bboxes.shape[0]):
  218. pred = _soft_nms(bboxes[i], scores[i])
  219. out_offsets.append(pred.shape[0] + out_offsets[-1])
  220. pred_res.append(pred)
  221. res = fluid.LoDTensor()
  222. res.set_lod([out_offsets])
  223. if len(pred_res) == 0:
  224. pred_res = np.array([[1]], dtype=np.float32)
  225. res.set(np.vstack(pred_res).astype(np.float32), fluid.CPUPlace())
  226. return res
  227. pred_result = create_tmp_var(
  228. fluid.default_main_program(),
  229. name='softnms_pred_result',
  230. dtype='float32',
  231. shape=[-1, 6],
  232. lod_level=1)
  233. fluid.layers.py_func(
  234. func=_batch_softnms, x=[bboxes, scores], out=pred_result)
  235. return pred_result