Explorar el Código

add picodet for restful

FlyingQianMM hace 4 años
padre
commit
c0cdf6698a
Se han modificado 1 ficheros con 38 adiciones y 1 borrados
  1. 38 1
      paddlex_restful/restful/project/train/detection.py

+ 38 - 1
paddlex_restful/restful/project/train/detection.py

@@ -90,6 +90,40 @@ def build_rcnn_transforms(params):
     return train_transforms, eval_transforms
 
 
+def build_pico_transforms(params):
+    from paddlex import transforms as T
+    target_size = params.image_shape[0]
+    dt_list = []
+    dt_list.extend([
+        T.RandomDistort(
+            brightness_range=params.brightness_range,
+            brightness_prob=params.brightness_prob,
+            contrast_range=params.contrast_range,
+            contrast_prob=params.contrast_prob,
+            saturation_range=params.saturation_range,
+            saturation_prob=params.saturation_prob,
+            hue_range=params.hue_range,
+            hue_prob=params.hue_prob),
+    ])
+    crop_image = params.crop_image
+    if crop_image:
+        dt_list.append(T.RandomCrop())
+    dt_list.extend([
+        T.Resize(
+            target_size=target_size, interp='RANDOM'),
+        T.RandomHorizontalFlip(prob=params.horizontal_flip_prob), T.Normalize(
+            mean=params.image_mean, std=params.image_std)
+    ])
+    train_transforms = T.Compose(dt_list)
+    eval_transforms = T.Compose([
+        T.Resize(
+            target_size=target_size, interp='CUBIC'),
+        T.Normalize(
+            mean=params.image_mean, std=params.image_std),
+    ])
+    return train_transforms, eval_transforms
+
+
 def build_voc_datasets(dataset_path, train_transforms, eval_transforms):
     import paddlex as pdx
     train_file_list = osp.join(dataset_path, 'train_list.txt')
@@ -157,6 +191,8 @@ def train(task_path, dataset_path, params):
     pdx.log_level = 3
     if params.model in ['YOLOv3', 'PPYOLO', 'PPYOLOTiny', 'PPYOLOv2']:
         train_transforms, eval_transforms = build_yolo_transforms(params)
+    elif params.model in ['PicoDet']:
+        train_transforms, eval_transforms = build_pico_transforms(params)
     elif params.model in ['FasterRCNN', 'MaskRCNN']:
         train_transforms, eval_transforms = build_rcnn_transforms(params)
     if osp.exists(osp.join(dataset_path, 'JPEGImages')) and \
@@ -194,7 +230,8 @@ def train(task_path, dataset_path, params):
         # prune
         dataset = eval_dataset or train_dataset
         im_shape = dataset[0]['image'].shape[:2]
-        if getattr(model, 'with_fpn', False):
+        if getattr(model, 'with_fpn',
+                   False) or model.__class__.__name__ == 'PicoDet':
             im_shape[0] = int(np.ceil(im_shape[0] / 32) * 32)
             im_shape[1] = int(np.ceil(im_shape[1] / 32) * 32)
         inputs = [{