Просмотр исходного кода

set batch_size to 1 if train rcnn with negative samples

will-jl944 4 лет назад
Родитель
Сommit
ce45d7df9b
1 измененных файлов с 7 добавлено и 0 удалено
  1. 7 0
      paddlex/cv/models/detector.py

+ 7 - 0
paddlex/cv/models/detector.py

@@ -193,6 +193,13 @@ class BaseDetector(BaseModel):
                 If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                 `pretrain_weights` can be set simultaneously. Defaults to None.
         """
+        if train_dataset.pos_num < len(
+                train_dataset.file_list
+        ) and train_batch_size != 1 and 'RCNN' in self.__class__.__name__:
+            train_batch_size = 1
+            logging.warning(
+                "Training RCNN models with negative samples only support batch size equals to 1, "
+                "`train_batch_size` is forcibly set to 1.")
         if self.status == 'Infer':
             logging.error(
                 "Exported inference model does not support training.",