Prechádzať zdrojové kódy

[Feat] Optimize CLI (#3748)

* Optimize PaddleX CLI

* Fix bug
Lin Manhui 6 mesiacov pred
rodič
commit
b91d75adfa
1 zmenil súbory, kde vykonal 13 pridanie a 12 odobranie
  1. 13 12
      paddlex/paddlex_cli.py

+ 13 - 12
paddlex/paddlex_cli.py

@@ -66,15 +66,9 @@ def args_cfg():
     ################# install pdx #################
     install_group.add_argument(
         "--install",
-        action="store_true",
-        default=False,
-        help="Install specified PaddleX plugins.",
-    )
-    install_group.add_argument(
-        "plugins",
         nargs="*",
-        default=[],
-        help="Names of custom development plugins to install (space-separated).",
+        metavar="PLUGIN",
+        help="Install specified PaddleX plugins.",
     )
     install_group.add_argument(
         "--no_deps",
@@ -193,7 +187,10 @@ def args_cfg():
     pipeline = args.pipeline
     pipeline_args = []
 
-    if not (args.install or args.serve or args.paddle2onnx) and pipeline is not None:
+    if (
+        not (args.install is not None or args.serve or args.paddle2onnx)
+        and pipeline is not None
+    ):
         if os.path.isfile(pipeline):
             pipeline_name = load_pipeline_config(pipeline)["pipeline_name"]
         else:
@@ -263,7 +260,7 @@ def install(args):
     # Disable eager initialization
     os.environ["PADDLE_PDX_EAGER_INIT"] = "False"
 
-    plugins = args.plugins[:]
+    plugins = args.install[:]
 
     if "serving" in plugins:
         plugins.remove("serving")
@@ -435,8 +432,9 @@ def main():
         parser.print_help()
         sys.exit(2)
 
-    if args.install:
+    if args.install is not None:
         install(args)
+        return
     elif args.serve:
         serve(
             args.pipeline,
@@ -446,12 +444,14 @@ def main():
             host=args.host,
             port=args.port,
         )
+        return
     elif args.paddle2onnx:
         paddle_to_onnx(
             args.paddle_model_dir,
             args.onnx_model_dir,
             opset_version=args.opset_version,
         )
+        return
     else:
         if args.get_pipeline_config is not None:
             interactive_get_pipeline(args.get_pipeline_config, args.save_path)
@@ -464,7 +464,7 @@ def main():
                     pipeline_args_dict[arg_name] = getattr(args, arg_name)
                 else:
                     logging.warning(f"Argument {arg_name} is missing in args")
-            return pipeline_predict(
+            pipeline_predict(
                 args.pipeline,
                 args.input,
                 args.device,
@@ -473,3 +473,4 @@ def main():
                 hpi_config=args.hpi_config,
                 **pipeline_args_dict,
             )
+            return