batch_operators.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import traceback
  15. import multiprocessing as mp
  16. import random
  17. import numpy as np
  18. try:
  19. from collections.abc import Sequence
  20. except Exception:
  21. from collections import Sequence
  22. from paddle.fluid.dataloader.collate import default_collate_fn
  23. from .operators import Transform, Resize, ResizeByShort, _Permute, interp_dict
  24. from .box_utils import jaccard_overlap
  25. from paddlex.utils import logging
  26. class BatchCompose(Transform):
  27. def __init__(self, batch_transforms=None, collate_batch=True):
  28. super(BatchCompose, self).__init__()
  29. self.batch_transforms = batch_transforms
  30. self.collate_batch = collate_batch
  31. def __call__(self, samples):
  32. if self.batch_transforms is not None:
  33. for op in self.batch_transforms:
  34. try:
  35. samples = op(samples)
  36. except Exception as e:
  37. stack_info = traceback.format_exc()
  38. logging.warning("fail to map batch transform [{}] "
  39. "with error: {} and stack:\n{}".format(
  40. op, e, str(stack_info)))
  41. raise e
  42. samples = _Permute()(samples)
  43. extra_key = ['h', 'w', 'flipped']
  44. for k in extra_key:
  45. for sample in samples:
  46. if k in sample:
  47. sample.pop(k)
  48. if self.collate_batch:
  49. batch_data = default_collate_fn(samples)
  50. else:
  51. batch_data = {}
  52. for k in samples[0].keys():
  53. tmp_data = []
  54. for i in range(len(samples)):
  55. tmp_data.append(samples[i][k])
  56. if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
  57. tmp_data = np.stack(tmp_data, axis=0)
  58. batch_data[k] = tmp_data
  59. return batch_data
  60. class BatchRandomResize(Transform):
  61. """
  62. Resize a batch of input to random sizes.
  63. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  64. Args:
  65. target_sizes (List[int], List[list or tuple] or Tuple[list or tuple]):
  66. Multiple target sizes, each target size is an int or list/tuple of length 2.
  67. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  68. Interpolation method of resize. Defaults to 'LINEAR'.
  69. Raises:
  70. TypeError: Invalid type of target_size.
  71. ValueError: Invalid interpolation method.
  72. See Also:
  73. RandomResize: Resize input to random sizes.
  74. """
  75. def __init__(self, target_sizes, interp='NEAREST'):
  76. super(BatchRandomResize, self).__init__()
  77. if not (interp == "RANDOM" or interp in interp_dict):
  78. raise ValueError("interp should be one of {}".format(
  79. interp_dict.keys()))
  80. self.interp = interp
  81. assert isinstance(target_sizes, list), \
  82. "target_size must be List"
  83. for i, item in enumerate(target_sizes):
  84. if isinstance(item, int):
  85. target_sizes[i] = (item, item)
  86. self.target_size = target_sizes
  87. def __call__(self, samples):
  88. height, width = random.choice(self.target_size)
  89. resizer = Resize((height, width), interp=self.interp)
  90. samples = resizer(samples)
  91. return samples
  92. class BatchRandomResizeByShort(Transform):
  93. """Resize a batch of input to random sizes with keeping the aspect ratio.
  94. Attention:If interp is 'RANDOM', the interpolation method will be chose randomly.
  95. Args:
  96. short_sizes (List[int], Tuple[int]): Target sizes of the shorter side of the image(s).
  97. max_size (int, optional): The upper bound of longer side of the image(s).
  98. If max_size is -1, no upper bound is applied. Defaults to -1.
  99. interp ({'NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM'}, optional):
  100. Interpolation method of resize. Defaults to 'LINEAR'.
  101. Raises:
  102. TypeError: Invalid type of target_size.
  103. ValueError: Invalid interpolation method.
  104. See Also:
  105. RandomResizeByShort: Resize input to random sizes with keeping the aspect ratio.
  106. """
  107. def __init__(self, short_sizes, max_size=-1, interp='NEAREST'):
  108. super(BatchRandomResizeByShort, self).__init__()
  109. if not (interp == "RANDOM" or interp in interp_dict):
  110. raise ValueError("interp should be one of {}".format(
  111. interp_dict.keys()))
  112. self.interp = interp
  113. assert isinstance(short_sizes, list), \
  114. "short_sizes must be List"
  115. self.short_sizes = short_sizes
  116. self.max_size = max_size
  117. def __call__(self, samples):
  118. short_size = random.choice(self.short_sizes)
  119. resizer = ResizeByShort(
  120. short_size=short_size, max_size=self.max_size, interp=self.interp)
  121. samples = resizer(samples)
  122. return samples
  123. class _BatchPadding(Transform):
  124. def __init__(self, pad_to_stride=0, pad_gt=False):
  125. super(_BatchPadding, self).__init__()
  126. self.pad_to_stride = pad_to_stride
  127. self.pad_gt = pad_gt
  128. def __call__(self, samples):
  129. coarsest_stride = self.pad_to_stride
  130. max_shape = np.array([data['image'].shape for data in samples]).max(
  131. axis=0)
  132. if coarsest_stride > 0:
  133. max_shape[0] = int(
  134. np.ceil(max_shape[0] / coarsest_stride) * coarsest_stride)
  135. max_shape[1] = int(
  136. np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
  137. for data in samples:
  138. im = data['image']
  139. im_h, im_w, im_c = im.shape[:]
  140. padding_im = np.zeros(
  141. (max_shape[0], max_shape[1], im_c), dtype=np.float32)
  142. padding_im[:im_h, :im_w, :] = im
  143. data['image'] = padding_im
  144. if self.pad_gt:
  145. gt_num = []
  146. if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
  147. 'gt_poly']) > 0:
  148. pad_mask = True
  149. else:
  150. pad_mask = False
  151. if pad_mask:
  152. poly_num = []
  153. poly_part_num = []
  154. point_num = []
  155. for data in samples:
  156. gt_num.append(data['gt_bbox'].shape[0])
  157. if pad_mask:
  158. poly_num.append(len(data['gt_poly']))
  159. for poly in data['gt_poly']:
  160. poly_part_num.append(int(len(poly)))
  161. for p_p in poly:
  162. point_num.append(int(len(p_p) / 2))
  163. gt_num_max = max(gt_num)
  164. for i, data in enumerate(samples):
  165. gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
  166. gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
  167. is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
  168. if pad_mask:
  169. poly_num_max = max(poly_num)
  170. poly_part_num_max = max(poly_part_num)
  171. point_num_max = max(point_num)
  172. gt_masks_data = -np.ones(
  173. [poly_num_max, poly_part_num_max, point_num_max, 2],
  174. dtype=np.float32)
  175. gt_num = data['gt_bbox'].shape[0]
  176. gt_box_data[0:gt_num, :] = data['gt_bbox']
  177. gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
  178. if 'is_crowd' in data:
  179. is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd'])
  180. data['is_crowd'] = is_crowd_data
  181. data['gt_bbox'] = gt_box_data
  182. data['gt_class'] = gt_class_data
  183. if pad_mask:
  184. for j, poly in enumerate(data['gt_poly']):
  185. for k, p_p in enumerate(poly):
  186. pp_np = np.array(p_p).reshape(-1, 2)
  187. gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
  188. data['gt_poly'] = gt_masks_data
  189. if 'gt_score' in data:
  190. gt_score_data = np.zeros([gt_num_max], dtype=np.float32)
  191. gt_score_data[0:gt_num] = data['gt_score'][:gt_num, 0]
  192. data['gt_score'] = gt_score_data
  193. if 'difficult' in data:
  194. diff_data = np.zeros([gt_num_max], dtype=np.int32)
  195. diff_data[0:gt_num] = data['difficult'][:gt_num, 0]
  196. data['difficult'] = diff_data
  197. return samples
  198. class _Gt2YoloTarget(Transform):
  199. """
  200. Generate YOLOv3 targets by groud truth data, this operator is only used in
  201. fine grained YOLOv3 loss mode
  202. """
  203. def __init__(self,
  204. anchors,
  205. anchor_masks,
  206. downsample_ratios,
  207. num_classes=80,
  208. iou_thresh=1.):
  209. super(_Gt2YoloTarget, self).__init__()
  210. self.anchors = anchors
  211. self.anchor_masks = anchor_masks
  212. self.downsample_ratios = downsample_ratios
  213. self.num_classes = num_classes
  214. self.iou_thresh = iou_thresh
  215. def __call__(self, samples, context=None):
  216. assert len(self.anchor_masks) == len(self.downsample_ratios), \
  217. "anchor_masks', and 'downsample_ratios' should have same length."
  218. h, w = samples[0]['image'].shape[:2]
  219. an_hw = np.array(self.anchors) / np.array([[w, h]])
  220. for sample in samples:
  221. gt_bbox = sample['gt_bbox']
  222. gt_class = sample['gt_class']
  223. if 'gt_score' not in sample:
  224. sample['gt_score'] = np.ones(
  225. (gt_bbox.shape[0], 1), dtype=np.float32)
  226. gt_score = sample['gt_score']
  227. for i, (
  228. mask, downsample_ratio
  229. ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
  230. grid_h = int(h / downsample_ratio)
  231. grid_w = int(w / downsample_ratio)
  232. target = np.zeros(
  233. (len(mask), 6 + self.num_classes, grid_h, grid_w),
  234. dtype=np.float32)
  235. for b in range(gt_bbox.shape[0]):
  236. gx, gy, gw, gh = gt_bbox[b, :]
  237. cls = gt_class[b]
  238. score = gt_score[b]
  239. if gw <= 0. or gh <= 0. or score <= 0.:
  240. continue
  241. # find best match anchor index
  242. best_iou = 0.
  243. best_idx = -1
  244. for an_idx in range(an_hw.shape[0]):
  245. iou = jaccard_overlap(
  246. [0., 0., gw, gh],
  247. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  248. if iou > best_iou:
  249. best_iou = iou
  250. best_idx = an_idx
  251. gi = int(gx * grid_w)
  252. gj = int(gy * grid_h)
  253. # gtbox should be regresed in this layes if best match
  254. # anchor index in anchor mask of this layer
  255. if best_idx in mask:
  256. best_n = mask.index(best_idx)
  257. # x, y, w, h, scale
  258. target[best_n, 0, gj, gi] = gx * grid_w - gi
  259. target[best_n, 1, gj, gi] = gy * grid_h - gj
  260. target[best_n, 2, gj, gi] = np.log(
  261. gw * w / self.anchors[best_idx][0])
  262. target[best_n, 3, gj, gi] = np.log(
  263. gh * h / self.anchors[best_idx][1])
  264. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  265. # objectness record gt_score
  266. target[best_n, 5, gj, gi] = score
  267. # classification
  268. target[best_n, 6 + cls, gj, gi] = 1.
  269. # For non-matched anchors, calculate the target if the iou
  270. # between anchor and gt is larger than iou_thresh
  271. if self.iou_thresh < 1:
  272. for idx, mask_i in enumerate(mask):
  273. if mask_i == best_idx: continue
  274. iou = jaccard_overlap(
  275. [0., 0., gw, gh],
  276. [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
  277. if iou > self.iou_thresh and target[idx, 5, gj,
  278. gi] == 0.:
  279. # x, y, w, h, scale
  280. target[idx, 0, gj, gi] = gx * grid_w - gi
  281. target[idx, 1, gj, gi] = gy * grid_h - gj
  282. target[idx, 2, gj, gi] = np.log(
  283. gw * w / self.anchors[mask_i][0])
  284. target[idx, 3, gj, gi] = np.log(
  285. gh * h / self.anchors[mask_i][1])
  286. target[idx, 4, gj, gi] = 2.0 - gw * gh
  287. # objectness record gt_score
  288. target[idx, 5, gj, gi] = score
  289. # classification
  290. target[idx, 5 + cls, gj, gi] = 1.
  291. sample['target{}'.format(i)] = target
  292. # remove useless gt_class and gt_score after target calculated
  293. sample.pop('gt_class')
  294. sample.pop('gt_score')
  295. return samples