|
|
@@ -12,7 +12,6 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-
|
|
|
import os
|
|
|
|
|
|
from ...base import BaseModel
|
|
|
@@ -96,14 +95,6 @@ class SegModel(BaseModel):
|
|
|
if dy2st:
|
|
|
config.update_dy2st(dy2st)
|
|
|
|
|
|
- if amp is not None:
|
|
|
- if amp != 'OFF':
|
|
|
- cli_args.append(CLIArgument('--precision', 'fp16'))
|
|
|
- cli_args.append(CLIArgument('--amp_level', amp))
|
|
|
-
|
|
|
- if num_workers is not None:
|
|
|
- cli_args.append(CLIArgument('--num_workers', num_workers))
|
|
|
-
|
|
|
if use_vdl:
|
|
|
cli_args.append(CLIArgument('--use_vdl'))
|
|
|
|
|
|
@@ -119,6 +110,8 @@ class SegModel(BaseModel):
|
|
|
cli_args.append(CLIArgument('--save_interval', save_interval))
|
|
|
|
|
|
do_eval = kwargs.pop('do_eval', True)
|
|
|
+ repeats = kwargs.pop('repeats', None)
|
|
|
+ seed = kwargs.pop('seed', None)
|
|
|
|
|
|
profile = kwargs.pop('profile', None)
|
|
|
if profile is not None:
|
|
|
@@ -128,13 +121,44 @@ class SegModel(BaseModel):
|
|
|
if log_iters is not None:
|
|
|
cli_args.append(CLIArgument('--log_iters', log_iters))
|
|
|
|
|
|
- repeats = kwargs.pop('repeats', None)
|
|
|
- if repeats is not None:
|
|
|
- cli_args.append(CLIArgument('--repeats', repeats))
|
|
|
-
|
|
|
- seed = kwargs.pop('seed', None)
|
|
|
- if seed is not None:
|
|
|
- cli_args.append(CLIArgument('--seed', seed))
|
|
|
+ # Benchmarking mode settings
|
|
|
+ benchmark = kwargs.pop('benchmark', None)
|
|
|
+ if benchmark is not None:
|
|
|
+ envs = benchmark.get('env', None)
|
|
|
+ seed = benchmark.get('seed', None)
|
|
|
+ repeats = benchmark.get('repeats', None)
|
|
|
+ do_eval = benchmark.get('do_eval', False)
|
|
|
+ num_workers = benchmark.get('num_workers', None)
|
|
|
+ config.update_log_ranks(device)
|
|
|
+ amp = benchmark.get('amp', None)
|
|
|
+ config.update_print_mem_info(benchmark.get('print_mem_info', True))
|
|
|
+ if repeats is not None:
|
|
|
+ assert isinstance(repeats, int), 'repeats must be an integer.'
|
|
|
+ cli_args.append(CLIArgument('--repeats', repeats))
|
|
|
+ if num_workers is not None:
|
|
|
+ assert isinstance(num_workers,
|
|
|
+ int), 'num_workers must be an integer.'
|
|
|
+ cli_args.append(CLIArgument('--num_workers', num_workers))
|
|
|
+ if seed is not None:
|
|
|
+ assert isinstance(seed, int), 'seed must be an integer.'
|
|
|
+ cli_args.append(CLIArgument('--seed', seed))
|
|
|
+ if amp in ['O1', 'O2']:
|
|
|
+ cli_args.append(CLIArgument('--precision', 'fp16'))
|
|
|
+ cli_args.append(CLIArgument('--amp_level', amp))
|
|
|
+ if envs is not None:
|
|
|
+ for env_name, env_value in envs.items():
|
|
|
+ os.environ[env_name] = env_value
|
|
|
+ else:
|
|
|
+ if amp is not None:
|
|
|
+ if amp != 'OFF':
|
|
|
+ cli_args.append(CLIArgument('--precision', 'fp16'))
|
|
|
+ cli_args.append(CLIArgument('--amp_level', amp))
|
|
|
+ if num_workers is not None:
|
|
|
+ cli_args.append(CLIArgument('--num_workers', num_workers))
|
|
|
+ if repeats is not None:
|
|
|
+ cli_args.append(CLIArgument('--repeats', repeats))
|
|
|
+ if seed is not None:
|
|
|
+ cli_args.append(CLIArgument('--seed', seed))
|
|
|
|
|
|
self._assert_empty_kwargs(kwargs)
|
|
|
|