| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import argparse
- import subprocess
- import sys
- from importlib_resources import files, as_file
- from . import create_pipeline
- from .inference.pipelines import create_pipeline_from_config, load_pipeline_config
- from .repo_manager import setup, get_all_supported_repo_names
- from .utils import logging
- from .utils.interactive_get_pipeline import interactive_get_pipeline
- from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
def _install_serving_deps():
    """Install the extra requirements needed by the serving plugin.

    The ``serving_requirements.txt`` resource bundled in the ``paddlex``
    package is exposed as a filesystem path for the duration of the ``with``
    block. It is then installed with ``pip`` running under the current
    interpreter (``sys.executable``).

    Returns:
        int: The return code from ``subprocess.check_call``. A non-zero pip
        exit raises ``subprocess.CalledProcessError`` instead of returning.
    """
    requirements = files("paddlex").joinpath("serving_requirements.txt")
    with as_file(requirements) as req_path:
        pip_command = [sys.executable, "-m", "pip", "install", "-r", str(req_path)]
        return subprocess.check_call(pip_command)
def args_cfg():
    """Build the command-line parser for PaddleX.

    The parser always provides three option groups: plugin installation,
    pipeline prediction and serving. For prediction (``--pipeline`` given
    without ``--install`` or ``--serve``), the pipeline name is checked
    against ``PIPELINE_ARGUMENTS``. That pipeline's extra options are then
    registered in an additional argument group.

    Returns:
        tuple: ``(parser, pipeline_args)``.
            ``parser`` is the configured ``argparse.ArgumentParser``.
            ``pipeline_args`` is the list of pipeline-specific argument specs
            that were registered, or an empty list when none apply. Each spec
            is a dict read via ``"name"``, ``"type"`` and an optional
            ``"help"``.

    Note:
        Reads ``sys.argv`` (through ``parse_known_args``) to discover the
        pipeline name. Calls ``sys.exit(1)`` if the requested pipeline is
        not supported for CLI prediction.
    """

    def parse_str(s):
        """Convert the string value of a bool-typed pipeline option.

        Returns None for "None"/"none"/"NONE" and True or False for the
        accepted spellings below. Any other string is returned unchanged.
        """
        if s in ("None", "none", "NONE"):
            return None
        elif s in ("TRUE", "True", "true", "T", "t"):
            return True
        elif s in ("FALSE", "False", "false", "F", "f"):
            return False
        return s

    # This text is the parser's *description*. Passed positionally it would
    # become ``prog`` and replace the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
    )

    install_group = parser.add_argument_group("Install PaddleX Options")
    pipeline_group = parser.add_argument_group("Pipeline Predict Options")
    serving_group = parser.add_argument_group("Serving Options")

    ################# install pdx #################
    install_group.add_argument(
        "--install",
        action="store_true",
        default=False,
        help="Install specified PaddleX plugins.",
    )
    install_group.add_argument(
        "plugins",
        nargs="*",
        default=[],
        help="Names of custom development plugins to install (space-separated).",
    )
    install_group.add_argument(
        "--no_deps",
        action="store_true",
        help="Install custom development plugins without their dependencies.",
    )
    install_group.add_argument(
        "--platform",
        type=str,
        choices=["github.com", "gitee.com"],
        default="github.com",
        help="Platform to use for installation (default: github.com).",
    )
    install_group.add_argument(
        "-y",
        "--yes",
        dest="update_repos",
        action="store_true",
        help="Automatically confirm prompts and update repositories.",
    )
    install_group.add_argument(
        "--use_local_repos",
        action="store_true",
        default=False,
        help="Use local repositories if they exist.",
    )

    ################# pipeline predict #################
    pipeline_group.add_argument(
        "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
    )
    pipeline_group.add_argument(
        "--input",
        type=str,
        default=None,
        help="Input data or path for the pipeline, supports specific file and directory.",
    )
    pipeline_group.add_argument(
        "--save_path",
        type=str,
        default=None,
        help="Path to save the prediction results.",
    )
    pipeline_group.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
    )
    pipeline_group.add_argument(
        "--use_hpip", action="store_true", help="Enable HPIP acceleration if available."
    )
    pipeline_group.add_argument(
        "--serial_number", type=str, help="Serial number for device identification."
    )
    pipeline_group.add_argument(
        "--update_license",
        action="store_true",
        help="Update the software license information.",
    )
    pipeline_group.add_argument(
        "--get_pipeline_config",
        type=str,
        default=None,
        help="Retrieve the configuration for a specified pipeline.",
    )

    ################# serving #################
    serving_group.add_argument(
        "--serve",
        action="store_true",
        help="Start the serving application to handle requests.",
    )
    serving_group.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host address to serve on (default: 0.0.0.0).",
    )
    serving_group.add_argument(
        "--port",
        type=int,
        default=8080,
        help="Port number to serve on (default: 8080).",
    )

    # Pre-parse with only the generic options to learn which pipeline, if
    # any, was requested. Pipeline-specific options are not registered yet
    # and are simply left unparsed here.
    args, _ = parser.parse_known_args()
    pipeline_name = args.pipeline
    pipeline_args = []

    # Pipeline-specific options are only used for CLI prediction. ``--serve``
    # passes the pipeline name to ``load_pipeline_config`` and never reads
    # these options, so it should not be limited to PIPELINE_ARGUMENTS.
    if args.install or args.serve or pipeline_name is None:
        return parser, pipeline_args

    if pipeline_name not in PIPELINE_ARGUMENTS:
        support_pipelines = ", ".join(PIPELINE_ARGUMENTS)
        logging.error(
            f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
        )
        sys.exit(1)

    pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
    # A pipeline may be registered without any extra options (value None).
    if pipeline_args is None:
        pipeline_args = []

    pipeline_specific_group = parser.add_argument_group(
        f"{pipeline_name.capitalize()} Pipeline Options"
    )
    for arg in pipeline_args:
        pipeline_specific_group.add_argument(
            arg["name"],
            # argparse's bool() would treat any non-empty string as True, so
            # bool options are parsed by parse_str instead.
            type=parse_str if arg["type"] is bool else arg["type"],
            help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
        )

    return parser, pipeline_args
def install(args):
    """Install PaddleX plugins, or the serving dependencies.

    Args:
        args (argparse.Namespace): Parsed CLI arguments. Reads ``plugins``,
            ``no_deps``, ``platform``, ``update_repos`` and
            ``use_local_repos``.

    The name ``serving`` is special: it installs the serving requirements
    with pip and cannot be combined with other plugin names (the process
    exits with status 1 if it is). Any other names are passed to ``setup``
    as repository names. With no names at all, every supported repository
    is installed.
    """
    # Enable debug info
    os.environ["PADDLE_PDX_DEBUG"] = "True"
    # Disable eager initialization
    os.environ["PADDLE_PDX_EAGER_INIT"] = "False"

    plugins = list(args.plugins)

    if "serving" in plugins:
        other_plugins = [name for name in plugins if name != "serving"]
        if other_plugins:
            # The serving path returns early, so any other names would be
            # silently skipped. Reject the combination explicitly instead.
            logging.error(
                f"`serving` cannot be installed together with other plugins ({', '.join(other_plugins)}); please install it separately."
            )
            sys.exit(1)
        _install_serving_deps()
        return

    repo_names = plugins if plugins else get_all_supported_repo_names()
    setup(
        repo_names=repo_names,
        no_deps=args.no_deps,
        platform=args.platform,
        update_repos=args.update_repos,
        use_local_repos=args.use_local_repos,
    )
def _get_hpi_params(serial_number, update_license):
    """Pack the licensing-related CLI options into the ``hpi_params`` mapping.

    Args:
        serial_number: Value of ``--serial_number`` (may be None).
        update_license: Value of ``--update_license``.

    Returns:
        dict: ``{"serial_number": ..., "update_license": ...}``, suitable for
        the ``hpi_params`` keyword of pipeline creation.
    """
    return dict(serial_number=serial_number, update_license=update_license)
def pipeline_predict(
    pipeline,
    input,
    device,
    save_path,
    use_hpip,
    serial_number,
    update_license,
    **pipeline_args,
):
    """Create a pipeline, run it on ``input`` and report every result.

    Args:
        pipeline: Pipeline identifier forwarded to ``create_pipeline``.
        input: Input forwarded unchanged to the pipeline's ``predict``.
        device: Device string forwarded to ``create_pipeline``.
        save_path: If truthy, each result is also saved here via
            ``save_all``.
        use_hpip: Forwarded to ``create_pipeline``.
        serial_number: Packed into ``hpi_params`` for ``create_pipeline``.
        update_license: Packed into ``hpi_params`` for ``create_pipeline``.
        **pipeline_args: Extra keyword arguments forwarded to ``predict``.
    """
    predictor = create_pipeline(
        pipeline,
        device=device,
        use_hpip=use_hpip,
        hpi_params=_get_hpi_params(serial_number, update_license),
    )
    for res in predictor.predict(input, **pipeline_args):
        res.print()
        if save_path:
            res.save_all(save_path=save_path)
def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, port):
    """Build a pipeline from its config and run it as a serving application.

    Args:
        pipeline: Pipeline name or config reference passed to
            ``load_pipeline_config``.
        device: Device string forwarded to pipeline creation.
        use_hpip: Forwarded to pipeline creation.
        serial_number: Packed into ``hpi_params`` for pipeline creation.
        update_license: Packed into ``hpi_params`` for pipeline creation.
        host: Address handed to ``run_server``.
        port: Port handed to ``run_server``.
    """
    # Deferred import: presumably keeps serving-only dependencies off the
    # CLI's normal import path (TODO confirm).
    from .inference.pipelines.serving import create_pipeline_app, run_server

    hpi_params = _get_hpi_params(serial_number, update_license)
    config = load_pipeline_config(pipeline)
    built_pipeline = create_pipeline_from_config(
        config,
        device=device,
        use_hpip=use_hpip,
        hpi_params=hpi_params,
    )
    application = create_pipeline_app(built_pipeline, config)
    run_server(application, host=host, port=port, debug=False)
# Entry point for the command-line interface.
def main():
    """Run the PaddleX command-line interface.

    With no arguments at all, prints the help text. Otherwise exactly one
    action is dispatched, in this priority order:

    1. ``--install``: install plugins.
    2. ``--serve``: serve a pipeline; requires ``--pipeline``.
    3. ``--get_pipeline_config``: interactive config retrieval.
    4. Otherwise: pipeline prediction; requires ``--pipeline``.

    Returns:
        Whatever ``pipeline_predict`` returns in prediction mode, else None.
    """
    parser, pipeline_args = args_cfg()
    args = parser.parse_args()

    if len(sys.argv) == 1:
        logging.warning("No arguments provided. Displaying help information:")
        parser.print_help()
        return

    if args.install:
        install(args)
    elif args.serve:
        # Without a pipeline the server cannot load any config; report it as
        # a usage error (parser.error prints usage and exits) instead.
        if args.pipeline is None:
            parser.error("--serve requires --pipeline to be specified.")
        serve(
            args.pipeline,
            device=args.device,
            use_hpip=args.use_hpip,
            serial_number=args.serial_number,
            update_license=args.update_license,
            host=args.host,
            port=args.port,
        )
    elif args.get_pipeline_config is not None:
        interactive_get_pipeline(args.get_pipeline_config, args.save_path)
    else:
        if args.pipeline is None:
            parser.error("--pipeline is required to run prediction.")

        pipeline_args_dict = {}
        from .utils.flags import USE_NEW_INFERENCE

        if USE_NEW_INFERENCE:
            for arg in pipeline_args:
                # Mirror argparse's dest derivation: strip the leading dashes
                # and turn any inner dashes into underscores.
                arg_name = arg["name"].lstrip("-").replace("-", "_")
                if hasattr(args, arg_name):
                    pipeline_args_dict[arg_name] = getattr(args, arg_name)
                else:
                    logging.warning(f"Argument {arg_name} is missing in args")

        return pipeline_predict(
            args.pipeline,
            args.input,
            args.device,
            args.save_path,
            use_hpip=args.use_hpip,
            serial_number=args.serial_number,
            update_license=args.update_license,
            **pipeline_args_dict,
        )
|