# paddlex_cli.py
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import argparse
  16. import subprocess
  17. import sys
  18. import shutil
  19. from pathlib import Path
  20. from importlib_resources import files, as_file
  21. from . import create_pipeline
  22. from .inference.pipelines import load_pipeline_config
  23. from .repo_manager import setup, get_all_supported_repo_names
  24. from .utils.flags import FLAGS_json_format_model
  25. from .utils import logging
  26. from .utils.interactive_get_pipeline import interactive_get_pipeline
  27. from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
  28. def args_cfg():
  29. """parse cli arguments"""
  30. def parse_str(s):
  31. """convert str type value
  32. to None type if it is "None",
  33. to bool type if it means True or False.
  34. """
  35. if s in ("None", "none", "NONE"):
  36. return None
  37. elif s in ("TRUE", "True", "true", "T", "t"):
  38. return True
  39. elif s in ("FALSE", "False", "false", "F", "f"):
  40. return False
  41. return s
  42. parser = argparse.ArgumentParser(
  43. "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
  44. )
  45. install_group = parser.add_argument_group("Install PaddleX Options")
  46. pipeline_group = parser.add_argument_group("Pipeline Predict Options")
  47. serving_group = parser.add_argument_group("Serving Options")
  48. paddle2onnx_group = parser.add_argument_group("Paddle2ONNX Options")
  49. ################# install pdx #################
  50. install_group.add_argument(
  51. "--install",
  52. action="store_true",
  53. default=False,
  54. help="Install specified PaddleX plugins.",
  55. )
  56. install_group.add_argument(
  57. "plugins",
  58. nargs="*",
  59. default=[],
  60. help="Names of custom development plugins to install (space-separated).",
  61. )
  62. install_group.add_argument(
  63. "--no_deps",
  64. action="store_true",
  65. help="Install custom development plugins without their dependencies.",
  66. )
  67. install_group.add_argument(
  68. "--platform",
  69. type=str,
  70. choices=["github.com", "gitee.com"],
  71. default="github.com",
  72. help="Platform to use for installation (default: github.com).",
  73. )
  74. install_group.add_argument(
  75. "-y",
  76. "--yes",
  77. dest="update_repos",
  78. action="store_true",
  79. help="Automatically confirm prompts and update repositories.",
  80. )
  81. install_group.add_argument(
  82. "--use_local_repos",
  83. action="store_true",
  84. default=False,
  85. help="Use local repositories if they exist.",
  86. )
  87. ################# pipeline predict #################
  88. pipeline_group.add_argument(
  89. "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
  90. )
  91. pipeline_group.add_argument(
  92. "--input",
  93. type=str,
  94. default=None,
  95. help="Input data or path for the pipeline, supports specific file and directory.",
  96. )
  97. pipeline_group.add_argument(
  98. "--save_path",
  99. type=str,
  100. default=None,
  101. help="Path to save the prediction results.",
  102. )
  103. pipeline_group.add_argument(
  104. "--device",
  105. type=str,
  106. default=None,
  107. help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
  108. )
  109. pipeline_group.add_argument(
  110. "--use_hpip", action="store_true", help="Enable HPIP acceleration if available."
  111. )
  112. pipeline_group.add_argument(
  113. "--get_pipeline_config",
  114. type=str,
  115. default=None,
  116. help="Retrieve the configuration for a specified pipeline.",
  117. )
  118. ################# serving #################
  119. serving_group.add_argument(
  120. "--serve",
  121. action="store_true",
  122. help="Start the serving application to handle requests.",
  123. )
  124. serving_group.add_argument(
  125. "--host",
  126. type=str,
  127. default="0.0.0.0",
  128. help="Host address to serve on (default: 0.0.0.0).",
  129. )
  130. serving_group.add_argument(
  131. "--port",
  132. type=int,
  133. default=8080,
  134. help="Port number to serve on (default: 8080).",
  135. )
  136. ################# paddle2onnx #################
  137. paddle2onnx_group.add_argument(
  138. "--paddle2onnx", action="store_true", help="Convert Paddle model to ONNX format"
  139. )
  140. paddle2onnx_group.add_argument(
  141. "--paddle_model_dir", type=str, help="Directory containing the Paddle model"
  142. )
  143. paddle2onnx_group.add_argument(
  144. "--onnx_model_dir",
  145. type=str,
  146. default="onnx",
  147. help="Output directory for the ONNX model",
  148. )
  149. paddle2onnx_group.add_argument(
  150. "--opset_version", type=int, help="Version of the ONNX opset to use"
  151. )
  152. # Parse known arguments to get the pipeline name
  153. args, remaining_args = parser.parse_known_args()
  154. pipeline = args.pipeline
  155. pipeline_args = []
  156. if not (args.install or args.serve or args.paddle2onnx) and pipeline is not None:
  157. if os.path.isfile(pipeline):
  158. pipeline_name = load_pipeline_config(pipeline)["pipeline_name"]
  159. else:
  160. pipeline_name = pipeline
  161. if pipeline_name not in PIPELINE_ARGUMENTS:
  162. support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
  163. logging.error(
  164. f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
  165. )
  166. sys.exit(1)
  167. pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
  168. if pipeline_args is None:
  169. pipeline_args = []
  170. pipeline_specific_group = parser.add_argument_group(
  171. f"{pipeline_name.capitalize()} Pipeline Options"
  172. )
  173. for arg in pipeline_args:
  174. pipeline_specific_group.add_argument(
  175. arg["name"],
  176. type=parse_str if arg["type"] is bool else arg["type"],
  177. help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
  178. )
  179. return parser, pipeline_args
  180. def install(args):
  181. """install paddlex"""
  182. def _install_serving_deps():
  183. with as_file(files("paddlex").joinpath("serving_requirements.txt")) as req_file:
  184. return subprocess.check_call(
  185. [sys.executable, "-m", "pip", "install", "-r", str(req_file)]
  186. )
  187. def _install_paddle2onnx_deps():
  188. with as_file(
  189. files("paddlex").joinpath("paddle2onnx_requirements.txt")
  190. ) as req_file:
  191. return subprocess.check_call(
  192. [sys.executable, "-m", "pip", "install", "-r", str(req_file)]
  193. )
  194. def _install_hpi_deps(device_type):
  195. support_device_type = ["cpu", "gpu"]
  196. if device_type not in support_device_type:
  197. logging.error(
  198. "HPI installation failed!\n"
  199. "Supported device_type: %s. Your input device_type: %s.\n"
  200. "Please ensure the device_type is correct.",
  201. support_device_type,
  202. device_type,
  203. )
  204. sys.exit(2)
  205. if device_type == "cpu":
  206. packages = ["ultra_infer_python", "paddlex_hpi"]
  207. elif device_type == "gpu":
  208. packages = ["ultra_infer_gpu_python", "paddlex_hpi"]
  209. return subprocess.check_call(
  210. [sys.executable, "-m", "pip", "install"]
  211. + packages
  212. + [
  213. "--find-links",
  214. "https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/pipeline_deploy/high_performance_inference.md",
  215. ]
  216. )
  217. # Enable debug info
  218. os.environ["PADDLE_PDX_DEBUG"] = "True"
  219. # Disable eager initialization
  220. os.environ["PADDLE_PDX_EAGER_INIT"] = "False"
  221. plugins = args.plugins[:]
  222. if "serving" in plugins:
  223. plugins.remove("serving")
  224. if plugins:
  225. logging.error("`serving` cannot be used together with other plugins.")
  226. sys.exit(2)
  227. _install_serving_deps()
  228. return
  229. if "paddle2onnx" in plugins:
  230. plugins.remove("paddle2onnx")
  231. if plugins:
  232. logging.error("`paddle2onnx` cannot be used together with other plugins.")
  233. sys.exit(2)
  234. _install_paddle2onnx_deps()
  235. return
  236. hpi_plugins = list(filter(lambda name: name.startswith("hpi-"), plugins))
  237. if hpi_plugins:
  238. for i in hpi_plugins:
  239. plugins.remove(i)
  240. if plugins:
  241. logging.error("`hpi` cannot be used together with other plugins.")
  242. sys.exit(2)
  243. if len(hpi_plugins) > 1 or len(hpi_plugins[0].split("-")) != 2:
  244. logging.error(
  245. "Invalid HPI plugin installation format detected.\n"
  246. "Correct format: paddlex --install hpi-<device_type>\n"
  247. "Example: paddlex --install hpi-gpu"
  248. )
  249. sys.exit(2)
  250. device_type = hpi_plugins[0].split("-")[1]
  251. _install_hpi_deps(device_type=device_type)
  252. return
  253. if plugins:
  254. repo_names = plugins
  255. elif len(plugins) == 0:
  256. repo_names = get_all_supported_repo_names()
  257. setup(
  258. repo_names=repo_names,
  259. no_deps=args.no_deps,
  260. platform=args.platform,
  261. update_repos=args.update_repos,
  262. use_local_repos=args.use_local_repos,
  263. )
  264. return
  265. def pipeline_predict(
  266. pipeline,
  267. input,
  268. device,
  269. save_path,
  270. use_hpip,
  271. **pipeline_args,
  272. ):
  273. """pipeline predict"""
  274. pipeline = create_pipeline(pipeline, device=device, use_hpip=use_hpip)
  275. result = pipeline.predict(input, **pipeline_args)
  276. for res in result:
  277. res.print()
  278. if save_path:
  279. res.save_all(save_path=save_path)
  280. def serve(pipeline, *, device, use_hpip, host, port):
  281. from .inference.serving.basic_serving import create_pipeline_app, run_server
  282. pipeline_config = load_pipeline_config(pipeline)
  283. pipeline = create_pipeline(config=pipeline_config, device=device, use_hpip=use_hpip)
  284. app = create_pipeline_app(pipeline, pipeline_config)
  285. run_server(app, host=host, port=port)
# TODO: Move to another module
def paddle_to_onnx(paddle_model_dir, onnx_model_dir, *, opset_version):
    """Convert a Paddle inference model directory to ONNX.

    Validates the input directory, shells out to the `paddle2onnx` CLI tool,
    and copies the pipeline config (plus known auxiliary files) alongside the
    generated ONNX model. Calls `sys.exit` with a message on any validation
    or conversion failure.

    Args:
        paddle_model_dir: Directory holding the Paddle model
            (inference.json/.pdmodel, inference.pdiparams, inference.yml).
        onnx_model_dir: Destination directory for the ONNX artifacts.
        opset_version: ONNX opset to target; None selects a default based on
            the model file format (see `_run_paddle2onnx`).
    """
    # Fixed filenames of the Paddle inference artifact layout.
    PD_MODEL_FILE_PREFIX = "inference"
    PD_PARAMS_FILENAME = "inference.pdiparams"
    ONNX_MODEL_FILENAME = "inference.onnx"
    CONFIG_FILENAME = "inference.yml"
    # Extra files copied to the output dir when present (best-effort).
    ADDITIONAL_FILENAMES = ["scaler.pkl"]

    def _check_input_dir(input_dir, pd_model_file_ext):
        # Exits with a message on the first missing piece; order matters only
        # for which error the user sees first.
        if input_dir is None:
            sys.exit("Input directory must be specified")
        if not input_dir.exists():
            sys.exit(f"{input_dir} does not exist")
        if not input_dir.is_dir():
            sys.exit(f"{input_dir} is not a directory")
        model_path = (input_dir / PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)
        if not model_path.exists():
            sys.exit(f"{model_path} does not exist")
        params_path = input_dir / PD_PARAMS_FILENAME
        if not params_path.exists():
            sys.exit(f"{params_path} does not exist")
        config_path = input_dir / CONFIG_FILENAME
        if not config_path.exists():
            sys.exit(f"{config_path} does not exist")

    def _check_paddle2onnx():
        # The converter is an external executable installed by the
        # `paddle2onnx` plugin; bail out early if it is not on PATH.
        if shutil.which("paddle2onnx") is None:
            sys.exit("Paddle2ONNX is not available. Please install the plugin first.")

    def _run_paddle2onnx(input_dir, pd_model_file_ext, output_dir, opset_version):
        logging.info("Paddle2ONNX conversion starting...")
        # XXX: To circumvent Paddle2ONNX's bug
        # Default opset depends on the model format: 19 for the newer .json
        # format, 7 for legacy .pdmodel.
        if opset_version is None:
            if pd_model_file_ext == ".json":
                opset_version = 19
            else:
                opset_version = 7
            logging.info("Using default ONNX opset version: %d", opset_version)
        cmd = [
            "paddle2onnx",
            "--model_dir",
            str(input_dir),
            "--model_filename",
            str(Path(PD_MODEL_FILE_PREFIX).with_suffix(pd_model_file_ext)),
            "--params_filename",
            PD_PARAMS_FILENAME,
            "--save_file",
            str(output_dir / ONNX_MODEL_FILENAME),
            "--opset_version",
            str(opset_version),
        ]
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError as e:
            sys.exit(f"Paddle2ONNX conversion failed with exit code {e.returncode}")
        logging.info("Paddle2ONNX conversion succeeded")

    def _copy_config_file(input_dir, output_dir):
        src_path = input_dir / CONFIG_FILENAME
        dst_path = output_dir / CONFIG_FILENAME
        shutil.copy(src_path, dst_path)
        logging.info(f"Copied {src_path} to {dst_path}")

    def _copy_additional_files(input_dir, output_dir):
        for filename in ADDITIONAL_FILENAMES:
            src_path = input_dir / filename
            # Auxiliary files are optional; skip silently when absent.
            if not src_path.exists():
                continue
            dst_path = output_dir / filename
            shutil.copy(src_path, dst_path)
            logging.info(f"Copied {src_path} to {dst_path}")

    paddle_model_dir = Path(paddle_model_dir)
    onnx_model_dir = Path(onnx_model_dir)
    logging.info(f"Input dir: {paddle_model_dir}")
    logging.info(f"Output dir: {onnx_model_dir}")
    # Prefer the .json model format; fall back to legacy .pdmodel when the
    # json-format flag is off and no inference.json file is present.
    pd_model_file_ext = ".json"
    if not FLAGS_json_format_model:
        if not (paddle_model_dir / f"{PD_MODEL_FILE_PREFIX}.json").exists():
            pd_model_file_ext = ".pdmodel"
    _check_input_dir(paddle_model_dir, pd_model_file_ext)
    _check_paddle2onnx()
    _run_paddle2onnx(paddle_model_dir, pd_model_file_ext, onnx_model_dir, opset_version)
    # Skip the copies when converting in place (output dir == input dir);
    # assumes paddle2onnx created the output dir if it did not exist --
    # TODO confirm.
    if not (onnx_model_dir.exists() and onnx_model_dir.samefile(paddle_model_dir)):
        _copy_config_file(paddle_model_dir, onnx_model_dir)
        _copy_additional_files(paddle_model_dir, onnx_model_dir)
    logging.info("Done")
  367. # for CLI
  368. def main():
  369. """API for commad line"""
  370. parser, pipeline_args = args_cfg()
  371. args = parser.parse_args()
  372. if len(sys.argv) == 1:
  373. logging.warning("No arguments provided. Displaying help information:")
  374. parser.print_help()
  375. sys.exit(2)
  376. if args.install:
  377. install(args)
  378. elif args.serve:
  379. serve(
  380. args.pipeline,
  381. device=args.device,
  382. use_hpip=args.use_hpip,
  383. host=args.host,
  384. port=args.port,
  385. )
  386. elif args.paddle2onnx:
  387. paddle_to_onnx(
  388. args.paddle_model_dir,
  389. args.onnx_model_dir,
  390. opset_version=args.opset_version,
  391. )
  392. else:
  393. if args.get_pipeline_config is not None:
  394. interactive_get_pipeline(args.get_pipeline_config, args.save_path)
  395. else:
  396. pipeline_args_dict = {}
  397. for arg in pipeline_args:
  398. arg_name = arg["name"].lstrip("-")
  399. if hasattr(args, arg_name):
  400. pipeline_args_dict[arg_name] = getattr(args, arg_name)
  401. else:
  402. logging.warning(f"Argument {arg_name} is missing in args")
  403. return pipeline_predict(
  404. args.pipeline,
  405. args.input,
  406. args.device,
  407. args.save_path,
  408. use_hpip=args.use_hpip,
  409. **pipeline_args_dict,
  410. )