|
|
@@ -24,6 +24,7 @@ from .inference.pipelines import create_pipeline_from_config, load_pipeline_conf
|
|
|
from .repo_manager import setup, get_all_supported_repo_names
|
|
|
from .utils import logging
|
|
|
from .utils.interactive_get_pipeline import interactive_get_pipeline
|
|
|
+from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
|
|
|
|
|
|
|
|
|
def _install_serving_deps():
|
|
|
@@ -41,7 +42,7 @@ def args_cfg():
|
|
|
to None type if it is "None",
|
|
|
to bool type if it means True or False.
|
|
|
"""
|
|
|
- if s in ("None"):
|
|
|
+ if s in ("None", "none", "NONE"):
|
|
|
return None
|
|
|
elif s in ("TRUE", "True", "true", "T", "t"):
|
|
|
return True
|
|
|
@@ -49,43 +50,140 @@ def args_cfg():
|
|
|
return False
|
|
|
return s
|
|
|
|
|
|
- parser = argparse.ArgumentParser()
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
|
|
|
+ )
|
|
|
+
|
|
|
+ install_group = parser.add_argument_group("Install PaddleX Options")
|
|
|
+ pipeline_group = parser.add_argument_group("Pipeline Predict Options")
|
|
|
+ serving_group = parser.add_argument_group("Serving Options")
|
|
|
|
|
|
################# install pdx #################
|
|
|
- parser.add_argument("--install", action="store_true", default=False, help="")
|
|
|
- parser.add_argument("plugins", nargs="*", default=[])
|
|
|
- parser.add_argument("--no_deps", action="store_true")
|
|
|
- parser.add_argument("--platform", type=str, default="github.com")
|
|
|
- parser.add_argument(
|
|
|
+ install_group.add_argument(
|
|
|
+ "--install",
|
|
|
+ action="store_true",
|
|
|
+ default=False,
|
|
|
+ help="Install specified PaddleX plugins.",
|
|
|
+ )
|
|
|
+ install_group.add_argument(
|
|
|
+ "plugins",
|
|
|
+ nargs="*",
|
|
|
+ default=[],
|
|
|
+ help="Names of custom development plugins to install (space-separated).",
|
|
|
+ )
|
|
|
+ install_group.add_argument(
|
|
|
+ "--no_deps",
|
|
|
+ action="store_true",
|
|
|
+ help="Install custom development plugins without their dependencies.",
|
|
|
+ )
|
|
|
+ install_group.add_argument(
|
|
|
+ "--platform",
|
|
|
+ type=str,
|
|
|
+ choices=["github.com", "gitee.com"],
|
|
|
+ default="github.com",
|
|
|
+ help="Platform to use for installation (default: github.com).",
|
|
|
+ )
|
|
|
+ install_group.add_argument(
|
|
|
"-y",
|
|
|
"--yes",
|
|
|
dest="update_repos",
|
|
|
action="store_true",
|
|
|
- help="Whether to update_repos all packages.",
|
|
|
+ help="Automatically confirm prompts and update repositories.",
|
|
|
)
|
|
|
- parser.add_argument(
|
|
|
+ install_group.add_argument(
|
|
|
"--use_local_repos",
|
|
|
action="store_true",
|
|
|
default=False,
|
|
|
- help="Use local repos when existing.",
|
|
|
+ help="Use local repositories if they exist.",
|
|
|
)
|
|
|
|
|
|
################# pipeline predict #################
|
|
|
- parser.add_argument("--pipeline", type=str, help="")
|
|
|
- parser.add_argument("--input", type=str, default=None, help="")
|
|
|
- parser.add_argument("--save_path", type=str, default=None, help="")
|
|
|
- parser.add_argument("--device", type=str, default=None, help="")
|
|
|
- parser.add_argument("--use_hpip", action="store_true")
|
|
|
- parser.add_argument("--serial_number", type=str)
|
|
|
- parser.add_argument("--update_license", action="store_true")
|
|
|
- parser.add_argument("--get_pipeline_config", type=str, default=None, help="")
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
|
|
|
+ )
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--input",
|
|
|
+ type=str,
|
|
|
+ default=None,
|
|
|
+ help="Input data or path for the pipeline, supports specific file and directory.",
|
|
|
+ )
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--save_path",
|
|
|
+ type=str,
|
|
|
+ default=None,
|
|
|
+ help="Path to save the prediction results.",
|
|
|
+ )
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--device",
|
|
|
+ type=str,
|
|
|
+ default=None,
|
|
|
+ help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
|
|
|
+ )
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--use_hpip", action="store_true", help="Enable HPIP acceleration if available."
|
|
|
+ )
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--serial_number", type=str, help="Serial number for device identification."
|
|
|
+ )
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--update_license",
|
|
|
+ action="store_true",
|
|
|
+ help="Update the software license information.",
|
|
|
+ )
|
|
|
+ pipeline_group.add_argument(
|
|
|
+ "--get_pipeline_config",
|
|
|
+ type=str,
|
|
|
+ default=None,
|
|
|
+ help="Retrieve the configuration for a specified pipeline.",
|
|
|
+ )
|
|
|
|
|
|
################# serving #################
|
|
|
- parser.add_argument("--serve", action="store_true")
|
|
|
- parser.add_argument("--host", type=str, default="0.0.0.0")
|
|
|
- parser.add_argument("--port", type=int, default=8080)
|
|
|
+ serving_group.add_argument(
|
|
|
+ "--serve",
|
|
|
+ action="store_true",
|
|
|
+ help="Start the serving application to handle requests.",
|
|
|
+ )
|
|
|
+ serving_group.add_argument(
|
|
|
+ "--host",
|
|
|
+ type=str,
|
|
|
+ default="0.0.0.0",
|
|
|
+ help="Host address to serve on (default: 0.0.0.0).",
|
|
|
+ )
|
|
|
+ serving_group.add_argument(
|
|
|
+ "--port",
|
|
|
+ type=int,
|
|
|
+ default=8080,
|
|
|
+ help="Port number to serve on (default: 8080).",
|
|
|
+ )
|
|
|
|
|
|
- return parser
|
|
|
+ # Parse known arguments to get the pipeline name
|
|
|
+ args, remaining_args = parser.parse_known_args()
|
|
|
+ pipeline_name = args.pipeline
|
|
|
+ pipeline_args = []
|
|
|
+
|
|
|
+ if not args.install and pipeline_name is not None:
|
|
|
+
|
|
|
+ if pipeline_name not in PIPELINE_ARGUMENTS:
|
|
|
+ support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
|
|
|
+ logging.error(
|
|
|
+ f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
|
|
|
+ )
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
|
|
|
+ if pipeline_args is None:
|
|
|
+ pipeline_args = []
|
|
|
+ pipeline_specific_group = parser.add_argument_group(
|
|
|
+ f"{pipeline_name.capitalize()} Pipeline Options"
|
|
|
+ )
|
|
|
+ for arg in pipeline_args:
|
|
|
+ pipeline_specific_group.add_argument(
|
|
|
+ arg["name"],
|
|
|
+ type=parse_str if arg["type"] is bool else arg["type"],
|
|
|
+ help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
|
|
|
+ )
|
|
|
+
|
|
|
+ return parser, pipeline_args
|
|
|
|
|
|
|
|
|
def install(args):
|
|
|
@@ -121,16 +219,23 @@ def _get_hpi_params(serial_number, update_license):
|
|
|
|
|
|
|
|
|
def pipeline_predict(
|
|
|
- pipeline, input, device, save_path, use_hpip, serial_number, update_license
|
|
|
+ pipeline,
|
|
|
+ input,
|
|
|
+ device,
|
|
|
+ save_path,
|
|
|
+ use_hpip,
|
|
|
+ serial_number,
|
|
|
+ update_license,
|
|
|
+ **pipeline_args,
|
|
|
):
|
|
|
"""pipeline predict"""
|
|
|
hpi_params = _get_hpi_params(serial_number, update_license)
|
|
|
pipeline = create_pipeline(
|
|
|
pipeline, device=device, use_hpip=use_hpip, hpi_params=hpi_params
|
|
|
)
|
|
|
- result = pipeline(input)
|
|
|
+ result = pipeline.predict(input, **pipeline_args)
|
|
|
for res in result:
|
|
|
- res.print(json_format=False)
|
|
|
+ res.print()
|
|
|
if save_path:
|
|
|
res.save_all(save_path=save_path)
|
|
|
|
|
|
@@ -150,10 +255,12 @@ def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, po
|
|
|
# for CLI
|
|
|
def main():
|
|
|
"""API for commad line"""
|
|
|
- args = args_cfg().parse_args()
|
|
|
+ parser, pipeline_args = args_cfg()
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
if len(sys.argv) == 1:
|
|
|
logging.warning("No arguments provided. Displaying help information:")
|
|
|
- args_cfg().print_help()
|
|
|
+ parser.print_help()
|
|
|
return
|
|
|
|
|
|
if args.install:
|
|
|
@@ -172,6 +279,13 @@ def main():
|
|
|
if args.get_pipeline_config is not None:
|
|
|
interactive_get_pipeline(args.get_pipeline_config, args.save_path)
|
|
|
else:
|
|
|
+ pipeline_args_dict = {}
|
|
|
+ for arg in pipeline_args:
|
|
|
+ arg_name = arg["name"].lstrip("-")
|
|
|
+ if hasattr(args, arg_name):
|
|
|
+ pipeline_args_dict[arg_name] = getattr(args, arg_name)
|
|
|
+ else:
|
|
|
+ logging.warning(f"Argument {arg_name} is missing in args")
|
|
|
return pipeline_predict(
|
|
|
args.pipeline,
|
|
|
args.input,
|
|
|
@@ -180,4 +294,5 @@ def main():
|
|
|
use_hpip=args.use_hpip,
|
|
|
serial_number=args.serial_number,
|
|
|
update_license=args.update_license,
|
|
|
+ **pipeline_args_dict,
|
|
|
)
|