paddlex.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import argparse
  16. import textwrap
  17. from types import SimpleNamespace
  18. from prettytable import PrettyTable
  19. from .utils.config import AttrDict
  20. from .modules.base.predictor import build_predictor, BasePredictor
  21. from .repo_manager import setup, get_all_supported_repo_names
  22. def args_cfg():
  23. """parse cli arguments
  24. """
  25. def str2bool(v):
  26. """convert str to bool type
  27. """
  28. return v.lower() in ("true", "t", "1")
  29. parser = argparse.ArgumentParser()
  30. ################# install pdx #################
  31. parser.add_argument(
  32. '--install', action='store_true', default=False, help="")
  33. parser.add_argument('devkits', nargs='*', default=[])
  34. parser.add_argument('--no_deps', action='store_true')
  35. parser.add_argument('--platform', type=str, default='github.com')
  36. parser.add_argument('--update_repos', action='store_true')
  37. parser.add_argument(
  38. '-y',
  39. '--yes',
  40. dest='reinstall',
  41. action='store_true',
  42. help="Whether to reinstall all packages.")
  43. ################# infer #################
  44. parser.add_argument('--predict', action='store_true', default=True, help="")
  45. parser.add_argument('--model_name', type=str, help="")
  46. parser.add_argument('--model', type=str, help="")
  47. parser.add_argument('--input_path', type=str, help="")
  48. parser.add_argument('--output', type=str, help="")
  49. parser.add_argument('--device', type=str, default='gpu:0', help="")
  50. return parser.parse_args()
  51. def get_all_models():
  52. """Get all models that have been registered
  53. """
  54. all_models = BasePredictor.all()
  55. model_map = {}
  56. for model in all_models:
  57. module = all_models[model].__name__
  58. if module not in model_map:
  59. model_map[module] = []
  60. model_map[module].append(model)
  61. return model_map
  62. def print_info():
  63. """Print list of supported models in formatted.
  64. """
  65. try:
  66. sz = os.get_terminal_size()
  67. total_width = sz.columns
  68. first_width = 30
  69. second_width = total_width - first_width if total_width > 50 else 10
  70. except OSError:
  71. total_width = 100
  72. second_width = 100
  73. total_width -= 4
  74. models_table = PrettyTable()
  75. models_table.field_names = ["Modules", "Models"]
  76. model_map = get_all_models()
  77. for module in model_map:
  78. models = model_map[module]
  79. models_table.add_row(
  80. [
  81. textwrap.fill(
  82. f"{module}", width=total_width // 5), textwrap.fill(
  83. " ".join(models), width=total_width * 4 // 5)
  84. ],
  85. divider=True)
  86. table_width = len(str(models_table).split("\n")[0])
  87. print("{}".format("-" * table_width))
  88. print("PaddleX".center(table_width))
  89. print(models_table)
  90. print("Powered by PaddlePaddle!".rjust(table_width))
  91. print("{}".format("-" * table_width))
  92. def install(args):
  93. """install paddlex
  94. """
  95. # Enable debug info
  96. os.environ['PADDLE_PDX_DEBUG'] = 'True'
  97. # Disable eager initialization
  98. os.environ['PADDLE_PDX_EAGER_INIT'] = 'False'
  99. repo_names = args.devkits
  100. if len(repo_names) == 0:
  101. repo_names = get_all_supported_repo_names()
  102. setup(
  103. repo_names=repo_names,
  104. reinstall=args.reinstall or None,
  105. no_deps=args.no_deps,
  106. platform=args.platform,
  107. update_repos=args.update_repos)
  108. return
  109. def build_predict_config(model_name, model, input_path, device, output):
  110. """build predict config for paddlex
  111. """
  112. def dict2attrdict(dict_obj):
  113. """convert dict object to AttrDict
  114. """
  115. for key, value in dict_obj.items():
  116. if isinstance(value, dict):
  117. dict_obj[key] = dict2attrdict(value)
  118. return AttrDict(dict_obj)
  119. config = {
  120. 'Global': {
  121. 'model': model_name,
  122. 'device': device,
  123. 'output': output
  124. },
  125. 'Predict': {
  126. 'model_dir': model,
  127. 'input_path': input_path,
  128. }
  129. }
  130. return dict2attrdict(config)
  131. def predict(model_name, model, input_path, device, output):
  132. """predict using paddlex
  133. """
  134. config = build_predict_config(model_name, model, input_path, device, output)
  135. predict = build_predictor(config)
  136. return predict()
  137. # for CLI
  138. def main():
  139. """API for commad line
  140. """
  141. args = args_cfg()
  142. if args.install:
  143. install(args)
  144. else:
  145. print_info()
  146. return predict(args.model_name, args.model, args.input_path,
  147. args.device, args.output)