# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import ast
import importlib.resources
import os
import shutil
import subprocess
import sys
from pathlib import Path

from . import create_pipeline
from .constants import MODEL_FILE_PREFIX
from .inference.pipelines import load_pipeline_config
from .inference.utils.model_paths import get_model_paths
from .repo_manager import get_all_supported_repo_names, setup
from .utils import logging
from .utils.deps import (
    get_paddle2onnx_spec,
    get_serving_dep_specs,
    require_paddle2onnx_plugin,
)
from .utils.env import get_paddle_cuda_version
from .utils.install import install_packages
from .utils.interactive_get_pipeline import interactive_get_pipeline
from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
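
# Example invocations handled by the parser built below (a sketch; the "OCR"
# pipeline name and the file paths are illustrative):
#   paddlex --install hpi-gpu
#   paddlex --pipeline OCR --input sample.png --device gpu:0 --save_path ./output
#   paddlex --serve --pipeline OCR --host 0.0.0.0 --port 8080
#   paddlex --paddle2onnx --paddle_model_dir ./model --onnx_model_dir ./onnx_model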


def args_cfg():
    """Parse the CLI arguments."""

    def parse_str(s):
        """Convert a string to None if it spells "None" and to a bool if it
        spells True or False; otherwise return it unchanged.
        """
        if s in ("None", "none", "NONE"):
            return None
        elif s in ("TRUE", "True", "true", "T", "t"):
            return True
        elif s in ("FALSE", "False", "false", "F", "f"):
            return False
        return s

    parser = argparse.ArgumentParser(
        "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
    )

    install_group = parser.add_argument_group("Install PaddleX Options")
    pipeline_group = parser.add_argument_group("Pipeline Predict Options")
    serving_group = parser.add_argument_group("Serving Options")
    paddle2onnx_group = parser.add_argument_group("Paddle2ONNX Options")

    ################# install pdx #################
    install_group.add_argument(
        "--install",
        action="store_true",
        default=False,
        help="Install specified PaddleX plugins.",
    )
    install_group.add_argument(
        "plugins",
        nargs="*",
        default=[],
        help="Names of custom development plugins to install (space-separated).",
    )
    install_group.add_argument(
        "--no_deps",
        action="store_true",
        help="Install custom development plugins without their dependencies.",
    )
    install_group.add_argument(
        "--platform",
        type=str,
        choices=["github.com", "gitee.com"],
        default="github.com",
        help="Platform to use for installation (default: github.com).",
    )
    install_group.add_argument(
        "-y",
        "--yes",
        dest="update_repos",
        action="store_true",
        help="Automatically confirm prompts and update repositories.",
    )
    install_group.add_argument(
        "--use_local_repos",
        action="store_true",
        default=False,
        help="Use local repositories if they exist.",
    )
    install_group.add_argument(
        "--deps_to_replace",
        type=str,
        nargs="+",
        default=None,
        help="Replace dependency versions when installing from repositories.",
    )

    ################# pipeline predict #################
    pipeline_group.add_argument(
        "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
    )
    pipeline_group.add_argument(
        "--input",
        type=str,
        default=None,
        help="Input data or path for the pipeline; supports a specific file or a directory.",
    )
    pipeline_group.add_argument(
        "--save_path",
        type=str,
        default=None,
        help="Path to save the prediction results.",
    )
    pipeline_group.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
    )
    pipeline_group.add_argument(
        "--use_hpip",
        action="store_true",
        help="Use the high-performance inference plugin.",
    )
    pipeline_group.add_argument(
        "--hpi_config",
        type=ast.literal_eval,
        help="High-performance inference configuration.",
    )
    pipeline_group.add_argument(
        "--get_pipeline_config",
        type=str,
        default=None,
        help="Retrieve the configuration for a specified pipeline.",
    )

    ################# serving #################
    serving_group.add_argument(
        "--serve",
        action="store_true",
        help="Start the serving application to handle requests.",
    )
    serving_group.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host address to serve on (default: 0.0.0.0).",
    )
    serving_group.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port number to serve on (default: 8080).",
    )
    # Serving also uses `--pipeline`, `--device`, and `--use_hpip`

    ################# paddle2onnx #################
    paddle2onnx_group.add_argument(
        "--paddle2onnx",
        action="store_true",
        help="Convert a PaddlePaddle model to ONNX format.",
    )
    paddle2onnx_group.add_argument(
        "--paddle_model_dir",
        type=str,
        help="Directory containing the PaddlePaddle model.",
    )
    paddle2onnx_group.add_argument(
        "--onnx_model_dir",
        type=str,
        help="Output directory for the ONNX model.",
    )
    paddle2onnx_group.add_argument(
        "--opset_version", type=int, default=7, help="Version of the ONNX opset to use."
    )

    # Parse known arguments to get the pipeline name
    args, remaining_args = parser.parse_known_args()
    pipeline = args.pipeline
    pipeline_args = []

    if not (args.install or args.serve or args.paddle2onnx) and pipeline is not None:
        if os.path.isfile(pipeline):
            pipeline_name = load_pipeline_config(pipeline)["pipeline_name"]
        else:
            pipeline_name = pipeline

        if pipeline_name not in PIPELINE_ARGUMENTS:
            support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
            logging.error(
                f"Unsupported pipeline: {pipeline_name}. CLI predict only supports these pipelines: {support_pipelines}\n"
            )
            sys.exit(1)

        # Register the pipeline-specific arguments; bool-typed options go
        # through parse_str so that "True"/"False"/"None" strings are converted.
        pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
        if pipeline_args is None:
            pipeline_args = []
        pipeline_specific_group = parser.add_argument_group(
            f"{pipeline_name.capitalize()} Pipeline Options"
        )
        for arg in pipeline_args:
            pipeline_specific_group.add_argument(
                arg["name"],
                type=parse_str if arg["type"] is bool else arg["type"],
                help=arg.get("help", f"Argument for the {pipeline_name} pipeline."),
            )

    return parser, pipeline_args
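
# Each PIPELINE_ARGUMENTS entry consumed above is expected to be a list of dicts
# of the form {"name": "--flag_name", "type": <type>, "help": "..."} ("name",
# "type", and "help" are the keys read by the registration loop; "--flag_name"
# is illustrative).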


def install(args):
    """Install PaddleX plugins and their dependencies."""

    def _install_serving_deps():
        reqs = get_serving_dep_specs()
        # Should we sort the requirements?
        install_packages(reqs)

    def _install_paddle2onnx_deps():
        install_packages([get_paddle2onnx_spec()])

    def _install_hpi_deps(device_type):
        SUPPORTED_DEVICE_TYPES = ["cpu", "gpu", "npu"]
        if device_type not in SUPPORTED_DEVICE_TYPES:
            logging.error(
                "Failed to install the high-performance plugin.\n"
                "Supported device types: %s. Your input device type: %s.\n",
                SUPPORTED_DEVICE_TYPES,
                device_type,
            )
            sys.exit(2)

        if device_type == "cpu":
            package = "ultra-infer-python"
        elif device_type == "gpu":
            if get_paddle_cuda_version()[0] != 11:
                sys.exit(
                    "You are not using PaddlePaddle compiled with CUDA 11. Currently, CUDA versions other than 11.x are not supported by the high-performance inference plugin."
                )
            package = "ultra-infer-gpu-python"
        elif device_type == "npu":
            package = "ultra-infer-npu-python"

        with importlib.resources.path("paddlex", "hpip_links.html") as f:
            install_packages([package], pip_install_opts=["--find-links", str(f)])

    # Enable debug info
    os.environ["PADDLE_PDX_DEBUG"] = "True"
    # Disable eager initialization
    os.environ["PADDLE_PDX_EAGER_INIT"] = "False"

    plugins = args.plugins[:]

    if "serving" in plugins:
        plugins.remove("serving")
        if plugins:
            logging.error("`serving` cannot be used together with other plugins.")
            sys.exit(2)
        _install_serving_deps()
        return

    if "paddle2onnx" in plugins:
        plugins.remove("paddle2onnx")
        if plugins:
            logging.error("`paddle2onnx` cannot be used together with other plugins.")
            sys.exit(2)
        _install_paddle2onnx_deps()
        return

    hpi_plugins = list(filter(lambda name: name.startswith("hpi-"), plugins))
    if hpi_plugins:
        for i in hpi_plugins:
            plugins.remove(i)
        if plugins:
            logging.error("`hpi` cannot be used together with other plugins.")
            sys.exit(2)
        if len(hpi_plugins) > 1 or len(hpi_plugins[0].split("-")) != 2:
            logging.error(
                "Invalid HPI plugin installation format detected.\n"
                "Correct format: paddlex --install hpi-<device_type>\n"
                "Example: paddlex --install hpi-gpu"
            )
            sys.exit(2)
        device_type = hpi_plugins[0].split("-")[1]
        _install_hpi_deps(device_type=device_type)
        return

    if plugins:
        repo_names = plugins
    else:
        repo_names = get_all_supported_repo_names()
    setup(
        repo_names=repo_names,
        no_deps=args.no_deps,
        platform=args.platform,
        update_repos=args.update_repos,
        use_local_repos=args.use_local_repos,
        deps_to_replace=args.deps_to_replace,
    )
    return
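
# Plugin resolution above: `serving`, `paddle2onnx`, and `hpi-<device_type>` are
# mutually exclusive shortcuts that install extra dependencies only; any other
# names are treated as repository names, and no names at all installs every
# supported repository. For example (illustrative):
#   paddlex --install            # set up all supported repos
#   paddlex --install serving    # serving dependencies only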


def pipeline_predict(
    pipeline,
    input,
    device,
    save_path,
    use_hpip,
    hpi_config,
    **pipeline_args,
):
    """Create a pipeline and run prediction over the input."""
    pipeline = create_pipeline(
        pipeline, device=device, use_hpip=use_hpip, hpi_config=hpi_config
    )
    result = pipeline.predict(input, **pipeline_args)
    for res in result:
        res.print()
        if save_path:
            res.save_all(save_path=save_path)
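
# A roughly equivalent programmatic call, for reference (a minimal sketch; the
# pipeline name and paths are illustrative):
#   from paddlex import create_pipeline
#   pipeline = create_pipeline("OCR", device="gpu:0")
#   for res in pipeline.predict("sample.png"):
#       res.print()
#       res.save_all(save_path="./output")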


def serve(pipeline, *, device, use_hpip, hpi_config, host, port):
    """Start the serving application for the given pipeline."""
    from .inference.serving.basic_serving import create_pipeline_app, run_server

    pipeline_config = load_pipeline_config(pipeline)
    pipeline = create_pipeline(
        config=pipeline_config, device=device, use_hpip=use_hpip, hpi_config=hpi_config
    )
    app = create_pipeline_app(pipeline, pipeline_config)
    run_server(app, host=host, port=port)
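
# Typically reached via (the pipeline name is illustrative):
#   paddlex --serve --pipeline OCR --host 0.0.0.0 --port 8080
# `--device` and `--use_hpip` from the predict group apply here as well.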


# TODO: Move to another module
def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
    """Convert a PaddlePaddle model directory to ONNX format."""
    require_paddle2onnx_plugin()

    ONNX_MODEL_FILENAME = f"{MODEL_FILE_PREFIX}.onnx"
    CONFIG_FILENAME = f"{MODEL_FILE_PREFIX}.yml"
    ADDITIONAL_FILENAMES = ["scaler.pkl"]

    def _check_input_dir(input_dir):
        if input_dir is None:
            sys.exit("Input directory must be specified")
        if not input_dir.exists():
            sys.exit(f"{input_dir} does not exist")
        if not input_dir.is_dir():
            sys.exit(f"{input_dir} is not a directory")
        model_paths = get_model_paths(input_dir)
        if "paddle" not in model_paths:
            sys.exit("PaddlePaddle model does not exist")
        config_path = input_dir / CONFIG_FILENAME
        if not config_path.exists():
            sys.exit(f"{config_path} does not exist")

    def _check_paddle2onnx():
        if shutil.which("paddle2onnx") is None:
            sys.exit("Paddle2ONNX is not available. Please install the plugin first.")

    def _run_paddle2onnx(input_dir, output_dir, opset_version):
        model_paths = get_model_paths(input_dir)
        logging.info("Paddle2ONNX conversion starting...")
        # XXX: To circumvent Paddle2ONNX's bug
        cmd = [
            "paddle2onnx",
            "--model_dir",
            str(model_paths["paddle"][0].parent),
            "--model_filename",
            str(model_paths["paddle"][0].name),
            "--params_filename",
            str(model_paths["paddle"][1].name),
            "--save_file",
            str(output_dir / ONNX_MODEL_FILENAME),
            "--opset_version",
            str(opset_version),
        ]
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError as e:
            sys.exit(f"Paddle2ONNX conversion failed with exit code {e.returncode}")
        logging.info("Paddle2ONNX conversion succeeded")

    def _copy_config_file(input_dir, output_dir):
        src_path = input_dir / CONFIG_FILENAME
        dst_path = output_dir / CONFIG_FILENAME
        shutil.copy(src_path, dst_path)
        logging.info(f"Copied {src_path} to {dst_path}")

    def _copy_additional_files(input_dir, output_dir):
        for filename in ADDITIONAL_FILENAMES:
            src_path = input_dir / filename
            if not src_path.exists():
                continue
            dst_path = output_dir / filename
            shutil.copy(src_path, dst_path)
            logging.info(f"Copied {src_path} to {dst_path}")

    paddle_model_dir = Path(paddle_model_dir)
    # Default to converting in place when no output directory is given
    if not onnx_model_dir:
        onnx_model_dir = paddle_model_dir
    onnx_model_dir = Path(onnx_model_dir)
    logging.info(f"Input dir: {paddle_model_dir}")
    logging.info(f"Output dir: {onnx_model_dir}")
    _check_input_dir(paddle_model_dir)
    _check_paddle2onnx()
    _run_paddle2onnx(paddle_model_dir, onnx_model_dir, opset_version)
    # Copy the config and auxiliary files unless the output is the input directory itself
    if not (onnx_model_dir.exists() and onnx_model_dir.samefile(paddle_model_dir)):
        _copy_config_file(paddle_model_dir, onnx_model_dir)
        _copy_additional_files(paddle_model_dir, onnx_model_dir)
    logging.info("Done")
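
# Expected layout, assuming MODEL_FILE_PREFIX == "inference" (the actual value
# comes from .constants):
#   <paddle_model_dir>/inference.yml           -> <onnx_model_dir>/inference.yml
#   <paddle_model_dir>/<model + params files>  -> <onnx_model_dir>/inference.onnx
# CLI entry point (illustrative paths):
#   paddlex --paddle2onnx --paddle_model_dir ./model --onnx_model_dir ./onnx_model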


# for CLI
def main():
    """Entry point for the PaddleX command line."""
    parser, pipeline_args = args_cfg()
    args = parser.parse_args()

    if len(sys.argv) == 1:
        logging.warning("No arguments provided. Displaying help information:")
        parser.print_help()
        sys.exit(2)

    if args.install:
        install(args)
    elif args.serve:
        serve(
            args.pipeline,
            device=args.device,
            use_hpip=args.use_hpip,
            hpi_config=args.hpi_config,
            host=args.host,
            port=args.port,
        )
    elif args.paddle2onnx:
        paddle_to_onnx(
            args.paddle_model_dir,
            args.onnx_model_dir,
            opset_version=args.opset_version,
        )
    else:
        if args.get_pipeline_config is not None:
            interactive_get_pipeline(args.get_pipeline_config, args.save_path)
        else:
            # Collect the dynamically registered pipeline-specific arguments
            pipeline_args_dict = {}
            for arg in pipeline_args:
                arg_name = arg["name"].lstrip("-")
                if hasattr(args, arg_name):
                    pipeline_args_dict[arg_name] = getattr(args, arg_name)
                else:
                    logging.warning(f"Argument {arg_name} is missing in args")
            return pipeline_predict(
                args.pipeline,
                args.input,
                args.device,
                args.save_path,
                use_hpip=args.use_hpip,
                hpi_config=args.hpi_config,
                **pipeline_args_dict,
            )