|
|
@@ -712,6 +712,27 @@ class PicoDet(BaseDetector):
|
|
|
super(PicoDet, self).__init__(
|
|
|
model_name='PicoDet', num_classes=num_classes, **params)
|
|
|
|
|
|
+ def _compose_batch_transform(self, transforms, mode='train'):
|
|
|
+ if mode == 'eval':
|
|
|
+ collate_batch = True
|
|
|
+ else:
|
|
|
+ collate_batch = False
|
|
|
+
|
|
|
+ custom_batch_transforms = []
|
|
|
+ for i, op in enumerate(transforms.transforms):
|
|
|
+ if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
|
|
|
+ if mode != 'train':
|
|
|
+ raise Exception(
|
|
|
+ "{} cannot be present in the {} transforms. ".format(
|
|
|
+ op.__class__.__name__, mode) +
|
|
|
+ "Please check the {} transforms.".format(mode))
|
|
|
+ custom_batch_transforms.insert(0, copy.deepcopy(op))
|
|
|
+
|
|
|
+ batch_transforms = BatchCompose(
|
|
|
+ custom_batch_transforms, collate_batch=collate_batch)
|
|
|
+
|
|
|
+ return batch_transforms
|
|
|
+
|
|
|
|
|
|
class YOLOv3(BaseDetector):
|
|
|
def __init__(self,
|