# paddlex_cli.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import ast
  16. import importlib.resources
  17. import os
  18. import shutil
  19. import subprocess
  20. import sys
  21. from pathlib import Path
  22. from . import create_pipeline
  23. from .constants import MODEL_FILE_PREFIX
  24. from .inference.pipelines import load_pipeline_config
  25. from .inference.utils.model_paths import get_model_paths
  26. from .repo_manager import get_all_supported_repo_names, setup
  27. from .utils import logging
  28. from .utils.deps import (
  29. get_dep_version,
  30. get_paddle2onnx_spec,
  31. get_serving_dep_specs,
  32. is_paddle2onnx_plugin_available,
  33. require_paddle2onnx_plugin,
  34. )
  35. from .utils.env import get_paddle_cuda_version
  36. from .utils.install import install_packages, uninstall_packages
  37. from .utils.interactive_get_pipeline import interactive_get_pipeline
  38. from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
def args_cfg():
    """Build the PaddleX CLI argument parser.

    Returns:
        tuple: ``(parser, pipeline_args)`` where ``parser`` is the configured
        ``argparse.ArgumentParser`` and ``pipeline_args`` is the list of
        pipeline-specific argument descriptors registered for the selected
        pipeline (empty when no pipeline-specific options apply).
    """

    def parse_str(s):
        """convert str type value
        to None type if it is "None",
        to bool type if it means True or False.
        """
        if s in ("None", "none", "NONE"):
            return None
        elif s in ("TRUE", "True", "true", "T", "t"):
            return True
        elif s in ("FALSE", "False", "false", "F", "f"):
            return False
        # Any other string is passed through unchanged.
        return s

    parser = argparse.ArgumentParser(
        "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
    )

    # Options are grouped by the CLI mode they belong to.
    install_group = parser.add_argument_group("Install PaddleX Options")
    pipeline_group = parser.add_argument_group("Pipeline Predict Options")
    serving_group = parser.add_argument_group("Serving Options")
    paddle2onnx_group = parser.add_argument_group("Paddle2ONNX Options")

    ################# install pdx #################
    install_group.add_argument(
        "--install",
        nargs="*",
        metavar="PLUGIN",
        help="Install specified PaddleX plugins.",
    )
    install_group.add_argument(
        "--no_deps",
        action="store_true",
        help="Install custom development plugins without their dependencies.",
    )
    install_group.add_argument(
        "--platform",
        type=str,
        choices=["github.com", "gitee.com"],
        default="github.com",
        help="Platform to use for installation (default: github.com).",
    )
    install_group.add_argument(
        "-y",
        "--yes",
        dest="update_repos",
        action="store_true",
        help="Automatically confirm prompts and update repositories.",
    )
    install_group.add_argument(
        "--use_local_repos",
        action="store_true",
        default=False,
        help="Use local repositories if they exist.",
    )
    install_group.add_argument(
        "--deps_to_replace",
        type=str,
        nargs="+",
        default=None,
        help="Replace dependency version when installing from repositories.",
    )

    ################# pipeline predict #################
    pipeline_group.add_argument(
        "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
    )
    pipeline_group.add_argument(
        "--input",
        type=str,
        default=None,
        help="Input data or path for the pipeline, supports specific file and directory.",
    )
    pipeline_group.add_argument(
        "--save_path",
        type=str,
        default=None,
        help="Path to save the prediction results.",
    )
    pipeline_group.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
    )
    pipeline_group.add_argument(
        "--use_hpip",
        action="store_true",
        help="Use high-performance inference plugin.",
    )
    pipeline_group.add_argument(
        "--hpi_config",
        # Parsed as a Python literal (dict/list/str/...), not raw text.
        type=ast.literal_eval,
        help="High-performance inference configuration.",
    )
    pipeline_group.add_argument(
        "--get_pipeline_config",
        type=str,
        default=None,
        help="Retrieve the configuration for a specified pipeline.",
    )

    ################# serving #################
    serving_group.add_argument(
        "--serve",
        action="store_true",
        help="Start the serving application to handle requests.",
    )
    serving_group.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host address to serve on (default: 0.0.0.0).",
    )
    serving_group.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port number to serve on (default: 8080).",
    )
    # Serving also uses `--pipeline`, `--device`, `--use_hpip`, and `--hpi_config`

    ################# paddle2onnx #################
    paddle2onnx_group.add_argument(
        "--paddle2onnx",
        action="store_true",
        help="Convert PaddlePaddle model to ONNX format.",
    )
    paddle2onnx_group.add_argument(
        "--paddle_model_dir",
        type=str,
        help="Directory containing the PaddlePaddle model.",
    )
    paddle2onnx_group.add_argument(
        "--onnx_model_dir",
        type=str,
        help="Output directory for the ONNX model.",
    )
    paddle2onnx_group.add_argument(
        "--opset_version", type=int, default=7, help="Version of the ONNX opset to use."
    )

    # Parse known arguments to get the pipeline name
    # (`remaining_args` is intentionally ignored; pipeline-specific options are
    # registered below and picked up by the final `parse_args()` in `main`).
    args, remaining_args = parser.parse_known_args()
    pipeline = args.pipeline
    pipeline_args = []

    # Pipeline-specific options only matter in predict mode (not for
    # --install / --serve / --paddle2onnx) and only when a pipeline is given.
    if (
        not (args.install is not None or args.serve or args.paddle2onnx)
        and pipeline is not None
    ):
        # `--pipeline` may be either a config file path or a pipeline name.
        if os.path.isfile(pipeline):
            pipeline_name = load_pipeline_config(pipeline)["pipeline_name"]
        else:
            pipeline_name = pipeline

        if pipeline_name not in PIPELINE_ARGUMENTS:
            support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
            logging.error(
                f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
            )
            sys.exit(1)

        pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
        if pipeline_args is None:
            pipeline_args = []
        pipeline_specific_group = parser.add_argument_group(
            f"{pipeline_name.capitalize()} Pipeline Options"
        )
        for arg in pipeline_args:
            pipeline_specific_group.add_argument(
                arg["name"],
                # Boolean options come in as strings; `parse_str` maps them to
                # True/False/None.
                type=parse_str if arg["type"] is bool else arg["type"],
                help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
            )

    return parser, pipeline_args
def install(args):
    """Install the PaddleX plugins/repositories selected via ``--install``.

    ``args.install`` holds the requested plugin names (an empty list means
    "set up all supported repositories"). The special plugins ``serving``,
    ``paddle2onnx``, and ``hpi-<device_type>`` are mutually exclusive with
    any other plugin; invalid combinations exit the process with status 2.
    """

    def _install_serving_deps():
        # Install the dependency specs required by the serving plugin.
        reqs = get_serving_dep_specs()
        # Should we sort the requirements?
        install_packages(reqs)

    def _install_paddle2onnx_deps():
        # The Paddle2ONNX plugin is a single pinned package spec.
        install_packages([get_paddle2onnx_spec()])

    def _install_hpi_deps(device_type):
        # Install the high-performance inference plugin for one device type.
        SUPPORTED_DEVICE_TYPES = ["cpu", "gpu", "npu"]
        if device_type not in SUPPORTED_DEVICE_TYPES:
            logging.error(
                "Failed to install the high-performance plugin.\n"
                "Supported device types: %s. Your input device type: %s.\n",
                SUPPORTED_DEVICE_TYPES,
                device_type,
            )
            sys.exit(2)

        # Pick the bundled find-links page matching the CUDA major version.
        hpip_links_file = "hpip_links.html"
        if device_type == "gpu":
            cuda_version = get_paddle_cuda_version()
            if not cuda_version:
                sys.exit(
                    "No CUDA version found. Please make sure you have installed PaddlePaddle with CUDA enabled."
                )
            if cuda_version[0] == 12:
                hpip_links_file = "hpip_links_cu12.html"
            elif cuda_version[0] != 11:
                sys.exit(
                    "Currently, only CUDA versions 11.x and 12.x are supported by the high-performance inference plugin."
                )

        package_mapping = {
            "cpu": "ultra-infer-python",
            "gpu": "ultra-infer-gpu-python",
            "npu": "ultra-infer-npu-python",
        }
        package = package_mapping[device_type]

        # The ultra-infer variants are mutually exclusive: uninstall any other
        # variant before installing the requested one.
        other_packages = set(package_mapping.values()) - {package}
        for other_package in other_packages:
            version = get_dep_version(other_package)
            if version is not None:
                logging.info(
                    f"The high-performance inference plugin '{package}' is mutually exclusive with '{other_package}' (version {version} installed). Uninstalling '{other_package}'..."
                )
                uninstall_packages([other_package])

        # Resolve the find-links file shipped inside the `paddlex` package.
        with importlib.resources.path("paddlex", hpip_links_file) as f:
            version = get_dep_version(package)
            if version is None:
                install_packages([package], pip_install_opts=["--find-links", str(f)])
            else:
                # Already installed: ask before reinstalling.
                response = input(
                    f"The high-performance inference plugin is already installed (version {repr(version)}). Do you want to reinstall it? (y/n):"
                )
                if response.lower() in ["y", "yes"]:
                    uninstall_packages([package])
                    install_packages(
                        [package],
                        pip_install_opts=[
                            "--find-links",
                            str(f),
                        ],
                    )
                else:
                    # User declined; skip the Paddle2ONNX hint below as well.
                    return

        if not is_paddle2onnx_plugin_available():
            logging.info(
                "The Paddle2ONNX plugin is not available. It is recommended to run `paddlex --install paddle2onnx` to install the Paddle2ONNX plugin to use the full functionality of high-performance inference."
            )

    # Enable debug info
    os.environ["PADDLE_PDX_DEBUG"] = "True"
    # Disable eager initialization
    os.environ["PADDLE_PDX_EAGER_INIT"] = "False"

    # Work on a copy so `args.install` itself is left untouched.
    plugins = args.install[:]

    # `serving` must be installed on its own.
    if "serving" in plugins:
        plugins.remove("serving")
        if plugins:
            logging.error("`serving` cannot be used together with other plugins.")
            sys.exit(2)
        _install_serving_deps()
        return

    # `paddle2onnx` must be installed on its own.
    if "paddle2onnx" in plugins:
        plugins.remove("paddle2onnx")
        if plugins:
            logging.error("`paddle2onnx` cannot be used together with other plugins.")
            sys.exit(2)
        _install_paddle2onnx_deps()
        return

    # `hpi-<device_type>` must be installed on its own and exactly once.
    hpi_plugins = list(filter(lambda name: name.startswith("hpi-"), plugins))
    if hpi_plugins:
        for i in hpi_plugins:
            plugins.remove(i)
        if plugins:
            logging.error("`hpi` cannot be used together with other plugins.")
            sys.exit(2)
        if len(hpi_plugins) > 1 or len(hpi_plugins[0].split("-")) != 2:
            logging.error(
                "Invalid HPI plugin installation format detected.\n"
                "Correct format: paddlex --install hpi-<device_type>\n"
                "Example: paddlex --install hpi-gpu"
            )
            sys.exit(2)
        device_type = hpi_plugins[0].split("-")[1]
        _install_hpi_deps(device_type=device_type)
        return

    # Remaining (or empty) plugin list: set up development repositories.
    if plugins:
        repo_names = plugins
    elif len(plugins) == 0:
        repo_names = get_all_supported_repo_names()
    setup(
        repo_names=repo_names,
        no_deps=args.no_deps,
        platform=args.platform,
        update_repos=args.update_repos,
        use_local_repos=args.use_local_repos,
        deps_to_replace=args.deps_to_replace,
    )
    return
  323. def pipeline_predict(
  324. pipeline,
  325. input,
  326. device,
  327. save_path,
  328. use_hpip,
  329. hpi_config,
  330. **pipeline_args,
  331. ):
  332. """pipeline predict"""
  333. pipeline = create_pipeline(
  334. pipeline, device=device, use_hpip=use_hpip, hpi_config=hpi_config
  335. )
  336. result = pipeline.predict(input, **pipeline_args)
  337. for res in result:
  338. res.print()
  339. if save_path:
  340. res.save_all(save_path=save_path)
  341. def serve(pipeline, *, device, use_hpip, hpi_config, host, port):
  342. from .inference.serving.basic_serving import create_pipeline_app, run_server
  343. pipeline_config = load_pipeline_config(pipeline)
  344. pipeline = create_pipeline(
  345. config=pipeline_config, device=device, use_hpip=use_hpip, hpi_config=hpi_config
  346. )
  347. app = create_pipeline_app(pipeline, pipeline_config)
  348. run_server(app, host=host, port=port)
  349. # TODO: Move to another module
  350. def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
  351. require_paddle2onnx_plugin()
  352. ONNX_MODEL_FILENAME = f"{MODEL_FILE_PREFIX}.onnx"
  353. CONFIG_FILENAME = f"{MODEL_FILE_PREFIX}.yml"
  354. ADDITIONAL_FILENAMES = ["scaler.pkl"]
  355. def _check_input_dir(input_dir):
  356. if input_dir is None:
  357. sys.exit("Input directory must be specified")
  358. if not input_dir.exists():
  359. sys.exit(f"{input_dir} does not exist")
  360. if not input_dir.is_dir():
  361. sys.exit(f"{input_dir} is not a directory")
  362. model_paths = get_model_paths(input_dir)
  363. if "paddle" not in model_paths:
  364. sys.exit("PaddlePaddle model does not exist")
  365. config_path = input_dir / CONFIG_FILENAME
  366. if not config_path.exists():
  367. sys.exit(f"{config_path} does not exist")
  368. def _check_paddle2onnx():
  369. if shutil.which("paddle2onnx") is None:
  370. sys.exit("Paddle2ONNX is not available. Please install the plugin first.")
  371. def _run_paddle2onnx(input_dir, output_dir, opset_version):
  372. model_paths = get_model_paths(input_dir)
  373. logging.info("Paddle2ONNX conversion starting...")
  374. # XXX: To circumvent Paddle2ONNX's bug
  375. cmd = [
  376. "paddle2onnx",
  377. "--model_dir",
  378. str(model_paths["paddle"][0].parent),
  379. "--model_filename",
  380. str(model_paths["paddle"][0].name),
  381. "--params_filename",
  382. str(model_paths["paddle"][1].name),
  383. "--save_file",
  384. str(output_dir / ONNX_MODEL_FILENAME),
  385. "--opset_version",
  386. str(opset_version),
  387. ]
  388. try:
  389. subprocess.check_call(cmd)
  390. except subprocess.CalledProcessError as e:
  391. sys.exit(f"Paddle2ONNX conversion failed with exit code {e.returncode}")
  392. logging.info("Paddle2ONNX conversion succeeded")
  393. def _copy_config_file(input_dir, output_dir):
  394. src_path = input_dir / CONFIG_FILENAME
  395. dst_path = output_dir / CONFIG_FILENAME
  396. shutil.copy(src_path, dst_path)
  397. logging.info(f"Copied {src_path} to {dst_path}")
  398. def _copy_additional_files(input_dir, output_dir):
  399. for filename in ADDITIONAL_FILENAMES:
  400. src_path = input_dir / filename
  401. if not src_path.exists():
  402. continue
  403. dst_path = output_dir / filename
  404. shutil.copy(src_path, dst_path)
  405. logging.info(f"Copied {src_path} to {dst_path}")
  406. paddle_model_dir = Path(paddle_model_dir)
  407. if not onnx_model_dir:
  408. onnx_model_dir = paddle_model_dir
  409. onnx_model_dir = Path(onnx_model_dir)
  410. logging.info(f"Input dir: {paddle_model_dir}")
  411. logging.info(f"Output dir: {onnx_model_dir}")
  412. _check_input_dir(paddle_model_dir)
  413. _check_paddle2onnx()
  414. _run_paddle2onnx(paddle_model_dir, onnx_model_dir, opset_version)
  415. if not (onnx_model_dir.exists() and onnx_model_dir.samefile(paddle_model_dir)):
  416. _copy_config_file(paddle_model_dir, onnx_model_dir)
  417. _copy_additional_files(paddle_model_dir, onnx_model_dir)
  418. logging.info("Done")
  419. # for CLI
  420. def main():
  421. """API for command line"""
  422. parser, pipeline_args = args_cfg()
  423. args = parser.parse_args()
  424. if len(sys.argv) == 1:
  425. logging.warning("No arguments provided. Displaying help information:")
  426. parser.print_help()
  427. sys.exit(2)
  428. if args.install is not None:
  429. install(args)
  430. return
  431. elif args.serve:
  432. serve(
  433. args.pipeline,
  434. device=args.device,
  435. use_hpip=args.use_hpip or None,
  436. hpi_config=args.hpi_config,
  437. host=args.host,
  438. port=args.port,
  439. )
  440. return
  441. elif args.paddle2onnx:
  442. paddle_to_onnx(
  443. args.paddle_model_dir,
  444. args.onnx_model_dir,
  445. opset_version=args.opset_version,
  446. )
  447. return
  448. else:
  449. if args.get_pipeline_config is not None:
  450. interactive_get_pipeline(args.get_pipeline_config, args.save_path)
  451. else:
  452. pipeline_args_dict = {}
  453. for arg in pipeline_args:
  454. arg_name = arg["name"].lstrip("-")
  455. if hasattr(args, arg_name):
  456. pipeline_args_dict[arg_name] = getattr(args, arg_name)
  457. else:
  458. logging.warning(f"Argument {arg_name} is missing in args")
  459. pipeline_predict(
  460. args.pipeline,
  461. args.input,
  462. args.device,
  463. args.save_path,
  464. use_hpip=args.use_hpip or None,
  465. hpi_config=args.hpi_config,
  466. **pipeline_args_dict,
  467. )
  468. return