_models.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897
  1. from __future__ import annotations
  2. import os
  3. import inspect
  4. import weakref
  5. from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
  6. from datetime import date, datetime
  7. from typing_extensions import (
  8. List,
  9. Unpack,
  10. Literal,
  11. ClassVar,
  12. Protocol,
  13. Required,
  14. Sequence,
  15. ParamSpec,
  16. TypedDict,
  17. TypeGuard,
  18. final,
  19. override,
  20. runtime_checkable,
  21. )
  22. import pydantic
  23. from pydantic.fields import FieldInfo
  24. from ._types import (
  25. Body,
  26. IncEx,
  27. Query,
  28. ModelT,
  29. Headers,
  30. Timeout,
  31. NotGiven,
  32. AnyMapping,
  33. HttpxRequestFiles,
  34. )
  35. from ._utils import (
  36. PropertyInfo,
  37. is_list,
  38. is_given,
  39. json_safe,
  40. lru_cache,
  41. is_mapping,
  42. parse_date,
  43. coerce_boolean,
  44. parse_datetime,
  45. strip_not_given,
  46. extract_type_arg,
  47. is_annotated_type,
  48. is_type_alias_type,
  49. strip_annotated_type,
  50. )
  51. from ._compat import (
  52. PYDANTIC_V1,
  53. ConfigDict,
  54. GenericModel as BaseGenericModel,
  55. get_args,
  56. is_union,
  57. parse_obj,
  58. get_origin,
  59. is_literal_type,
  60. get_model_config,
  61. get_model_fields,
  62. field_get_default,
  63. )
  64. from ._constants import RAW_RESPONSE_HEADER
  65. if TYPE_CHECKING:
  66. from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
  67. __all__ = ["BaseModel", "GenericModel"]
  68. _T = TypeVar("_T")
  69. _BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
  70. P = ParamSpec("P")
  71. ReprArgs = Sequence[Tuple[Optional[str], Any]]
  72. @runtime_checkable
  73. class _ConfigProtocol(Protocol):
  74. allow_population_by_field_name: bool
  75. class BaseModel(pydantic.BaseModel):
  76. if PYDANTIC_V1:
  77. @property
  78. @override
  79. def model_fields_set(self) -> set[str]:
  80. # a forwards-compat shim for pydantic v2
  81. return self.__fields_set__ # type: ignore
  82. class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
  83. extra: Any = pydantic.Extra.allow # type: ignore
  84. @override
  85. def __repr_args__(self) -> ReprArgs:
  86. # we don't want these attributes to be included when something like `rich.print` is used
  87. return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]
  88. else:
  89. model_config: ClassVar[ConfigDict] = ConfigDict(
  90. extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
  91. )
  92. if TYPE_CHECKING:
  93. _request_id: Optional[str] = None
  94. """The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.
  95. This will **only** be set for the top-level response object, it will not be defined for nested objects. For example:
  96. ```py
  97. completion = await client.chat.completions.create(...)
  98. completion._request_id # req_id_xxx
  99. completion.usage._request_id # raises `AttributeError`
  100. ```
  101. Note: unlike other properties that use an `_` prefix, this property
  102. *is* public. Unless documented otherwise, all other `_` prefix properties,
  103. methods and modules are *private*.
  104. """
  105. def to_dict(
  106. self,
  107. *,
  108. mode: Literal["json", "python"] = "python",
  109. use_api_names: bool = True,
  110. exclude_unset: bool = True,
  111. exclude_defaults: bool = False,
  112. exclude_none: bool = False,
  113. warnings: bool = True,
  114. ) -> dict[str, object]:
  115. """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
  116. By default, fields that were not set by the API will not be included,
  117. and keys will match the API response, *not* the property names from the model.
  118. For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
  119. the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
  120. Args:
  121. mode:
  122. If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
  123. If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
  124. use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
  125. exclude_unset: Whether to exclude fields that have not been explicitly set.
  126. exclude_defaults: Whether to exclude fields that are set to their default value from the output.
  127. exclude_none: Whether to exclude fields that have a value of `None` from the output.
  128. warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
  129. """
  130. return self.model_dump(
  131. mode=mode,
  132. by_alias=use_api_names,
  133. exclude_unset=exclude_unset,
  134. exclude_defaults=exclude_defaults,
  135. exclude_none=exclude_none,
  136. warnings=warnings,
  137. )
  138. def to_json(
  139. self,
  140. *,
  141. indent: int | None = 2,
  142. use_api_names: bool = True,
  143. exclude_unset: bool = True,
  144. exclude_defaults: bool = False,
  145. exclude_none: bool = False,
  146. warnings: bool = True,
  147. ) -> str:
  148. """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
  149. By default, fields that were not set by the API will not be included,
  150. and keys will match the API response, *not* the property names from the model.
  151. For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
  152. the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
  153. Args:
  154. indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
  155. use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
  156. exclude_unset: Whether to exclude fields that have not been explicitly set.
  157. exclude_defaults: Whether to exclude fields that have the default value.
  158. exclude_none: Whether to exclude fields that have a value of `None`.
  159. warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
  160. """
  161. return self.model_dump_json(
  162. indent=indent,
  163. by_alias=use_api_names,
  164. exclude_unset=exclude_unset,
  165. exclude_defaults=exclude_defaults,
  166. exclude_none=exclude_none,
  167. warnings=warnings,
  168. )
  169. @override
  170. def __str__(self) -> str:
  171. # mypy complains about an invalid self arg
  172. return f"{self.__repr_name__()}({self.__repr_str__(', ')})" # type: ignore[misc]
  173. # Override the 'construct' method in a way that supports recursive parsing without validation.
  174. # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
  175. @classmethod
  176. @override
  177. def construct( # pyright: ignore[reportIncompatibleMethodOverride]
  178. __cls: Type[ModelT],
  179. _fields_set: set[str] | None = None,
  180. **values: object,
  181. ) -> ModelT:
  182. m = __cls.__new__(__cls)
  183. fields_values: dict[str, object] = {}
  184. config = get_model_config(__cls)
  185. populate_by_name = (
  186. config.allow_population_by_field_name
  187. if isinstance(config, _ConfigProtocol)
  188. else config.get("populate_by_name")
  189. )
  190. if _fields_set is None:
  191. _fields_set = set()
  192. model_fields = get_model_fields(__cls)
  193. for name, field in model_fields.items():
  194. key = field.alias
  195. if key is None or (key not in values and populate_by_name):
  196. key = name
  197. if key in values:
  198. fields_values[name] = _construct_field(value=values[key], field=field, key=key)
  199. _fields_set.add(name)
  200. else:
  201. fields_values[name] = field_get_default(field)
  202. extra_field_type = _get_extra_fields_type(__cls)
  203. _extra = {}
  204. for key, value in values.items():
  205. if key not in model_fields:
  206. parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value
  207. if PYDANTIC_V1:
  208. _fields_set.add(key)
  209. fields_values[key] = parsed
  210. else:
  211. _extra[key] = parsed
  212. object.__setattr__(m, "__dict__", fields_values)
  213. if PYDANTIC_V1:
  214. # init_private_attributes() does not exist in v2
  215. m._init_private_attributes() # type: ignore
  216. # copied from Pydantic v1's `construct()` method
  217. object.__setattr__(m, "__fields_set__", _fields_set)
  218. else:
  219. # these properties are copied from Pydantic's `model_construct()` method
  220. object.__setattr__(m, "__pydantic_private__", None)
  221. object.__setattr__(m, "__pydantic_extra__", _extra)
  222. object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
  223. return m
  224. if not TYPE_CHECKING:
  225. # type checkers incorrectly complain about this assignment
  226. # because the type signatures are technically different
  227. # although not in practice
  228. model_construct = construct
  229. if PYDANTIC_V1:
  230. # we define aliases for some of the new pydantic v2 methods so
  231. # that we can just document these methods without having to specify
  232. # a specific pydantic version as some users may not know which
  233. # pydantic version they are currently using
  234. @override
  235. def model_dump(
  236. self,
  237. *,
  238. mode: Literal["json", "python"] | str = "python",
  239. include: IncEx | None = None,
  240. exclude: IncEx | None = None,
  241. context: Any | None = None,
  242. by_alias: bool | None = None,
  243. exclude_unset: bool = False,
  244. exclude_defaults: bool = False,
  245. exclude_none: bool = False,
  246. exclude_computed_fields: bool = False,
  247. round_trip: bool = False,
  248. warnings: bool | Literal["none", "warn", "error"] = True,
  249. fallback: Callable[[Any], Any] | None = None,
  250. serialize_as_any: bool = False,
  251. ) -> dict[str, Any]:
  252. """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
  253. Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
  254. Args:
  255. mode: The mode in which `to_python` should run.
  256. If mode is 'json', the output will only contain JSON serializable types.
  257. If mode is 'python', the output may contain non-JSON-serializable Python objects.
  258. include: A set of fields to include in the output.
  259. exclude: A set of fields to exclude from the output.
  260. context: Additional context to pass to the serializer.
  261. by_alias: Whether to use the field's alias in the dictionary key if defined.
  262. exclude_unset: Whether to exclude fields that have not been explicitly set.
  263. exclude_defaults: Whether to exclude fields that are set to their default value.
  264. exclude_none: Whether to exclude fields that have a value of `None`.
  265. exclude_computed_fields: Whether to exclude computed fields.
  266. While this can be useful for round-tripping, it is usually recommended to use the dedicated
  267. `round_trip` parameter instead.
  268. round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T].
  269. warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
  270. "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
  271. fallback: A function to call when an unknown value is encountered. If not provided,
  272. a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
  273. serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
  274. Returns:
  275. A dictionary representation of the model.
  276. """
  277. if mode not in {"json", "python"}:
  278. raise ValueError("mode must be either 'json' or 'python'")
  279. if round_trip != False:
  280. raise ValueError("round_trip is only supported in Pydantic v2")
  281. if warnings != True:
  282. raise ValueError("warnings is only supported in Pydantic v2")
  283. if context is not None:
  284. raise ValueError("context is only supported in Pydantic v2")
  285. if serialize_as_any != False:
  286. raise ValueError("serialize_as_any is only supported in Pydantic v2")
  287. if fallback is not None:
  288. raise ValueError("fallback is only supported in Pydantic v2")
  289. if exclude_computed_fields != False:
  290. raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
  291. dumped = super().dict( # pyright: ignore[reportDeprecated]
  292. include=include,
  293. exclude=exclude,
  294. by_alias=by_alias if by_alias is not None else False,
  295. exclude_unset=exclude_unset,
  296. exclude_defaults=exclude_defaults,
  297. exclude_none=exclude_none,
  298. )
  299. return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped
  300. @override
  301. def model_dump_json(
  302. self,
  303. *,
  304. indent: int | None = None,
  305. ensure_ascii: bool = False,
  306. include: IncEx | None = None,
  307. exclude: IncEx | None = None,
  308. context: Any | None = None,
  309. by_alias: bool | None = None,
  310. exclude_unset: bool = False,
  311. exclude_defaults: bool = False,
  312. exclude_none: bool = False,
  313. exclude_computed_fields: bool = False,
  314. round_trip: bool = False,
  315. warnings: bool | Literal["none", "warn", "error"] = True,
  316. fallback: Callable[[Any], Any] | None = None,
  317. serialize_as_any: bool = False,
  318. ) -> str:
  319. """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
  320. Generates a JSON representation of the model using Pydantic's `to_json` method.
  321. Args:
  322. indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
  323. include: Field(s) to include in the JSON output. Can take either a string or set of strings.
  324. exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
  325. by_alias: Whether to serialize using field aliases.
  326. exclude_unset: Whether to exclude fields that have not been explicitly set.
  327. exclude_defaults: Whether to exclude fields that have the default value.
  328. exclude_none: Whether to exclude fields that have a value of `None`.
  329. round_trip: Whether to use serialization/deserialization between JSON and class instance.
  330. warnings: Whether to show any warnings that occurred during serialization.
  331. Returns:
  332. A JSON string representation of the model.
  333. """
  334. if round_trip != False:
  335. raise ValueError("round_trip is only supported in Pydantic v2")
  336. if warnings != True:
  337. raise ValueError("warnings is only supported in Pydantic v2")
  338. if context is not None:
  339. raise ValueError("context is only supported in Pydantic v2")
  340. if serialize_as_any != False:
  341. raise ValueError("serialize_as_any is only supported in Pydantic v2")
  342. if fallback is not None:
  343. raise ValueError("fallback is only supported in Pydantic v2")
  344. if ensure_ascii != False:
  345. raise ValueError("ensure_ascii is only supported in Pydantic v2")
  346. if exclude_computed_fields != False:
  347. raise ValueError("exclude_computed_fields is only supported in Pydantic v2")
  348. return super().json( # type: ignore[reportDeprecated]
  349. indent=indent,
  350. include=include,
  351. exclude=exclude,
  352. by_alias=by_alias if by_alias is not None else False,
  353. exclude_unset=exclude_unset,
  354. exclude_defaults=exclude_defaults,
  355. exclude_none=exclude_none,
  356. )
  357. def _construct_field(value: object, field: FieldInfo, key: str) -> object:
  358. if value is None:
  359. return field_get_default(field)
  360. if PYDANTIC_V1:
  361. type_ = cast(type, field.outer_type_) # type: ignore
  362. else:
  363. type_ = field.annotation # type: ignore
  364. if type_ is None:
  365. raise RuntimeError(f"Unexpected field type is None for {key}")
  366. return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None))
  367. def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None:
  368. if PYDANTIC_V1:
  369. # TODO
  370. return None
  371. schema = cls.__pydantic_core_schema__
  372. if schema["type"] == "model":
  373. fields = schema["schema"]
  374. if fields["type"] == "model-fields":
  375. extras = fields.get("extras_schema")
  376. if extras and "cls" in extras:
  377. # mypy can't narrow the type
  378. return extras["cls"] # type: ignore[no-any-return]
  379. return None
  380. def is_basemodel(type_: type) -> bool:
  381. """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
  382. if is_union(type_):
  383. for variant in get_args(type_):
  384. if is_basemodel(variant):
  385. return True
  386. return False
  387. return is_basemodel_type(type_)
  388. def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
  389. origin = get_origin(type_) or type_
  390. if not inspect.isclass(origin):
  391. return False
  392. return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)
  393. def build(
  394. base_model_cls: Callable[P, _BaseModelT],
  395. *args: P.args,
  396. **kwargs: P.kwargs,
  397. ) -> _BaseModelT:
  398. """Construct a BaseModel class without validation.
  399. This is useful for cases where you need to instantiate a `BaseModel`
  400. from an API response as this provides type-safe params which isn't supported
  401. by helpers like `construct_type()`.
  402. ```py
  403. build(MyModel, my_field_a="foo", my_field_b=123)
  404. ```
  405. """
  406. if args:
  407. raise TypeError(
  408. "Received positional arguments which are not supported; Keyword arguments must be used instead",
  409. )
  410. return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
  411. def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
  412. """Loose coercion to the expected type with construction of nested values.
  413. Note: the returned value from this function is not guaranteed to match the
  414. given type.
  415. """
  416. return cast(_T, construct_type(value=value, type_=type_))
  417. def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object:
  418. """Loose coercion to the expected type with construction of nested values.
  419. If the given value does not match the expected type then it is returned as-is.
  420. """
  421. # store a reference to the original type we were given before we extract any inner
  422. # types so that we can properly resolve forward references in `TypeAliasType` annotations
  423. original_type = None
  424. # we allow `object` as the input type because otherwise, passing things like
  425. # `Literal['value']` will be reported as a type error by type checkers
  426. type_ = cast("type[object]", type_)
  427. if is_type_alias_type(type_):
  428. original_type = type_ # type: ignore[unreachable]
  429. type_ = type_.__value__ # type: ignore[unreachable]
  430. # unwrap `Annotated[T, ...]` -> `T`
  431. if metadata is not None and len(metadata) > 0:
  432. meta: tuple[Any, ...] = tuple(metadata)
  433. elif is_annotated_type(type_):
  434. meta = get_args(type_)[1:]
  435. type_ = extract_type_arg(type_, 0)
  436. else:
  437. meta = tuple()
  438. # we need to use the origin class for any types that are subscripted generics
  439. # e.g. Dict[str, object]
  440. origin = get_origin(type_) or type_
  441. args = get_args(type_)
  442. if is_union(origin):
  443. try:
  444. return validate_type(type_=cast("type[object]", original_type or type_), value=value)
  445. except Exception:
  446. pass
  447. # if the type is a discriminated union then we want to construct the right variant
  448. # in the union, even if the data doesn't match exactly, otherwise we'd break code
  449. # that relies on the constructed class types, e.g.
  450. #
  451. # class FooType:
  452. # kind: Literal['foo']
  453. # value: str
  454. #
  455. # class BarType:
  456. # kind: Literal['bar']
  457. # value: int
  458. #
  459. # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
  460. # we'd end up constructing `FooType` when it should be `BarType`.
  461. discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
  462. if discriminator and is_mapping(value):
  463. variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
  464. if variant_value and isinstance(variant_value, str):
  465. variant_type = discriminator.mapping.get(variant_value)
  466. if variant_type:
  467. return construct_type(type_=variant_type, value=value)
  468. # if the data is not valid, use the first variant that doesn't fail while deserializing
  469. for variant in args:
  470. try:
  471. return construct_type(value=value, type_=variant)
  472. except Exception:
  473. continue
  474. raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
  475. if origin == dict:
  476. if not is_mapping(value):
  477. return value
  478. _, items_type = get_args(type_) # Dict[_, items_type]
  479. return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}
  480. if (
  481. not is_literal_type(type_)
  482. and inspect.isclass(origin)
  483. and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel))
  484. ):
  485. if is_list(value):
  486. return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]
  487. if is_mapping(value):
  488. if issubclass(type_, BaseModel):
  489. return type_.construct(**value) # type: ignore[arg-type]
  490. return cast(Any, type_).construct(**value)
  491. if origin == list:
  492. if not is_list(value):
  493. return value
  494. inner_type = args[0] # List[inner_type]
  495. return [construct_type(value=entry, type_=inner_type) for entry in value]
  496. if origin == float:
  497. if isinstance(value, int):
  498. coerced = float(value)
  499. if coerced != value:
  500. return value
  501. return coerced
  502. return value
  503. if type_ == datetime:
  504. try:
  505. return parse_datetime(value) # type: ignore
  506. except Exception:
  507. return value
  508. if type_ == date:
  509. try:
  510. return parse_date(value) # type: ignore
  511. except Exception:
  512. return value
  513. return value
  514. @runtime_checkable
  515. class CachedDiscriminatorType(Protocol):
  516. __discriminator__: DiscriminatorDetails
  517. DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary()
  518. class DiscriminatorDetails:
  519. field_name: str
  520. """The name of the discriminator field in the variant class, e.g.
  521. ```py
  522. class Foo(BaseModel):
  523. type: Literal['foo']
  524. ```
  525. Will result in field_name='type'
  526. """
  527. field_alias_from: str | None
  528. """The name of the discriminator field in the API response, e.g.
  529. ```py
  530. class Foo(BaseModel):
  531. type: Literal['foo'] = Field(alias='type_from_api')
  532. ```
  533. Will result in field_alias_from='type_from_api'
  534. """
  535. mapping: dict[str, type]
  536. """Mapping of discriminator value to variant type, e.g.
  537. {'foo': FooVariant, 'bar': BarVariant}
  538. """
  539. def __init__(
  540. self,
  541. *,
  542. mapping: dict[str, type],
  543. discriminator_field: str,
  544. discriminator_alias: str | None,
  545. ) -> None:
  546. self.mapping = mapping
  547. self.field_name = discriminator_field
  548. self.field_alias_from = discriminator_alias
  549. def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
  550. cached = DISCRIMINATOR_CACHE.get(union)
  551. if cached is not None:
  552. return cached
  553. discriminator_field_name: str | None = None
  554. for annotation in meta_annotations:
  555. if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
  556. discriminator_field_name = annotation.discriminator
  557. break
  558. if not discriminator_field_name:
  559. return None
  560. mapping: dict[str, type] = {}
  561. discriminator_alias: str | None = None
  562. for variant in get_args(union):
  563. variant = strip_annotated_type(variant)
  564. if is_basemodel_type(variant):
  565. if PYDANTIC_V1:
  566. field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
  567. if not field_info:
  568. continue
  569. # Note: if one variant defines an alias then they all should
  570. discriminator_alias = field_info.alias
  571. if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
  572. for entry in get_args(annotation):
  573. if isinstance(entry, str):
  574. mapping[entry] = variant
  575. else:
  576. field = _extract_field_schema_pv2(variant, discriminator_field_name)
  577. if not field:
  578. continue
  579. # Note: if one variant defines an alias then they all should
  580. discriminator_alias = field.get("serialization_alias")
  581. field_schema = field["schema"]
  582. if field_schema["type"] == "literal":
  583. for entry in cast("LiteralSchema", field_schema)["expected"]:
  584. if isinstance(entry, str):
  585. mapping[entry] = variant
  586. if not mapping:
  587. return None
  588. details = DiscriminatorDetails(
  589. mapping=mapping,
  590. discriminator_field=discriminator_field_name,
  591. discriminator_alias=discriminator_alias,
  592. )
  593. DISCRIMINATOR_CACHE.setdefault(union, details)
  594. return details
  595. def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
  596. schema = model.__pydantic_core_schema__
  597. if schema["type"] == "definitions":
  598. schema = schema["schema"]
  599. if schema["type"] != "model":
  600. return None
  601. schema = cast("ModelSchema", schema)
  602. fields_schema = schema["schema"]
  603. if fields_schema["type"] != "model-fields":
  604. return None
  605. fields_schema = cast("ModelFieldsSchema", fields_schema)
  606. field = fields_schema["fields"].get(field_name)
  607. if not field:
  608. return None
  609. return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]
  610. def validate_type(*, type_: type[_T], value: object) -> _T:
  611. """Strict validation that the given value matches the expected type"""
  612. if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
  613. return cast(_T, parse_obj(type_, value))
  614. return cast(_T, _validate_non_model_type(type_=type_, value=value))
  615. def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None:
  616. """Add a pydantic config for the given type.
  617. Note: this is a no-op on Pydantic v1.
  618. """
  619. setattr(typ, "__pydantic_config__", config) # noqa: B010
  620. def add_request_id(obj: BaseModel, request_id: str | None) -> None:
  621. obj._request_id = request_id
  622. # in Pydantic v1, using setattr like we do above causes the attribute
  623. # to be included when serializing the model which we don't want in this
  624. # case so we need to explicitly exclude it
  625. if PYDANTIC_V1:
  626. try:
  627. exclude_fields = obj.__exclude_fields__ # type: ignore
  628. except AttributeError:
  629. cast(Any, obj).__exclude_fields__ = {"_request_id", "__exclude_fields__"}
  630. else:
  631. cast(Any, obj).__exclude_fields__ = {*(exclude_fields or {}), "_request_id", "__exclude_fields__"}
  632. # our use of subclassing here causes weirdness for type checkers,
  633. # so we just pretend that we don't subclass
  634. if TYPE_CHECKING:
  635. GenericModel = BaseModel
  636. else:
  637. class GenericModel(BaseGenericModel, BaseModel):
  638. pass
  639. if not PYDANTIC_V1:
  640. from pydantic import TypeAdapter as _TypeAdapter
  641. _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
  642. if TYPE_CHECKING:
  643. from pydantic import TypeAdapter
  644. else:
  645. TypeAdapter = _CachedTypeAdapter
  646. def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
  647. return TypeAdapter(type_).validate_python(value)
  648. elif not TYPE_CHECKING: # TODO: condition is weird
  649. class RootModel(GenericModel, Generic[_T]):
  650. """Used as a placeholder to easily convert runtime types to a Pydantic format
  651. to provide validation.
  652. For example:
  653. ```py
  654. validated = RootModel[int](__root__="5").__root__
  655. # validated: 5
  656. ```
  657. """
  658. __root__: _T
  659. def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
  660. model = _create_pydantic_model(type_).validate(value)
  661. return cast(_T, model.__root__)
  662. def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
  663. return RootModel[type_] # type: ignore
  664. class FinalRequestOptionsInput(TypedDict, total=False):
  665. method: Required[str]
  666. url: Required[str]
  667. params: Query
  668. headers: Headers
  669. max_retries: int
  670. timeout: float | Timeout | None
  671. files: HttpxRequestFiles | None
  672. idempotency_key: str
  673. json_data: Body
  674. extra_json: AnyMapping
  675. follow_redirects: bool
  676. @final
  677. class FinalRequestOptions(pydantic.BaseModel):
  678. method: str
  679. url: str
  680. params: Query = {}
  681. headers: Union[Headers, NotGiven] = NotGiven()
  682. max_retries: Union[int, NotGiven] = NotGiven()
  683. timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
  684. files: Union[HttpxRequestFiles, None] = None
  685. idempotency_key: Union[str, None] = None
  686. post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
  687. follow_redirects: Union[bool, None] = None
  688. # It should be noted that we cannot use `json` here as that would override
  689. # a BaseModel method in an incompatible fashion.
  690. json_data: Union[Body, None] = None
  691. extra_json: Union[AnyMapping, None] = None
  692. if PYDANTIC_V1:
  693. class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
  694. arbitrary_types_allowed: bool = True
  695. else:
  696. model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
  697. def get_max_retries(self, max_retries: int) -> int:
  698. if isinstance(self.max_retries, NotGiven):
  699. return max_retries
  700. return self.max_retries
  701. def _strip_raw_response_header(self) -> None:
  702. if not is_given(self.headers):
  703. return
  704. if self.headers.get(RAW_RESPONSE_HEADER):
  705. self.headers = {**self.headers}
  706. self.headers.pop(RAW_RESPONSE_HEADER)
  707. # override the `construct` method so that we can run custom transformations.
  708. # this is necessary as we don't want to do any actual runtime type checking
  709. # (which means we can't use validators) but we do want to ensure that `NotGiven`
  710. # values are not present
  711. #
  712. # type ignore required because we're adding explicit types to `**values`
  713. @classmethod
  714. def construct( # type: ignore
  715. cls,
  716. _fields_set: set[str] | None = None,
  717. **values: Unpack[FinalRequestOptionsInput],
  718. ) -> FinalRequestOptions:
  719. kwargs: dict[str, Any] = {
  720. # we unconditionally call `strip_not_given` on any value
  721. # as it will just ignore any non-mapping types
  722. key: strip_not_given(value)
  723. for key, value in values.items()
  724. }
  725. if PYDANTIC_V1:
  726. return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
  727. return super().model_construct(_fields_set, **kwargs)
  728. if not TYPE_CHECKING:
  729. # type checkers incorrectly complain about this assignment
  730. model_construct = construct