Forráskód Böngészése

fix batch_size_per_gpu for ppyolo

FlyingQianMM 5 éve
szülő
commit
c1a0754048
1 módosított fájl, 12 hozzáadás és 9 törlés
  1. 12 9
      paddlex/cv/models/ppyolo.py

+ 12 - 9
paddlex/cv/models/ppyolo.py

@@ -125,7 +125,8 @@ class PPYOLO(BaseAPI):
         self.with_dcn_v2 = with_dcn_v2
 
         if paddle.__version__ < '1.8.4' and paddle.__version__ != '0.0.0':
-            raise Exception("PPYOLO requires paddlepaddle or paddlepaddle-gpu >= 1.8.4")
+            raise Exception(
+                "PPYOLO requires paddlepaddle or paddlepaddle-gpu >= 1.8.4")
 
     def _get_backbone(self, backbone_name):
         if backbone_name.startswith('ResNet50_vd'):
@@ -162,8 +163,7 @@ class PPYOLO(BaseAPI):
             use_matrix_nms=self.use_matrix_nms,
             use_fine_grained_loss=self.use_fine_grained_loss,
             use_iou_loss=self.use_iou_loss,
-            batch_size=self.batch_size_per_gpu
-            if hasattr(self, 'batch_size_per_gpu') else 8)
+            batch_size=getattr(self, 'batch_size_per_gpu', 8))
         if mode == 'train' and self.use_iou_loss or self.use_iou_aware:
             model.max_height = self.max_height
             model.max_width = self.max_width
@@ -302,8 +302,7 @@ class PPYOLO(BaseAPI):
         self.use_ema = use_ema
         self.ema_decay = ema_decay
 
-        self.batch_size_per_gpu = int(train_batch_size /
-                                      paddlex.env_info['num'])
+        self.batch_size_per_gpu = self._get_single_card_bs(train_batch_size)
         if self.use_fine_grained_loss:
             for transform in train_dataset.transforms.transforms:
                 if isinstance(transform, paddlex.det.transforms.Resize):
@@ -451,7 +450,11 @@ class PPYOLO(BaseAPI):
         return evaluate_metrics
 
     @staticmethod
-    def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
+    def _preprocess(images,
+                    transforms,
+                    model_type,
+                    class_name,
+                    thread_pool=None):
         arrange_transforms(
             model_type=model_type,
             class_name=class_name,
@@ -546,9 +549,9 @@ class PPYOLO(BaseAPI):
 
         if transforms is None:
             transforms = self.test_transforms
-        im, im_size = PPYOLO._preprocess(img_file_list, transforms,
-                                         self.model_type,
-                                         self.__class__.__name__, self.thread_pool)
+        im, im_size = PPYOLO._preprocess(
+            img_file_list, transforms, self.model_type,
+            self.__class__.__name__, self.thread_pool)
 
         with fluid.scope_guard(self.scope):
             result = self.exe.run(self.test_prog,