runner.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import io
  16. import sys
  17. import abc
  18. import shlex
  19. import locale
  20. import asyncio
  21. from .utils.arg import CLIArgument
  22. from .utils.subprocess import run_cmd as _run_cmd, CompletedProcess
  23. from ...utils import logging
  24. from ...utils.misc import abspath
  25. from ...utils.device import parse_device
  26. from ...utils.flags import DRY_RUN
  27. from ...utils.errors import raise_unsupported_api_error, CalledProcessError
# Names exported by `from <this module> import *`.
__all__ = ["BaseRunner", "InferOnlyRunner"]
  29. class BaseRunner(metaclass=abc.ABCMeta):
  30. """
  31. Abstract base class of Runner.
  32. Runner is responsible for executing training/inference/compression commands.
  33. """
  34. def __init__(self, runner_root_path):
  35. """
  36. Initialize the instance.
  37. Args:
  38. runner_root_path (str): Path of the directory where the scripts reside.
  39. """
  40. super().__init__()
  41. self.runner_root_path = abspath(runner_root_path)
  42. # Path to python interpreter
  43. self.python = sys.executable
  44. def prepare(self):
  45. """
  46. Make preparations for the execution of commands.
  47. For example, download prerequisites and install dependencies.
  48. """
  49. # By default we do nothing
  50. pass
  51. @abc.abstractmethod
  52. def train(self, config_path, cli_args, device, ips, save_dir, do_eval=True):
  53. """
  54. Execute model training command.
  55. Args:
  56. config_path (str): Path of the configuration file.
  57. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  58. arguments.
  59. device (str): A string that describes the device(s) to use, e.g.,
  60. 'cpu', 'xpu:0', 'gpu:1,2'.
  61. ips (str|None): Paddle cluster node ips, e.g.,
  62. '192.168.0.16,192.168.0.17'.
  63. save_dir (str): Directory to save log files.
  64. do_eval (bool, optional): Whether to perform model evaluation during
  65. training. Default: True.
  66. Returns:
  67. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  68. """
  69. raise NotImplementedError
  70. @abc.abstractmethod
  71. def evaluate(self, config_path, cli_args, device, ips):
  72. """
  73. Execute model evaluation command.
  74. Args:
  75. config_path (str): Path of the configuration file.
  76. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  77. arguments.
  78. device (str): A string that describes the device(s) to use, e.g.,
  79. 'cpu', 'xpu:0', 'gpu:1,2'.
  80. ips (str|None): Paddle cluster node ips, e.g.,
  81. '192.168.0.16,192.168.0.17'.
  82. Returns:
  83. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  84. """
  85. raise NotImplementedError
  86. @abc.abstractmethod
  87. def predict(self, config_path, cli_args, device):
  88. """
  89. Execute prediction command.
  90. Args:
  91. config_path (str): Path of the configuration file.
  92. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  93. arguments.
  94. device (str): A string that describes the device(s) to use, e.g.,
  95. 'cpu', 'xpu:0', 'gpu:1,2'.
  96. Returns:
  97. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  98. """
  99. raise NotImplementedError
  100. @abc.abstractmethod
  101. def export(self, config_path, cli_args, device):
  102. """
  103. Execute model export command.
  104. Args:
  105. config_path (str): Path of the configuration file.
  106. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  107. arguments.
  108. device (str): A string that describes the device(s) to use, e.g.,
  109. 'cpu', 'xpu:0', 'gpu:1,2'.
  110. Returns:
  111. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  112. """
  113. raise NotImplementedError
  114. @abc.abstractmethod
  115. def infer(self, config_path, cli_args, device):
  116. """
  117. Execute model inference command.
  118. Args:
  119. config_path (str): Path of the configuration file.
  120. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  121. arguments.
  122. device (str): A string that describes the device(s) to use, e.g.,
  123. 'cpu', 'xpu:0', 'gpu:1,2'.
  124. Returns:
  125. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  126. """
  127. raise NotImplementedError
  128. @abc.abstractmethod
  129. def compression(
  130. self, config_path, train_cli_args, export_cli_args, device, train_save_dir
  131. ):
  132. """
  133. Execute model compression (quantization aware training and model export)
  134. commands.
  135. Args:
  136. config_path (str): Path of the configuration file.
  137. train_cli_args (list[base.utils.arg.CLIArgument]): List of
  138. command-line arguments used for model training.
  139. export_cli_args (list[base.utils.arg.CLIArgument]): List of
  140. command-line arguments used for model export.
  141. device (str): A string that describes the device(s) to use, e.g.,
  142. 'cpu', 'xpu:0', 'gpu:1,2'.
  143. train_save_dir (str): Directory to store model snapshots.
  144. Returns:
  145. tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
  146. """
  147. raise NotImplementedError
  148. def distributed(self, device, ips=None, log_dir=None):
  149. """distributed"""
  150. # TODO: docstring
  151. args = [self.python]
  152. if device is None:
  153. return args, None
  154. device, dev_ids = parse_device(device)
  155. if dev_ids is None or len(dev_ids) == 0:
  156. return args, None
  157. else:
  158. num_devices = len(dev_ids)
  159. dev_ids = ",".join([str(n) for n in dev_ids])
  160. if num_devices > 1:
  161. args.extend(["-m", "paddle.distributed.launch"])
  162. args.extend(["--devices", dev_ids])
  163. if ips is not None:
  164. args.extend(["--ips", ips])
  165. if log_dir is None:
  166. log_dir = os.getcwd()
  167. args.extend(["--log_dir", self._get_dist_train_log_dir(log_dir)])
  168. elif num_devices == 1:
  169. new_env = os.environ.copy()
  170. if device == "xpu":
  171. new_env["XPU_VISIBLE_DEVICES"] = dev_ids
  172. elif device == "npu":
  173. new_env["ASCEND_RT_VISIBLE_DEVICES"] = dev_ids
  174. elif device == "mlu":
  175. new_env["MLU_VISIBLE_DEVICES"] = dev_ids
  176. elif device == "gcu":
  177. new_env["TOPS_VISIBLE_DEVICES"] = dev_ids
  178. else:
  179. new_env["CUDA_VISIBLE_DEVICES"] = dev_ids
  180. return args, new_env
  181. return args, None
  182. def run_cmd(
  183. self,
  184. cmd,
  185. env=None,
  186. switch_wdir=True,
  187. silent=False,
  188. echo=True,
  189. capture_output=False,
  190. log_path=None,
  191. ):
  192. """run_cmd"""
  193. def _trans_args(cmd):
  194. out = []
  195. for ele in cmd:
  196. if isinstance(ele, CLIArgument):
  197. out.extend(ele.lst)
  198. else:
  199. out.append(ele)
  200. return out
  201. cmd = _trans_args(cmd)
  202. if DRY_RUN:
  203. # TODO: Accommodate Windows system
  204. logging.info(" ".join(shlex.quote(x) for x in cmd))
  205. # Mock return
  206. return CompletedProcess(cmd, returncode=0, stdout=str(cmd), stderr=None)
  207. if switch_wdir:
  208. if isinstance(switch_wdir, str):
  209. # In this case `switch_wdir` specifies a relative path
  210. cwd = os.path.join(self.runner_root_path, switch_wdir)
  211. else:
  212. cwd = self.runner_root_path
  213. else:
  214. cwd = None
  215. if not capture_output:
  216. if log_path is not None:
  217. logging.warning(
  218. "`log_path` will be ignored when `capture_output` is False."
  219. )
  220. cp = _run_cmd(
  221. cmd,
  222. env=env,
  223. cwd=cwd,
  224. silent=silent,
  225. echo=echo,
  226. pipe_stdout=False,
  227. pipe_stderr=False,
  228. blocking=True,
  229. )
  230. cp = CompletedProcess(
  231. cp.args, cp.returncode, stdout=cp.stdout, stderr=cp.stderr
  232. )
  233. else:
  234. # Refer to
  235. # https://stackoverflow.com/questions/17190221/subprocess-popen-cloning-stdout-and-stderr-both-to-terminal-and-variables/25960956
  236. async def _read_display_and_record_from_stream(
  237. in_stream, out_stream, files
  238. ):
  239. # According to
  240. # https://docs.python.org/3/library/subprocess.html#frequently-used-arguments
  241. _ENCODING = locale.getpreferredencoding(False)
  242. chars = []
  243. out_stream_is_buffered = hasattr(out_stream, "buffer")
  244. while True:
  245. flush = False
  246. char = await in_stream.read(1)
  247. if char == b"":
  248. break
  249. if out_stream_is_buffered:
  250. out_stream.buffer.write(char)
  251. chars.append(char)
  252. if char == b"\n":
  253. flush = True
  254. elif char == b"\r":
  255. # NOTE: In order to get tqdm progress bars to produce normal outputs
  256. # we treat '\r' as an ending character of line
  257. flush = True
  258. if flush:
  259. line = b"".join(chars)
  260. line = line.decode(_ENCODING)
  261. if not out_stream_is_buffered:
  262. # We use line buffering
  263. out_stream.write(line)
  264. else:
  265. out_stream.buffer.flush()
  266. for f in files:
  267. f.write(line)
  268. chars.clear()
  269. async def _tee_proc_call(proc_call, out_files, err_files):
  270. proc = await proc_call
  271. await asyncio.gather(
  272. _read_display_and_record_from_stream(
  273. proc.stdout, sys.stdout, out_files
  274. ),
  275. _read_display_and_record_from_stream(
  276. proc.stderr, sys.stderr, err_files
  277. ),
  278. )
  279. # NOTE: https://docs.python.org/3/library/subprocess.html#subprocess.Popen.wait
  280. retcode = await proc.wait()
  281. return retcode
  282. # Non-blocking call with stdout and stderr piped
  283. with io.StringIO() as stdout_buf, io.StringIO() as stderr_buf:
  284. proc_call = _run_cmd(
  285. cmd,
  286. env=env,
  287. cwd=cwd,
  288. echo=echo,
  289. silent=silent,
  290. pipe_stdout=True,
  291. pipe_stderr=True,
  292. blocking=False,
  293. async_run=True,
  294. )
  295. out_files = [stdout_buf]
  296. err_files = [stderr_buf]
  297. if log_path is not None:
  298. log_dir = os.path.dirname(log_path)
  299. os.makedirs(log_dir, exist_ok=True)
  300. log_file = open(log_path, "w", encoding="utf-8")
  301. logging.info(f"\nLog path: {os.path.abspath(log_path)} \n")
  302. out_files.append(log_file)
  303. err_files.append(log_file)
  304. try:
  305. retcode = asyncio.run(
  306. _tee_proc_call(proc_call, out_files, err_files)
  307. )
  308. finally:
  309. if log_path is not None:
  310. log_file.close()
  311. cp = CompletedProcess(
  312. cmd, retcode, stdout_buf.getvalue(), stderr_buf.getvalue()
  313. )
  314. if cp.returncode != 0:
  315. raise CalledProcessError(
  316. cp.returncode, cp.args, output=cp.stdout, stderr=cp.stderr
  317. )
  318. return cp
  319. def _get_dist_train_log_dir(self, log_dir):
  320. """_get_dist_train_log_dir"""
  321. return os.path.join(log_dir, "distributed_train_logs")
  322. def _get_train_log_path(self, log_dir):
  323. """_get_train_log_path"""
  324. return os.path.join(log_dir, "train.log")
  325. class InferOnlyRunner(BaseRunner):
  326. """InferOnlyRunner"""
  327. def train(self, *args, **kwargs):
  328. """train"""
  329. raise_unsupported_api_error(self.__class__, "train")
  330. def evaluate(self, *args, **kwargs):
  331. """evaluate"""
  332. raise_unsupported_api_error(self.__class__, "evalaute")
  333. def predict(self, *args, **kwargs):
  334. """predict"""
  335. raise_unsupported_api_error(self.__class__, "predict")
  336. def export(self, *args, **kwargs):
  337. """export"""
  338. raise_unsupported_api_error(self.__class__, "export")
  339. def compression(self, *args, **kwargs):
  340. """compression"""
  341. raise_unsupported_api_error(self.__class__, "compression")