arg.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import shlex
  16. class CLIArgument(object):
  17. """CLIArgument"""
  18. def __init__(self, key, *vals, quote=False, sep=" "):
  19. super().__init__()
  20. self.key = str(key)
  21. self.vals = [str(v) for v in vals]
  22. if quote and os.name != "posix":
  23. raise ValueError("`quote` cannot be True on non-POSIX compliant systems.")
  24. self.quote = quote
  25. self.sep = sep
  26. def __repr__(self):
  27. return self.sep.join(self.lst)
  28. @property
  29. def lst(self):
  30. """lst"""
  31. if self.quote:
  32. vals = [shlex.quote(val) for val in self.vals]
  33. else:
  34. vals = self.vals
  35. return [self.key, *vals]
  36. def gather_opts_args(args, opts_key):
  37. """gather_opts_args"""
  38. def _is_opts_arg(arg):
  39. return arg.key == opts_key
  40. args = sorted(args, key=_is_opts_arg)
  41. idx = None
  42. for i, arg in enumerate(args):
  43. if _is_opts_arg(arg):
  44. idx = i
  45. break
  46. if idx is not None:
  47. opts_args = args[idx:]
  48. args = args[:idx]
  49. all_vals = []
  50. for arg in opts_args:
  51. all_vals.extend(arg.vals)
  52. args.append(CLIArgument(opts_key, *all_vals))
  53. return args