|
|
@@ -27,10 +27,10 @@ from paddlex.utils import logging
|
|
|
|
|
|
|
|
|
class BatchCompose(Transform):
|
|
|
- def __init__(self, batch_transforms=None):
|
|
|
+ def __init__(self, batch_transforms=None, collate_batch=True):
|
|
|
super(BatchCompose, self).__init__()
|
|
|
self.batch_transforms = batch_transforms
|
|
|
- self.lock = mp.Lock()
|
|
|
+ self.collate_batch = collate_batch
|
|
|
|
|
|
def __call__(self, samples):
|
|
|
if self.batch_transforms is not None:
|
|
|
@@ -46,7 +46,23 @@ class BatchCompose(Transform):
|
|
|
|
|
|
samples = _Permute()(samples)
|
|
|
|
|
|
- batch_data = default_collate_fn(samples)
|
|
|
+ extra_key = ['h', 'w', 'flipped']
|
|
|
+ for k in extra_key:
|
|
|
+ for sample in samples:
|
|
|
+ if k in sample:
|
|
|
+ sample.pop(k)
|
|
|
+
|
|
|
+ if self.collate_batch:
|
|
|
+ batch_data = default_collate_fn(samples)
|
|
|
+ else:
|
|
|
+ batch_data = {}
|
|
|
+ for k in samples[0].keys():
|
|
|
+ tmp_data = []
|
|
|
+ for i in range(len(samples)):
|
|
|
+ tmp_data.append(samples[i][k])
|
|
|
+ if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
|
|
|
+ tmp_data = np.stack(tmp_data, axis=0)
|
|
|
+ batch_data[k] = tmp_data
|
|
|
return batch_data
|
|
|
|
|
|
|