
Force single-card training for RCNN models with negative samples

will-jl944 · 4 years ago · commit 7298976d96
4 changed files with 167 additions and 12 deletions
  1. paddlex/cv/datasets/coco.py (+1, -1)
  2. paddlex/cv/datasets/voc.py (+4, -2)
  3. paddlex/cv/models/base.py (+5, -1)
  4. paddlex/cv/models/detector.py (+157, -8)

+ 1 - 1
paddlex/cv/datasets/coco.py

@@ -196,7 +196,7 @@ class CocoDetection(VOCDetection):
             logging.error(
                 "No coco record found in %s" % ann_file, exit=True)
         self.pos_num = len(self.file_list)
-        if self.allow_empty:
+        if self.allow_empty and neg_file_list:
             self.file_list += self._sample_empty(neg_file_list)
         logging.info(
             "{} samples in file {}, including {} positive samples and {} negative samples.".

+ 4 - 2
paddlex/cv/datasets/voc.py

@@ -290,7 +290,7 @@ class VOCDetection(Dataset):
             logging.error(
                 "No voc record found in %s" % file_list, exit=True)
         self.pos_num = len(self.file_list)
-        if self.allow_empty:
+        if self.allow_empty and neg_file_list:
             self.file_list += self._sample_empty(neg_file_list)
         logging.info(
             "{} samples in file {}, including {} positive samples and {} negative samples.".
@@ -423,7 +423,9 @@ class VOCDetection(Dataset):
                 **
                 label_info
             })
-        self.file_list += self._sample_empty(neg_file_list)
+        if neg_file_list:
+            self.allow_empty = True
+            self.file_list += self._sample_empty(neg_file_list)
         logging.info(
             "{} negative samples added. Dataset contains {} positive samples and {} negative samples.".
             format(
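
This second hunk sits in the dataset's negative-sample injection method (its name is truncated by the diff; in PaddleX 2.x this is `add_negative_samples`). Flipping `allow_empty` to True on the fly means a dataset originally built without negatives reports `pos_num < len(file_list)` afterwards, which is exactly what the RCNN single-card checks below key on. A hedged usage sketch, assuming the PaddleX 2.x API and placeholder paths:

    import paddlex as pdx
    from paddlex import transforms as T

    train_transforms = T.Compose([T.RandomHorizontalFlip(), T.Normalize()])

    # Built from positive samples only; allow_empty defaults to False.
    train_dataset = pdx.datasets.VOCDetection(
        data_dir='dataset',
        file_list='dataset/train_list.txt',
        label_list='dataset/labels.txt',
        transforms=train_transforms)

    # Injecting background-only images flips allow_empty to True, so
    # pos_num < len(file_list) holds from here on.
    train_dataset.add_negative_samples(image_dir='dataset/background')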

+ 5 - 1
paddlex/cv/models/base.py

@@ -271,7 +271,11 @@ class BaseModel:
             transforms=train_dataset.transforms,
             mode='train')
 
-        nranks = paddle.distributed.get_world_size()
+        if "RCNN" in self.__class__.__name__ and train_dataset.pos_num < len(
+                train_dataset.file_list):
+            nranks = 1
+        else:
+            nranks = paddle.distributed.get_world_size()
         local_rank = paddle.distributed.get_rank()
         if nranks > 1:
             find_unused_parameters = getattr(self, 'find_unused_parameters',
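
The effect of the override, isolated from the class for clarity (the helper name is illustrative, not part of PaddleX):

    import paddle

    def effective_nranks(model_cls_name, pos_num, total_num):
        # Illustrative helper: RCNN models whose dataset contains negative
        # samples (pos_num < total_num) are treated as single-card jobs;
        # everything else keeps the real world size.
        if 'RCNN' in model_cls_name and pos_num < total_num:
            return 1
        return paddle.distributed.get_world_size()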

+ 157 - 8
paddlex/cv/models/detector.py

@@ -18,7 +18,6 @@ import collections
 import copy
 import os
 import os.path as osp
-import six
 import numpy as np
 import paddle
 from paddle.static import InputSpec
@@ -29,6 +28,7 @@ import paddlex.utils.logging as logging
 from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Padding
 from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, _BatchPadding, _Gt2YoloTarget
 from paddlex.cv.transforms import arrange_transforms
+from paddlex.utils import get_single_card_bs
 from .base import BaseModel
 from .utils.det_metrics import VOCMetric, COCOMetric
 from .utils.ema import ExponentialMovingAverage
@@ -193,13 +193,6 @@ class BaseDetector(BaseModel):
                 If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                 `pretrain_weights` can be set simultaneously. Defaults to None.
         """
-        if train_dataset.pos_num < len(
-                train_dataset.file_list
-        ) and train_batch_size != 1 and 'RCNN' in self.__class__.__name__:
-            train_batch_size = 1
-            logging.warning(
-                "Training RCNN models with negative samples only support batch size equals to 1, "
-                "`train_batch_size` is forcibly set to 1.")
         if self.status == 'Infer':
             logging.error(
                 "Exported inference model does not support training.",
@@ -982,6 +975,84 @@ class FasterRCNN(BaseDetector):
         super(FasterRCNN, self).__init__(
             model_name='FasterRCNN', num_classes=num_classes, **params)
 
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=64,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=10,
+              save_dir='output',
+              pretrain_weights='IMAGENET',
+              learning_rate=.001,
+              warmup_steps=0,
+              warmup_start_lr=0.0,
+              lr_decay_epochs=(216, 243),
+              lr_decay_gamma=0.1,
+              metric=None,
+              use_ema=False,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated during the training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .001.
+            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
+            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
+            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
+            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
+            metric({'VOC', 'COCO', None}, optional):
+                Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
+            use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
+        """
+        if train_dataset.pos_num < len(train_dataset.file_list):
+            train_dataset.num_workers = 0
+            if train_batch_size != 1:
+                train_batch_size = 1
+                logging.warning(
+                    "Training RCNN models with negative samples only supports "
+                    "batch size 1 on a single GPU/CPU card; `train_batch_size` "
+                    "is forcibly set to 1.")
+            nranks = paddle.distributed.get_world_size()
+            local_rank = paddle.distributed.get_rank()
+            # single card training
+            if nranks < 2 or local_rank == 0:
+                super(FasterRCNN, self).train(
+                    num_epochs, train_dataset, train_batch_size, eval_dataset,
+                    optimizer, save_interval_epochs, log_interval_steps,
+                    save_dir, pretrain_weights, learning_rate, warmup_steps,
+                    warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric,
+                    use_ema, early_stop, early_stop_patience, use_vdl,
+                    resume_checkpoint)
+        else:
+            super(FasterRCNN, self).train(
+                num_epochs, train_dataset, train_batch_size, eval_dataset,
+                optimizer, save_interval_epochs, log_interval_steps, save_dir,
+                pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
+                lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
+                early_stop_patience, use_vdl, resume_checkpoint)
+
     def _compose_batch_transform(self, transforms, mode='train'):
         if mode == 'train':
             default_batch_transforms = [
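
With the override in place, a multi-card launch remains safe: every rank enters `train`, but when negatives are present only `local_rank == 0` falls through to the actual training loop and the other processes return immediately. A usage sketch reusing `train_dataset` from the earlier snippet (epoch counts, learning rate, and paths are placeholders):

    import paddlex as pdx

    model = pdx.det.FasterRCNN(num_classes=len(train_dataset.labels))
    model.train(
        num_epochs=12,
        train_dataset=train_dataset,  # contains negative samples
        train_batch_size=2,           # forcibly reset to 1 with a warning
        learning_rate=0.0025,
        lr_decay_epochs=(8, 11),
        save_dir='output/faster_rcnn')
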
@@ -1762,6 +1833,84 @@ class MaskRCNN(BaseDetector):
         super(MaskRCNN, self).__init__(
             model_name='MaskRCNN', num_classes=num_classes, **params)
 
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=64,
+              eval_dataset=None,
+              optimizer=None,
+              save_interval_epochs=1,
+              log_interval_steps=10,
+              save_dir='output',
+              pretrain_weights='IMAGENET',
+              learning_rate=.001,
+              warmup_steps=0,
+              warmup_start_lr=0.0,
+              lr_decay_epochs=(216, 243),
+              lr_decay_gamma=0.1,
+              metric=None,
+              use_ema=False,
+              early_stop=False,
+              early_stop_patience=5,
+              use_vdl=True,
+              resume_checkpoint=None):
+        """
+        Train the model.
+        Args:
+            num_epochs(int): The number of epochs.
+            train_dataset(paddlex.dataset): Training dataset.
+            train_batch_size(int, optional): Total batch size among all cards used in training. Defaults to 64.
+            eval_dataset(paddlex.dataset, optional):
+                Evaluation dataset. If None, the model will not be evaluated during the training process. Defaults to None.
+            optimizer(paddle.optimizer.Optimizer or None, optional):
+                Optimizer used for training. If None, a default optimizer is used. Defaults to None.
+            save_interval_epochs(int, optional): Epoch interval for saving the model. Defaults to 1.
+            log_interval_steps(int, optional): Step interval for printing training information. Defaults to 10.
+            save_dir(str, optional): Directory to save the model. Defaults to 'output'.
+            pretrain_weights(str or None, optional):
+                None or name/path of pretrained weights. If None, no pretrained weights will be loaded. Defaults to 'IMAGENET'.
+            learning_rate(float, optional): Learning rate for training. Defaults to .001.
+            warmup_steps(int, optional): The number of steps of warm-up training. Defaults to 0.
+            warmup_start_lr(float, optional): Start learning rate of warm-up training. Defaults to 0.0.
+            lr_decay_epochs(list or tuple, optional): Epoch milestones for learning rate decay. Defaults to (216, 243).
+            lr_decay_gamma(float, optional): Gamma coefficient of learning rate decay. Defaults to .1.
+            metric({'VOC', 'COCO', None}, optional):
+                Evaluation metric. If None, determine the metric according to the dataset format. Defaults to None.
+            use_ema(bool, optional): Whether to use exponential moving average strategy. Defaults to False.
+            early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
+            early_stop_patience(int, optional): Early stop patience. Defaults to 5.
+            use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
+                `pretrain_weights` can be set simultaneously. Defaults to None.
+        """
+        if train_dataset.pos_num < len(train_dataset.file_list):
+            train_dataset.num_workers = 0
+            if train_batch_size != 1:
+                train_batch_size = 1
+                logging.warning(
+                    "Training RCNN models with negative samples only supports "
+                    "batch size 1 on a single GPU/CPU card; `train_batch_size` "
+                    "is forcibly set to 1.")
+            nranks = paddle.distributed.get_world_size()
+            local_rank = paddle.distributed.get_rank()
+            # single card training
+            if nranks < 2 or local_rank == 0:
+                super(MaskRCNN, self).train(
+                    num_epochs, train_dataset, train_batch_size, eval_dataset,
+                    optimizer, save_interval_epochs, log_interval_steps,
+                    save_dir, pretrain_weights, learning_rate, warmup_steps,
+                    warmup_start_lr, lr_decay_epochs, lr_decay_gamma, metric,
+                    use_ema, early_stop, early_stop_patience, use_vdl,
+                    resume_checkpoint)
+        else:
+            super(MaskRCNN, self).train(
+                num_epochs, train_dataset, train_batch_size, eval_dataset,
+                optimizer, save_interval_epochs, log_interval_steps, save_dir,
+                pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
+                lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
+                early_stop_patience, use_vdl, resume_checkpoint)
+
     def _compose_batch_transform(self, transforms, mode='train'):
         if mode == 'train':
             default_batch_transforms = [
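
The rank gate shared by both overrides, reduced to its core (the callable below is a placeholder for the `super().train(...)` call):

    import paddle

    def run_on_rank_zero(run_training):
        # Under `python -m paddle.distributed.launch` only rank 0 proceeds;
        # the remaining processes return without training.
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks < 2 or local_rank == 0:
            run_training()  # placeholder for the super().train(...) call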