paddlex_cli.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import argparse
  16. import subprocess
  17. import sys
  18. from importlib_resources import files, as_file
  19. from . import create_pipeline
  20. from .inference.pipelines import create_pipeline_from_config, load_pipeline_config
  21. from .repo_manager import setup, get_all_supported_repo_names
  22. from .utils import logging
  23. from .utils.interactive_get_pipeline import interactive_get_pipeline
  24. from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
  25. def _install_serving_deps():
  26. with as_file(files("paddlex").joinpath("serving_requirements.txt")) as req_file:
  27. return subprocess.check_call(
  28. [sys.executable, "-m", "pip", "install", "-r", str(req_file)]
  29. )
  30. def args_cfg():
  31. """parse cli arguments"""
  32. def parse_str(s):
  33. """convert str type value
  34. to None type if it is "None",
  35. to bool type if it means True or False.
  36. """
  37. if s in ("None", "none", "NONE"):
  38. return None
  39. elif s in ("TRUE", "True", "true", "T", "t"):
  40. return True
  41. elif s in ("FALSE", "False", "false", "F", "f"):
  42. return False
  43. return s
  44. parser = argparse.ArgumentParser(
  45. "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
  46. )
  47. install_group = parser.add_argument_group("Install PaddleX Options")
  48. pipeline_group = parser.add_argument_group("Pipeline Predict Options")
  49. serving_group = parser.add_argument_group("Serving Options")
  50. ################# install pdx #################
  51. install_group.add_argument(
  52. "--install",
  53. action="store_true",
  54. default=False,
  55. help="Install specified PaddleX plugins.",
  56. )
  57. install_group.add_argument(
  58. "plugins",
  59. nargs="*",
  60. default=[],
  61. help="Names of custom development plugins to install (space-separated).",
  62. )
  63. install_group.add_argument(
  64. "--no_deps",
  65. action="store_true",
  66. help="Install custom development plugins without their dependencies.",
  67. )
  68. install_group.add_argument(
  69. "--platform",
  70. type=str,
  71. choices=["github.com", "gitee.com"],
  72. default="github.com",
  73. help="Platform to use for installation (default: github.com).",
  74. )
  75. install_group.add_argument(
  76. "-y",
  77. "--yes",
  78. dest="update_repos",
  79. action="store_true",
  80. help="Automatically confirm prompts and update repositories.",
  81. )
  82. install_group.add_argument(
  83. "--use_local_repos",
  84. action="store_true",
  85. default=False,
  86. help="Use local repositories if they exist.",
  87. )
  88. ################# pipeline predict #################
  89. pipeline_group.add_argument(
  90. "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
  91. )
  92. pipeline_group.add_argument(
  93. "--input",
  94. type=str,
  95. default=None,
  96. help="Input data or path for the pipeline, supports specific file and directory.",
  97. )
  98. pipeline_group.add_argument(
  99. "--save_path",
  100. type=str,
  101. default=None,
  102. help="Path to save the prediction results.",
  103. )
  104. pipeline_group.add_argument(
  105. "--device",
  106. type=str,
  107. default=None,
  108. help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
  109. )
  110. pipeline_group.add_argument(
  111. "--use_hpip", action="store_true", help="Enable HPIP acceleration if available."
  112. )
  113. pipeline_group.add_argument(
  114. "--serial_number", type=str, help="Serial number for device identification."
  115. )
  116. pipeline_group.add_argument(
  117. "--update_license",
  118. action="store_true",
  119. help="Update the software license information.",
  120. )
  121. pipeline_group.add_argument(
  122. "--get_pipeline_config",
  123. type=str,
  124. default=None,
  125. help="Retrieve the configuration for a specified pipeline.",
  126. )
  127. ################# serving #################
  128. serving_group.add_argument(
  129. "--serve",
  130. action="store_true",
  131. help="Start the serving application to handle requests.",
  132. )
  133. serving_group.add_argument(
  134. "--host",
  135. type=str,
  136. default="0.0.0.0",
  137. help="Host address to serve on (default: 0.0.0.0).",
  138. )
  139. serving_group.add_argument(
  140. "--port",
  141. type=int,
  142. default=8080,
  143. help="Port number to serve on (default: 8080).",
  144. )
  145. # Parse known arguments to get the pipeline name
  146. args, remaining_args = parser.parse_known_args()
  147. pipeline_name = args.pipeline
  148. pipeline_args = []
  149. if not args.install and pipeline_name is not None:
  150. if pipeline_name not in PIPELINE_ARGUMENTS:
  151. support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
  152. logging.error(
  153. f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
  154. )
  155. sys.exit(1)
  156. pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
  157. if pipeline_args is None:
  158. pipeline_args = []
  159. pipeline_specific_group = parser.add_argument_group(
  160. f"{pipeline_name.capitalize()} Pipeline Options"
  161. )
  162. for arg in pipeline_args:
  163. pipeline_specific_group.add_argument(
  164. arg["name"],
  165. type=parse_str if arg["type"] is bool else arg["type"],
  166. help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
  167. )
  168. return parser, pipeline_args
  169. def install(args):
  170. """install paddlex"""
  171. # Enable debug info
  172. os.environ["PADDLE_PDX_DEBUG"] = "True"
  173. # Disable eager initialization
  174. os.environ["PADDLE_PDX_EAGER_INIT"] = "False"
  175. plugins = args.plugins[:]
  176. if "serving" in plugins:
  177. plugins.remove("serving")
  178. _install_serving_deps()
  179. return
  180. if plugins:
  181. repo_names = plugins
  182. elif len(plugins) == 0:
  183. repo_names = get_all_supported_repo_names()
  184. setup(
  185. repo_names=repo_names,
  186. no_deps=args.no_deps,
  187. platform=args.platform,
  188. update_repos=args.update_repos,
  189. use_local_repos=args.use_local_repos,
  190. )
  191. return
  192. def _get_hpi_params(serial_number, update_license):
  193. return {"serial_number": serial_number, "update_license": update_license}
  194. def pipeline_predict(
  195. pipeline,
  196. input,
  197. device,
  198. save_path,
  199. use_hpip,
  200. serial_number,
  201. update_license,
  202. **pipeline_args,
  203. ):
  204. """pipeline predict"""
  205. hpi_params = _get_hpi_params(serial_number, update_license)
  206. pipeline = create_pipeline(
  207. pipeline, device=device, use_hpip=use_hpip, hpi_params=hpi_params
  208. )
  209. result = pipeline.predict(input, **pipeline_args)
  210. for res in result:
  211. res.print()
  212. if save_path:
  213. res.save_all(save_path=save_path)
  214. def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, port):
  215. from .inference.pipelines.serving import create_pipeline_app, run_server
  216. hpi_params = _get_hpi_params(serial_number, update_license)
  217. pipeline_config = load_pipeline_config(pipeline)
  218. pipeline = create_pipeline_from_config(
  219. pipeline_config, device=device, use_hpip=use_hpip, hpi_params=hpi_params
  220. )
  221. app = create_pipeline_app(pipeline, pipeline_config)
  222. run_server(app, host=host, port=port, debug=False)
  223. # for CLI
  224. def main():
  225. """API for commad line"""
  226. parser, pipeline_args = args_cfg()
  227. args = parser.parse_args()
  228. if len(sys.argv) == 1:
  229. logging.warning("No arguments provided. Displaying help information:")
  230. parser.print_help()
  231. return
  232. if args.install:
  233. install(args)
  234. elif args.serve:
  235. serve(
  236. args.pipeline,
  237. device=args.device,
  238. use_hpip=args.use_hpip,
  239. serial_number=args.serial_number,
  240. update_license=args.update_license,
  241. host=args.host,
  242. port=args.port,
  243. )
  244. else:
  245. if args.get_pipeline_config is not None:
  246. interactive_get_pipeline(args.get_pipeline_config, args.save_path)
  247. else:
  248. pipeline_args_dict = {}
  249. from .utils.flags import USE_NEW_INFERENCE
  250. if USE_NEW_INFERENCE:
  251. for arg in pipeline_args:
  252. arg_name = arg["name"].lstrip("-")
  253. if hasattr(args, arg_name):
  254. pipeline_args_dict[arg_name] = getattr(args, arg_name)
  255. else:
  256. logging.warning(f"Argument {arg_name} is missing in args")
  257. return pipeline_predict(
  258. args.pipeline,
  259. args.input,
  260. args.device,
  261. args.save_path,
  262. use_hpip=args.use_hpip,
  263. serial_number=args.serial_number,
  264. update_license=args.update_license,
  265. **pipeline_args_dict,
  266. )