_response.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848
  1. from __future__ import annotations
  2. import os
  3. import inspect
  4. import logging
  5. import datetime
  6. import functools
  7. from types import TracebackType
  8. from typing import (
  9. TYPE_CHECKING,
  10. Any,
  11. Union,
  12. Generic,
  13. TypeVar,
  14. Callable,
  15. Iterator,
  16. AsyncIterator,
  17. cast,
  18. overload,
  19. )
  20. from typing_extensions import Awaitable, ParamSpec, override, get_origin
  21. import anyio
  22. import httpx
  23. import pydantic
  24. from ._types import NoneType
  25. from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base
  26. from ._models import BaseModel, is_basemodel, add_request_id
  27. from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER
  28. from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
  29. from ._exceptions import OpenAIError, APIResponseValidationError
  30. if TYPE_CHECKING:
  31. from ._models import FinalRequestOptions
  32. from ._base_client import BaseClient
  33. P = ParamSpec("P")
  34. R = TypeVar("R")
  35. _T = TypeVar("_T")
  36. _APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
  37. _AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
  38. log: logging.Logger = logging.getLogger(__name__)
  39. class BaseAPIResponse(Generic[R]):
  40. _cast_to: type[R]
  41. _client: BaseClient[Any, Any]
  42. _parsed_by_type: dict[type[Any], Any]
  43. _is_sse_stream: bool
  44. _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
  45. _options: FinalRequestOptions
  46. http_response: httpx.Response
  47. retries_taken: int
  48. """The number of retries made. If no retries happened this will be `0`"""
  49. def __init__(
  50. self,
  51. *,
  52. raw: httpx.Response,
  53. cast_to: type[R],
  54. client: BaseClient[Any, Any],
  55. stream: bool,
  56. stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
  57. options: FinalRequestOptions,
  58. retries_taken: int = 0,
  59. ) -> None:
  60. self._cast_to = cast_to
  61. self._client = client
  62. self._parsed_by_type = {}
  63. self._is_sse_stream = stream
  64. self._stream_cls = stream_cls
  65. self._options = options
  66. self.http_response = raw
  67. self.retries_taken = retries_taken
  68. @property
  69. def headers(self) -> httpx.Headers:
  70. return self.http_response.headers
  71. @property
  72. def http_request(self) -> httpx.Request:
  73. """Returns the httpx Request instance associated with the current response."""
  74. return self.http_response.request
  75. @property
  76. def status_code(self) -> int:
  77. return self.http_response.status_code
  78. @property
  79. def url(self) -> httpx.URL:
  80. """Returns the URL for which the request was made."""
  81. return self.http_response.url
  82. @property
  83. def method(self) -> str:
  84. return self.http_request.method
  85. @property
  86. def http_version(self) -> str:
  87. return self.http_response.http_version
  88. @property
  89. def elapsed(self) -> datetime.timedelta:
  90. """The time taken for the complete request/response cycle to complete."""
  91. return self.http_response.elapsed
  92. @property
  93. def is_closed(self) -> bool:
  94. """Whether or not the response body has been closed.
  95. If this is False then there is response data that has not been read yet.
  96. You must either fully consume the response body or call `.close()`
  97. before discarding the response to prevent resource leaks.
  98. """
  99. return self.http_response.is_closed
  100. @override
  101. def __repr__(self) -> str:
  102. return (
  103. f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>"
  104. )
  105. def _parse(self, *, to: type[_T] | None = None) -> R | _T:
  106. cast_to = to if to is not None else self._cast_to
  107. # unwrap `TypeAlias('Name', T)` -> `T`
  108. if is_type_alias_type(cast_to):
  109. cast_to = cast_to.__value__ # type: ignore[unreachable]
  110. # unwrap `Annotated[T, ...]` -> `T`
  111. if cast_to and is_annotated_type(cast_to):
  112. cast_to = extract_type_arg(cast_to, 0)
  113. origin = get_origin(cast_to) or cast_to
  114. if self._is_sse_stream:
  115. if to:
  116. if not is_stream_class_type(to):
  117. raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
  118. return cast(
  119. _T,
  120. to(
  121. cast_to=extract_stream_chunk_type(
  122. to,
  123. failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
  124. ),
  125. response=self.http_response,
  126. client=cast(Any, self._client),
  127. ),
  128. )
  129. if self._stream_cls:
  130. return cast(
  131. R,
  132. self._stream_cls(
  133. cast_to=extract_stream_chunk_type(self._stream_cls),
  134. response=self.http_response,
  135. client=cast(Any, self._client),
  136. ),
  137. )
  138. stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls)
  139. if stream_cls is None:
  140. raise MissingStreamClassError()
  141. return cast(
  142. R,
  143. stream_cls(
  144. cast_to=cast_to,
  145. response=self.http_response,
  146. client=cast(Any, self._client),
  147. ),
  148. )
  149. if cast_to is NoneType:
  150. return cast(R, None)
  151. response = self.http_response
  152. if cast_to == str:
  153. return cast(R, response.text)
  154. if cast_to == bytes:
  155. return cast(R, response.content)
  156. if cast_to == int:
  157. return cast(R, int(response.text))
  158. if cast_to == float:
  159. return cast(R, float(response.text))
  160. if cast_to == bool:
  161. return cast(R, response.text.lower() == "true")
  162. # handle the legacy binary response case
  163. if inspect.isclass(cast_to) and cast_to.__name__ == "HttpxBinaryResponseContent":
  164. return cast(R, cast_to(response)) # type: ignore
  165. if origin == APIResponse:
  166. raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
  167. if inspect.isclass(origin) and issubclass(origin, httpx.Response):
  168. # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
  169. # and pass that class to our request functions. We cannot change the variance to be either
  170. # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
  171. # the response class ourselves but that is something that should be supported directly in httpx
  172. # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
  173. if cast_to != httpx.Response:
  174. raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
  175. return cast(R, response)
  176. if (
  177. inspect.isclass(
  178. origin # pyright: ignore[reportUnknownArgumentType]
  179. )
  180. and not issubclass(origin, BaseModel)
  181. and issubclass(origin, pydantic.BaseModel)
  182. ):
  183. raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")
  184. if (
  185. cast_to is not object
  186. and not origin is list
  187. and not origin is dict
  188. and not origin is Union
  189. and not issubclass(origin, BaseModel)
  190. ):
  191. raise RuntimeError(
  192. f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
  193. )
  194. # split is required to handle cases where additional information is included
  195. # in the response, e.g. application/json; charset=utf-8
  196. content_type, *_ = response.headers.get("content-type", "*").split(";")
  197. if not content_type.endswith("json"):
  198. if is_basemodel(cast_to):
  199. try:
  200. data = response.json()
  201. except Exception as exc:
  202. log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
  203. else:
  204. return self._client._process_response_data(
  205. data=data,
  206. cast_to=cast_to, # type: ignore
  207. response=response,
  208. )
  209. if self._client._strict_response_validation:
  210. raise APIResponseValidationError(
  211. response=response,
  212. message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
  213. body=response.text,
  214. )
  215. # If the API responds with content that isn't JSON then we just return
  216. # the (decoded) text without performing any parsing so that you can still
  217. # handle the response however you need to.
  218. return response.text # type: ignore
  219. data = response.json()
  220. return self._client._process_response_data(
  221. data=data,
  222. cast_to=cast_to, # type: ignore
  223. response=response,
  224. )
  225. class APIResponse(BaseAPIResponse[R]):
  226. @property
  227. def request_id(self) -> str | None:
  228. return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
  229. @overload
  230. def parse(self, *, to: type[_T]) -> _T: ...
  231. @overload
  232. def parse(self) -> R: ...
  233. def parse(self, *, to: type[_T] | None = None) -> R | _T:
  234. """Returns the rich python representation of this response's data.
  235. For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
  236. You can customise the type that the response is parsed into through
  237. the `to` argument, e.g.
  238. ```py
  239. from openai import BaseModel
  240. class MyModel(BaseModel):
  241. foo: str
  242. obj = response.parse(to=MyModel)
  243. print(obj.foo)
  244. ```
  245. We support parsing:
  246. - `BaseModel`
  247. - `dict`
  248. - `list`
  249. - `Union`
  250. - `str`
  251. - `int`
  252. - `float`
  253. - `httpx.Response`
  254. """
  255. cache_key = to if to is not None else self._cast_to
  256. cached = self._parsed_by_type.get(cache_key)
  257. if cached is not None:
  258. return cached # type: ignore[no-any-return]
  259. if not self._is_sse_stream:
  260. self.read()
  261. parsed = self._parse(to=to)
  262. if is_given(self._options.post_parser):
  263. parsed = self._options.post_parser(parsed)
  264. if isinstance(parsed, BaseModel):
  265. add_request_id(parsed, self.request_id)
  266. self._parsed_by_type[cache_key] = parsed
  267. return cast(R, parsed)
  268. def read(self) -> bytes:
  269. """Read and return the binary response content."""
  270. try:
  271. return self.http_response.read()
  272. except httpx.StreamConsumed as exc:
  273. # The default error raised by httpx isn't very
  274. # helpful in our case so we re-raise it with
  275. # a different error message.
  276. raise StreamAlreadyConsumed() from exc
  277. def text(self) -> str:
  278. """Read and decode the response content into a string."""
  279. self.read()
  280. return self.http_response.text
  281. def json(self) -> object:
  282. """Read and decode the JSON response content."""
  283. self.read()
  284. return self.http_response.json()
  285. def close(self) -> None:
  286. """Close the response and release the connection.
  287. Automatically called if the response body is read to completion.
  288. """
  289. self.http_response.close()
  290. def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
  291. """
  292. A byte-iterator over the decoded response content.
  293. This automatically handles gzip, deflate and brotli encoded responses.
  294. """
  295. for chunk in self.http_response.iter_bytes(chunk_size):
  296. yield chunk
  297. def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
  298. """A str-iterator over the decoded response content
  299. that handles both gzip, deflate, etc but also detects the content's
  300. string encoding.
  301. """
  302. for chunk in self.http_response.iter_text(chunk_size):
  303. yield chunk
  304. def iter_lines(self) -> Iterator[str]:
  305. """Like `iter_text()` but will only yield chunks for each line"""
  306. for chunk in self.http_response.iter_lines():
  307. yield chunk
  308. class AsyncAPIResponse(BaseAPIResponse[R]):
  309. @property
  310. def request_id(self) -> str | None:
  311. return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return]
  312. @overload
  313. async def parse(self, *, to: type[_T]) -> _T: ...
  314. @overload
  315. async def parse(self) -> R: ...
  316. async def parse(self, *, to: type[_T] | None = None) -> R | _T:
  317. """Returns the rich python representation of this response's data.
  318. For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
  319. You can customise the type that the response is parsed into through
  320. the `to` argument, e.g.
  321. ```py
  322. from openai import BaseModel
  323. class MyModel(BaseModel):
  324. foo: str
  325. obj = response.parse(to=MyModel)
  326. print(obj.foo)
  327. ```
  328. We support parsing:
  329. - `BaseModel`
  330. - `dict`
  331. - `list`
  332. - `Union`
  333. - `str`
  334. - `httpx.Response`
  335. """
  336. cache_key = to if to is not None else self._cast_to
  337. cached = self._parsed_by_type.get(cache_key)
  338. if cached is not None:
  339. return cached # type: ignore[no-any-return]
  340. if not self._is_sse_stream:
  341. await self.read()
  342. parsed = self._parse(to=to)
  343. if is_given(self._options.post_parser):
  344. parsed = self._options.post_parser(parsed)
  345. if isinstance(parsed, BaseModel):
  346. add_request_id(parsed, self.request_id)
  347. self._parsed_by_type[cache_key] = parsed
  348. return cast(R, parsed)
  349. async def read(self) -> bytes:
  350. """Read and return the binary response content."""
  351. try:
  352. return await self.http_response.aread()
  353. except httpx.StreamConsumed as exc:
  354. # the default error raised by httpx isn't very
  355. # helpful in our case so we re-raise it with
  356. # a different error message
  357. raise StreamAlreadyConsumed() from exc
  358. async def text(self) -> str:
  359. """Read and decode the response content into a string."""
  360. await self.read()
  361. return self.http_response.text
  362. async def json(self) -> object:
  363. """Read and decode the JSON response content."""
  364. await self.read()
  365. return self.http_response.json()
  366. async def close(self) -> None:
  367. """Close the response and release the connection.
  368. Automatically called if the response body is read to completion.
  369. """
  370. await self.http_response.aclose()
  371. async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
  372. """
  373. A byte-iterator over the decoded response content.
  374. This automatically handles gzip, deflate and brotli encoded responses.
  375. """
  376. async for chunk in self.http_response.aiter_bytes(chunk_size):
  377. yield chunk
  378. async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
  379. """A str-iterator over the decoded response content
  380. that handles both gzip, deflate, etc but also detects the content's
  381. string encoding.
  382. """
  383. async for chunk in self.http_response.aiter_text(chunk_size):
  384. yield chunk
  385. async def iter_lines(self) -> AsyncIterator[str]:
  386. """Like `iter_text()` but will only yield chunks for each line"""
  387. async for chunk in self.http_response.aiter_lines():
  388. yield chunk
  389. class BinaryAPIResponse(APIResponse[bytes]):
  390. """Subclass of APIResponse providing helpers for dealing with binary data.
  391. Note: If you want to stream the response data instead of eagerly reading it
  392. all at once then you should use `.with_streaming_response` when making
  393. the API request, e.g. `.with_streaming_response.get_binary_response()`
  394. """
  395. def write_to_file(
  396. self,
  397. file: str | os.PathLike[str],
  398. ) -> None:
  399. """Write the output to the given file.
  400. Accepts a filename or any path-like object, e.g. pathlib.Path
  401. Note: if you want to stream the data to the file instead of writing
  402. all at once then you should use `.with_streaming_response` when making
  403. the API request, e.g. `.with_streaming_response.get_binary_response()`
  404. """
  405. with open(file, mode="wb") as f:
  406. for data in self.iter_bytes():
  407. f.write(data)
  408. class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]):
  409. """Subclass of APIResponse providing helpers for dealing with binary data.
  410. Note: If you want to stream the response data instead of eagerly reading it
  411. all at once then you should use `.with_streaming_response` when making
  412. the API request, e.g. `.with_streaming_response.get_binary_response()`
  413. """
  414. async def write_to_file(
  415. self,
  416. file: str | os.PathLike[str],
  417. ) -> None:
  418. """Write the output to the given file.
  419. Accepts a filename or any path-like object, e.g. pathlib.Path
  420. Note: if you want to stream the data to the file instead of writing
  421. all at once then you should use `.with_streaming_response` when making
  422. the API request, e.g. `.with_streaming_response.get_binary_response()`
  423. """
  424. path = anyio.Path(file)
  425. async with await path.open(mode="wb") as f:
  426. async for data in self.iter_bytes():
  427. await f.write(data)
  428. class StreamedBinaryAPIResponse(APIResponse[bytes]):
  429. def stream_to_file(
  430. self,
  431. file: str | os.PathLike[str],
  432. *,
  433. chunk_size: int | None = None,
  434. ) -> None:
  435. """Streams the output to the given file.
  436. Accepts a filename or any path-like object, e.g. pathlib.Path
  437. """
  438. with open(file, mode="wb") as f:
  439. for data in self.iter_bytes(chunk_size):
  440. f.write(data)
  441. class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):
  442. async def stream_to_file(
  443. self,
  444. file: str | os.PathLike[str],
  445. *,
  446. chunk_size: int | None = None,
  447. ) -> None:
  448. """Streams the output to the given file.
  449. Accepts a filename or any path-like object, e.g. pathlib.Path
  450. """
  451. path = anyio.Path(file)
  452. async with await path.open(mode="wb") as f:
  453. async for data in self.iter_bytes(chunk_size):
  454. await f.write(data)
  455. class MissingStreamClassError(TypeError):
  456. def __init__(self) -> None:
  457. super().__init__(
  458. "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference",
  459. )
  460. class StreamAlreadyConsumed(OpenAIError):
  461. """
  462. Attempted to read or stream content, but the content has already
  463. been streamed.
  464. This can happen if you use a method like `.iter_lines()` and then attempt
  465. to read th entire response body afterwards, e.g.
  466. ```py
  467. response = await client.post(...)
  468. async for line in response.iter_lines():
  469. ... # do something with `line`
  470. content = await response.read()
  471. # ^ error
  472. ```
  473. If you want this behaviour you'll need to either manually accumulate the response
  474. content or call `await response.read()` before iterating over the stream.
  475. """
  476. def __init__(self) -> None:
  477. message = (
  478. "Attempted to read or stream some content, but the content has "
  479. "already been streamed. "
  480. "This could be due to attempting to stream the response "
  481. "content more than once."
  482. "\n\n"
  483. "You can fix this by manually accumulating the response content while streaming "
  484. "or by calling `.read()` before starting to stream."
  485. )
  486. super().__init__(message)
  487. class ResponseContextManager(Generic[_APIResponseT]):
  488. """Context manager for ensuring that a request is not made
  489. until it is entered and that the response will always be closed
  490. when the context manager exits
  491. """
  492. def __init__(self, request_func: Callable[[], _APIResponseT]) -> None:
  493. self._request_func = request_func
  494. self.__response: _APIResponseT | None = None
  495. def __enter__(self) -> _APIResponseT:
  496. self.__response = self._request_func()
  497. return self.__response
  498. def __exit__(
  499. self,
  500. exc_type: type[BaseException] | None,
  501. exc: BaseException | None,
  502. exc_tb: TracebackType | None,
  503. ) -> None:
  504. if self.__response is not None:
  505. self.__response.close()
  506. class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]):
  507. """Context manager for ensuring that a request is not made
  508. until it is entered and that the response will always be closed
  509. when the context manager exits
  510. """
  511. def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None:
  512. self._api_request = api_request
  513. self.__response: _AsyncAPIResponseT | None = None
  514. async def __aenter__(self) -> _AsyncAPIResponseT:
  515. self.__response = await self._api_request
  516. return self.__response
  517. async def __aexit__(
  518. self,
  519. exc_type: type[BaseException] | None,
  520. exc: BaseException | None,
  521. exc_tb: TracebackType | None,
  522. ) -> None:
  523. if self.__response is not None:
  524. await self.__response.close()
  525. def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]:
  526. """Higher order function that takes one of our bound API methods and wraps it
  527. to support streaming and returning the raw `APIResponse` object directly.
  528. """
  529. @functools.wraps(func)
  530. def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]:
  531. extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  532. extra_headers[RAW_RESPONSE_HEADER] = "stream"
  533. kwargs["extra_headers"] = extra_headers
  534. make_request = functools.partial(func, *args, **kwargs)
  535. return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request))
  536. return wrapped
  537. def async_to_streamed_response_wrapper(
  538. func: Callable[P, Awaitable[R]],
  539. ) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]:
  540. """Higher order function that takes one of our bound API methods and wraps it
  541. to support streaming and returning the raw `APIResponse` object directly.
  542. """
  543. @functools.wraps(func)
  544. def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]:
  545. extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  546. extra_headers[RAW_RESPONSE_HEADER] = "stream"
  547. kwargs["extra_headers"] = extra_headers
  548. make_request = func(*args, **kwargs)
  549. return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request))
  550. return wrapped
  551. def to_custom_streamed_response_wrapper(
  552. func: Callable[P, object],
  553. response_cls: type[_APIResponseT],
  554. ) -> Callable[P, ResponseContextManager[_APIResponseT]]:
  555. """Higher order function that takes one of our bound API methods and an `APIResponse` class
  556. and wraps the method to support streaming and returning the given response class directly.
  557. Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
  558. """
  559. @functools.wraps(func)
  560. def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]:
  561. extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  562. extra_headers[RAW_RESPONSE_HEADER] = "stream"
  563. extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
  564. kwargs["extra_headers"] = extra_headers
  565. make_request = functools.partial(func, *args, **kwargs)
  566. return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request))
  567. return wrapped
  568. def async_to_custom_streamed_response_wrapper(
  569. func: Callable[P, Awaitable[object]],
  570. response_cls: type[_AsyncAPIResponseT],
  571. ) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]:
  572. """Higher order function that takes one of our bound API methods and an `APIResponse` class
  573. and wraps the method to support streaming and returning the given response class directly.
  574. Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
  575. """
  576. @functools.wraps(func)
  577. def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]:
  578. extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  579. extra_headers[RAW_RESPONSE_HEADER] = "stream"
  580. extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
  581. kwargs["extra_headers"] = extra_headers
  582. make_request = func(*args, **kwargs)
  583. return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request))
  584. return wrapped
  585. def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
  586. """Higher order function that takes one of our bound API methods and wraps it
  587. to support returning the raw `APIResponse` object directly.
  588. """
  589. @functools.wraps(func)
  590. def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]:
  591. extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  592. extra_headers[RAW_RESPONSE_HEADER] = "raw"
  593. kwargs["extra_headers"] = extra_headers
  594. return cast(APIResponse[R], func(*args, **kwargs))
  595. return wrapped
  596. def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]:
  597. """Higher order function that takes one of our bound API methods and wraps it
  598. to support returning the raw `APIResponse` object directly.
  599. """
  600. @functools.wraps(func)
  601. async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]:
  602. extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  603. extra_headers[RAW_RESPONSE_HEADER] = "raw"
  604. kwargs["extra_headers"] = extra_headers
  605. return cast(AsyncAPIResponse[R], await func(*args, **kwargs))
  606. return wrapped
  607. def to_custom_raw_response_wrapper(
  608. func: Callable[P, object],
  609. response_cls: type[_APIResponseT],
  610. ) -> Callable[P, _APIResponseT]:
  611. """Higher order function that takes one of our bound API methods and an `APIResponse` class
  612. and wraps the method to support returning the given response class directly.
  613. Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
  614. """
  615. @functools.wraps(func)
  616. def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT:
  617. extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  618. extra_headers[RAW_RESPONSE_HEADER] = "raw"
  619. extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
  620. kwargs["extra_headers"] = extra_headers
  621. return cast(_APIResponseT, func(*args, **kwargs))
  622. return wrapped
  623. def async_to_custom_raw_response_wrapper(
  624. func: Callable[P, Awaitable[object]],
  625. response_cls: type[_AsyncAPIResponseT],
  626. ) -> Callable[P, Awaitable[_AsyncAPIResponseT]]:
  627. """Higher order function that takes one of our bound API methods and an `APIResponse` class
  628. and wraps the method to support returning the given response class directly.
  629. Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])`
  630. """
  631. @functools.wraps(func)
  632. def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]:
  633. extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})}
  634. extra_headers[RAW_RESPONSE_HEADER] = "raw"
  635. extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls
  636. kwargs["extra_headers"] = extra_headers
  637. return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs))
  638. return wrapped
  639. def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type:
  640. """Given a type like `APIResponse[T]`, returns the generic type variable `T`.
  641. This also handles the case where a concrete subclass is given, e.g.
  642. ```py
  643. class MyResponse(APIResponse[bytes]):
  644. ...
  645. extract_response_type(MyResponse) -> bytes
  646. ```
  647. """
  648. return extract_type_var_from_base(
  649. typ,
  650. generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse)),
  651. index=0,
  652. )