batch_operators.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import traceback
  15. import random
  16. import numpy as np
  17. try:
  18. from collections.abc import Sequence
  19. except Exception:
  20. from collections import Sequence
  21. from .operators import Transform, Resize, ResizeByShort, _Permute, interp_dict
  22. from .box_utils import jaccard_overlap
  23. from paddlex.utils import logging
  24. from paddlex.cv.datasets.utils import default_collate_fn
  25. class BatchCompose(Transform):
  26. def __init__(self, batch_transforms=None, collate_batch=True):
  27. super(BatchCompose, self).__init__()
  28. self.batch_transforms = batch_transforms
  29. self.collate_batch = collate_batch
  30. def __call__(self, samples):
  31. if self.batch_transforms is not None:
  32. for op in self.batch_transforms:
  33. try:
  34. samples = op(samples)
  35. except Exception as e:
  36. stack_info = traceback.format_exc()
  37. logging.warning("fail to map batch transform [{}] "
  38. "with error: {} and stack:\n{}".format(
  39. op, e, str(stack_info)))
  40. raise e
  41. samples = _Permute()(samples)
  42. extra_key = ['h', 'w', 'flipped']
  43. for k in extra_key:
  44. for sample in samples:
  45. if k in sample:
  46. sample.pop(k)
  47. if self.collate_batch:
  48. batch_data = default_collate_fn(samples)
  49. else:
  50. batch_data = {}
  51. for k in samples[0].keys():
  52. tmp_data = []
  53. for i in range(len(samples)):
  54. tmp_data.append(samples[i][k])
  55. if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
  56. tmp_data = np.stack(tmp_data, axis=0)
  57. batch_data[k] = tmp_data
  58. return batch_data
  59. class BatchRandomResize(Transform):
  60. """
  61. Resize a batch of input to random sizes.
  62. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  63. Args:
  64. target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
  65. Multiple target sizes, each target size is an int or list/tuple of length 2.
  66. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  67. Interpolation method of resize. Defaults to 'LINEAR'.
  68. Raises:
  69. TypeError: Invalid type of target_size.
  70. ValueError: Invalid interpolation method.
  71. See Also:
  72. RandomResize: Resize input to random sizes.
  73. """
  74. def __init__(self, target_sizes, interp='NEAREST'):
  75. super(BatchRandomResize, self).__init__()
  76. if not (interp == "RANDOM" or interp in interp_dict):
  77. raise ValueError("interp should be one of {}".format(
  78. interp_dict.keys()))
  79. self.interp = interp
  80. assert isinstance(target_sizes, list), \
  81. "target_size must be List"
  82. for i, item in enumerate(target_sizes):
  83. if isinstance(item, int):
  84. target_sizes[i] = (item, item)
  85. self.target_size = target_sizes
  86. def __call__(self, samples):
  87. height, width = random.choice(self.target_size)
  88. resizer = Resize((height, width), interp=self.interp)
  89. samples = resizer(samples)
  90. return samples
  91. class BatchRandomResizeByShort(Transform):
  92. """Resize a batch of input to random sizes with keeping the aspect ratio.
  93. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  94. Args:
  95. short_sizes (List[int], Tuple[int]): Target sizes of the shorter side of the image(s).
  96. max_size (int, optional): The upper bound of longer side of the image(s).
  97. If max_size is -1, no upper bound is applied. Defaults to -1.
  98. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  99. Interpolation method of resize. Defaults to 'LINEAR'.
  100. Raises:
  101. TypeError: Invalid type of target_size.
  102. ValueError: Invalid interpolation method.
  103. See Also:
  104. RandomResizeByShort: Resize input to random sizes with keeping the aspect ratio.
  105. """
  106. def __init__(self, short_sizes, max_size=-1, interp='NEAREST'):
  107. super(BatchRandomResizeByShort, self).__init__()
  108. if not (interp == "RANDOM" or interp in interp_dict):
  109. raise ValueError("interp should be one of {}".format(
  110. interp_dict.keys()))
  111. self.interp = interp
  112. assert isinstance(short_sizes, list), \
  113. "short_sizes must be List"
  114. self.short_sizes = short_sizes
  115. self.max_size = max_size
  116. def __call__(self, samples):
  117. short_size = random.choice(self.short_sizes)
  118. resizer = ResizeByShort(
  119. short_size=short_size, max_size=self.max_size, interp=self.interp)
  120. samples = resizer(samples)
  121. return samples
  122. class _BatchPadding(Transform):
  123. def __init__(self, pad_to_stride=0):
  124. super(_BatchPadding, self).__init__()
  125. self.pad_to_stride = pad_to_stride
  126. def __call__(self, samples):
  127. coarsest_stride = self.pad_to_stride
  128. max_shape = np.array([data['image'].shape for data in samples]).max(
  129. axis=0)
  130. if coarsest_stride > 0:
  131. max_shape[0] = int(
  132. np.ceil(max_shape[0] / coarsest_stride) * coarsest_stride)
  133. max_shape[1] = int(
  134. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  135. for data in samples:
  136. im = data['image']
  137. im_h, im_w, im_c = im.shape[:]
  138. padding_im = np.zeros(
  139. (max_shape[0], max_shape[1], im_c), dtype=np.float32)
  140. padding_im[:im_h, :im_w, :] = im
  141. data['image'] = padding_im
  142. return samples
  143. class _Gt2YoloTarget(Transform):
  144. """
  145. Generate YOLOv3 targets by groud truth data, this operator is only used in
  146. fine grained YOLOv3 loss mode
  147. """
  148. def __init__(self,
  149. anchors,
  150. anchor_masks,
  151. downsample_ratios,
  152. num_classes=80,
  153. iou_thresh=1.):
  154. super(_Gt2YoloTarget, self).__init__()
  155. self.anchors = anchors
  156. self.anchor_masks = anchor_masks
  157. self.downsample_ratios = downsample_ratios
  158. self.num_classes = num_classes
  159. self.iou_thresh = iou_thresh
  160. def __call__(self, samples, context=None):
  161. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  162. "anchor_masks', and 'downsample_ratios' should have same length."
  163. h, w = samples[0]['image'].shape[:2]
  164. an_hw = np.array(self.anchors) / np.array([[w, h]])
  165. for sample in samples:
  166. gt_bbox = sample['gt_bbox']
  167. gt_class = sample['gt_class']
  168. if 'gt_score' not in sample:
  169. sample['gt_score'] = np.ones(
  170. (gt_bbox.shape[0], 1), dtype=np.float32)
  171. gt_score = sample['gt_score']
  172. for i, (
  173. mask, downsample_ratio
  174. ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
  175. grid_h = int(h / downsample_ratio)
  176. grid_w = int(w / downsample_ratio)
  177. target = np.zeros(
  178. (len(mask), 6 + self.num_classes, grid_h, grid_w),
  179. dtype=np.float32)
  180. for b in range(gt_bbox.shape[0]):
  181. gx, gy, gw, gh = gt_bbox[b, :]
  182. cls = gt_class[b]
  183. score = gt_score[b]
  184. if gw <= 0. or gh <= 0. or score <= 0.:
  185. continue
  186. # find best match anchor index
  187. best_iou = 0.
  188. best_idx = -1
  189. for an_idx in range(an_hw.shape[0]):
  190. iou = jaccard_overlap(
  191. [0., 0., gw, gh],
  192. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  193. if iou > best_iou:
  194. best_iou = iou
  195. best_idx = an_idx
  196. gi = int(gx * grid_w)
  197. gj = int(gy * grid_h)
  198. # gtbox should be regresed in this layes if best match
  199. # anchor index in anchor mask of this layer
  200. if best_idx in mask:
  201. best_n = mask.index(best_idx)
  202. # x, y, w, h, scale
  203. target[best_n, 0, gj, gi] = gx * grid_w - gi
  204. target[best_n, 1, gj, gi] = gy * grid_h - gj
  205. target[best_n, 2, gj, gi] = np.log(
  206. gw * w / self.anchors[best_idx][0])
  207. target[best_n, 3, gj, gi] = np.log(
  208. gh * h / self.anchors[best_idx][1])
  209. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  210. # objectness record gt_score
  211. target[best_n, 5, gj, gi] = score
  212. # classification
  213. target[best_n, 6 + cls, gj, gi] = 1.
  214. # For non-matched anchors, calculate the target if the iou
  215. # between anchor and gt is larger than iou_thresh
  216. if self.iou_thresh < 1:
  217. for idx, mask_i in enumerate(mask):
  218. if mask_i == best_idx: continue
  219. iou = jaccard_overlap(
  220. [0., 0., gw, gh],
  221. [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
  222. if iou > self.iou_thresh and target[idx, 5, gj,
  223. gi] == 0.:
  224. # x, y, w, h, scale
  225. target[idx, 0, gj, gi] = gx * grid_w - gi
  226. target[idx, 1, gj, gi] = gy * grid_h - gj
  227. target[idx, 2, gj, gi] = np.log(
  228. gw * w / self.anchors[mask_i][0])
  229. target[idx, 3, gj, gi] = np.log(
  230. gh * h / self.anchors[mask_i][1])
  231. target[idx, 4, gj, gi] = 2.0 - gw * gh
  232. # objectness record gt_score
  233. target[idx, 5, gj, gi] = score
  234. # classification
  235. target[idx, 5 + cls, gj, gi] = 1.
  236. sample['target{}'.format(i)] = target
  237. # remove useless gt_class and gt_score after target calculated
  238. sample.pop('gt_class')
  239. sample.pop('gt_score')
  240. return samples