paddlex_cli.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import importlib.resources
  16. import os
  17. import shutil
  18. import subprocess
  19. import sys
  20. from pathlib import Path
  21. from . import create_pipeline
  22. from .constants import MODEL_FILE_PREFIX
  23. from .inference.pipelines import load_pipeline_config
  24. from .repo_manager import get_all_supported_repo_names, setup
  25. from .utils import logging
  26. from .utils.deps import (
  27. get_paddle2onnx_spec,
  28. get_serving_dep_specs,
  29. require_paddle2onnx_plugin,
  30. )
  31. from .utils.flags import FLAGS_json_format_model
  32. from .utils.install import install_packages
  33. from .utils.interactive_get_pipeline import interactive_get_pipeline
  34. from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
  35. def args_cfg():
  36. """parse cli arguments"""
  37. def parse_str(s):
  38. """convert str type value
  39. to None type if it is "None",
  40. to bool type if it means True or False.
  41. """
  42. if s in ("None", "none", "NONE"):
  43. return None
  44. elif s in ("TRUE", "True", "true", "T", "t"):
  45. return True
  46. elif s in ("FALSE", "False", "false", "F", "f"):
  47. return False
  48. return s
  49. parser = argparse.ArgumentParser(
  50. "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
  51. )
  52. install_group = parser.add_argument_group("Install PaddleX Options")
  53. pipeline_group = parser.add_argument_group("Pipeline Predict Options")
  54. serving_group = parser.add_argument_group("Serving Options")
  55. paddle2onnx_group = parser.add_argument_group("Paddle2ONNX Options")
  56. ################# install pdx #################
  57. install_group.add_argument(
  58. "--install",
  59. action="store_true",
  60. default=False,
  61. help="Install specified PaddleX plugins.",
  62. )
  63. install_group.add_argument(
  64. "plugins",
  65. nargs="*",
  66. default=[],
  67. help="Names of custom development plugins to install (space-separated).",
  68. )
  69. install_group.add_argument(
  70. "--no_deps",
  71. action="store_true",
  72. help="Install custom development plugins without their dependencies.",
  73. )
  74. install_group.add_argument(
  75. "--platform",
  76. type=str,
  77. choices=["github.com", "gitee.com"],
  78. default="github.com",
  79. help="Platform to use for installation (default: github.com).",
  80. )
  81. install_group.add_argument(
  82. "-y",
  83. "--yes",
  84. dest="update_repos",
  85. action="store_true",
  86. help="Automatically confirm prompts and update repositories.",
  87. )
  88. install_group.add_argument(
  89. "--use_local_repos",
  90. action="store_true",
  91. default=False,
  92. help="Use local repositories if they exist.",
  93. )
  94. install_group.add_argument(
  95. "--deps_to_replace",
  96. type=str,
  97. nargs="+",
  98. default=None,
  99. help="Replace dependency version when installing from repositories.",
  100. )
  101. ################# pipeline predict #################
  102. pipeline_group.add_argument(
  103. "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
  104. )
  105. pipeline_group.add_argument(
  106. "--input",
  107. type=str,
  108. default=None,
  109. help="Input data or path for the pipeline, supports specific file and directory.",
  110. )
  111. pipeline_group.add_argument(
  112. "--save_path",
  113. type=str,
  114. default=None,
  115. help="Path to save the prediction results.",
  116. )
  117. pipeline_group.add_argument(
  118. "--device",
  119. type=str,
  120. default=None,
  121. help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
  122. )
  123. pipeline_group.add_argument(
  124. "--use_hpip",
  125. action="store_true",
  126. help="Enable HPIP acceleration by default.",
  127. )
  128. pipeline_group.add_argument(
  129. "--get_pipeline_config",
  130. type=str,
  131. default=None,
  132. help="Retrieve the configuration for a specified pipeline.",
  133. )
  134. ################# serving #################
  135. serving_group.add_argument(
  136. "--serve",
  137. action="store_true",
  138. help="Start the serving application to handle requests.",
  139. )
  140. serving_group.add_argument(
  141. "--host",
  142. type=str,
  143. default="0.0.0.0",
  144. help="Host address to serve on (default: 0.0.0.0).",
  145. )
  146. serving_group.add_argument(
  147. "--port",
  148. type=int,
  149. default=8080,
  150. help="Port number to serve on (default: 8080).",
  151. )
  152. # Serving also uses `--pipeline`, `--device`, and `--use_hpip`
  153. ################# paddle2onnx #################
  154. paddle2onnx_group.add_argument(
  155. "--paddle2onnx",
  156. action="store_true",
  157. help="Convert PaddlePaddle model to ONNX format",
  158. )
  159. paddle2onnx_group.add_argument(
  160. "--paddle_model_dir",
  161. type=str,
  162. help="Directory containing the PaddlePaddle model",
  163. )
  164. paddle2onnx_group.add_argument(
  165. "--onnx_model_dir",
  166. type=str,
  167. help="Output directory for the ONNX model",
  168. )
  169. paddle2onnx_group.add_argument(
  170. "--opset_version", type=int, help="Version of the ONNX opset to use"
  171. )
  172. # Parse known arguments to get the pipeline name
  173. args, remaining_args = parser.parse_known_args()
  174. pipeline = args.pipeline
  175. pipeline_args = []
  176. if not (args.install or args.serve or args.paddle2onnx) and pipeline is not None:
  177. if os.path.isfile(pipeline):
  178. pipeline_name = load_pipeline_config(pipeline)["pipeline_name"]
  179. else:
  180. pipeline_name = pipeline
  181. if pipeline_name not in PIPELINE_ARGUMENTS:
  182. support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
  183. logging.error(
  184. f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
  185. )
  186. sys.exit(1)
  187. pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
  188. if pipeline_args is None:
  189. pipeline_args = []
  190. pipeline_specific_group = parser.add_argument_group(
  191. f"{pipeline_name.capitalize()} Pipeline Options"
  192. )
  193. for arg in pipeline_args:
  194. pipeline_specific_group.add_argument(
  195. arg["name"],
  196. type=parse_str if arg["type"] is bool else arg["type"],
  197. help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
  198. )
  199. return parser, pipeline_args
  200. def install(args):
  201. """install paddlex"""
  202. def _install_serving_deps():
  203. reqs = get_serving_dep_specs()
  204. # Should we sort the requirements?
  205. install_packages(reqs)
  206. def _install_paddle2onnx_deps():
  207. install_packages([get_paddle2onnx_spec()])
  208. def _install_hpi_deps(device_type):
  209. SUPPORTED_DEVICE_TYPES = ["cpu", "gpu", "npu"]
  210. if device_type not in SUPPORTED_DEVICE_TYPES:
  211. logging.error(
  212. "HPI installation failed!\n"
  213. "Supported device_type: %s. Your input device_type: %s.\n"
  214. "Please ensure the device_type is correct.",
  215. SUPPORTED_DEVICE_TYPES,
  216. device_type,
  217. )
  218. sys.exit(2)
  219. if device_type == "cpu":
  220. package = "ultra-infer-python"
  221. elif device_type == "gpu":
  222. package = "ultra-infer-gpu-python"
  223. elif device_type == "npu":
  224. package = "ultra-infer-npu-python"
  225. with importlib.resources.path("paddlex", "hpip_links.html") as f:
  226. install_packages([package], pip_install_opts=["--find-links", str(f)])
  227. # Enable debug info
  228. os.environ["PADDLE_PDX_DEBUG"] = "True"
  229. # Disable eager initialization
  230. os.environ["PADDLE_PDX_EAGER_INIT"] = "False"
  231. plugins = args.plugins[:]
  232. if "serving" in plugins:
  233. plugins.remove("serving")
  234. if plugins:
  235. logging.error("`serving` cannot be used together with other plugins.")
  236. sys.exit(2)
  237. _install_serving_deps()
  238. return
  239. if "paddle2onnx" in plugins:
  240. plugins.remove("paddle2onnx")
  241. if plugins:
  242. logging.error("`paddle2onnx` cannot be used together with other plugins.")
  243. sys.exit(2)
  244. _install_paddle2onnx_deps()
  245. return
  246. hpi_plugins = list(filter(lambda name: name.startswith("hpi-"), plugins))
  247. if hpi_plugins:
  248. for i in hpi_plugins:
  249. plugins.remove(i)
  250. if plugins:
  251. logging.error("`hpi` cannot be used together with other plugins.")
  252. sys.exit(2)
  253. if len(hpi_plugins) > 1 or len(hpi_plugins[0].split("-")) != 2:
  254. logging.error(
  255. "Invalid HPI plugin installation format detected.\n"
  256. "Correct format: paddlex --install hpi-<device_type>\n"
  257. "Example: paddlex --install hpi-gpu"
  258. )
  259. sys.exit(2)
  260. device_type = hpi_plugins[0].split("-")[1]
  261. _install_hpi_deps(device_type=device_type)
  262. return
  263. if plugins:
  264. repo_names = plugins
  265. elif len(plugins) == 0:
  266. repo_names = get_all_supported_repo_names()
  267. setup(
  268. repo_names=repo_names,
  269. no_deps=args.no_deps,
  270. platform=args.platform,
  271. update_repos=args.update_repos,
  272. use_local_repos=args.use_local_repos,
  273. deps_to_replace=args.deps_to_replace,
  274. )
  275. return
  276. def pipeline_predict(
  277. pipeline,
  278. input,
  279. device,
  280. save_path,
  281. use_hpip,
  282. **pipeline_args,
  283. ):
  284. """pipeline predict"""
  285. pipeline = create_pipeline(pipeline, device=device, use_hpip=use_hpip)
  286. result = pipeline.predict(input, **pipeline_args)
  287. for res in result:
  288. res.print()
  289. if save_path:
  290. res.save_all(save_path=save_path)
  291. def serve(pipeline, *, device, use_hpip, host, port):
  292. from .inference.serving.basic_serving import create_pipeline_app, run_server
  293. pipeline_config = load_pipeline_config(pipeline)
  294. pipeline = create_pipeline(config=pipeline_config, device=device, use_hpip=use_hpip)
  295. app = create_pipeline_app(pipeline, pipeline_config)
  296. run_server(app, host=host, port=port)
  297. # TODO: Move to another module
  298. def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
  299. require_paddle2onnx_plugin()
  300. PD_MODEL_FILE_PREFIX = MODEL_FILE_PREFIX
  301. PD_PARAMS_FILENAME = f"{MODEL_FILE_PREFIX}.pdiparams"
  302. ONNX_MODEL_FILENAME = f"{MODEL_FILE_PREFIX}.onnx"
  303. CONFIG_FILENAME = f"{MODEL_FILE_PREFIX}.yml"
  304. ADDITIONAL_FILENAMES = ["scaler.pkl"]
  305. def _check_input_dir(input_dir, pd_model_file_ext):
  306. if input_dir is None:
  307. sys.exit("Input directory must be specified")
  308. if not input_dir.exists():
  309. sys.exit(f"{input_dir} does not exist")
  310. if not input_dir.is_dir():
  311. sys.exit(f"{input_dir} is not a directory")
  312. model_path = (input_dir / PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)
  313. if not model_path.exists():
  314. sys.exit(f"{model_path} does not exist")
  315. params_path = input_dir / PD_PARAMS_FILENAME
  316. if not params_path.exists():
  317. sys.exit(f"{params_path} does not exist")
  318. config_path = input_dir / CONFIG_FILENAME
  319. if not config_path.exists():
  320. sys.exit(f"{config_path} does not exist")
  321. def _check_paddle2onnx():
  322. if shutil.which("paddle2onnx") is None:
  323. sys.exit("Paddle2ONNX is not available. Please install the plugin first.")
  324. def _run_paddle2onnx(input_dir, pd_model_file_ext, output_dir, opset_version):
  325. logging.info("Paddle2ONNX conversion starting...")
  326. # XXX: To circumvent Paddle2ONNX's bug
  327. if opset_version is None:
  328. if pd_model_file_ext == ".json":
  329. opset_version = 19
  330. else:
  331. opset_version = 7
  332. logging.info("Using default ONNX opset version: %d", opset_version)
  333. cmd = [
  334. "paddle2onnx",
  335. "--model_dir",
  336. str(input_dir),
  337. "--model_filename",
  338. str(Path(PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)),
  339. "--params_filename",
  340. PD_PARAMS_FILENAME,
  341. "--save_file",
  342. str(output_dir / ONNX_MODEL_FILENAME),
  343. "--opset_version",
  344. str(opset_version),
  345. ]
  346. try:
  347. subprocess.check_call(cmd)
  348. except subprocess.CalledProcessError as e:
  349. sys.exit(f"Paddle2ONNX conversion failed with exit code {e.returncode}")
  350. logging.info("Paddle2ONNX conversion succeeded")
  351. def _copy_config_file(input_dir, output_dir):
  352. src_path = input_dir / CONFIG_FILENAME
  353. dst_path = output_dir / CONFIG_FILENAME
  354. shutil.copy(src_path, dst_path)
  355. logging.info(f"Copied {src_path} to {dst_path}")
  356. def _copy_additional_files(input_dir, output_dir):
  357. for filename in ADDITIONAL_FILENAMES:
  358. src_path = input_dir / filename
  359. if not src_path.exists():
  360. continue
  361. dst_path = output_dir / filename
  362. shutil.copy(src_path, dst_path)
  363. logging.info(f"Copied {src_path} to {dst_path}")
  364. paddle_model_dir = Path(paddle_model_dir)
  365. if not onnx_model_dir:
  366. onnx_model_dir = paddle_model_dir
  367. onnx_model_dir = Path(onnx_model_dir)
  368. logging.info(f"Input dir: {paddle_model_dir}")
  369. logging.info(f"Output dir: {onnx_model_dir}")
  370. pd_model_file_ext = ".json"
  371. if not FLAGS_json_format_model:
  372. if not (paddle_model_dir / f"{PD_MODEL_FILE_PREFIX}.json").exists():
  373. pd_model_file_ext = ".pdmodel"
  374. _check_input_dir(paddle_model_dir, pd_model_file_ext)
  375. _check_paddle2onnx()
  376. _run_paddle2onnx(paddle_model_dir, pd_model_file_ext, onnx_model_dir, opset_version)
  377. if not (onnx_model_dir.exists() and onnx_model_dir.samefile(paddle_model_dir)):
  378. _copy_config_file(paddle_model_dir, onnx_model_dir)
  379. _copy_additional_files(paddle_model_dir, onnx_model_dir)
  380. logging.info("Done")
  381. # for CLI
  382. def main():
  383. """API for command line"""
  384. parser, pipeline_args = args_cfg()
  385. args = parser.parse_args()
  386. if len(sys.argv) == 1:
  387. logging.warning("No arguments provided. Displaying help information:")
  388. parser.print_help()
  389. sys.exit(2)
  390. if args.install:
  391. install(args)
  392. elif args.serve:
  393. serve(
  394. args.pipeline,
  395. device=args.device,
  396. use_hpip=args.use_hpip,
  397. host=args.host,
  398. port=args.port,
  399. )
  400. elif args.paddle2onnx:
  401. paddle_to_onnx(
  402. args.paddle_model_dir,
  403. args.onnx_model_dir,
  404. opset_version=args.opset_version,
  405. )
  406. else:
  407. if args.get_pipeline_config is not None:
  408. interactive_get_pipeline(args.get_pipeline_config, args.save_path)
  409. else:
  410. pipeline_args_dict = {}
  411. for arg in pipeline_args:
  412. arg_name = arg["name"].lstrip("-")
  413. if hasattr(args, arg_name):
  414. pipeline_args_dict[arg_name] = getattr(args, arg_name)
  415. else:
  416. logging.warning(f"Argument {arg_name} is missing in args")
  417. return pipeline_predict(
  418. args.pipeline,
  419. args.input,
  420. args.device,
  421. args.save_path,
  422. use_hpip=args.use_hpip,
  423. **pipeline_args_dict,
  424. )