zhangyubo0722 před 1 rokem
rodič
revize
c90a468fd3
2 změnil soubory, kde provedl 18 přidání a 14 odebrání
  1. 1 1
      paddlex/paddlex_cli.py
  2. 17 13
      paddlex/utils/interactive_get_pipeline.py

+ 1 - 1
paddlex/paddlex_cli.py

@@ -177,7 +177,7 @@ def main():
         )
     else:
         if args.get_pipeline_config is not None:
-            interactive_get_pipeline(args.get_pipeline_config)
+            interactive_get_pipeline(args.get_pipeline_config, args.save_path)
         else:
             return pipeline_predict(
                 args.pipeline,

+ 17 - 13
paddlex/utils/interactive_get_pipeline.py

@@ -19,24 +19,28 @@ from ..utils import logging
 from ..inference.utils.get_pipeline_path import get_pipeline_path
 
 
-def interactive_get_pipeline(pipeline):
+def interactive_get_pipeline(pipeline, save_path):
     file_path = get_pipeline_path(pipeline)
     file_name = Path(file_path).name
 
-    logging.info(
-        "Please enter the path that you want to save the pipeline config file: (default `./`)"
-    )
-    target_path = input() or "."
+    if save_path is None:
+        logging.info(
+            "Please enter the path that you want to save the pipeline config file: (default `./`)"
+        )
+        target_path = input() or "."
+    else:
+        target_path = save_path
     target_path = Path(target_path)
 
-    if not target_path.suffix in (".yaml", ".yml"):
-        if not target_path.exists():
-            try:
-                target_path.mkdir(parents=True, exist_ok=True)
-            except Exception as e:
-                logging.error(f"Failed to create directory: {e}")
-                return
-        target_path = target_path / file_name
+    if target_path.suffix not in (".yaml", ".yml"):
+        target_path /= file_name
+
+    if not target_path.parent.exists():
+        try:
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+        except Exception as e:
+            logging.error(f"Failed to create directory: {e}")
+            return
 
     if target_path.exists():
         logging.info(f"The file({target_path}) already exists. Is it covered? (y/N):")