# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import io
import sys
import abc
import shlex
import locale
import asyncio

from .utils.arg import CLIArgument
from .utils.subprocess import run_cmd as _run_cmd, CompletedProcess
from ...utils import logging
from ...utils.misc import abspath
from ...utils.device import parse_device
from ...utils.flags import DRY_RUN
from ...utils.errors import raise_unsupported_api_error, CalledProcessError

__all__ = ["BaseRunner", "InferOnlyRunner"]


class BaseRunner(metaclass=abc.ABCMeta):
    """
    Abstract base class of Runner.

    Runner is responsible for executing training/inference/compression commands.
    """

    def __init__(self, runner_root_path):
        """
        Initialize the instance.

        Args:
            runner_root_path (str): Path of the directory where the scripts reside.
        """
        super().__init__()
        self.runner_root_path = abspath(runner_root_path)
        # Path to the Python interpreter
        self.python = sys.executable

    def prepare(self):
        """
        Make preparations for the execution of commands.

        For example, download prerequisites and install dependencies.
        """
        # By default we do nothing
        pass
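
    # A subclass that depends on a third-party suite could override `prepare()`
    # to install its requirements. Hypothetical sketch only (the requirements
    # file path is assumed, not part of this API):
    #
    #     def prepare(self):
    #         reqs = os.path.join(self.runner_root_path, "requirements.txt")
    #         self.run_cmd([self.python, "-m", "pip", "install", "-r", reqs])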

    @abc.abstractmethod
    def train(self, config_path, cli_args, device, ips, save_dir, do_eval=True):
        """
        Execute model training command.

        Args:
            config_path (str): Path of the configuration file.
            cli_args (list[base.utils.arg.CLIArgument]): List of command-line
                arguments.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'xpu:0', 'gpu:1,2'.
            ips (str|None): Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'.
            save_dir (str): Directory to save log files.
            do_eval (bool, optional): Whether to perform model evaluation during
                training. Default: True.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def evaluate(self, config_path, cli_args, device, ips):
        """
        Execute model evaluation command.

        Args:
            config_path (str): Path of the configuration file.
            cli_args (list[base.utils.arg.CLIArgument]): List of command-line
                arguments.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'xpu:0', 'gpu:1,2'.
            ips (str|None): Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def predict(self, config_path, cli_args, device):
        """
        Execute prediction command.

        Args:
            config_path (str): Path of the configuration file.
            cli_args (list[base.utils.arg.CLIArgument]): List of command-line
                arguments.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'xpu:0', 'gpu:1,2'.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def export(self, config_path, cli_args, device):
        """
        Execute model export command.

        Args:
            config_path (str): Path of the configuration file.
            cli_args (list[base.utils.arg.CLIArgument]): List of command-line
                arguments.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'xpu:0', 'gpu:1,2'.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def infer(self, config_path, cli_args, device):
        """
        Execute model inference command.

        Args:
            config_path (str): Path of the configuration file.
            cli_args (list[base.utils.arg.CLIArgument]): List of command-line
                arguments.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'xpu:0', 'gpu:1,2'.

        Returns:
            paddlex.repo_apis.base.utils.subprocess.CompletedProcess
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compression(
        self, config_path, train_cli_args, export_cli_args, device, train_save_dir
    ):
        """
        Execute model compression (quantization aware training and model export)
        commands.

        Args:
            config_path (str): Path of the configuration file.
            train_cli_args (list[base.utils.arg.CLIArgument]): List of
                command-line arguments used for model training.
            export_cli_args (list[base.utils.arg.CLIArgument]): List of
                command-line arguments used for model export.
            device (str): A string that describes the device(s) to use, e.g.,
                'cpu', 'xpu:0', 'gpu:1,2'.
            train_save_dir (str): Directory to store model snapshots.

        Returns:
            tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
        """
        raise NotImplementedError

    def distributed(self, device, ips=None, log_dir=None):
        """
        Build the command prefix (and, if needed, the environment) used to
        launch a script on the given device(s).

        Args:
            device (str|None): A string that describes the device(s) to use,
                e.g., 'cpu', 'xpu:0', 'gpu:1,2'.
            ips (str|None): Paddle cluster node ips, e.g.,
                '192.168.0.16,192.168.0.17'.
            log_dir (str|None): Directory for distributed training logs.
                Defaults to the current working directory.

        Returns:
            tuple[list[str], dict|None]: The command prefix (starting with the
            Python interpreter) and the environment to run it with, or None to
            inherit the current environment.
        """
        args = [self.python]
        if device is None:
            return args, None
        device, dev_ids = parse_device(device)
        if len(dev_ids) == 0:
            return args, None
        else:
            num_devices = len(dev_ids)
            dev_ids = ",".join([str(n) for n in dev_ids])
        if num_devices > 1:
            args.extend(["-m", "paddle.distributed.launch"])
            args.extend(["--devices", dev_ids])
            if ips is not None:
                args.extend(["--ips", ips])
            if log_dir is None:
                log_dir = os.getcwd()
            args.extend(["--log_dir", self._get_dist_train_log_dir(log_dir)])
        elif num_devices == 1:
            new_env = os.environ.copy()
            if device == "xpu":
                new_env["XPU_VISIBLE_DEVICES"] = dev_ids
            elif device == "npu":
                new_env["ASCEND_RT_VISIBLE_DEVICES"] = dev_ids
            elif device == "mlu":
                new_env["MLU_VISIBLE_DEVICES"] = dev_ids
            else:
                new_env["CUDA_VISIBLE_DEVICES"] = dev_ids
            return args, new_env
        return args, None
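
    # For illustration (device strings assumed): with device='gpu:1,2' the
    # prefix is roughly
    #     [sys.executable, "-m", "paddle.distributed.launch",
    #      "--devices", "1,2", "--log_dir", "<log_dir>/distributed_train_logs"]
    # with env None, whereas a single device such as 'gpu:0' keeps the plain
    # [sys.executable] prefix and instead returns a copy of os.environ with
    # CUDA_VISIBLE_DEVICES="0" (or the XPU/NPU/MLU equivalent) set.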

    def run_cmd(
        self,
        cmd,
        env=None,
        switch_wdir=True,
        silent=False,
        echo=True,
        capture_output=False,
        log_path=None,
    ):
        """
        Run a command and wait for it to finish.

        By default the working directory is switched to `self.runner_root_path`
        (or to a subdirectory of it when `switch_wdir` is a relative path).
        When `capture_output` is True, stdout/stderr are echoed to the terminal,
        captured, and optionally teed to `log_path`.
        """

        def _trans_args(cmd):
            out = []
            for ele in cmd:
                if isinstance(ele, CLIArgument):
                    out.extend(ele.lst)
                else:
                    out.append(ele)
            return out
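
        # For example (hypothetical argument): a CLIArgument wrapping
        # ("--epochs", "10") contributes ["--epochs", "10"] via its `lst`
        # attribute, while plain strings are passed through unchanged.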

        cmd = _trans_args(cmd)

        if DRY_RUN:
            # TODO: Accommodate Windows systems
            logging.info(" ".join(shlex.quote(x) for x in cmd))
            # Mock return
            return CompletedProcess(cmd, returncode=0, stdout=str(cmd), stderr=None)

        if switch_wdir:
            if isinstance(switch_wdir, str):
                # In this case `switch_wdir` specifies a relative path
                cwd = os.path.join(self.runner_root_path, switch_wdir)
            else:
                cwd = self.runner_root_path
        else:
            cwd = None

        if not capture_output:
            if log_path is not None:
                logging.warning(
                    "`log_path` will be ignored when `capture_output` is False."
                )

            cp = _run_cmd(
                cmd,
                env=env,
                cwd=cwd,
                silent=silent,
                echo=echo,
                pipe_stdout=False,
                pipe_stderr=False,
                blocking=True,
            )
            cp = CompletedProcess(
                cp.args, cp.returncode, stdout=cp.stdout, stderr=cp.stderr
            )
        else:
            # Refer to
            # https://stackoverflow.com/questions/17190221/subprocess-popen-cloning-stdout-and-stderr-both-to-terminal-and-variables/25960956

            async def _read_display_and_record_from_stream(
                in_stream, out_stream, files
            ):
                # According to
                # https://docs.python.org/3/library/subprocess.html#frequently-used-arguments
                _ENCODING = locale.getpreferredencoding(False)
                chars = []
                out_stream_is_buffered = hasattr(out_stream, "buffer")
                while True:
                    flush = False
                    char = await in_stream.read(1)
                    if char == b"":
                        break
                    if out_stream_is_buffered:
                        out_stream.buffer.write(char)
                    chars.append(char)
                    if char == b"\n":
                        flush = True
                    elif char == b"\r":
                        # NOTE: We also treat '\r' as a line terminator so that
                        # tqdm progress bars are displayed properly.
                        flush = True
                    if flush:
                        line = b"".join(chars)
                        line = line.decode(_ENCODING)
                        if not out_stream_is_buffered:
                            # We use line buffering
                            out_stream.write(line)
                        else:
                            out_stream.buffer.flush()
                        for f in files:
                            f.write(line)
                        chars.clear()

            async def _tee_proc_call(proc_call, out_files, err_files):
                proc = await proc_call
                await asyncio.gather(
                    _read_display_and_record_from_stream(
                        proc.stdout, sys.stdout, out_files
                    ),
                    _read_display_and_record_from_stream(
                        proc.stderr, sys.stderr, err_files
                    ),
                )
                # NOTE: https://docs.python.org/3/library/subprocess.html#subprocess.Popen.wait
                retcode = await proc.wait()
                return retcode

            # Non-blocking call with stdout and stderr piped
            with io.StringIO() as stdout_buf, io.StringIO() as stderr_buf:
                proc_call = _run_cmd(
                    cmd,
                    env=env,
                    cwd=cwd,
                    echo=echo,
                    silent=silent,
                    pipe_stdout=True,
                    pipe_stderr=True,
                    blocking=False,
                    async_run=True,
                )
                out_files = [stdout_buf]
                err_files = [stderr_buf]
                if log_path is not None:
                    log_dir = os.path.dirname(log_path)
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = open(log_path, "w", encoding="utf-8")
                    logging.info(f"\nLog path: {os.path.abspath(log_path)} \n")
                    out_files.append(log_file)
                    err_files.append(log_file)
                try:
                    retcode = asyncio.run(
                        _tee_proc_call(proc_call, out_files, err_files)
                    )
                finally:
                    if log_path is not None:
                        log_file.close()
                cp = CompletedProcess(
                    cmd, retcode, stdout_buf.getvalue(), stderr_buf.getvalue()
                )

        if cp.returncode != 0:
            raise CalledProcessError(
                cp.returncode, cp.args, output=cp.stdout, stderr=cp.stderr
            )
        return cp
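
    # Note: with `capture_output=True` the child's output is streamed to the
    # terminal and collected, so callers can read `cp.stdout`/`cp.stderr` (and
    # optionally tee it to `log_path`); in either mode a nonzero exit status
    # raises CalledProcessError instead of returning a failed CompletedProcess.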

    def _get_dist_train_log_dir(self, log_dir):
        """_get_dist_train_log_dir"""
        return os.path.join(log_dir, "distributed_train_logs")

    def _get_train_log_path(self, log_dir):
        """_get_train_log_path"""
        return os.path.join(log_dir, "train.log")


class InferOnlyRunner(BaseRunner):
    """InferOnlyRunner"""

    def train(self, *args, **kwargs):
        """train"""
        raise_unsupported_api_error(self.__class__, "train")

    def evaluate(self, *args, **kwargs):
        """evaluate"""
        raise_unsupported_api_error(self.__class__, "evaluate")

    def predict(self, *args, **kwargs):
        """predict"""
        raise_unsupported_api_error(self.__class__, "predict")

    def export(self, *args, **kwargs):
        """export"""
        raise_unsupported_api_error(self.__class__, "export")

    def compression(self, *args, **kwargs):
        """compression"""
        raise_unsupported_api_error(self.__class__, "compression")
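

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module's API): a typical concrete
# Runner builds a command with `distributed()` and executes it via `run_cmd()`.
# The class name, script path, and `--config` flag below are hypothetical.
#
#     class DemoRunner(BaseRunner):
#         def train(self, config_path, cli_args, device, ips, save_dir, do_eval=True):
#             python, env = self.distributed(device, ips, log_dir=save_dir)
#             cmd = [*python, "tools/train.py", "--config", config_path, *cli_args]
#             return self.run_cmd(
#                 cmd,
#                 env=env,
#                 switch_wdir=True,
#                 echo=True,
#                 capture_output=True,
#                 log_path=self._get_train_log_path(save_dir),
#             )
# ---------------------------------------------------------------------------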