runner.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import io
  16. import sys
  17. import abc
  18. import shlex
  19. import locale
  20. import asyncio
  21. from .utils.arg import CLIArgument
  22. from .utils.subprocess import run_cmd as _run_cmd, CompletedProcess
  23. from ...utils import logging
  24. from ...utils.misc import abspath
  25. from ...utils.flags import DRY_RUN
  26. from ...utils.errors import raise_unsupported_api_error, CalledProcessError
  27. __all__ = ["BaseRunner", "InferOnlyRunner"]
  28. class BaseRunner(metaclass=abc.ABCMeta):
  29. """
  30. Abstract base class of Runner.
  31. Runner is responsible for executing training/inference/compression commands.
  32. """
  33. def __init__(self, runner_root_path):
  34. """
  35. Initialize the instance.
  36. Args:
  37. runner_root_path (str): Path of the directory where the scripts reside.
  38. """
  39. super().__init__()
  40. self.runner_root_path = abspath(runner_root_path)
  41. # Path to python interpreter
  42. self.python = sys.executable
  43. def prepare(self):
  44. """
  45. Make preparations for the execution of commands.
  46. For example, download prerequisites and install dependencies.
  47. """
  48. # By default we do nothing
  49. pass
  50. @abc.abstractmethod
  51. def train(self, config_path, cli_args, device, ips, save_dir, do_eval=True):
  52. """
  53. Execute model training command.
  54. Args:
  55. config_path (str): Path of the configuration file.
  56. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  57. arguments.
  58. device (str): A string that describes the device(s) to use, e.g.,
  59. 'cpu', 'xpu:0', 'gpu:1,2'.
  60. ips (str|None): Paddle cluster node ips, e.g.,
  61. '192.168.0.16,192.168.0.17'.
  62. save_dir (str): Directory to save log files.
  63. do_eval (bool, optional): Whether to perform model evaluation during
  64. training. Default: True.
  65. Returns:
  66. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  67. """
  68. raise NotImplementedError
  69. @abc.abstractmethod
  70. def evaluate(self, config_path, cli_args, device, ips):
  71. """
  72. Execute model evaluation command.
  73. Args:
  74. config_path (str): Path of the configuration file.
  75. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  76. arguments.
  77. device (str): A string that describes the device(s) to use, e.g.,
  78. 'cpu', 'xpu:0', 'gpu:1,2'.
  79. ips (str|None): Paddle cluster node ips, e.g.,
  80. '192.168.0.16,192.168.0.17'.
  81. Returns:
  82. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  83. """
  84. raise NotImplementedError
  85. @abc.abstractmethod
  86. def predict(self, config_path, cli_args, device):
  87. """
  88. Execute prediction command.
  89. Args:
  90. config_path (str): Path of the configuration file.
  91. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  92. arguments.
  93. device (str): A string that describes the device(s) to use, e.g.,
  94. 'cpu', 'xpu:0', 'gpu:1,2'.
  95. Returns:
  96. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  97. """
  98. raise NotImplementedError
  99. @abc.abstractmethod
  100. def export(self, config_path, cli_args, device):
  101. """
  102. Execute model export command.
  103. Args:
  104. config_path (str): Path of the configuration file.
  105. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  106. arguments.
  107. device (str): A string that describes the device(s) to use, e.g.,
  108. 'cpu', 'xpu:0', 'gpu:1,2'.
  109. Returns:
  110. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  111. """
  112. raise NotImplementedError
  113. @abc.abstractmethod
  114. def infer(self, config_path, cli_args, device):
  115. """
  116. Execute model inference command.
  117. Args:
  118. config_path (str): Path of the configuration file.
  119. cli_args (list[base.utils.arg.CLIArgument]): List of command-line
  120. arguments.
  121. device (str): A string that describes the device(s) to use, e.g.,
  122. 'cpu', 'xpu:0', 'gpu:1,2'.
  123. Returns:
  124. paddlex.repo_apis.base.utils.subprocess.CompletedProcess
  125. """
  126. raise NotImplementedError
  127. @abc.abstractmethod
  128. def compression(
  129. self, config_path, train_cli_args, export_cli_args, device, train_save_dir
  130. ):
  131. """
  132. Execute model compression (quantization aware training and model export)
  133. commands.
  134. Args:
  135. config_path (str): Path of the configuration file.
  136. train_cli_args (list[base.utils.arg.CLIArgument]): List of
  137. command-line arguments used for model training.
  138. export_cli_args (list[base.utils.arg.CLIArgument]): List of
  139. command-line arguments used for model export.
  140. device (str): A string that describes the device(s) to use, e.g.,
  141. 'cpu', 'xpu:0', 'gpu:1,2'.
  142. train_save_dir (str): Directory to store model snapshots.
  143. Returns:
  144. tuple[paddlex.repo_apis.base.utils.subprocess.CompletedProcess]
  145. """
  146. raise NotImplementedError
  147. def distributed(self, device, ips=None, log_dir=None):
  148. """distributed"""
  149. # TODO: docstring
  150. args = [self.python]
  151. if device is None:
  152. return args, None
  153. device, dev_ids = self.parse_device(device)
  154. if len(dev_ids) == 0:
  155. return args, None
  156. else:
  157. num_devices = len(dev_ids)
  158. dev_ids = ",".join(dev_ids)
  159. if num_devices > 1:
  160. args.extend(["-m", "paddle.distributed.launch"])
  161. args.extend(["--devices", dev_ids])
  162. if ips is not None:
  163. args.extend(["--ips", ips])
  164. if log_dir is None:
  165. log_dir = os.getcwd()
  166. args.extend(["--log_dir", self._get_dist_train_log_dir(log_dir)])
  167. elif num_devices == 1:
  168. new_env = os.environ.copy()
  169. if device == "xpu":
  170. new_env["XPU_VISIBLE_DEVICES"] = dev_ids
  171. elif device == "npu":
  172. new_env["ASCEND_RT_VISIBLE_DEVICES"] = dev_ids
  173. elif device == "mlu":
  174. new_env["MLU_VISIBLE_DEVICES"] = dev_ids
  175. else:
  176. new_env["CUDA_VISIBLE_DEVICES"] = dev_ids
  177. return args, new_env
  178. return args, None
  179. def parse_device(self, device):
  180. """parse_device"""
  181. # According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
  182. if ":" not in device:
  183. device_type, dev_ids = device, []
  184. else:
  185. device_type, dev_ids = device.split(":")
  186. dev_ids = dev_ids.split(",")
  187. if device_type not in ("cpu", "gpu", "xpu", "npu", "mlu"):
  188. raise ValueError("Unsupported device type.")
  189. for dev_id in dev_ids:
  190. if not dev_id.isdigit():
  191. raise ValueError("Device ID must be an integer.")
  192. return device_type, dev_ids
  193. def run_cmd(
  194. self,
  195. cmd,
  196. env=None,
  197. switch_wdir=True,
  198. silent=False,
  199. echo=True,
  200. capture_output=False,
  201. log_path=None,
  202. ):
  203. """run_cmd"""
  204. def _trans_args(cmd):
  205. out = []
  206. for ele in cmd:
  207. if isinstance(ele, CLIArgument):
  208. out.extend(ele.lst)
  209. else:
  210. out.append(ele)
  211. return out
  212. cmd = _trans_args(cmd)
  213. if DRY_RUN:
  214. # TODO: Accommodate Windows system
  215. logging.info(" ".join(shlex.quote(x) for x in cmd))
  216. # Mock return
  217. return CompletedProcess(cmd, returncode=0, stdout=str(cmd), stderr=None)
  218. if switch_wdir:
  219. if isinstance(switch_wdir, str):
  220. # In this case `switch_wdir` specifies a relative path
  221. cwd = os.path.join(self.runner_root_path, switch_wdir)
  222. else:
  223. cwd = self.runner_root_path
  224. else:
  225. cwd = None
  226. if not capture_output:
  227. if log_path is not None:
  228. logging.warning(
  229. "`log_path` will be ignored when `capture_output` is False."
  230. )
  231. cp = _run_cmd(
  232. cmd,
  233. env=env,
  234. cwd=cwd,
  235. silent=silent,
  236. echo=echo,
  237. pipe_stdout=False,
  238. pipe_stderr=False,
  239. blocking=True,
  240. )
  241. cp = CompletedProcess(
  242. cp.args, cp.returncode, stdout=cp.stdout, stderr=cp.stderr
  243. )
  244. else:
  245. # Refer to
  246. # https://stackoverflow.com/questions/17190221/subprocess-popen-cloning-stdout-and-stderr-both-to-terminal-and-variables/25960956
  247. async def _read_display_and_record_from_stream(
  248. in_stream, out_stream, files
  249. ):
  250. # According to
  251. # https://docs.python.org/3/library/subprocess.html#frequently-used-arguments
  252. _ENCODING = locale.getpreferredencoding(False)
  253. chars = []
  254. out_stream_is_buffered = hasattr(out_stream, "buffer")
  255. while True:
  256. flush = False
  257. char = await in_stream.read(1)
  258. if char == b"":
  259. break
  260. if out_stream_is_buffered:
  261. out_stream.buffer.write(char)
  262. chars.append(char)
  263. if char == b"\n":
  264. flush = True
  265. elif char == b"\r":
  266. # NOTE: In order to get tqdm progress bars to produce normal outputs
  267. # we treat '\r' as an ending character of line
  268. flush = True
  269. if flush:
  270. line = b"".join(chars)
  271. line = line.decode(_ENCODING)
  272. if not out_stream_is_buffered:
  273. # We use line buffering
  274. out_stream.write(line)
  275. else:
  276. out_stream.buffer.flush()
  277. for f in files:
  278. f.write(line)
  279. chars.clear()
  280. async def _tee_proc_call(proc_call, out_files, err_files):
  281. proc = await proc_call
  282. await asyncio.gather(
  283. _read_display_and_record_from_stream(
  284. proc.stdout, sys.stdout, out_files
  285. ),
  286. _read_display_and_record_from_stream(
  287. proc.stderr, sys.stderr, err_files
  288. ),
  289. )
  290. # NOTE: https://docs.python.org/3/library/subprocess.html#subprocess.Popen.wait
  291. retcode = await proc.wait()
  292. return retcode
  293. # Non-blocking call with stdout and stderr piped
  294. with io.StringIO() as stdout_buf, io.StringIO() as stderr_buf:
  295. proc_call = _run_cmd(
  296. cmd,
  297. env=env,
  298. cwd=cwd,
  299. echo=echo,
  300. silent=silent,
  301. pipe_stdout=True,
  302. pipe_stderr=True,
  303. blocking=False,
  304. async_run=True,
  305. )
  306. out_files = [stdout_buf]
  307. err_files = [stderr_buf]
  308. if log_path is not None:
  309. log_dir = os.path.dirname(log_path)
  310. os.makedirs(log_dir, exist_ok=True)
  311. log_file = open(log_path, "w", encoding="utf-8")
  312. logging.info(f"\nLog path: {os.path.abspath(log_path)} \n")
  313. out_files.append(log_file)
  314. err_files.append(log_file)
  315. try:
  316. retcode = asyncio.run(
  317. _tee_proc_call(proc_call, out_files, err_files)
  318. )
  319. finally:
  320. if log_path is not None:
  321. log_file.close()
  322. cp = CompletedProcess(
  323. cmd, retcode, stdout_buf.getvalue(), stderr_buf.getvalue()
  324. )
  325. if cp.returncode != 0:
  326. raise CalledProcessError(
  327. cp.returncode, cp.args, output=cp.stdout, stderr=cp.stderr
  328. )
  329. return cp
  330. def _get_dist_train_log_dir(self, log_dir):
  331. """_get_dist_train_log_dir"""
  332. return os.path.join(log_dir, "distributed_train_logs")
  333. def _get_train_log_path(self, log_dir):
  334. """_get_train_log_path"""
  335. return os.path.join(log_dir, "train.log")
  336. class InferOnlyRunner(BaseRunner):
  337. """InferOnlyRunner"""
  338. def train(self, *args, **kwargs):
  339. """train"""
  340. raise_unsupported_api_error(self.__class__, "train")
  341. def evaluate(self, *args, **kwargs):
  342. """evaluate"""
  343. raise_unsupported_api_error(self.__class__, "evalaute")
  344. def predict(self, *args, **kwargs):
  345. """predict"""
  346. raise_unsupported_api_error(self.__class__, "predict")
  347. def export(self, *args, **kwargs):
  348. """export"""
  349. raise_unsupported_api_error(self.__class__, "export")
  350. def compression(self, *args, **kwargs):
  351. """compression"""
  352. raise_unsupported_api_error(self.__class__, "compression")