瀏覽代碼

update supported params for benchmark

zhouchangda 1 年之前
父節點
當前提交
e3f0bb5e61

+ 5 - 4
paddlex/modules/base/trainer/trainer.py

@@ -47,12 +47,10 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
         super().__init__()
         self.global_config = config.Global
         self.train_config = config.Train
+        self.benchmark_config = config.get('Benchmark', None)
 
         self.deamon = self.build_deamon(self.global_config)
         self.pdx_config, self.pdx_model = build_model(self.global_config.model)
-        #TODO: Control by configuration to support benchmark
-        self.pdx_config.update_log_ranks(self.get_device())
-        self.pdx_config.disable_print_mem_info()
 
     def train(self, *args, **kwargs):
         """execute model training
@@ -60,7 +58,10 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
         os.makedirs(self.global_config.output, exist_ok=True)
         self.update_config()
         self.dump_config()
-        train_result = self.pdx_model.train(**self.get_train_kwargs())
+        train_args = self.get_train_kwargs()
+        if self.benchmark_config is not None:
+            train_args.update({'benchmark': self.benchmark_config})
+        train_result = self.pdx_model.train(**train_args)
         assert train_result.returncode == 0, f"Encountered an unexpected error({train_result.returncode}) in \
 training!"
 

+ 44 - 15
paddlex/repo_apis/PaddleClas_api/cls/config.py

@@ -218,24 +218,55 @@ indicating that no pretrained model to be used."
         ]
         self.update(_cfg)
 
-    def enable_shared_memory(self):
-        """enable shared memory setting of train and eval dataloader
+    def update_shared_memory(self, shared_memory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memory,
+                          bool), "shared_memory should be a bool"
+        _cfg = [
+            f'DataLoader.Train.loader.use_shared_memory={shared_memory}',
+            f'DataLoader.Eval.loader.use_shared_memory={shared_memory}',
+        ]
+        self.update(_cfg)
+
+    def update_shuffle(self, shuffle: bool):
+        """update shuffle setting of train and eval dataloader
+        
+        Args:
+            shuffle (bool): whether or not to shuffle the data
         """
+        assert isinstance(shuffle, bool), "shuffle should be a bool"
         _cfg = [
-            f'DataLoader.Train.loader.use_shared_memory=True',
-            f'DataLoader.Eval.loader.use_shared_memory=True',
+            f'DataLoader.Train.loader.shuffle={shuffle}',
+            f'DataLoader.Eval.loader.shuffle={shuffle}',
         ]
         self.update(_cfg)
 
-    def disable_shared_memory(self):
-        """disable shared memory setting of train and eval dataloader
+    def update_dali(self, dali: bool):
+        """enable DALI setting of train and eval dataloader
+        
+        Args:
+            dali (bool): whether or not to use DALI
         """
+        assert isinstance(dali, bool), "dali should be a bool"
         _cfg = [
-            f'DataLoader.Train.loader.use_shared_memory=False',
-            f'DataLoader.Eval.loader.use_shared_memory=False',
+            f'Global.use_dali={dali}',
         ]
         self.update(_cfg)
 
+    def update_seed(self, seed: int):
+        """update seed
+
+        Args:
+            seed (int): the random seed value to set
+        """
+        _cfg = [f'Global.seed={seed}']
+        self.update(_cfg)
+
     def update_device(self, device: str):
         """update device setting
 
@@ -331,13 +362,11 @@ indicating that no pretrained model to be used."
         log_ranks = device.split(':')[1]
         self.update([f'Global.log_ranks="{log_ranks}"'])
 
-    def enable_print_mem_info(self):
-        """print memory info"""
-        self.update([f'Global.print_mem_info=True'])
-
-    def disable_print_mem_info(self):
-        """do not print memory info"""
-        self.update([f'Global.print_mem_info=False'])
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting print memory info"""
+        assert isinstance(print_mem_info,
+                          bool), "print_mem_info should be a bool"
+        self.update([f'Global.print_mem_info={print_mem_info}'])
 
     def _update_predict_img(self, infer_img: str, infer_list: str=None):
         """update image to be predicted

+ 24 - 1
paddlex/repo_apis/PaddleClas_api/cls/model.py

@@ -82,7 +82,6 @@ class ClsModel(BaseModel):
             config._update_output_dir(save_dir)
             if num_workers is not None:
                 config.update_num_workers(num_workers)
-            config.dump(config_path)
 
             cli_args = []
             do_eval = kwargs.pop('do_eval', True)
@@ -90,6 +89,30 @@ class ClsModel(BaseModel):
             if profile is not None:
                 cli_args.append(CLIArgument('--profiler_options', profile))
 
+            # Benchmarking mode settings
+            benchmark = kwargs.pop('benchmark', None)
+            if benchmark is not None:
+                envs = benchmark.get('env', None)
+                seed = benchmark.get('seed', None)
+                do_eval = benchmark.get('do_eval', False)
+                num_workers = benchmark.get('num_workers', None)
+                config.update_log_ranks(device)
+                config._update_amp(benchmark.get('amp', None))
+                config.update_dali(benchmark.get('dali', False))
+                config.update_shuffle(benchmark.get('shuffle', False))
+                config.update_shared_memory(
+                    benchmark.get('shared_memory', True))
+                config.update_print_mem_info(
+                    benchmark.get('print_mem_info', True))
+                if num_workers is not None:
+                    config.update_num_workers(num_workers)
+                if seed is not None:
+                    config.update_seed(seed)
+                if envs is not None:
+                    for env_name, env_value in envs.items():
+                        os.environ[env_name] = env_value
+
+            config.dump(config_path)
             self._assert_empty_kwargs(kwargs)
 
             return self.runner.train(

+ 2 - 21
paddlex/repo_apis/PaddleDetection_api/instance_seg/config.py

@@ -15,9 +15,10 @@
 from ...base import BaseConfig
 from ....utils.misc import abspath
 from ..config_helper import PPDetConfigMixin
+from ..object_det.config import DetConfig
 
 
-class InstanceSegConfig(BaseConfig, PPDetConfigMixin):
+class InstanceSegConfig(DetConfig):
     """ InstanceSegConfig """
 
     def load(self, config_path: str):
@@ -287,14 +288,6 @@ class InstanceSegConfig(BaseConfig, PPDetConfigMixin):
         log_ranks = device.split(':')[1]
         self.update({'log_ranks': log_ranks})
 
-    def enable_print_mem_info(self):
-        """print memory info"""
-        self.update({'print_mem_info': True})
-
-    def disable_print_mem_info(self):
-        """do not print memory info"""
-        self.update({'print_mem_info': False})
-
     def update_weights(self, weight_path: str):
         """update model weight
 
@@ -339,18 +332,6 @@ class InstanceSegConfig(BaseConfig, PPDetConfigMixin):
         """
         self['worker_num'] = num_workers
 
-    def enable_shared_memory(self):
-        """enable shared memory setting of train and eval dataloader
-        """
-        self.update({'TrainReader': {'use_shared_memory': True}})
-        self.update({'EvalReader': {'use_shared_memory': True}})
-
-    def disable_shared_memory(self):
-        """disable shared memory setting of train and eval dataloader
-        """
-        self.update({'TrainReader': {'use_shared_memory': False}})
-        self.update({'EvalReader': {'use_shared_memory': False}})
-
     def update_static_assigner_epochs(self, static_assigner_epochs: int):
         """update static assigner epochs value
 

+ 28 - 6
paddlex/repo_apis/PaddleDetection_api/instance_seg/model.py

@@ -75,9 +75,6 @@ class InstanceSegModel(BaseModel):
             cli_args.append(CLIArgument('--resume', resume_dir))
         if dy2st:
             cli_args.append(CLIArgument('--to_static'))
-        if amp != 'OFF' and amp is not None:
-            # TODO: consider amp is O1 or O2 in ppdet
-            cli_args.append(CLIArgument('--amp'))
         if num_workers is not None:
             config.update_num_workers(num_workers)
         if save_dir is None:
@@ -90,14 +87,39 @@ class InstanceSegModel(BaseModel):
             cli_args.append(CLIArgument('--vdl_log_dir', save_dir))
 
         do_eval = kwargs.pop('do_eval', True)
+        enable_ce = kwargs.pop('enable_ce', None)
 
         profile = kwargs.pop('profile', None)
         if profile is not None:
             cli_args.append(CLIArgument('--profiler_options', profile))
 
-        enable_ce = kwargs.pop('enable_ce', None)
-        if enable_ce is not None:
-            cli_args.append(CLIArgument('--enable_ce', enable_ce))
+        # Benchmarking mode settings
+        benchmark = kwargs.pop('benchmark', None)
+        if benchmark is not None:
+            envs = benchmark.get('env', None)
+            amp = benchmark.get('amp', None)
+            do_eval = benchmark.get('do_eval', False)
+            num_workers = benchmark.get('num_workers', None)
+            config.update_log_ranks(device)
+            config.update_shuffle(benchmark.get('shuffle', False))
+            config.update_shared_memory(benchmark.get('shared_memory', True))
+            config.update_print_mem_info(benchmark.get('print_mem_info', True))
+            if num_workers is not None:
+                config.update_num_workers(num_workers)
+            if amp == 'O1':
+                # TODO: ppdet only support ampO1
+                cli_args.append(CLIArgument('--amp'))
+            if envs is not None:
+                for env_name, env_value in envs.items():
+                    os.environ[env_name] = env_value
+            # set seed to 0 for benchmark mode by enable_ce
+            cli_args.append(CLIArgument('--enable_ce', True))
+        else:
+            if amp != 'OFF' and amp is not None:
+                # TODO: consider amp is O1 or O2 in ppdet
+                cli_args.append(CLIArgument('--amp'))
+            if enable_ce is not None:
+                cli_args.append(CLIArgument('--enable_ce', enable_ce))
 
         self._assert_empty_kwargs(kwargs)
 

+ 24 - 18
paddlex/repo_apis/PaddleDetection_api/object_det/config.py

@@ -298,13 +298,31 @@ class DetConfig(BaseConfig, PPDetConfigMixin):
         log_ranks = device.split(':')[1]
         self.update({'log_ranks': log_ranks})
 
-    def enable_print_mem_info(self):
-        """print memory info"""
-        self.update({'print_mem_info': True})
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting print memory info"""
+        assert isinstance(print_mem_info,
+                          bool), "print_mem_info should be a bool"
+        self.update({'print_mem_info': print_mem_info})
+
+    def update_shared_memory(self, shared_memory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memory,
+                          bool), "shared_memory should be a bool"
+        self.update({'TrainReader': {'use_shared_memory': shared_memory}})
+        self.update({'EvalReader': {'use_shared_memory': shared_memory}})
 
-    def disable_print_mem_info(self):
-        """do not print memory info"""
-        self.update({'print_mem_info': False})
+    def update_shuffle(self, shuffle: bool):
+        """update shuffle setting of train and eval dataloader
+        
+        Args:
+            shuffle (bool): whether or not to shuffle the data
+        """
+        assert isinstance(shuffle, bool), "shuffle should be a bool"
+        self.update({'TrainReader': {'shuffle': shuffle}})
+        self.update({'EvalReader': {'shuffle': shuffle}})
 
     def update_weights(self, weight_path: str):
         """update model weight
@@ -350,18 +368,6 @@ class DetConfig(BaseConfig, PPDetConfigMixin):
         """
         self['worker_num'] = num_workers
 
-    def enable_shared_memory(self):
-        """enable shared memory setting of train and eval dataloader
-        """
-        self.update({'TrainReader': {'use_shared_memory': True}})
-        self.update({'EvalReader': {'use_shared_memory': True}})
-
-    def disable_shared_memory(self):
-        """disable shared memory setting of train and eval dataloader
-        """
-        self.update({'TrainReader': {'use_shared_memory': False}})
-        self.update({'EvalReader': {'use_shared_memory': False}})
-
     def _recursively_set(self, config: dict, update_dict: dict):
         """recursively set config
 

+ 29 - 7
paddlex/repo_apis/PaddleDetection_api/object_det/model.py

@@ -76,9 +76,6 @@ class DetModel(BaseModel):
             cli_args.append(CLIArgument('--resume', resume_dir))
         if dy2st:
             cli_args.append(CLIArgument('--to_static'))
-        if amp != 'OFF' and amp is not None:
-            # TODO: consider amp is O1 or O2 in ppdet
-            cli_args.append(CLIArgument('--amp'))
         if num_workers is not None:
             config.update_num_workers(num_workers)
         if save_dir is None:
@@ -91,16 +88,41 @@ class DetModel(BaseModel):
             cli_args.append(CLIArgument('--vdl_log_dir', save_dir))
 
         do_eval = kwargs.pop('do_eval', True)
+        enable_ce = kwargs.pop('enable_ce', None)
 
         profile = kwargs.pop('profile', None)
         if profile is not None:
             cli_args.append(CLIArgument('--profiler_options', profile))
 
-        enable_ce = kwargs.pop('enable_ce', None)
-        if enable_ce is not None:
-            cli_args.append(CLIArgument('--enable_ce', enable_ce))
+        # Benchmarking mode settings
+        benchmark = kwargs.pop('benchmark', None)
+        if benchmark is not None:
+            envs = benchmark.get('env', None)
+            amp = benchmark.get('amp', None)
+            do_eval = benchmark.get('do_eval', False)
+            num_workers = benchmark.get('num_workers', None)
+            config.update_log_ranks(device)
+            config.update_shuffle(benchmark.get('shuffle', False))
+            config.update_shared_memory(benchmark.get('shared_memory', True))
+            config.update_print_mem_info(benchmark.get('print_mem_info', True))
+            if num_workers is not None:
+                config.update_num_workers(num_workers)
+            if amp == 'O1':
+                # TODO: ppdet only support ampO1
+                cli_args.append(CLIArgument('--amp'))
+            if envs is not None:
+                for env_name, env_value in envs.items():
+                    os.environ[env_name] = env_value
+            # set seed to 0 for benchmark mode by enable_ce
+            cli_args.append(CLIArgument('--enable_ce', True))
+        else:
+            if amp != 'OFF' and amp is not None:
+                # TODO: consider amp is O1 or O2 in ppdet
+                cli_args.append(CLIArgument('--amp'))
+            if enable_ce is not None:
+                cli_args.append(CLIArgument('--enable_ce', enable_ce))
 
         self._assert_empty_kwargs(kwargs)
 
         with self._create_new_config_file() as config_path:
             config.dump(config_path)

+ 0 - 1
paddlex/repo_apis/PaddleOCR_api/table_rec/model.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 from ....utils import logging

+ 0 - 1
paddlex/repo_apis/PaddleOCR_api/text_det/model.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 from ...base.utils.arg import CLIArgument

+ 48 - 24
paddlex/repo_apis/PaddleOCR_api/text_rec/config.py

@@ -279,13 +279,55 @@ class TextRecConfig(BaseConfig):
         log_ranks = device.split(':')[1]
         self.update({'Global.log_ranks': log_ranks})
 
-    def enable_print_mem_info(self):
-        """print memory info"""
-        self.update({'Global.print_mem_info': True})
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting print memory info"""
+        assert isinstance(print_mem_info,
+                          bool), "print_mem_info should be a bool"
+        self.update({'Global.print_mem_info': print_mem_info})
+
+    def update_shared_memory(self, shared_memory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memory,
+                          bool), "shared_memory should be a bool"
+        _cfg = {
+            'Train.loader.use_shared_memory': shared_memory,
+            'Eval.loader.use_shared_memory': shared_memory,
+        }
+        self.update(_cfg)
+
+    def update_shuffle(self, shuffle: bool):
+        """update shuffle setting of train and eval dataloader
+        
+        Args:
+            shuffle (bool): whether or not to shuffle the data
+        """
+        assert isinstance(shuffle, bool), "shuffle should be a bool"
+        _cfg = {
+            'Train.loader.shuffle': shuffle,
+            'Eval.loader.shuffle': shuffle,
+        }
+        self.update(_cfg)
+
+    def update_cal_metrics(self, cal_metrics: bool):
+        """update calculate metrics setting
+        Args:
+            cal_metrics (bool): whether or not to calculate metrics during train
+        """
+        assert isinstance(cal_metrics, bool), "cal_metrics should be a bool"
+        self.update({'Global.cal_metric_during_train': f'{cal_metrics}'})
 
-    def disable_print_mem_info(self):
-        """do not print memory info"""
-        self.update({'Global.print_mem_info': False})
+    def update_seed(self, seed: int):
+        """update seed
+
+        Args:
+            seed (int): the random seed value to set
+        """
+        assert isinstance(seed, int), "seed should be an int"
+        self.update({'Global.seed': seed})
 
     def _update_eval_interval_by_epoch(self, eval_interval):
         """update eval interval(by epoch)
@@ -370,24 +412,6 @@ class TextRecConfig(BaseConfig):
             else:
                 self['Eval']['loader']['num_workers'] = num_workers
 
-    def enable_shared_memory(self):
-        """enable shared memory setting of train and eval dataloader
-        """
-        _cfg = {
-            'Train.loader.use_shared_memory': True,
-            'Train.loader.use_shared_memory': True,
-        }
-        self.update(_cfg)
-
-    def disable_shared_memory(self):
-        """disable shared memory setting of train and eval dataloader
-        """
-        _cfg = {
-            'Train.loader.use_shared_memory': False,
-            'Train.loader.use_shared_memory': False,
-        }
-        self.update(_cfg)
-
     def _get_model_type(self) -> str:
         """get model type
 

+ 21 - 1
paddlex/repo_apis/PaddleOCR_api/text_rec/model.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 from ...base import BaseModel
@@ -104,6 +103,27 @@ class TextRecModel(BaseModel):
         if profile is not None:
             cli_args.append(CLIArgument('--profiler_options', profile))
 
+        # Benchmarking mode settings
+        benchmark = kwargs.pop('benchmark', None)
+        if benchmark is not None:
+            envs = benchmark.get('env', None)
+            seed = benchmark.get('seed', None)
+            do_eval = benchmark.get('do_eval', False)
+            num_workers = benchmark.get('num_workers', None)
+            config.update_log_ranks(device)
+            config._update_amp(benchmark.get('amp', None))
+            config.update_shuffle(benchmark.get('shuffle', False))
+            config.update_cal_metrics(benchmark.get('cal_metrics', True))
+            config.update_shared_memory(benchmark.get('shared_memory', True))
+            config.update_print_mem_info(benchmark.get('print_mem_info', True))
+            if num_workers is not None:
+                config.update_num_workers(num_workers)
+            if seed is not None:
+                config.update_seed(seed)
+            if envs is not None:
+                for env_name, env_value in envs.items():
+                    os.environ[env_name] = env_value
+
         self._assert_empty_kwargs(kwargs)
 
         with self._create_new_config_file() as config_path:

+ 5 - 7
paddlex/repo_apis/PaddleSeg_api/base_seg_config.py

@@ -67,13 +67,11 @@ class BaseSegConfig(BaseConfig):
         log_ranks = device.split(':')[1]
         self.set_val('log_ranks', log_ranks)
 
-    def enable_print_mem_info(self):
-        """print memory info"""
-        self.set_val('print_mem_info', True)
-
-    def disable_print_mem_info(self):
-        """do not print memory info"""
-        self.set_val('print_mem_info', False)
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting print memory info"""
+        assert isinstance(print_mem_info,
+                          bool), "print_mem_info should be a bool"
+        self.set_val('print_mem_info', print_mem_info)
 
     def update_pretrained_weights(self, weight_path, is_backbone=False):
         """ update_pretrained_weights """

+ 40 - 16
paddlex/repo_apis/PaddleSeg_api/seg/model.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 from ...base import BaseModel
@@ -96,14 +95,6 @@ class SegModel(BaseModel):
         if dy2st:
             config.update_dy2st(dy2st)
 
-        if amp is not None:
-            if amp != 'OFF':
-                cli_args.append(CLIArgument('--precision', 'fp16'))
-                cli_args.append(CLIArgument('--amp_level', amp))
-
-        if num_workers is not None:
-            cli_args.append(CLIArgument('--num_workers', num_workers))
-
         if use_vdl:
             cli_args.append(CLIArgument('--use_vdl'))
 
@@ -119,6 +110,8 @@ class SegModel(BaseModel):
             cli_args.append(CLIArgument('--save_interval', save_interval))
 
         do_eval = kwargs.pop('do_eval', True)
+        repeats = kwargs.pop('repeats', None)
+        seed = kwargs.pop('seed', None)
 
         profile = kwargs.pop('profile', None)
         if profile is not None:
@@ -128,13 +121,44 @@ class SegModel(BaseModel):
         if log_iters is not None:
             cli_args.append(CLIArgument('--log_iters', log_iters))
 
-        repeats = kwargs.pop('repeats', None)
-        if repeats is not None:
-            cli_args.append(CLIArgument('--repeats', repeats))
-
-        seed = kwargs.pop('seed', None)
-        if seed is not None:
-            cli_args.append(CLIArgument('--seed', seed))
+        # Benchmarking mode settings
+        benchmark = kwargs.pop('benchmark', None)
+        if benchmark is not None:
+            envs = benchmark.get('env', None)
+            seed = benchmark.get('seed', None)
+            repeats = benchmark.get('repeats', None)
+            do_eval = benchmark.get('do_eval', False)
+            num_workers = benchmark.get('num_workers', None)
+            config.update_log_ranks(device)
+            amp = benchmark.get('amp', None)
+            config.update_print_mem_info(benchmark.get('print_mem_info', True))
+            if repeats is not None:
+                assert isinstance(repeats, int), 'repeats must be an integer.'
+                cli_args.append(CLIArgument('--repeats', repeats))
+            if num_workers is not None:
+                assert isinstance(num_workers,
+                                  int), 'num_workers must be an integer.'
+                cli_args.append(CLIArgument('--num_workers', num_workers))
+            if seed is not None:
+                assert isinstance(seed, int), 'seed must be an integer.'
+                cli_args.append(CLIArgument('--seed', seed))
+            if amp in ['O1', 'O2']:
+                cli_args.append(CLIArgument('--precision', 'fp16'))
+                cli_args.append(CLIArgument('--amp_level', amp))
+            if envs is not None:
+                for env_name, env_value in envs.items():
+                    os.environ[env_name] = env_value
+        else:
+            if amp is not None:
+                if amp != 'OFF':
+                    cli_args.append(CLIArgument('--precision', 'fp16'))
+                    cli_args.append(CLIArgument('--amp_level', amp))
+            if num_workers is not None:
+                cli_args.append(CLIArgument('--num_workers', num_workers))
+            if repeats is not None:
+                cli_args.append(CLIArgument('--repeats', repeats))
+            if seed is not None:
+                cli_args.append(CLIArgument('--seed', seed))
 
         self._assert_empty_kwargs(kwargs)
 

+ 5 - 7
paddlex/repo_apis/PaddleTS_api/ts_base/config.py

@@ -141,13 +141,11 @@ class BaseTSConfig(BaseConfig):
         # PaddleTS does not support multi-device training currently.
         pass
 
-    def enable_print_mem_info(self):
-        """print memory info"""
-        self.update({'print_mem_info': True})
-
-    def disable_print_mem_info(self):
-        """do not print memory info"""
-        self.update({'print_mem_info': False})
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting print memory info"""
+        assert isinstance(print_mem_info,
+                          bool), "print_mem_info should be a bool"
+        self.update({'print_mem_info': print_mem_info})
 
     def update_dataset(self, dataset_dir: str, dataset_type: str=None):
         """update dataset settings

+ 18 - 4
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 from ...base import BaseModel
@@ -84,9 +83,6 @@ class TSModel(BaseModel):
             device_type, _ = self.runner.parse_device(device)
             cli_args.append(CLIArgument('--device', device_type))
 
-        if num_workers is not None:
-            cli_args.append(CLIArgument('--num_workers', num_workers))
-
         if save_dir is not None:
             save_dir = abspath(save_dir)
         else:
@@ -94,6 +90,24 @@ class TSModel(BaseModel):
             save_dir = abspath(os.path.join('output', 'train'))
         cli_args.append(CLIArgument('--save_dir', save_dir))
 
+        # Benchmarking mode settings
+        benchmark = kwargs.pop('benchmark', None)
+        if benchmark is not None:
+            envs = benchmark.get('env', None)
+            num_workers = benchmark.get('num_workers', None)
+            config.update_log_ranks(device)
+            config.update_print_mem_info(benchmark.get('print_mem_info', True))
+            if num_workers is not None:
+                assert isinstance(num_workers,
+                                  int), "num_workers must be an integer"
+                cli_args.append(CLIArgument('--num_workers', num_workers))
+            if envs is not None:
+                for env_name, env_value in envs.items():
+                    os.environ[env_name] = env_value
+        else:
+            if num_workers is not None:
+                cli_args.append(CLIArgument('--num_workers', num_workers))
+
         self._assert_empty_kwargs(kwargs)
 
         with self._create_new_config_file() as config_path: