Browse Source

support CLI pipeline params (#2835)

* support CLI pipeline params

* update

* polish 'paddlex -h' command
cuicheng01 10 months ago
parent
commit
c81b5d13b7
2 changed files with 230 additions and 27 deletions
  1. 142 27
      paddlex/paddlex_cli.py
  2. 88 0
      paddlex/utils/pipeline_arguments.py

+ 142 - 27
paddlex/paddlex_cli.py

@@ -24,6 +24,7 @@ from .inference.pipelines import create_pipeline_from_config, load_pipeline_conf
 from .repo_manager import setup, get_all_supported_repo_names
 from .utils import logging
 from .utils.interactive_get_pipeline import interactive_get_pipeline
+from .utils.pipeline_arguments import PIPELINE_ARGUMENTS
 
 
 def _install_serving_deps():
@@ -41,7 +42,7 @@ def args_cfg():
         to None type if it is "None",
         to bool type if it means True or False.
         """
-        if s in ("None"):
+        if s in ("None", "none", "NONE"):
             return None
         elif s in ("TRUE", "True", "true", "T", "t"):
             return True
@@ -49,43 +50,140 @@ def args_cfg():
             return False
         return s
 
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(
+        "Command-line interface for PaddleX. Use the options below to install plugins, run pipeline predictions, or start the serving application."
+    )
+
+    install_group = parser.add_argument_group("Install PaddleX Options")
+    pipeline_group = parser.add_argument_group("Pipeline Predict Options")
+    serving_group = parser.add_argument_group("Serving Options")
 
     ################# install pdx #################
-    parser.add_argument("--install", action="store_true", default=False, help="")
-    parser.add_argument("plugins", nargs="*", default=[])
-    parser.add_argument("--no_deps", action="store_true")
-    parser.add_argument("--platform", type=str, default="github.com")
-    parser.add_argument(
+    install_group.add_argument(
+        "--install",
+        action="store_true",
+        default=False,
+        help="Install specified PaddleX plugins.",
+    )
+    install_group.add_argument(
+        "plugins",
+        nargs="*",
+        default=[],
+        help="Names of custom development plugins to install (space-separated).",
+    )
+    install_group.add_argument(
+        "--no_deps",
+        action="store_true",
+        help="Install custom development plugins without their dependencies.",
+    )
+    install_group.add_argument(
+        "--platform",
+        type=str,
+        choices=["github.com", "gitee.com"],
+        default="github.com",
+        help="Platform to use for installation (default: github.com).",
+    )
+    install_group.add_argument(
         "-y",
         "--yes",
         dest="update_repos",
         action="store_true",
-        help="Whether to update_repos all packages.",
+        help="Automatically confirm prompts and update repositories.",
     )
-    parser.add_argument(
+    install_group.add_argument(
         "--use_local_repos",
         action="store_true",
         default=False,
-        help="Use local repos when existing.",
+        help="Use local repositories if they exist.",
     )
 
     ################# pipeline predict #################
-    parser.add_argument("--pipeline", type=str, help="")
-    parser.add_argument("--input", type=str, default=None, help="")
-    parser.add_argument("--save_path", type=str, default=None, help="")
-    parser.add_argument("--device", type=str, default=None, help="")
-    parser.add_argument("--use_hpip", action="store_true")
-    parser.add_argument("--serial_number", type=str)
-    parser.add_argument("--update_license", action="store_true")
-    parser.add_argument("--get_pipeline_config", type=str, default=None, help="")
+    pipeline_group.add_argument(
+        "--pipeline", type=str, help="Name of the pipeline to execute for prediction."
+    )
+    pipeline_group.add_argument(
+        "--input",
+        type=str,
+        default=None,
+        help="Input data or path for the pipeline, supports specific file and directory.",
+    )
+    pipeline_group.add_argument(
+        "--save_path",
+        type=str,
+        default=None,
+        help="Path to save the prediction results.",
+    )
+    pipeline_group.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="Device to run the pipeline on (e.g., 'cpu', 'gpu:0').",
+    )
+    pipeline_group.add_argument(
+        "--use_hpip", action="store_true", help="Enable HPIP acceleration if available."
+    )
+    pipeline_group.add_argument(
+        "--serial_number", type=str, help="Serial number for device identification."
+    )
+    pipeline_group.add_argument(
+        "--update_license",
+        action="store_true",
+        help="Update the software license information.",
+    )
+    pipeline_group.add_argument(
+        "--get_pipeline_config",
+        type=str,
+        default=None,
+        help="Retrieve the configuration for a specified pipeline.",
+    )
 
     ################# serving #################
-    parser.add_argument("--serve", action="store_true")
-    parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int, default=8080)
+    serving_group.add_argument(
+        "--serve",
+        action="store_true",
+        help="Start the serving application to handle requests.",
+    )
+    serving_group.add_argument(
+        "--host",
+        type=str,
+        default="0.0.0.0",
+        help="Host address to serve on (default: 0.0.0.0).",
+    )
+    serving_group.add_argument(
+        "--port",
+        type=int,
+        default=8080,
+        help="Port number to serve on (default: 8080).",
+    )
 
-    return parser
+    # Parse known arguments to get the pipeline name
+    args, remaining_args = parser.parse_known_args()
+    pipeline_name = args.pipeline
+    pipeline_args = []
+
+    if not args.install and pipeline_name is not None:
+
+        if pipeline_name not in PIPELINE_ARGUMENTS:
+            support_pipelines = ", ".join(PIPELINE_ARGUMENTS.keys())
+            logging.error(
+                f"Unsupported pipeline: {pipeline_name}, CLI predict only supports these pipelines: {support_pipelines}\n"
+            )
+            sys.exit(1)
+
+        pipeline_args = PIPELINE_ARGUMENTS[pipeline_name]
+        if pipeline_args is None:
+            pipeline_args = []
+        pipeline_specific_group = parser.add_argument_group(
+            f"{pipeline_name.capitalize()} Pipeline Options"
+        )
+        for arg in pipeline_args:
+            pipeline_specific_group.add_argument(
+                arg["name"],
+                type=parse_str if arg["type"] is bool else arg["type"],
+                help=arg.get("help", f"Argument for {pipeline_name} pipeline."),
+            )
+
+    return parser, pipeline_args
 
 
 def install(args):
@@ -121,16 +219,23 @@ def _get_hpi_params(serial_number, update_license):
 
 
 def pipeline_predict(
-    pipeline, input, device, save_path, use_hpip, serial_number, update_license
+    pipeline,
+    input,
+    device,
+    save_path,
+    use_hpip,
+    serial_number,
+    update_license,
+    **pipeline_args,
 ):
     """pipeline predict"""
     hpi_params = _get_hpi_params(serial_number, update_license)
     pipeline = create_pipeline(
         pipeline, device=device, use_hpip=use_hpip, hpi_params=hpi_params
     )
-    result = pipeline(input)
+    result = pipeline.predict(input, **pipeline_args)
     for res in result:
-        res.print(json_format=False)
+        res.print()
         if save_path:
             res.save_all(save_path=save_path)
 
@@ -150,10 +255,12 @@ def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, po
 # for CLI
 def main():
     """API for commad line"""
-    args = args_cfg().parse_args()
+    parser, pipeline_args = args_cfg()
+    args = parser.parse_args()
+
     if len(sys.argv) == 1:
         logging.warning("No arguments provided. Displaying help information:")
-        args_cfg().print_help()
+        parser.print_help()
         return
 
     if args.install:
@@ -172,6 +279,13 @@ def main():
         if args.get_pipeline_config is not None:
             interactive_get_pipeline(args.get_pipeline_config, args.save_path)
         else:
+            pipeline_args_dict = {}
+            for arg in pipeline_args:
+                arg_name = arg["name"].lstrip("-")
+                if hasattr(args, arg_name):
+                    pipeline_args_dict[arg_name] = getattr(args, arg_name)
+                else:
+                    logging.warning(f"Argument {arg_name} is missing in args")
             return pipeline_predict(
                 args.pipeline,
                 args.input,
@@ -180,4 +294,5 @@ def main():
                 use_hpip=args.use_hpip,
                 serial_number=args.serial_number,
                 update_license=args.update_license,
+                **pipeline_args_dict,
             )

+ 88 - 0
paddlex/utils/pipeline_arguments.py

@@ -0,0 +1,88 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+PIPELINE_ARGUMENTS = {
+    "OCR": [
+        {
+            "name": "--use_doc_orientation_classify",
+            "type": bool,
+            "help": "Determines whether to use document orientation classification",
+        },
+        {
+            "name": "--use_doc_unwarping",
+            "type": bool,
+            "help": "Determines whether to use document unwarping",
+        },
+        {
+            "name": "--use_textline_orientation",
+            "type": bool,
+            "help": "Determines whether to consider text line orientation",
+        },
+        {
+            "name": "--text_det_limit_side_len",
+            "type": int,
+            "help": "Sets the side length limit for text detection.",
+        },
+        {
+            "name": "--text_det_limit_type",
+            "type": str,
+            "help": "Sets the limit type for text detection.",
+        },
+        {
+            "name": "--text_det_thresh",
+            "type": float,
+            "help": "Sets the threshold for text detection.",
+        },
+        {
+            "name": "--text_det_box_thresh",
+            "type": float,
+            "help": "Sets the box threshold for text detection.",
+        },
+        {
+            "name": "--text_det_max_candidates",
+            "type": int,
+            "help": "Sets the maximum number of candidate boxes for text detection.",
+        },
+        {
+            "name": "--text_det_unclip_ratio",
+            "type": float,
+            "help": "Sets the unclip ratio for text detection.",
+        },
+        {
+            "name": "--text_det_use_dilation",
+            "type": bool,
+            "help": "Determines whether to use dilation in text detection.",
+        },
+        {
+            "name": "--text_rec_score_thresh",
+            "type": float,
+            "help": "Sets the score threshold for text recognition.",
+        },
+    ],
+    "object_detection": [
+        {
+            "name": "--threshold",
+            "type": float,
+            "help": "Sets the threshold for object detection.",
+        },
+    ],
+    "image_classification": [
+        {
+            "name": "--topk",
+            "type": int,
+            "help": "Sets the Top-K value for image classification.",
+        },
+    ],
+    "ts_classification": None,
+}