浏览代码

fix benchmark bug

zhouchangda 1 年之前
父节点
当前提交
5704df5765

+ 3 - 2
paddlex/repo_apis/PaddleClas_api/cls/model.py

@@ -62,7 +62,6 @@ class ClsModel(BaseModel):
         with self._create_new_config_file() as config_path:
             # Update YAML config file
             config = self.config.copy()
-            config._update_amp(amp)
             config.update_device(device)
             config._update_to_static(dy2st)
             config._update_use_vdl(use_vdl)
@@ -110,7 +109,9 @@ class ClsModel(BaseModel):
                     config.update_seed(seed)
                 if envs is not None:
                     for env_name, env_value in envs.items():
-                        os.environ[env_name] = env_value
+                        os.environ[env_name] = str(env_value)
+            else:
+                config._update_amp(amp)
 
             config.dump(config_path)
             self._assert_empty_kwargs(kwargs)

+ 0 - 1
paddlex/repo_apis/PaddleDetection_api/configs/RT-DETR-R50.yaml

@@ -120,7 +120,6 @@ HybridEncoder:
     dim_feedforward: 1024
     dropout: 0.
     activation: 'gelu'
-  expansion: 0.5
   depth_mult: 1.0
 
 RTDETRTransformer:

+ 1 - 1
paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py

@@ -111,7 +111,7 @@ class InstanceSegModel(BaseModel):
                 cli_args.append(CLIArgument('--amp'))
             if envs is not None:
                 for env_name, env_value in envs.items():
-                    os.environ[env_name] = env_value
+                    os.environ[env_name] = str(env_value)
             # set seed to 0 for benchmark mode by enable_ce
             cli_args.append(CLIArgument('--enable_ce', True))
         else:

+ 2 - 2
paddlex/repo_apis/PaddleDetection_api/object_det/config.py

@@ -321,8 +321,8 @@ class DetConfig(BaseConfig, PPDetConfigMixin):
             shuffle (bool): whether or not to shuffle the data
         """
         assert isinstance(shuffle, bool), "shuffle should be a bool"
-        self.update({'TrainReader': {'shuffle': f'{shuffle}'}})
-        self.update({'EvalReader': {'shuffle': f'{shuffle}'}})
+        self.update({'TrainReader': {'shuffle': shuffle}})
+        self.update({'EvalReader': {'shuffle': shuffle}})
 
     def update_weights(self, weight_path: str):
         """update model weight

+ 1 - 1
paddlex/repo_apis/PaddleDetection_api/object_det/model.py

@@ -112,7 +112,7 @@ class DetModel(BaseModel):
                 cli_args.append(CLIArgument('--amp'))
             if envs is not None:
                 for env_name, env_value in envs.items():
-                    os.environ[env_name] = env_value
+                    os.environ[env_name] = str(env_value)
             # set seed to 0 for benchmark mode by enable_ce
             cli_args.append(CLIArgument('--enable_ce', True))
         else:

+ 3 - 3
paddlex/repo_apis/PaddleOCR_api/text_rec/config.py

@@ -307,8 +307,8 @@ class TextRecConfig(BaseConfig):
         """
         assert isinstance(shuffle, bool), "shuffle should be a bool"
         _cfg = {
-            f'Train.loader.shuffle': f'{shuffle}',
-            f'Train.loader.shuffle': f'{shuffle}',
+            f'Train.loader.shuffle': shuffle,
+            f'Train.loader.shuffle': shuffle,
         }
         self.update(_cfg)
 
@@ -327,7 +327,7 @@ class TextRecConfig(BaseConfig):
             seed (int): the random seed value to set
         """
         assert isinstance(seed, int), "seed should be an int"
-        self.update({'Global.seed': f'{seed}'})
+        self.update({'Global.seed': seed})
 
     def _update_eval_interval_by_epoch(self, eval_interval):
         """update eval interval(by epoch)

+ 1 - 1
paddlex/repo_apis/PaddleOCR_api/text_rec/model.py

@@ -122,7 +122,7 @@ class TextRecModel(BaseModel):
                 config.update_seed(seed)
             if envs is not None:
                 for env_name, env_value in envs.items():
-                    os.environ[env_name] = env_value
+                    os.environ[env_name] = str(env_value)
 
         self._assert_empty_kwargs(kwargs)
 

+ 1 - 1
paddlex/repo_apis/PaddleSeg_api/seg/model.py

@@ -147,7 +147,7 @@ class SegModel(BaseModel):
                 cli_args.append(CLIArgument('--amp_level', amp))
             if envs is not None:
                 for env_name, env_value in envs.items():
-                    os.environ[env_name] = env_value
+                    os.environ[env_name] = str(env_value)
         else:
             if amp is not None:
                 if amp != 'OFF':

+ 1 - 1
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -103,7 +103,7 @@ class TSModel(BaseModel):
                 cli_args.append(CLIArgument('--num_workers', num_workers))
             if envs is not None:
                 for env_name, env_value in envs.items():
-                    os.environ[env_name] = env_value
+                    os.environ[env_name] = str(env_value)
         else:
             if num_workers is not None:
                 cli_args.append(CLIArgument('--num_workers', num_workers))