batch_operators.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import traceback
  15. import random
  16. import numpy as np
  17. try:
  18. from collections.abc import Sequence
  19. except Exception:
  20. from collections import Sequence
  21. from paddle.fluid.dataloader.collate import default_collate_fn
  22. from .operators import Transform, Resize, ResizeByShort, _Permute, interp_dict
  23. from .box_utils import jaccard_overlap
  24. from paddlex.utils import logging
  25. class BatchCompose(Transform):
  26. def __init__(self,
  27. batch_transforms=None,
  28. collate_batch=True,
  29. return_list=False):
  30. super(BatchCompose, self).__init__()
  31. self.batch_transforms = batch_transforms
  32. self.collate_batch = collate_batch
  33. self.return_list = return_list
  34. def __call__(self, samples):
  35. if self.batch_transforms is not None:
  36. for op in self.batch_transforms:
  37. try:
  38. samples = op(samples)
  39. except Exception as e:
  40. stack_info = traceback.format_exc()
  41. logging.warning("fail to map batch transform [{}] "
  42. "with error: {} and stack:\n{}".format(
  43. op, e, str(stack_info)))
  44. raise e
  45. samples = _Permute()(samples)
  46. extra_key = ['h', 'w', 'flipped']
  47. for k in extra_key:
  48. for sample in samples:
  49. if k in sample:
  50. sample.pop(k)
  51. if self.return_list:
  52. batch_data = [{
  53. k: np.expand_dims(
  54. sample[k], axis=0)
  55. for k in sample
  56. } for sample in samples]
  57. elif self.collate_batch:
  58. batch_data = default_collate_fn(samples)
  59. else:
  60. batch_data = {}
  61. for k in samples[0].keys():
  62. tmp_data = []
  63. for i in range(len(samples)):
  64. tmp_data.append(samples[i][k])
  65. if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
  66. tmp_data = np.stack(tmp_data, axis=0)
  67. batch_data[k] = tmp_data
  68. return batch_data
  69. class BatchRandomResize(Transform):
  70. """
  71. Resize a batch of input to random sizes.
  72. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  73. Args:
  74. target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
  75. Multiple target sizes, each target size is an int or list/tuple of length 2.
  76. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  77. Interpolation method of resize. Defaults to 'LINEAR'.
  78. Raises:
  79. TypeError: Invalid type of target_size.
  80. ValueError: Invalid interpolation method.
  81. See Also:
  82. RandomResize: Resize input to random sizes.
  83. """
  84. def __init__(self, target_sizes, interp='NEAREST'):
  85. super(BatchRandomResize, self).__init__()
  86. if not (interp == "RANDOM" or interp in interp_dict):
  87. raise ValueError("interp should be one of {}".format(
  88. interp_dict.keys()))
  89. self.interp = interp
  90. assert isinstance(target_sizes, list), \
  91. "target_size must be List"
  92. for i, item in enumerate(target_sizes):
  93. if isinstance(item, int):
  94. target_sizes[i] = (item, item)
  95. self.target_size = target_sizes
  96. def __call__(self, samples):
  97. height, width = random.choice(self.target_size)
  98. resizer = Resize((height, width), interp=self.interp)
  99. samples = resizer(samples)
  100. return samples
  101. class BatchRandomResizeByShort(Transform):
  102. """Resize a batch of input to random sizes with keeping the aspect ratio.
  103. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  104. Args:
  105. short_sizes (List[int], Tuple[int]): Target sizes of the shorter side of the image(s).
  106. max_size (int, optional): The upper bound of longer side of the image(s).
  107. If max_size is -1, no upper bound is applied. Defaults to -1.
  108. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  109. Interpolation method of resize. Defaults to 'LINEAR'.
  110. Raises:
  111. TypeError: Invalid type of target_size.
  112. ValueError: Invalid interpolation method.
  113. See Also:
  114. RandomResizeByShort: Resize input to random sizes with keeping the aspect ratio.
  115. """
  116. def __init__(self, short_sizes, max_size=-1, interp='NEAREST'):
  117. super(BatchRandomResizeByShort, self).__init__()
  118. if not (interp == "RANDOM" or interp in interp_dict):
  119. raise ValueError("interp should be one of {}".format(
  120. interp_dict.keys()))
  121. self.interp = interp
  122. assert isinstance(short_sizes, list), \
  123. "short_sizes must be List"
  124. self.short_sizes = short_sizes
  125. self.max_size = max_size
  126. def __call__(self, samples):
  127. short_size = random.choice(self.short_sizes)
  128. resizer = ResizeByShort(
  129. short_size=short_size, max_size=self.max_size, interp=self.interp)
  130. samples = resizer(samples)
  131. return samples
  132. class _BatchPadding(Transform):
  133. def __init__(self, pad_to_stride=0):
  134. super(_BatchPadding, self).__init__()
  135. self.pad_to_stride = pad_to_stride
  136. def __call__(self, samples):
  137. coarsest_stride = self.pad_to_stride
  138. max_shape = np.array([data['image'].shape for data in samples]).max(
  139. axis=0)
  140. if coarsest_stride > 0:
  141. max_shape[0] = int(
  142. np.ceil(max_shape[0] / coarsest_stride) * coarsest_stride)
  143. max_shape[1] = int(
  144. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  145. for data in samples:
  146. im = data['image']
  147. im_h, im_w, im_c = im.shape[:]
  148. padding_im = np.zeros(
  149. (max_shape[0], max_shape[1], im_c), dtype=np.float32)
  150. padding_im[:im_h, :im_w, :] = im
  151. data['image'] = padding_im
  152. return samples
  153. class _Gt2YoloTarget(Transform):
  154. """
  155. Generate YOLOv3 targets by groud truth data, this operator is only used in
  156. fine grained YOLOv3 loss mode
  157. """
  158. def __init__(self,
  159. anchors,
  160. anchor_masks,
  161. downsample_ratios,
  162. num_classes=80,
  163. iou_thresh=1.):
  164. super(_Gt2YoloTarget, self).__init__()
  165. self.anchors = anchors
  166. self.anchor_masks = anchor_masks
  167. self.downsample_ratios = downsample_ratios
  168. self.num_classes = num_classes
  169. self.iou_thresh = iou_thresh
  170. def __call__(self, samples, context=None):
  171. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  172. "anchor_masks', and 'downsample_ratios' should have same length."
  173. h, w = samples[0]['image'].shape[:2]
  174. an_hw = np.array(self.anchors) / np.array([[w, h]])
  175. for sample in samples:
  176. gt_bbox = sample['gt_bbox']
  177. gt_class = sample['gt_class']
  178. if 'gt_score' not in sample:
  179. sample['gt_score'] = np.ones(
  180. (gt_bbox.shape[0], 1), dtype=np.float32)
  181. gt_score = sample['gt_score']
  182. for i, (
  183. mask, downsample_ratio
  184. ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
  185. grid_h = int(h / downsample_ratio)
  186. grid_w = int(w / downsample_ratio)
  187. target = np.zeros(
  188. (len(mask), 6 + self.num_classes, grid_h, grid_w),
  189. dtype=np.float32)
  190. for b in range(gt_bbox.shape[0]):
  191. gx, gy, gw, gh = gt_bbox[b, :]
  192. cls = gt_class[b]
  193. score = gt_score[b]
  194. if gw <= 0. or gh <= 0. or score <= 0.:
  195. continue
  196. # find best match anchor index
  197. best_iou = 0.
  198. best_idx = -1
  199. for an_idx in range(an_hw.shape[0]):
  200. iou = jaccard_overlap(
  201. [0., 0., gw, gh],
  202. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  203. if iou > best_iou:
  204. best_iou = iou
  205. best_idx = an_idx
  206. gi = int(gx * grid_w)
  207. gj = int(gy * grid_h)
  208. # gtbox should be regresed in this layes if best match
  209. # anchor index in anchor mask of this layer
  210. if best_idx in mask:
  211. best_n = mask.index(best_idx)
  212. # x, y, w, h, scale
  213. target[best_n, 0, gj, gi] = gx * grid_w - gi
  214. target[best_n, 1, gj, gi] = gy * grid_h - gj
  215. target[best_n, 2, gj, gi] = np.log(
  216. gw * w / self.anchors[best_idx][0])
  217. target[best_n, 3, gj, gi] = np.log(
  218. gh * h / self.anchors[best_idx][1])
  219. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  220. # objectness record gt_score
  221. target[best_n, 5, gj, gi] = score
  222. # classification
  223. target[best_n, 6 + cls, gj, gi] = 1.
  224. # For non-matched anchors, calculate the target if the iou
  225. # between anchor and gt is larger than iou_thresh
  226. if self.iou_thresh < 1:
  227. for idx, mask_i in enumerate(mask):
  228. if mask_i == best_idx: continue
  229. iou = jaccard_overlap(
  230. [0., 0., gw, gh],
  231. [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
  232. if iou > self.iou_thresh and target[idx, 5, gj,
  233. gi] == 0.:
  234. # x, y, w, h, scale
  235. target[idx, 0, gj, gi] = gx * grid_w - gi
  236. target[idx, 1, gj, gi] = gy * grid_h - gj
  237. target[idx, 2, gj, gi] = np.log(
  238. gw * w / self.anchors[mask_i][0])
  239. target[idx, 3, gj, gi] = np.log(
  240. gh * h / self.anchors[mask_i][1])
  241. target[idx, 4, gj, gi] = 2.0 - gw * gh
  242. # objectness record gt_score
  243. target[idx, 5, gj, gi] = score
  244. # classification
  245. target[idx, 5 + cls, gj, gi] = 1.
  246. sample['target{}'.format(i)] = target
  247. # remove useless gt_class and gt_score after target calculated
  248. sample.pop('gt_class')
  249. sample.pop('gt_score')
  250. return samples