| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- from __future__ import annotations
- import sys
- import logging
- import argparse
- from typing import Any, List, Type, Optional
- from typing_extensions import ClassVar
- import httpx
- import pydantic
- import openai
- from . import _tools
- from .. import _ApiType, __version__
- from ._api import register_commands
- from ._utils import can_use_http2
- from ._errors import CLIError, display_error
- from .._compat import PYDANTIC_V1, ConfigDict, model_parse
- from .._models import BaseModel
- from .._exceptions import APIError
# Send every log record that reaches the root logger to stderr, prefixed with a timestamp.
formatter = logging.Formatter("[%(asctime)s] %(message)s")
handler = logging.StreamHandler(stream=sys.stderr)
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
class Arguments(BaseModel):
    """Global ``openai`` CLI options, validated from the top-level argparse namespace.

    Built with ``model_parse(Arguments, vars(parsed))``. Keys the model does not declare
    (e.g. the ``func`` handler installed via ``set_defaults``, and presumably any
    subcommand-specific options) are ignored rather than rejected.
    """

    # Ignore undeclared keys; the config spelling differs between pydantic v1 and v2.
    if PYDANTIC_V1:

        class Config(pydantic.BaseConfig):  # type: ignore
            extra: Any = pydantic.Extra.ignore  # type: ignore

    else:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )

    # number of -v/--verbose flags (argparse "count" action, default 0)
    verbosity: int
    # presumably never present in the namespace because `--version` prints and exits; hence the default
    version: Optional[str] = None

    # global connection options; argparse stores None for an omitted flag
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    # one or more proxy URLs from `-p/--proxy` (nargs="+")
    proxy: Optional[List[str]]
    api_type: Optional[_ApiType] = None
    api_version: Optional[str] = None

    # azure
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments
    unknown_args: List[str] = []
    # when False, argv is re-parsed strictly so that unrecognised arguments raise an error
    allow_unknown_args: bool = False
def _build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``openai`` argument parser.

    Registers the global connection/authentication options shared by every command,
    the ``-V/--version`` flag, and the ``api`` and ``tools`` subcommand trees. When no
    subcommand supplies its own ``func``, the default handler prints the help text.

    Returns:
        The fully configured parser.
    """
    parser = argparse.ArgumentParser(description=None, prog="openai")
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        dest="verbosity",
        default=0,
        help="Set verbosity.",
    )
    parser.add_argument("-b", "--api-base", help="What API base url to use.")
    parser.add_argument("-k", "--api-key", help="What API key to use.")
    parser.add_argument("-p", "--proxy", nargs="+", help="What proxy to use.")
    parser.add_argument(
        "-o",
        "--organization",
        help="Which organization to run as (will use your default organization if not specified)",
    )
    parser.add_argument(
        "-t",
        "--api-type",
        type=str,
        choices=("openai", "azure"),
        help="The backend API to call, must be `openai` or `azure`",
    )
    parser.add_argument(
        "--api-version",
        # the example must be an actual version string; the docs URL is only a pointer
        help="The Azure API version, e.g. '2024-02-01' (see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning)",
    )

    # azure
    parser.add_argument(
        "--azure-endpoint",
        help="The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'",
    )
    parser.add_argument(
        "--azure-ad-token",
        help="A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
    )

    # prints the package version and exits
    parser.add_argument(
        "-V",
        "--version",
        action="version",
        version="%(prog)s " + __version__,
    )

    # Fallback handler when no subcommand is chosen; subcommands presumably override `func`.
    # The bound method is used directly instead of a wrapper that shadowed the `help` builtin.
    parser.set_defaults(func=parser.print_help)

    subparsers = parser.add_subparsers()

    sub_api = subparsers.add_parser("api", help="Direct API calls")
    register_commands(sub_api)

    sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
    _tools.register_commands(sub_tools, subparsers)

    return parser
def main() -> int:
    """Run the CLI and convert the outcome into a process exit status.

    API, CLI and pydantic validation errors are reported through ``display_error``;
    a Ctrl-C just terminates the current output line on stderr.

    Returns:
        0 when the command completes, 1 when one of those errors or an interrupt occurred.
    """
    exit_code = 0
    try:
        _main()
    except (APIError, CLIError, pydantic.ValidationError) as err:
        display_error(err)
        exit_code = 1
    except KeyboardInterrupt:
        sys.stderr.write("\n")
        exit_code = 1
    return exit_code
def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
    """Parse ``sys.argv`` into the raw namespace, the validated global options and leftover args.

    argparse would swallow a literal ``--``, so everything from the first ``--`` onwards
    (separator included) is held back and appended verbatim to the leftovers.

    If the parsed options do not set ``allow_unknown_args``, the full command line is
    parsed a second time with ``parse_args`` so that unrecognised arguments error out.

    Returns:
        ``(namespace, global_options, leftover_arguments)``.
    """
    argv = sys.argv
    split_at = argv.index("--") if "--" in argv else len(argv)
    head = argv[1:split_at]
    tail = argv[split_at:]

    namespace, leftovers = parser.parse_known_args(head)
    # the held-back `--` section goes after whatever argparse did not recognise
    leftovers += tail

    global_options = model_parse(Arguments, vars(namespace))
    if global_options.allow_unknown_args:
        return namespace, global_options, leftovers

    # strict second pass: its only purpose is to raise on unknown arguments
    parser.parse_args()
    return namespace, global_options, leftovers
def _main() -> None:
    """Run the ``openai`` CLI for the current ``sys.argv``.

    Parses the arguments, builds the shared ``httpx.Client`` (with any ``--proxy``
    mounts), applies the global options to the ``openai`` module and invokes the
    selected command handler. The HTTP client is always closed afterwards.

    Raises:
        CLIError: if more than one proxy is given for the same URL scheme.
    """
    parser = _build_parser()
    parsed, args, unknown = _parse_args(parser)

    if args.verbosity != 0:
        # the user-facing flag is -v/--verbose; `verbosity` is only its argparse dest
        sys.stderr.write("Warning: --verbose isn't supported yet\n")

    http_client = httpx.Client(
        mounts=_build_proxy_mounts(args.proxy) or None,
        http2=can_use_http2(),
    )
    # everything after the client exists runs under `try` so the client is always closed
    try:
        _configure_openai(args, http_client)
        _dispatch(parsed, args, unknown)
    finally:
        try:
            http_client.close()
        except Exception:
            # best-effort cleanup: a close failure must not mask the command's own outcome,
            # but keep a trace for debugging instead of discarding it silently
            logger.debug("Failed to close the HTTP client", exc_info=True)


def _build_proxy_mounts(proxy_urls: Optional[List[str]]) -> dict[str, httpx.BaseTransport]:
    """Turn the ``--proxy`` URLs into httpx transport mounts.

    A proxy whose URL starts with ``https`` (case-insensitively) is mounted for
    ``https://`` requests; any other proxy is mounted for ``http://`` requests.

    Raises:
        CLIError: if two proxies would be mounted under the same prefix.
    """
    mounts: dict[str, httpx.BaseTransport] = {}
    for proxy in proxy_urls or ():
        # URL schemes are case-insensitive, so `HTTPS://...` is treated like `https://...`
        key = "https://" if proxy.lower().startswith("https") else "http://"
        if key in mounts:
            raise CLIError(f"Multiple {key} proxies given - only the last one would be used")
        mounts[key] = httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy)))
    return mounts


def _configure_openai(args: Arguments, http_client: httpx.Client) -> None:
    """Apply the global CLI options to the module-level ``openai`` client settings."""
    openai.http_client = http_client

    if args.organization:
        openai.organization = args.organization

    if args.api_key:
        openai.api_key = args.api_key

    if args.api_base:
        openai.base_url = args.api_base

    # azure
    if args.api_type is not None:
        openai.api_type = args.api_type

    if args.azure_endpoint is not None:
        openai.azure_endpoint = args.azure_endpoint

    if args.api_version is not None:
        openai.api_version = args.api_version

    if args.azure_ad_token is not None:
        openai.azure_ad_token = args.azure_ad_token


def _dispatch(parsed: argparse.Namespace, args: Arguments, unknown: list[str]) -> None:
    """Invoke the handler stored in ``parsed.func``, validating its arguments when a model is set."""
    if not args.args_model:
        parsed.func()
        return

    # we omit None values so that they can be defaulted to `NotGiven`
    # and we'll strip it from the API request
    command_args = {key: value for key, value in vars(parsed).items() if value is not None}
    command_args["unknown_args"] = unknown
    parsed.func(model_parse(args.args_model, command_args))
if __name__ == "__main__":
    # use main()'s return value as the process exit status
    raise SystemExit(main())
|