# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import glob
import os
import os.path as osp
import platform
import random
import numpy as np
import multiprocessing as mp
import paddle

from . import logging


def get_environ_info():
    """collect environment information"""
    env_info = dict()
    # TODO is_compiled_with_cuda() has not been moved
    compiled_with_cuda = paddle.is_compiled_with_cuda()
    if compiled_with_cuda:
        if 'gpu' in paddle.get_device():
            gpu_nums = paddle.distributed.get_world_size()
        else:
            gpu_nums = 0
        if gpu_nums == 0:
            # No visible GPU: hide CUDA devices so execution falls back to CPU.
            os.environ['CUDA_VISIBLE_DEVICES'] = ''

    place = 'gpu' if compiled_with_cuda and gpu_nums else 'cpu'
    env_info['place'] = place
    env_info['num'] = int(os.environ.get('CPU_NUM', 1))
    if place == 'gpu':
        env_info['num'] = gpu_nums
    return env_info
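
# Illustrative examples of the dict returned above (the values are assumptions
# for a CPU-only build and a single-GPU, single-process run, respectively):
#   {'place': 'cpu', 'num': 1}
#   {'place': 'gpu', 'num': 1}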


def get_num_workers(num_workers):
    """Resolve the number of DataLoader worker processes to use."""
    if not platform.system() == 'Linux':
        # DataLoader with multi-process mode is not supported
        # on MacOS and Windows currently.
        return 0
    if num_workers == 'auto':
        num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 2 else 2
    return num_workers
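
# Typical call site (a sketch, not code from this module): the resolved value
# is usually forwarded to paddle.io.DataLoader, e.g.
#   loader = paddle.io.DataLoader(dataset, batch_size=8,
#                                 num_workers=get_num_workers('auto'))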


def init_parallel_env():
    """Initialize the distributed environment when multiple trainers are used."""
    env = os.environ
    if 'FLAGS_allocator_strategy' not in os.environ:
        os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
    dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
    if dist:
        # Derive a per-trainer seed so each process gets a different,
        # yet reproducible, random stream.
        trainer_id = int(env['PADDLE_TRAINER_ID'])
        local_seed = (99 + trainer_id)
        random.seed(local_seed)
        np.random.seed(local_seed)

    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()
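

if __name__ == '__main__':
    # Illustrative smoke test only, not part of the original module; running it
    # requires importing this file as part of its package (e.g.
    # `python -m <package>.env`) because of the relative `logging` import above.
    info = get_environ_info()
    print('place: {}, num: {}'.format(info['place'], info['num']))
    print('num_workers:', get_num_workers('auto'))
    init_parallel_env()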