_cli.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. from __future__ import annotations
  2. import sys
  3. import logging
  4. import argparse
  5. from typing import Any, List, Type, Optional
  6. from typing_extensions import ClassVar
  7. import httpx
  8. import pydantic
  9. import openai
  10. from . import _tools
  11. from .. import _ApiType, __version__
  12. from ._api import register_commands
  13. from ._utils import can_use_http2
  14. from ._errors import CLIError, display_error
  15. from .._compat import PYDANTIC_V1, ConfigDict, model_parse
  16. from .._models import BaseModel
  17. from .._exceptions import APIError
  18. logger = logging.getLogger()
  19. formatter = logging.Formatter("[%(asctime)s] %(message)s")
  20. handler = logging.StreamHandler(sys.stderr)
  21. handler.setFormatter(formatter)
  22. logger.addHandler(handler)
class Arguments(BaseModel):
    """Global CLI options, validated from the top-level argparse namespace.

    Keys in the namespace that are not declared here (for example options added
    by subcommands) are ignored rather than rejected, via ``extra="ignore"``.
    """

    # Pydantic v1 and v2 spell the model configuration differently; both
    # branches configure the same "ignore extra fields" behaviour.
    if PYDANTIC_V1:

        class Config(pydantic.BaseConfig):  # type: ignore
            extra: Any = pydantic.Extra.ignore  # type: ignore

    else:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )

    # count of -v/--verbose flags (argparse dest="verbosity", action="count")
    verbosity: int
    version: Optional[str] = None

    # connection / auth options; argparse always supplies these keys (None when not given)
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    # one or more proxy URLs from -p/--proxy (nargs="+")
    proxy: Optional[List[str]]
    api_type: Optional[_ApiType] = None
    api_version: Optional[str] = None

    # azure
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args; when set, _main
    # validates the whole namespace against this model and passes the result to func
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments; when
    # allow_unknown_args is False, _parse_args re-parses strictly so unknown
    # arguments produce an argparse error
    unknown_args: List[str] = []
    allow_unknown_args: bool = False
  47. def _build_parser() -> argparse.ArgumentParser:
  48. parser = argparse.ArgumentParser(description=None, prog="openai")
  49. parser.add_argument(
  50. "-v",
  51. "--verbose",
  52. action="count",
  53. dest="verbosity",
  54. default=0,
  55. help="Set verbosity.",
  56. )
  57. parser.add_argument("-b", "--api-base", help="What API base url to use.")
  58. parser.add_argument("-k", "--api-key", help="What API key to use.")
  59. parser.add_argument("-p", "--proxy", nargs="+", help="What proxy to use.")
  60. parser.add_argument(
  61. "-o",
  62. "--organization",
  63. help="Which organization to run as (will use your default organization if not specified)",
  64. )
  65. parser.add_argument(
  66. "-t",
  67. "--api-type",
  68. type=str,
  69. choices=("openai", "azure"),
  70. help="The backend API to call, must be `openai` or `azure`",
  71. )
  72. parser.add_argument(
  73. "--api-version",
  74. help="The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'",
  75. )
  76. # azure
  77. parser.add_argument(
  78. "--azure-endpoint",
  79. help="The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'",
  80. )
  81. parser.add_argument(
  82. "--azure-ad-token",
  83. help="A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
  84. )
  85. # prints the package version
  86. parser.add_argument(
  87. "-V",
  88. "--version",
  89. action="version",
  90. version="%(prog)s " + __version__,
  91. )
  92. def help() -> None:
  93. parser.print_help()
  94. parser.set_defaults(func=help)
  95. subparsers = parser.add_subparsers()
  96. sub_api = subparsers.add_parser("api", help="Direct API calls")
  97. register_commands(sub_api)
  98. sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
  99. _tools.register_commands(sub_tools, subparsers)
  100. return parser
  101. def main() -> int:
  102. try:
  103. _main()
  104. except (APIError, CLIError, pydantic.ValidationError) as err:
  105. display_error(err)
  106. return 1
  107. except KeyboardInterrupt:
  108. sys.stderr.write("\n")
  109. return 1
  110. return 0
  111. def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
  112. # argparse by default will strip out the `--` but we want to keep it for unknown arguments
  113. if "--" in sys.argv:
  114. idx = sys.argv.index("--")
  115. known_args = sys.argv[1:idx]
  116. unknown_args = sys.argv[idx:]
  117. else:
  118. known_args = sys.argv[1:]
  119. unknown_args = []
  120. parsed, remaining_unknown = parser.parse_known_args(known_args)
  121. # append any remaining unknown arguments from the initial parsing
  122. remaining_unknown.extend(unknown_args)
  123. args = model_parse(Arguments, vars(parsed))
  124. if not args.allow_unknown_args:
  125. # we have to parse twice to ensure any unknown arguments
  126. # result in an error if that behaviour is desired
  127. parser.parse_args()
  128. return parsed, args, remaining_unknown
  129. def _main() -> None:
  130. parser = _build_parser()
  131. parsed, args, unknown = _parse_args(parser)
  132. if args.verbosity != 0:
  133. sys.stderr.write("Warning: --verbosity isn't supported yet\n")
  134. proxies: dict[str, httpx.BaseTransport] = {}
  135. if args.proxy is not None:
  136. for proxy in args.proxy:
  137. key = "https://" if proxy.startswith("https") else "http://"
  138. if key in proxies:
  139. raise CLIError(f"Multiple {key} proxies given - only the last one would be used")
  140. proxies[key] = httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy)))
  141. http_client = httpx.Client(
  142. mounts=proxies or None,
  143. http2=can_use_http2(),
  144. )
  145. openai.http_client = http_client
  146. if args.organization:
  147. openai.organization = args.organization
  148. if args.api_key:
  149. openai.api_key = args.api_key
  150. if args.api_base:
  151. openai.base_url = args.api_base
  152. # azure
  153. if args.api_type is not None:
  154. openai.api_type = args.api_type
  155. if args.azure_endpoint is not None:
  156. openai.azure_endpoint = args.azure_endpoint
  157. if args.api_version is not None:
  158. openai.api_version = args.api_version
  159. if args.azure_ad_token is not None:
  160. openai.azure_ad_token = args.azure_ad_token
  161. try:
  162. if args.args_model:
  163. parsed.func(
  164. model_parse(
  165. args.args_model,
  166. {
  167. **{
  168. # we omit None values so that they can be defaulted to `NotGiven`
  169. # and we'll strip it from the API request
  170. key: value
  171. for key, value in vars(parsed).items()
  172. if value is not None
  173. },
  174. "unknown_args": unknown,
  175. },
  176. )
  177. )
  178. else:
  179. parsed.func()
  180. finally:
  181. try:
  182. http_client.close()
  183. except Exception:
  184. pass
  185. if __name__ == "__main__":
  186. sys.exit(main())