_compat.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
  3. from datetime import date, datetime
  4. from typing_extensions import Self, Literal
  5. import pydantic
  6. from pydantic.fields import FieldInfo
  7. from ._types import IncEx, StrBytesIntFloat
  8. _T = TypeVar("_T")
  9. _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
  10. # --------------- Pydantic v2, v3 compatibility ---------------
  11. # Pyright incorrectly reports some of our functions as overriding a method when they don't
  12. # pyright: reportIncompatibleMethodOverride=false
  13. PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
  14. if TYPE_CHECKING:
  15. def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
  16. ...
  17. def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
  18. ...
  19. def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
  20. ...
  21. def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
  22. ...
  23. def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
  24. ...
  25. def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
  26. ...
  27. def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
  28. ...
  29. else:
  30. # v1 re-exports
  31. if PYDANTIC_V1:
  32. from pydantic.typing import (
  33. get_args as get_args,
  34. is_union as is_union,
  35. get_origin as get_origin,
  36. is_typeddict as is_typeddict,
  37. is_literal_type as is_literal_type,
  38. )
  39. from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
  40. else:
  41. from ._utils import (
  42. get_args as get_args,
  43. is_union as is_union,
  44. get_origin as get_origin,
  45. parse_date as parse_date,
  46. is_typeddict as is_typeddict,
  47. parse_datetime as parse_datetime,
  48. is_literal_type as is_literal_type,
  49. )
  50. # refactored config
  51. if TYPE_CHECKING:
  52. from pydantic import ConfigDict as ConfigDict
  53. else:
  54. if PYDANTIC_V1:
  55. # TODO: provide an error message here?
  56. ConfigDict = None
  57. else:
  58. from pydantic import ConfigDict as ConfigDict
  59. # renamed methods / properties
  60. def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
  61. if PYDANTIC_V1:
  62. return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
  63. else:
  64. return model.model_validate(value)
  65. def field_is_required(field: FieldInfo) -> bool:
  66. if PYDANTIC_V1:
  67. return field.required # type: ignore
  68. return field.is_required()
  69. def field_get_default(field: FieldInfo) -> Any:
  70. value = field.get_default()
  71. if PYDANTIC_V1:
  72. return value
  73. from pydantic_core import PydanticUndefined
  74. if value == PydanticUndefined:
  75. return None
  76. return value
  77. def field_outer_type(field: FieldInfo) -> Any:
  78. if PYDANTIC_V1:
  79. return field.outer_type_ # type: ignore
  80. return field.annotation
  81. def get_model_config(model: type[pydantic.BaseModel]) -> Any:
  82. if PYDANTIC_V1:
  83. return model.__config__ # type: ignore
  84. return model.model_config
  85. def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
  86. if PYDANTIC_V1:
  87. return model.__fields__ # type: ignore
  88. return model.model_fields
  89. def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
  90. if PYDANTIC_V1:
  91. return model.copy(deep=deep) # type: ignore
  92. return model.model_copy(deep=deep)
  93. def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
  94. if PYDANTIC_V1:
  95. return model.json(indent=indent) # type: ignore
  96. return model.model_dump_json(indent=indent)
  97. def model_dump(
  98. model: pydantic.BaseModel,
  99. *,
  100. exclude: IncEx | None = None,
  101. exclude_unset: bool = False,
  102. exclude_defaults: bool = False,
  103. warnings: bool = True,
  104. mode: Literal["json", "python"] = "python",
  105. ) -> dict[str, Any]:
  106. if (not PYDANTIC_V1) or hasattr(model, "model_dump"):
  107. return model.model_dump(
  108. mode=mode,
  109. exclude=exclude,
  110. exclude_unset=exclude_unset,
  111. exclude_defaults=exclude_defaults,
  112. # warnings are not supported in Pydantic v1
  113. warnings=True if PYDANTIC_V1 else warnings,
  114. )
  115. return cast(
  116. "dict[str, Any]",
  117. model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
  118. exclude=exclude,
  119. exclude_unset=exclude_unset,
  120. exclude_defaults=exclude_defaults,
  121. ),
  122. )
  123. def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
  124. if PYDANTIC_V1:
  125. return model.parse_obj(data) # pyright: ignore[reportDeprecated]
  126. return model.model_validate(data)
  127. def model_parse_json(model: type[_ModelT], data: str | bytes) -> _ModelT:
  128. if PYDANTIC_V1:
  129. return model.parse_raw(data) # pyright: ignore[reportDeprecated]
  130. return model.model_validate_json(data)
  131. def model_json_schema(model: type[_ModelT]) -> dict[str, Any]:
  132. if PYDANTIC_V1:
  133. return model.schema() # pyright: ignore[reportDeprecated]
  134. return model.model_json_schema()
  135. # generic models
  136. if TYPE_CHECKING:
  137. class GenericModel(pydantic.BaseModel): ...
  138. else:
  139. if PYDANTIC_V1:
  140. import pydantic.generics
  141. class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
  142. else:
  143. # there no longer needs to be a distinction in v2 but
  144. # we still have to create our own subclass to avoid
  145. # inconsistent MRO ordering errors
  146. class GenericModel(pydantic.BaseModel): ...
  147. # cached properties
  148. if TYPE_CHECKING:
  149. cached_property = property
  150. # we define a separate type (copied from typeshed)
  151. # that represents that `cached_property` is `set`able
  152. # at runtime, which differs from `@property`.
  153. #
  154. # this is a separate type as editors likely special case
  155. # `@property` and we don't want to cause issues just to have
  156. # more helpful internal types.
  157. class typed_cached_property(Generic[_T]):
  158. func: Callable[[Any], _T]
  159. attrname: str | None
  160. def __init__(self, func: Callable[[Any], _T]) -> None: ...
  161. @overload
  162. def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...
  163. @overload
  164. def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...
  165. def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
  166. raise NotImplementedError()
  167. def __set_name__(self, owner: type[Any], name: str) -> None: ...
  168. # __set__ is not defined at runtime, but @cached_property is designed to be settable
  169. def __set__(self, instance: object, value: _T) -> None: ...
  170. else:
  171. from functools import cached_property as cached_property
  172. typed_cached_property = cached_property