Browse Source

add picodet _compose_batch_transform

will-jl944 4 years ago
parent
commit
f29d2b1722
1 changed files with 21 additions and 0 deletions
  1. 21 0
      paddlex/cv/models/detector.py

+ 21 - 0
paddlex/cv/models/detector.py

@@ -712,6 +712,27 @@ class PicoDet(BaseDetector):
         super(PicoDet, self).__init__(
             model_name='PicoDet', num_classes=num_classes, **params)
 
+    def _compose_batch_transform(self, transforms, mode='train'):
+        if mode == 'eval':
+            collate_batch = True
+        else:
+            collate_batch = False
+
+        custom_batch_transforms = []
+        for i, op in enumerate(transforms.transforms):
+            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
+                if mode != 'train':
+                    raise Exception(
+                        "{} cannot be present in the {} transforms. ".format(
+                            op.__class__.__name__, mode) +
+                        "Please check the {} transforms.".format(mode))
+                custom_batch_transforms.insert(0, copy.deepcopy(op))
+
+        batch_transforms = BatchCompose(
+            custom_batch_transforms, collate_batch=collate_batch)
+
+        return batch_transforms
+
 
 class YOLOv3(BaseDetector):
     def __init__(self,