Explorar el Código

Merge pull request #47 from FlyingQianMM/develop_qh

batch_size is forced to be set to 1 in RCNN
Jason hace 5 años
padre
commit
e3de4b9912
Se han modificado 3 ficheros con 14 adiciones y 5 borrados
  1. 2 2
      docs/apis/models.md
  2. 6 2
      paddlex/cv/models/faster_rcnn.py
  3. 6 1
      paddlex/cv/models/mask_rcnn.py

+ 2 - 2
docs/apis/models.md

@@ -228,7 +228,7 @@ paddlex.det.FasterRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspec
 > **参数:**
 >
 > > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
-> > - **batch_size** (int): 验证数据批大小。默认为1。
+> > - **batch_size** (int): 验证数据批大小。默认为1。当前只支持设置为1。
 > > - **epoch_id** (int): 当前评估模型所在的训练轮数。
 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC'; 如为COCODetection,则`metric`为'COCO'。
 > > - **return_details** (bool): 是否返回详细信息。默认值为False。
@@ -309,7 +309,7 @@ paddlex.det.MaskRCNN(num_classes=81, backbone='ResNet50', with_fpn=True, aspect_
 > **参数:**
 >
 > > - **eval_dataset** (paddlex.datasets): 验证数据读取器。
-> > - **batch_size** (int): 验证数据批大小。默认为1。
+> > - **batch_size** (int): 验证数据批大小。默认为1。当前只支持设置为1。
 > > - **epoch_id** (int): 当前评估模型所在的训练轮数。
 > > - **metric** (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,根据用户传入的Dataset自动选择,如为VOCDetection,则`metric`为'VOC'; 如为COCODetection,则`metric`为'COCO'。
 > > - **return_details** (bool): 是否返回详细信息。默认值为False。

+ 6 - 2
paddlex/cv/models/faster_rcnn.py

@@ -259,7 +259,7 @@ class FasterRCNN(BaseAPI):
 
         Args:
             eval_dataset (paddlex.datasets): 验证数据读取器。
-            batch_size (int): 验证数据批大小。默认为1。
+            batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
             epoch_id (int): 当前评估模型所在的训练轮数。
             metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
                 根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
@@ -288,7 +288,11 @@ class FasterRCNN(BaseAPI):
                         "eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
                     )
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
-
+        if batch_size > 1:
+            batch_size = 1
+            logging.warning(
+                "Faster RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
+            )
         dataset = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
 

+ 6 - 1
paddlex/cv/models/mask_rcnn.py

@@ -225,7 +225,7 @@ class MaskRCNN(FasterRCNN):
 
         Args:
             eval_dataset (paddlex.datasets): 验证数据读取器。
-            batch_size (int): 验证数据批大小。默认为1。
+            batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
             epoch_id (int): 当前评估模型所在的训练轮数。
             metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
                 根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
@@ -253,6 +253,11 @@ class MaskRCNN(FasterRCNN):
                     raise Exception(
                         "eval_dataset should be datasets.COCODetection.")
         assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
+        if batch_size > 1:
+            batch_size = 1
+            logging.warning(
+                "Mask RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
+            )
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)