# paddlex_cli.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import importlib.resources
  16. import os
  17. import shutil
  18. import subprocess
  19. import sys
  20. from pathlib import Path
  21. from . import create_pipeline
  22. from .constants import MODEL_FILE_PREFIX
  23. from .inference.pipelines import load_pipeline_config
  24. from .repo_manager import get_all_supported_repo_names, setup
  25. from .utils import logging
  26. from .utils.deps import (
  27. get_paddle2onnx_spec,
  28. get_serving_dep_specs,
  29. require_paddle2onnx_plugin,
  30. )
  31. from .utils.env import get_cuda_version
  32. from .utils.flags import FLAGS_json_format_model
  33. from .utils.install import install_packages
  34. from .utils.interactive_get_pipeline import interactive_get_pipeline
  35. from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
  36. def args_cfg():
  37. """parse cli arguments"""
  38. def parse_str(s):
  39. """convert str type value
  40. to None type if it is "None",
  41. to bool type if it means True or False.
  42. """
  43. if s in ("None", "none", "NONE"):
  44. return None
  45. elif s in ("TRUE", "True", "true", "T", "t"):
  46. return True
  47. elif s in ("FALSE", "False", "false", "F", "f"):
  48. return False
  49. return s
  50. parser = argparse.ArgumentParser(
  51. "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
  52. )
  53. install_group = parser.add_argument_group("Install PaddleX Options")
  54. pipeline_group = parser.add_argument_group("Pipeline Predict Options")
  55. serving_group = parser.add_argument_group("Serving Options")
  56. paddle2onnx_group = parser.add_argument_group("Paddle2ONNX Options")
  57. ################# install pdx #################
  58. install_group.add_argument(
  59. "--install",
  60. action="store_true",
  61. default=False,
  62. help="Install specified PaddleX plugins.",
  63. )
  64. install_group.add_argument(
  65. "plugins",
  66. nargs="*",
  67. default=[],
  68. help="Names of custom development plugins to install (space-separated).",
  69. )
  70. install_group.add_argument(
  71. "--no_deps",
  72. action="store_true",
  73. help="Install custom development plugins without their dependencies.",
  74. )
  75. install_group.add_argument(
  76. "--platform",
  77. type=str,
  78. choices=["github.com", "gitee.com"],
  79. default="github.com",
  80. help="Platform to use for installation (default: github.com).",
  81. )
  82. install_group.add_argument(
  83. "-y",
  84. "--yes",
  85. dest="update_repos",
  86. action="store_true",
  87. help="Automatically confirm prompts and update repositories.",
  88. )
  89. install_group.add_argument(
  90. "--use_local_repos",
  91. action="store_true",
  92. default=False,
  93. help="Use local repositories if they exist.",
  94. )
  95. install_group.add_argument(
  96. "--deps_to_replace",
  97. type=str,
  98. nargs="+",
  99. default=None,
  100. help="Replace dependency version when installing from repositories.",
  101. )
  102. ################# pipeline predict #################
  103. pipeline_group.add_argument(
  104. "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
  105. )
  106. pipeline_group.add_argument(
  107. "--input",
  108. type=str,
  109. default=None,
  110. help="Input data or path for the pipeline, supports specific file and directory.",
  111. )
  112. pipeline_group.add_argument(
  113. "--save_path",
  114. type=str,
  115. default=None,
  116. help="Path to save the prediction results.",
  117. )
  118. pipeline_group.add_argument(
  119. "--device",
  120. type=str,
  121. default=None,
  122. help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
  123. )
  124. pipeline_group.add_argument(
  125. "--use_hpip",
  126. action="store_true",
  127. help="Enable HPIP acceleration by default.",
  128. )
  129. pipeline_group.add_argument(
  130. "--get_pipeline_config",
  131. type=str,
  132. default=None,
  133. help="Retrieve the configuration for a specified pipeline.",
  134. )
  135. ################# serving #################
  136. serving_group.add_argument(
  137. "--serve",
  138. action="store_true",
  139. help="Start the serving application to handle requests.",
  140. )
  141. serving_group.add_argument(
  142. "--host",
  143. type=str,
  144. default="0.0.0.0",
  145. help="Host address to serve on (default: 0.0.0.0).",
  146. )
  147. serving_group.add_argument(
  148. "--port",
  149. type=int,
  150. default=8080,
  151. help="Port number to serve on (default: 8080).",
  152. )
  153. # Serving also uses `--pipeline`, `--device`, and `--use_hpip`
  154. ################# paddle2onnx #################
  155. paddle2onnx_group.add_argument(
  156. "--paddle2onnx",
  157. action="store_true",
  158. help="Convert PaddlePaddle model to ONNX format",
  159. )
  160. paddle2onnx_group.add_argument(
  161. "--paddle_model_dir",
  162. type=str,
  163. help="Directory containing the PaddlePaddle model",
  164. )
  165. paddle2onnx_group.add_argument(
  166. "--onnx_model_dir",
  167. type=str,
  168. help="Output directory for the ONNX model",
  169. )
  170. paddle2onnx_group.add_argument(
  171. "--opset_version", type=int, default=7, help="Version of the ONNX opset to use"
  172. )
  173. # Parse known arguments to get the pipeline name
  174. args, remaining_args = parser.parse_known_args()
  175. pipeline = args.pipeline
  176. pipeline_args = []
  177. if not (args.install or args.serve or args.paddle2onnx) and pipeline is not None:
  178. if os.path.isfile(pipeline):
  179. pipeline_name = load_pipeline_config(pipeline)["pipeline_name"]
  180. else:
  181. pipeline_name = pipeline
  182. if pipeline_name not in PIPELINE_ARGUMENTS:
  183. support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
  184. logging.error(
  185. f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
  186. )
  187. sys.exit(1)
  188. pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
  189. if pipeline_args is None:
  190. pipeline_args = []
  191. pipeline_specific_group = parser.add_argument_group(
  192. f"{pipeline_name.capitalize()} Pipeline Options"
  193. )
  194. for arg in pipeline_args:
  195. pipeline_specific_group.add_argument(
  196. arg["name"],
  197. type=parse_str if arg["type"] is bool else arg["type"],
  198. help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
  199. )
  200. return parser, pipeline_args
  201. def install(args):
  202. """install paddlex"""
  203. def _install_serving_deps():
  204. reqs = get_serving_dep_specs()
  205. # Should we sort the requirements?
  206. install_packages(reqs)
  207. def _install_paddle2onnx_deps():
  208. install_packages([get_paddle2onnx_spec()])
  209. def _install_hpi_deps(device_type):
  210. SUPPORTED_DEVICE_TYPES = ["cpu", "gpu", "npu"]
  211. if device_type not in SUPPORTED_DEVICE_TYPES:
  212. logging.error(
  213. "HPI installation failed!\n"
  214. "Supported device_type: %s. Your input device_type: %s.\n"
  215. "Please ensure the device_type is correct.",
  216. SUPPORTED_DEVICE_TYPES,
  217. device_type,
  218. )
  219. sys.exit(2)
  220. if device_type == "cpu":
  221. package = "ultra-infer-python"
  222. elif device_type == "gpu":
  223. if get_cuda_version()[0] != 11:
  224. sys.exit("Currently, the CUDA version must be 11.x for GPU devices.")
  225. package = "ultra-infer-gpu-python"
  226. elif device_type == "npu":
  227. package = "ultra-infer-npu-python"
  228. with importlib.resources.path("paddlex", "hpip_links.html") as f:
  229. install_packages([package], pip_install_opts=["--find-links", str(f)])
  230. # Enable debug info
  231. os.environ["PADDLE_PDX_DEBUG"] = "True"
  232. # Disable eager initialization
  233. os.environ["PADDLE_PDX_EAGER_INIT"] = "False"
  234. plugins = args.plugins[:]
  235. if "serving" in plugins:
  236. plugins.remove("serving")
  237. if plugins:
  238. logging.error("`serving` cannot be used together with other plugins.")
  239. sys.exit(2)
  240. _install_serving_deps()
  241. return
  242. if "paddle2onnx" in plugins:
  243. plugins.remove("paddle2onnx")
  244. if plugins:
  245. logging.error("`paddle2onnx` cannot be used together with other plugins.")
  246. sys.exit(2)
  247. _install_paddle2onnx_deps()
  248. return
  249. hpi_plugins = list(filter(lambda name: name.startswith("hpi-"), plugins))
  250. if hpi_plugins:
  251. for i in hpi_plugins:
  252. plugins.remove(i)
  253. if plugins:
  254. logging.error("`hpi` cannot be used together with other plugins.")
  255. sys.exit(2)
  256. if len(hpi_plugins) > 1 or len(hpi_plugins[0].split("-")) != 2:
  257. logging.error(
  258. "Invalid HPI plugin installation format detected.\n"
  259. "Correct format: paddlex --install hpi-<device_type>\n"
  260. "Example: paddlex --install hpi-gpu"
  261. )
  262. sys.exit(2)
  263. device_type = hpi_plugins[0].split("-")[1]
  264. _install_hpi_deps(device_type=device_type)
  265. return
  266. if plugins:
  267. repo_names = plugins
  268. elif len(plugins) == 0:
  269. repo_names = get_all_supported_repo_names()
  270. setup(
  271. repo_names=repo_names,
  272. no_deps=args.no_deps,
  273. platform=args.platform,
  274. update_repos=args.update_repos,
  275. use_local_repos=args.use_local_repos,
  276. deps_to_replace=args.deps_to_replace,
  277. )
  278. return
  279. def pipeline_predict(
  280. pipeline,
  281. input,
  282. device,
  283. save_path,
  284. use_hpip,
  285. **pipeline_args,
  286. ):
  287. """pipeline predict"""
  288. pipeline = create_pipeline(pipeline, device=device, use_hpip=use_hpip)
  289. result = pipeline.predict(input, **pipeline_args)
  290. for res in result:
  291. res.print()
  292. if save_path:
  293. res.save_all(save_path=save_path)
  294. def serve(pipeline, *, device, use_hpip, host, port):
  295. from .inference.serving.basic_serving import create_pipeline_app, run_server
  296. pipeline_config = load_pipeline_config(pipeline)
  297. pipeline = create_pipeline(config=pipeline_config, device=device, use_hpip=use_hpip)
  298. app = create_pipeline_app(pipeline, pipeline_config)
  299. run_server(app, host=host, port=port)
  300. # TODO: Move to another module
  301. def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
  302. require_paddle2onnx_plugin()
  303. PD_MODEL_FILE_PREFIX = MODEL_FILE_PREFIX
  304. PD_PARAMS_FILENAME = f"{MODEL_FILE_PREFIX}.pdiparams"
  305. ONNX_MODEL_FILENAME = f"{MODEL_FILE_PREFIX}.onnx"
  306. CONFIG_FILENAME = f"{MODEL_FILE_PREFIX}.yml"
  307. ADDITIONAL_FILENAMES = ["scaler.pkl"]
  308. def _check_input_dir(input_dir, pd_model_file_ext):
  309. if input_dir is None:
  310. sys.exit("Input directory must be specified")
  311. if not input_dir.exists():
  312. sys.exit(f"{input_dir} does not exist")
  313. if not input_dir.is_dir():
  314. sys.exit(f"{input_dir} is not a directory")
  315. model_path = (input_dir / PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)
  316. if not model_path.exists():
  317. sys.exit(f"{model_path} does not exist")
  318. params_path = input_dir / PD_PARAMS_FILENAME
  319. if not params_path.exists():
  320. sys.exit(f"{params_path} does not exist")
  321. config_path = input_dir / CONFIG_FILENAME
  322. if not config_path.exists():
  323. sys.exit(f"{config_path} does not exist")
  324. def _check_paddle2onnx():
  325. if shutil.which("paddle2onnx") is None:
  326. sys.exit("Paddle2ONNX is not available. Please install the plugin first.")
  327. def _run_paddle2onnx(input_dir, pd_model_file_ext, output_dir, opset_version):
  328. logging.info("Paddle2ONNX conversion starting...")
  329. # XXX: To circumvent Paddle2ONNX's bug
  330. cmd = [
  331. "paddle2onnx",
  332. "--model_dir",
  333. str(input_dir),
  334. "--model_filename",
  335. str(Path(PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)),
  336. "--params_filename",
  337. PD_PARAMS_FILENAME,
  338. "--save_file",
  339. str(output_dir / ONNX_MODEL_FILENAME),
  340. "--opset_version",
  341. str(opset_version),
  342. ]
  343. try:
  344. subprocess.check_call(cmd)
  345. except subprocess.CalledProcessError as e:
  346. sys.exit(f"Paddle2ONNX conversion failed with exit code {e.returncode}")
  347. logging.info("Paddle2ONNX conversion succeeded")
  348. def _copy_config_file(input_dir, output_dir):
  349. src_path = input_dir / CONFIG_FILENAME
  350. dst_path = output_dir / CONFIG_FILENAME
  351. shutil.copy(src_path, dst_path)
  352. logging.info(f"Copied {src_path} to {dst_path}")
  353. def _copy_additional_files(input_dir, output_dir):
  354. for filename in ADDITIONAL_FILENAMES:
  355. src_path = input_dir / filename
  356. if not src_path.exists():
  357. continue
  358. dst_path = output_dir / filename
  359. shutil.copy(src_path, dst_path)
  360. logging.info(f"Copied {src_path} to {dst_path}")
  361. paddle_model_dir = Path(paddle_model_dir)
  362. if not onnx_model_dir:
  363. onnx_model_dir = paddle_model_dir
  364. onnx_model_dir = Path(onnx_model_dir)
  365. logging.info(f"Input dir: {paddle_model_dir}")
  366. logging.info(f"Output dir: {onnx_model_dir}")
  367. pd_model_file_ext = ".json"
  368. if not FLAGS_json_format_model:
  369. if not (paddle_model_dir / f"{PD_MODEL_FILE_PREFIX}.json").exists():
  370. pd_model_file_ext = ".pdmodel"
  371. _check_input_dir(paddle_model_dir, pd_model_file_ext)
  372. _check_paddle2onnx()
  373. _run_paddle2onnx(paddle_model_dir, pd_model_file_ext, onnx_model_dir, opset_version)
  374. if not (onnx_model_dir.exists() and onnx_model_dir.samefile(paddle_model_dir)):
  375. _copy_config_file(paddle_model_dir, onnx_model_dir)
  376. _copy_additional_files(paddle_model_dir, onnx_model_dir)
  377. logging.info("Done")
  378. # for CLI
  379. def main():
  380. """API for command line"""
  381. parser, pipeline_args = args_cfg()
  382. args = parser.parse_args()
  383. if len(sys.argv) == 1:
  384. logging.warning("No arguments provided. Displaying help information:")
  385. parser.print_help()
  386. sys.exit(2)
  387. if args.install:
  388. install(args)
  389. elif args.serve:
  390. serve(
  391. args.pipeline,
  392. device=args.device,
  393. use_hpip=args.use_hpip,
  394. host=args.host,
  395. port=args.port,
  396. )
  397. elif args.paddle2onnx:
  398. paddle_to_onnx(
  399. args.paddle_model_dir,
  400. args.onnx_model_dir,
  401. opset_version=args.opset_version,
  402. )
  403. else:
  404. if args.get_pipeline_config is not None:
  405. interactive_get_pipeline(args.get_pipeline_config, args.save_path)
  406. else:
  407. pipeline_args_dict = {}
  408. for arg in pipeline_args:
  409. arg_name = arg["name"].lstrip("-")
  410. if hasattr(args, arg_name):
  411. pipeline_args_dict[arg_name] = getattr(args, arg_name)
  412. else:
  413. logging.warning(f"Argument {arg_name} is missing in args")
  414. return pipeline_predict(
  415. args.pipeline,
  416. args.input,
  417. args.device,
  418. args.save_path,
  419. use_hpip=args.use_hpip,
  420. **pipeline_args_dict,
  421. )