# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import textwrap
from types import SimpleNamespace

from .pipelines import build_pipeline, BasePipeline
from .repo_manager import setup, get_all_supported_repo_names
from .utils import logging


def args_cfg():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed arguments covering both the ``--install``
        mode and the pipeline-prediction mode.
    """

    def parse_str(s):
        """Convert a CLI string value to a Python value.

        The exact string "None" becomes ``None``; the listed spellings of
        true/false become ``True``/``False``; any other string is returned
        unchanged.
        """
        # Exact comparison on purpose: the former ``s in ("None")`` was a
        # substring test against the string "None" (the parentheses do not
        # build a tuple), so values like "", "e", "on" or "one" became None.
        if s == "None":
            return None
        if s in ("TRUE", "True", "true", "T", "t"):
            return True
        if s in ("FALSE", "False", "false", "F", "f"):
            return False
        return s

    parser = argparse.ArgumentParser()

    ################# install pdx #################
    parser.add_argument(
        '--install', action='store_true', default=False, help="")
    # Repo names handed to ``setup``; ``install`` falls back to every
    # supported repo when this is left empty.
    parser.add_argument('devkits', nargs='*', default=[])
    parser.add_argument('--no_deps', action='store_true')
    parser.add_argument('--platform', type=str, default='github.com')
    parser.add_argument(
        '-y',
        '--yes',
        dest='update_repos',
        action='store_true',
        help="Whether to update_repos all packages.")
    parser.add_argument(
        '--use_local_repos',
        action='store_true',
        default=False,
        help="Use local repos when existing.")

    ################# pipeline predict #################
    # NOTE: ``store_true`` combined with ``default=True`` makes this flag
    # always True. ``main`` does not read it; prediction runs whenever
    # ``--install`` is absent. Kept so existing command lines still parse.
    parser.add_argument('--predict', action='store_true', default=True, help="")
    parser.add_argument('--pipeline', type=str, help="")
    parser.add_argument('--model', nargs='+', help="")
    # Each value goes through ``parse_str``, so the literal "None" is passed
    # on as ``None`` and "true"/"false" spellings as booleans.
    parser.add_argument('--model_dir', nargs='+', type=parse_str, help="")
    parser.add_argument('--input', type=str, help="")
    parser.add_argument('--output', type=str, default="./", help="")
    parser.add_argument('--device', type=str, default='gpu:0', help="")
    return parser.parse_args()


def install(args):
    """Install PaddleX devkit repos via ``setup``.

    Args:
        args (argparse.Namespace): parsed CLI arguments; reads ``devkits``,
            ``no_deps``, ``platform``, ``update_repos`` and
            ``use_local_repos``.
    """
    # Enable debug info
    os.environ['PADDLE_PDX_DEBUG'] = 'True'
    # Disable eager initialization
    os.environ['PADDLE_PDX_EAGER_INIT'] = 'False'

    repo_names = args.devkits
    # No explicit devkits on the command line -> install all supported repos.
    if not repo_names:
        repo_names = get_all_supported_repo_names()
    setup(
        repo_names=repo_names,
        no_deps=args.no_deps,
        platform=args.platform,
        update_repos=args.update_repos,
        use_local_repos=args.use_local_repos)


def pipeline_predict(pipeline, model_name_list, model_dir_list, input_path,
                     output, device):
    """Build a pipeline and run prediction on a single input.

    Args:
        pipeline (str | None): pipeline name passed to ``build_pipeline``
            (``--pipeline``).
        model_name_list (list[str] | None): model names (``--model``).
        model_dir_list (list | None): model directories (``--model_dir``),
            possibly containing ``None`` entries produced by ``parse_str``.
        input_path (str | None): passed to ``predict`` under the
            ``"input_path"`` key (``--input``).
        output (str): output location passed to ``build_pipeline``.
        device (str): device string such as ``"gpu:0"``.
    """
    pipeline = build_pipeline(pipeline, model_name_list, model_dir_list,
                              output, device)
    pipeline.predict({"input_path": input_path})


# for CLI
def main():
    """Entry point for the command line.

    Runs ``install`` when ``--install`` is given; otherwise builds the
    requested pipeline and runs prediction.
    """
    args = args_cfg()
    if args.install:
        install(args)
    else:
        return pipeline_predict(args.pipeline, args.model, args.model_dir,
                                args.input, args.output, args.device)