ソースを参照

refine batch_compose

will-jl944 4 年 前
コミット
51ebae6441
1 ファイル変更19 行追加3 行削除
  1. 19 3
      dygraph/paddlex/cv/transforms/batch_operators.py

+ 19 - 3
dygraph/paddlex/cv/transforms/batch_operators.py

@@ -27,10 +27,10 @@ from paddlex.utils import logging
 
 
 class BatchCompose(Transform):
-    def __init__(self, batch_transforms=None):
+    def __init__(self, batch_transforms=None, collate_batch=True):
         super(BatchCompose, self).__init__()
         self.batch_transforms = batch_transforms
-        self.lock = mp.Lock()
+        self.collate_batch = collate_batch
 
     def __call__(self, samples):
         if self.batch_transforms is not None:
@@ -46,7 +46,23 @@ class BatchCompose(Transform):
 
         samples = _Permute()(samples)
 
-        batch_data = default_collate_fn(samples)
+        extra_key = ['h', 'w', 'flipped']
+        for k in extra_key:
+            for sample in samples:
+                if k in sample:
+                    sample.pop(k)
+
+        if self.collate_batch:
+            batch_data = default_collate_fn(samples)
+        else:
+            batch_data = {}
+            for k in samples[0].keys():
+                tmp_data = []
+                for i in range(len(samples)):
+                    tmp_data.append(samples[i][k])
+                if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
+                    tmp_data = np.stack(tmp_data, axis=0)
+                batch_data[k] = tmp_data
         return batch_data