Browse Source

[Fix] Rename parameters of pipeline factory function and fix bugs when passing config paths (#2941)

* Remove import of create_pipeline_from_config

* Fix pipeline factory function bug
Lin Manhui 9 months ago
parent
commit
6de2bcff22
2 changed files with 20 additions and 19 deletions
  1. 19 18
      paddlex/inference/pipelines_new/__init__.py
  2. 1 1
      paddlex/paddlex_cli.py

+ 19 - 18
paddlex/inference/pipelines_new/__init__.py

@@ -73,12 +73,12 @@ def get_pipeline_path(pipeline_name: str) -> str:
     return pipeline_path
 
 
-def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
+def load_pipeline_config(pipeline: str) -> Dict[str, Any]:
     """
     Load the pipeline configuration.
 
     Args:
-        pipeline_name (str): The name of the pipeline or the path to the config file.
+        pipeline (str): The name of the pipeline or the path to the config file.
 
     Returns:
         Dict[str, Any]: The parsed pipeline configuration.
@@ -86,20 +86,20 @@ def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
     Raises:
         Exception: If the config file of pipeline does not exist.
     """
-    if not (pipeline_name.endswith(".yml") or pipeline_name.endswith(".yaml")):
-        pipeline_path = get_pipeline_path(pipeline_name)
+    if not (pipeline.endswith(".yml") or pipeline.endswith(".yaml")):
+        pipeline_path = get_pipeline_path(pipeline)
         if pipeline_path is None:
             raise Exception(
-                f"The pipeline ({pipeline_name}) does not exist! Please use a pipeline name or a config file path!"
+                f"The pipeline ({pipeline}) does not exist! Please use a pipeline name or a config file path!"
             )
     else:
-        pipeline_path = pipeline_name
+        pipeline_path = pipeline
     config = parse_config(pipeline_path)
     return config
 
 
 def create_pipeline(
-    pipeline_name: Optional[str] = None,
+    pipeline: Optional[str] = None,
     config: Optional[Dict[str, Any]] = None,
     device: Optional[str] = None,
     pp_option: Optional[PaddlePredictorOption] = None,
@@ -114,7 +114,7 @@ def create_pipeline(
     default config corresponding to the pipeline name.
 
     Args:
-        pipeline_name (Optional[str], optional): The name of the pipeline to
+        pipeline (Optional[str], optional): The name of the pipeline to
             create, or the path to the config file. Defaults to None.
         config (Optional[Dict[str, Any]], optional): The pipeline configuration.
             Defaults to None.
@@ -130,19 +130,20 @@ def create_pipeline(
     Returns:
         BasePipeline: The created pipeline instance.
     """
-    if pipeline_name is None and config is None:
+    if pipeline is None and config is None:
         raise ValueError(
-            "Both `pipeline_name` and `config` cannot be None at the same time."
+            "Both `pipeline` and `config` cannot be None at the same time."
         )
     if config is None:
-        config = load_pipeline_config(pipeline_name)
-    if pipeline_name is not None and config["pipeline_name"] != pipeline_name:
-        logging.warning(
-            "The pipeline name in the config (%r) is different from the specified pipeline name (%r). %r will be used.",
-            config["pipeline_name"],
-            pipeline_name,
-            config["pipeline_name"],
-        )
+        config = load_pipeline_config(pipeline)
+    else:
+        if pipeline is not None and config["pipeline_name"] != pipeline:
+            logging.warning(
+                "The pipeline name in the config (%r) is different from the specified pipeline name (%r). %r will be used.",
+                config["pipeline_name"],
+                pipeline,
+                config["pipeline_name"],
+            )
     pipeline_name = config["pipeline_name"]
 
     pipeline = BasePipeline.get(pipeline_name)(

+ 1 - 1
paddlex/paddlex_cli.py

@@ -22,7 +22,7 @@ from pathlib import Path
 from importlib_resources import files, as_file
 
 from . import create_pipeline
-from .inference.pipelines import create_pipeline_from_config, load_pipeline_config
+from .inference.pipelines import load_pipeline_config
 from .repo_manager import setup, get_all_supported_repo_names
 from .utils.flags import FLAGS_json_format_model
 from .utils import logging