浏览代码

set batch_size to 1 if train rcnn with negative samples

will-jl944 4 年之前
父节点
当前提交
ce45d7df9b
共有 1 个文件被更改,包括 7 次插入0 次删除
  1. 7 0
      paddlex/cv/models/detector.py

+ 7 - 0
paddlex/cv/models/detector.py

@@ -193,6 +193,13 @@ class BaseDetector(BaseModel):
                 If None, no training checkpoint will be resumed. At most one of `resume_checkpoint` and
                 `pretrain_weights` can be set simultaneously. Defaults to None.
         """
+        if train_dataset.pos_num < len(
+                train_dataset.file_list
+        ) and train_batch_size != 1 and 'RCNN' in self.__class__.__name__:
+            train_batch_size = 1
+            logging.warning(
+                "Training RCNN models with negative samples only support batch size equals to 1, "
+                "`train_batch_size` is forcibly set to 1.")
         if self.status == 'Infer':
             logging.error(
                 "Exported inference model does not support training.",