|
|
@@ -1,5 +1,5 @@
|
|
|
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
|
|
-#
|
|
|
+#
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
# You may obtain a copy of the License at
|
|
|
@@ -12,7 +12,6 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-
|
|
|
import os
|
|
|
import argparse
|
|
|
import textwrap
|
|
|
@@ -32,6 +31,11 @@ def args_cfg():
|
|
|
"""
|
|
|
return v.lower() in ("true", "t", "1")
|
|
|
|
|
|
+ def str2None(s):
|
|
|
+ """convert to None type if it is "None"
|
|
|
+ """
|
|
|
+ return None if s.lower() == 'none' else s
|
|
|
+
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
|
################# install pdx #################
|
|
|
@@ -52,6 +56,7 @@ def args_cfg():
|
|
|
parser.add_argument('--predict', action='store_true', default=True, help="")
|
|
|
parser.add_argument('--pipeline', type=str, help="")
|
|
|
parser.add_argument('--model', nargs='+', help="")
|
|
|
+ parser.add_argument('--model_dir', nargs='+', type=str2None, help="")
|
|
|
parser.add_argument('--input', type=str, help="")
|
|
|
parser.add_argument('--output', type=str, default="./", help="")
|
|
|
parser.add_argument('--device', type=str, default='gpu:0', help="")
|
|
|
@@ -79,10 +84,12 @@ def install(args):
|
|
|
return
|
|
|
|
|
|
|
|
|
-def pipeline_predict(pipeline, model_name_list, input_path, output, device):
|
|
|
+def pipeline_predict(pipeline, model_name_list, model_dir_list, input_path,
|
|
|
+ output, device):
|
|
|
"""pipeline predict
|
|
|
"""
|
|
|
- pipeline = build_pipeline(pipeline, model_name_list, output, device)
|
|
|
+ pipeline = build_pipeline(pipeline, model_name_list, model_dir_list, output,
|
|
|
+ device)
|
|
|
pipeline.predict({"input_path": input_path})
|
|
|
|
|
|
|
|
|
@@ -94,5 +101,5 @@ def main():
|
|
|
if args.install:
|
|
|
install(args)
|
|
|
else:
|
|
|
- return pipeline_predict(args.pipeline, args.model, args.input,
|
|
|
- args.output, args.device)
|
|
|
+ return pipeline_predict(args.pipeline, args.model, args.model_dir,
|
|
|
+ args.input, args.output, args.device)
|