runner.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import io
  16. import sys
  17. import abc
  18. import shlex
  19. import locale
  20. import asyncio
  21. from .utils.arg import CLIArgument
  22. from .utils.subprocess import run_cmd as _run_cmd, CompletedProcess
  23. from ...utils import logging
  24. from ...utils.misc import abspath
  25. from ...utils.flags import DRY_RUN
  26. from ...utils.errors import raise_unsupported_api_error, CalledProcessError
  27. __all__ = ['BaseRunner', 'InferOnlyRunner']
  28. class BaseRunner(metaclass=abc.ABCMeta):
  29. """
  30. Abstract base class of Runner.
  31. Runner is responsible for executing training/inference/compression commands.
  32. """
  33. def __init__(self, runner_root_path):
  34. """
  35. Initialize the instance.
  36. Args:
  37. runner_root_path (str): Path of the directory where the scripts reside.
  38. """
  39. super().__init__()
  40. self.runner_root_path = abspath(runner_root_path)
  41. # Path to python interpreter
  42. self.python = sys.executable
  43. def prepare(self):
  44. """
  45. Make preparations for the execution of commands.
  46. For example, download prerequisites and install dependencies.
  47. """
  48. # By default we do nothing
  49. pass
  50. @abc.abstractmethod
  51. def train(self, config_path, cli_args, device, ips, save_dir, do_eval=True):
  52. """
  53. Execute model training command.
  54. Args:
  55. config_path (str): Path of the configuration file.
  56. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  57. arguments.
  58. device (str): A string that describes the device(s) to use, e.g.,
  59. 'cpu', 'xpu:0', 'gpu:1,2'.
  60. ips (str|None): Paddle cluster node ips, e.g.,
  61. '192.168.0.16,192.168.0.17'.
  62. save_dir (str): Directory to save log files.
  63. do_eval (bool, optional): Whether to perform model evaluation during
  64. training. Default: True.
  65. Returns:
  66. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  67. """
  68. raise NotImplementedError
  69. @abc.abstractmethod
  70. def evaluate(self, config_path, cli_args, device, ips):
  71. """
  72. Execute model evaluation command.
  73. Args:
  74. config_path (str): Path of the configuration file.
  75. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  76. arguments.
  77. device (str): A string that describes the device(s) to use, e.g.,
  78. 'cpu', 'xpu:0', 'gpu:1,2'.
  79. ips (str|None): Paddle cluster node ips, e.g.,
  80. '192.168.0.16,192.168.0.17'.
  81. Returns:
  82. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  83. """
  84. raise NotImplementedError
  85. @abc.abstractmethod
  86. def predict(self, config_path, cli_args, device):
  87. """
  88. Execute prediction command.
  89. Args:
  90. config_path (str): Path of the configuration file.
  91. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  92. arguments.
  93. device (str): A string that describes the device(s) to use, e.g.,
  94. 'cpu', 'xpu:0', 'gpu:1,2'.
  95. Returns:
  96. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  97. """
  98. raise NotImplementedError
  99. @abc.abstractmethod
  100. def export(self, config_path, cli_args, device):
  101. """
  102. Execute model export command.
  103. Args:
  104. config_path (str): Path of the configuration file.
  105. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  106. arguments.
  107. device (str): A string that describes the device(s) to use, e.g.,
  108. 'cpu', 'xpu:0', 'gpu:1,2'.
  109. Returns:
  110. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  111. """
  112. raise NotImplementedError
  113. @abc.abstractmethod
  114. def infer(self, config_path, cli_args, device):
  115. """
  116. Execute model inference command.
  117. Args:
  118. config_path (str): Path of the configuration file.
  119. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  120. arguments.
  121. device (str): A string that describes the device(s) to use, e.g.,
  122. 'cpu', 'xpu:0', 'gpu:1,2'.
  123. Returns:
  124. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  125. """
  126. raise NotImplementedError
  127. @abc.abstractmethod
  128. def compression(self, config_path, train_cli_args, export_cli_args, device,
  129. train_save_dir):
  130. """
  131. Execute model compression (quantization aware training and model export)
  132. commands.
  133. Args:
  134. config_path (str): Path of the configuration file.
  135. train_cli_args (list[base.utils.arg.CLIArgument]): List of
  136. command-line arguments used for model training.
  137. export_cli_args (list[base.utils.arg.CLIArgument]): List of
  138. command-line arguments used for model export.
  139. device (str): A string that describes the device(s) to use, e.g.,
  140. 'cpu', 'xpu:0', 'gpu:1,2'.
  141. train_save_dir (str): Directory to store model snapshots.
  142. Returns:
  143. tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
  144. """
  145. raise NotImplementedError
  146. def distributed(self, device, ips=None, log_dir=None):
  147. """ distributed """
  148. # TODO: docstring
  149. args = [self.python]
  150. if device is None:
  151. return args, None
  152. device, dev_ids = self.parse_device(device)
  153. if len(dev_ids) == 0:
  154. return args, None
  155. else:
  156. num_devices = len(dev_ids)
  157. dev_ids = ','.join(dev_ids)
  158. if num_devices > 1:
  159. args.extend(['-m', 'paddle.distributed.launch'])
  160. args.extend(['--devices', dev_ids])
  161. if ips is not None:
  162. args.extend(['--ips', ips])
  163. if log_dir is None:
  164. log_dir = os.getcwd()
  165. args.extend(['--log_dir', self._get_dist_train_log_dir(log_dir)])
  166. elif num_devices == 1:
  167. new_env = os.environ.copy()
  168. if device == 'xpu':
  169. new_env['XPU_VISIBLE_DEVICES'] = dev_ids
  170. elif device == 'npu':
  171. new_env['ASCEND_RT_VISIBLE_DEVICES'] = dev_ids
  172. elif device == 'mlu':
  173. new_env['MLU_VISIBLE_DEVICES'] = dev_ids
  174. else:
  175. new_env['CUDA_VISIBLE_DEVICES'] = dev_ids
  176. return args, new_env
  177. return args, None
  178. def parse_device(self, device):
  179. """ parse_device """
  180. # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
  181. if ':' not in device:
  182. device_type, dev_ids = device, []
  183. else:
  184. device_type, dev_ids = device.split(':')
  185. dev_ids = dev_ids.split(',')
  186. if device_type not in ('cpu', 'gpu', 'xpu', 'npu', 'mlu'):
  187. raise ValueError("Unsupported device type.")
  188. for dev_id in dev_ids:
  189. if not dev_id.isdigit():
  190. raise ValueError("Device ID must be an integer.")
  191. return device_type, dev_ids
  192. def run_cmd(self,
  193. cmd,
  194. env=None,
  195. switch_wdir=True,
  196. silent=False,
  197. echo=True,
  198. capture_output=False,
  199. log_path=None):
  200. """ run_cmd """
  201. def _trans_args(cmd):
  202. out = []
  203. for ele in cmd:
  204. if isinstance(ele, CLIArgument):
  205. out.extend(ele.lst)
  206. else:
  207. out.append(ele)
  208. return out
  209. cmd = _trans_args(cmd)
  210. if DRY_RUN:
  211. # TODO: Accommodate Windows system
  212. logging.info(' '.join(shlex.quote(x) for x in cmd))
  213. # Mock return
  214. return CompletedProcess(
  215. cmd, returncode=0, stdout=str(cmd), stderr=None)
  216. if switch_wdir:
  217. if isinstance(switch_wdir, str):
  218. # In this case `switch_wdir` specifies a relative path
  219. cwd = os.path.join(self.runner_root_path, switch_wdir)
  220. else:
  221. cwd = self.runner_root_path
  222. else:
  223. cwd = None
  224. if not capture_output:
  225. if log_path is not None:
  226. logging.warning(
  227. "`log_path` will be ignored when `capture_output` is False.")
  228. cp = _run_cmd(
  229. cmd,
  230. env=env,
  231. cwd=cwd,
  232. silent=silent,
  233. echo=echo,
  234. pipe_stdout=False,
  235. pipe_stderr=False,
  236. blocking=True)
  237. cp = CompletedProcess(
  238. cp.args, cp.returncode, stdout=cp.stdout, stderr=cp.stderr)
  239. else:
  240. # Refer to
  241. # https://stackoverflow.com/questions/17190221/subprocess-popen-cloning-stdout-and-stderr-both-to-terminal-and-variables/25960956
  242. async def _read_display_and_record_from_stream(in_stream,
  243. out_stream, files):
  244. # According to
  245. # https://docs.python.org/3/library/subprocess.html#frequently-used-arguments
  246. _ENCODING = locale.getpreferredencoding(False)
  247. chars = []
  248. out_stream_is_buffered = hasattr(out_stream, 'buffer')
  249. while True:
  250. flush = False
  251. char = await in_stream.read(1)
  252. if char == b'':
  253. break
  254. if out_stream_is_buffered:
  255. out_stream.buffer.write(char)
  256. chars.append(char)
  257. if char == b'\n':
  258. flush = True
  259. elif char == b'\r':
  260. # NOTE: In order to get tqdm progress bars to produce normal outputs
  261. # we treat '\r' as an ending character of line
  262. flush = True
  263. if flush:
  264. line = b''.join(chars)
  265. line = line.decode(_ENCODING)
  266. if not out_stream_is_buffered:
  267. # We use line buffering
  268. out_stream.write(line)
  269. else:
  270. out_stream.buffer.flush()
  271. for f in files:
  272. f.write(line)
  273. chars.clear()
  274. async def _tee_proc_call(proc_call, out_files, err_files):
  275. proc = await proc_call
  276. await asyncio.gather(
  277. _read_display_and_record_from_stream(proc.stdout,
  278. sys.stdout, out_files),
  279. _read_display_and_record_from_stream(proc.stderr,
  280. sys.stderr, err_files))
  281. # NOTE: https://docs.python.org/3/library/subprocess.html#subprocess.Popen.wait
  282. retcode = await proc.wait()
  283. return retcode
  284. # Non-blocking call with stdout and stderr piped
  285. with io.StringIO() as stdout_buf, io.StringIO() as stderr_buf:
  286. proc_call = _run_cmd(
  287. cmd,
  288. env=env,
  289. cwd=cwd,
  290. echo=echo,
  291. silent=silent,
  292. pipe_stdout=True,
  293. pipe_stderr=True,
  294. blocking=False,
  295. async_run=True)
  296. out_files = [stdout_buf]
  297. err_files = [stderr_buf]
  298. if log_path is not None:
  299. log_dir = os.path.dirname(log_path)
  300. os.makedirs(log_dir, exist_ok=True)
  301. log_file = open(log_path, 'w', encoding='utf-8')
  302. logging.info(f"\nLog path: {os.path.abspath(log_path)} \n")
  303. out_files.append(log_file)
  304. err_files.append(log_file)
  305. try:
  306. retcode = asyncio.run(
  307. _tee_proc_call(proc_call, out_files, err_files))
  308. finally:
  309. if log_path is not None:
  310. log_file.close()
  311. cp = CompletedProcess(cmd, retcode,
  312. stdout_buf.getvalue(),
  313. stderr_buf.getvalue())
  314. if cp.returncode != 0:
  315. raise CalledProcessError(
  316. cp.returncode, cp.args, output=cp.stdout, stderr=cp.stderr)
  317. return cp
  318. def _get_dist_train_log_dir(self, log_dir):
  319. """ _get_dist_train_log_dir """
  320. return os.path.join(log_dir, 'distributed_train_logs')
  321. def _get_train_log_path(self, log_dir):
  322. """ _get_train_log_path """
  323. return os.path.join(log_dir, 'train.log')
  324. class InferOnlyRunner(BaseRunner):
  325. """ InferOnlyRunner """
  326. def train(self, *args, **kwargs):
  327. """ train """
  328. raise_unsupported_api_error(self.__class__, 'train')
  329. def evaluate(self, *args, **kwargs):
  330. """ evaluate """
  331. raise_unsupported_api_error(self.__class__, 'evalaute')
  332. def predict(self, *args, **kwargs):
  333. """ predict """
  334. raise_unsupported_api_error(self.__class__, 'predict')
  335. def export(self, *args, **kwargs):
  336. """ export """
  337. raise_unsupported_api_error(self.__class__, 'export')
  338. def compression(self, *args, **kwargs):
  339. """ compression """
  340. raise_unsupported_api_error(self.__class__, 'compression')