dataclasses.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. """
The main purpose is to enhance stdlib dataclasses by adding validation.
A pydantic dataclass can be generated from scratch or from a stdlib one.
Behind the scenes, a pydantic dataclass is just like a regular one to which we attach
a `BaseModel` and magic methods that trigger the validation of the data.
`__init__` and `__post_init__` are hence overridden and have extra logic to be
able to validate input data.
When a pydantic dataclass is generated from scratch, it's just a plain dataclass
with validation triggered at initialization.
The tricky part is for stdlib dataclasses that are converted into pydantic ones afterwards, e.g.
```py
@dataclasses.dataclass
class M:
    x: int

ValidatedM = pydantic.dataclasses.dataclass(M)
```
  17. We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
  18. ```py
  19. assert isinstance(ValidatedM(x=1), M)
  20. assert ValidatedM(x=1) == M(x=1)
  21. ```
This means we **don't want to create a new dataclass that inherits from it**.
The trick is to create a wrapper around `M` that will act as a proxy to trigger
validation without altering the default behaviour of `M`.
  25. """
  26. import copy
  27. import dataclasses
  28. import sys
  29. from contextlib import contextmanager
  30. from functools import wraps
  31. from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
  32. from typing_extensions import dataclass_transform
  33. from .class_validators import gather_all_validators
  34. from .config import BaseConfig, ConfigDict, Extra, get_config
  35. from .error_wrappers import ValidationError
  36. from .errors import DataclassTypeError
  37. from .fields import Field, FieldInfo, Required, Undefined
  38. from .main import create_model, validate_model
  39. from .utils import ClassAttribute
if TYPE_CHECKING:
    # Everything in this block is seen only by static type checkers and is skipped at runtime,
    # which is why these names are only referenced from string annotations in this module.
    from .main import BaseModel
    from .typing import CallableGenerator, NoArgAnyCallable

    DataclassT = TypeVar('DataclassT', bound='Dataclass')

    # Result of `dataclass(...)`: either the processed dataclass type itself or a `DataclassProxy`
    # wrapping a stdlib dataclass.
    DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']

    class Dataclass:
        """
        Static description of a pydantic dataclass, used only in type annotations.

        Declares the stdlib dataclass attributes plus the extra attributes and methods
        this module attaches to decorated classes.
        """

        # stdlib attributes
        __dataclass_fields__: ClassVar[Dict[str, Any]]
        __dataclass_params__: ClassVar[Any]  # in reality `dataclasses._DataclassParams`
        __post_init__: ClassVar[Callable[..., None]]
        # Added by pydantic
        __pydantic_run_validation__: ClassVar[bool]  # whether `__init__` / `__post_init__` trigger validation
        __post_init_post_parse__: ClassVar[Callable[..., None]]  # optional user hook, run after validation
        __pydantic_initialised__: ClassVar[bool]  # flipped to True on an instance once validation succeeded
        __pydantic_model__: ClassVar[Type[BaseModel]]  # model generated from the dataclass fields
        __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
        __pydantic_has_field_info_default__: ClassVar[bool]  # whether a `pydantic.Field` is used as default value

        def __init__(self, *args: object, **kwargs: object) -> None:
            pass

        @classmethod
        def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
            pass

        @classmethod
        def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
            pass
  65. __all__ = [
  66. 'dataclass',
  67. 'set_validation',
  68. 'create_pydantic_model_from_dataclass',
  69. 'is_builtin_dataclass',
  70. 'make_dataclass_validator',
  71. ]
  72. _T = TypeVar('_T')
# Overloaded signatures of `dataclass`, used by static type checkers only (`@overload`; the real
# implementation is defined right after this block):
# - keyword arguments only, as in `@dataclass(...)`: returns a decorator;
# - the class as first argument, as in `@dataclass` / `dataclass(cls)`: returns the processed class or proxy.
# Only the Python >= 3.10 variants declare `kw_only`, matching the implementation, which forwards
# `kw_only` to `dataclasses.dataclass` only on those versions.
# NOTE(review): PEP 681 says `dataclass_transform` may decorate the implementation or at most one
# overload; here it decorates every overload and the implementation - confirm type checkers accept this.
if sys.version_info >= (3, 10):

    # `@dataclass(...)` form: returns a decorator.
    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        ...

    # `@dataclass` / `dataclass(cls)` form: returns the processed class or proxy.
    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
        kw_only: bool = ...,
    ) -> 'DataclassClassOrWrapper':
        ...

else:

    # `@dataclass(...)` form (no `kw_only` before Python 3.10): returns a decorator.
    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
        ...

    # `@dataclass` / `dataclass(cls)` form (no `kw_only` before Python 3.10).
    @dataclass_transform(field_specifiers=(dataclasses.field, Field))
    @overload
    def dataclass(
        _cls: Type[_T],
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        config: Union[ConfigDict, Type[object], None] = None,
        validate_on_init: Optional[bool] = None,
        use_proxy: Optional[bool] = None,
    ) -> 'DataclassClassOrWrapper':
        ...
  139. @dataclass_transform(field_specifiers=(dataclasses.field, Field))
  140. def dataclass(
  141. _cls: Optional[Type[_T]] = None,
  142. *,
  143. init: bool = True,
  144. repr: bool = True,
  145. eq: bool = True,
  146. order: bool = False,
  147. unsafe_hash: bool = False,
  148. frozen: bool = False,
  149. config: Union[ConfigDict, Type[object], None] = None,
  150. validate_on_init: Optional[bool] = None,
  151. use_proxy: Optional[bool] = None,
  152. kw_only: bool = False,
  153. ) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
  154. """
  155. Like the python standard lib dataclasses but with type validation.
  156. The result is either a pydantic dataclass that will validate input data
  157. or a wrapper that will trigger validation around a stdlib dataclass
  158. to avoid modifying it directly
  159. """
  160. the_config = get_config(config)
  161. def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
  162. should_use_proxy = (
  163. use_proxy
  164. if use_proxy is not None
  165. else (
  166. is_builtin_dataclass(cls)
  167. and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
  168. )
  169. )
  170. if should_use_proxy:
  171. dc_cls_doc = ''
  172. dc_cls = DataclassProxy(cls)
  173. default_validate_on_init = False
  174. else:
  175. dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
  176. if sys.version_info >= (3, 10):
  177. dc_cls = dataclasses.dataclass(
  178. cls,
  179. init=init,
  180. repr=repr,
  181. eq=eq,
  182. order=order,
  183. unsafe_hash=unsafe_hash,
  184. frozen=frozen,
  185. kw_only=kw_only,
  186. )
  187. else:
  188. dc_cls = dataclasses.dataclass( # type: ignore
  189. cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
  190. )
  191. default_validate_on_init = True
  192. should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
  193. _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
  194. dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
  195. return dc_cls
  196. if _cls is None:
  197. return wrap
  198. return wrap(_cls)
  199. @contextmanager
  200. def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
  201. original_run_validation = cls.__pydantic_run_validation__
  202. try:
  203. cls.__pydantic_run_validation__ = value
  204. yield cls
  205. finally:
  206. cls.__pydantic_run_validation__ = original_run_validation
  207. class DataclassProxy:
  208. __slots__ = '__dataclass__'
  209. def __init__(self, dc_cls: Type['Dataclass']) -> None:
  210. object.__setattr__(self, '__dataclass__', dc_cls)
  211. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  212. with set_validation(self.__dataclass__, True):
  213. return self.__dataclass__(*args, **kwargs)
  214. def __getattr__(self, name: str) -> Any:
  215. return getattr(self.__dataclass__, name)
  216. def __setattr__(self, __name: str, __value: Any) -> None:
  217. return setattr(self.__dataclass__, __name, __value)
  218. def __instancecheck__(self, instance: Any) -> bool:
  219. return isinstance(instance, self.__dataclass__)
  220. def __copy__(self) -> 'DataclassProxy':
  221. return DataclassProxy(copy.copy(self.__dataclass__))
  222. def __deepcopy__(self, memo: Any) -> 'DataclassProxy':
  223. return DataclassProxy(copy.deepcopy(self.__dataclass__, memo))
  224. def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity)
  225. dc_cls: Type['Dataclass'],
  226. config: Type[BaseConfig],
  227. validate_on_init: bool,
  228. dc_cls_doc: str,
  229. ) -> None:
  230. """
  231. We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
  232. it won't even exist (code is generated on the fly by `dataclasses`)
  233. By default, we run validation after `__init__` or `__post_init__` if defined
  234. """
  235. init = dc_cls.__init__
  236. @wraps(init)
  237. def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
  238. if config.extra == Extra.ignore:
  239. init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
  240. elif config.extra == Extra.allow:
  241. for k, v in kwargs.items():
  242. self.__dict__.setdefault(k, v)
  243. init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
  244. else:
  245. init(self, *args, **kwargs)
  246. if hasattr(dc_cls, '__post_init__'):
  247. try:
  248. post_init = dc_cls.__post_init__.__wrapped__ # type: ignore[attr-defined]
  249. except AttributeError:
  250. post_init = dc_cls.__post_init__
  251. @wraps(post_init)
  252. def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
  253. if config.post_init_call == 'before_validation':
  254. post_init(self, *args, **kwargs)
  255. if self.__class__.__pydantic_run_validation__:
  256. self.__pydantic_validate_values__()
  257. if hasattr(self, '__post_init_post_parse__'):
  258. self.__post_init_post_parse__(*args, **kwargs)
  259. if config.post_init_call == 'after_validation':
  260. post_init(self, *args, **kwargs)
  261. setattr(dc_cls, '__init__', handle_extra_init)
  262. setattr(dc_cls, '__post_init__', new_post_init)
  263. else:
  264. @wraps(init)
  265. def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
  266. handle_extra_init(self, *args, **kwargs)
  267. if self.__class__.__pydantic_run_validation__:
  268. self.__pydantic_validate_values__()
  269. if hasattr(self, '__post_init_post_parse__'):
  270. # We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
  271. # public method `dataclasses.fields`
  272. # get all initvars and their default values
  273. initvars_and_values: Dict[str, Any] = {}
  274. for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
  275. if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined]
  276. try:
  277. # set arg value by default
  278. initvars_and_values[f.name] = args[i]
  279. except IndexError:
  280. initvars_and_values[f.name] = kwargs.get(f.name, f.default)
  281. self.__post_init_post_parse__(**initvars_and_values)
  282. setattr(dc_cls, '__init__', new_init)
  283. setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
  284. setattr(dc_cls, '__pydantic_initialised__', False)
  285. setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
  286. setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
  287. setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
  288. setattr(dc_cls, '__get_validators__', classmethod(_get_validators))
  289. if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
  290. setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)
  291. def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
  292. yield cls.__validate__
  293. def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
  294. with set_validation(cls, True):
  295. if isinstance(v, cls):
  296. v.__pydantic_validate_values__()
  297. return v
  298. elif isinstance(v, (list, tuple)):
  299. return cls(*v)
  300. elif isinstance(v, dict):
  301. return cls(**v)
  302. else:
  303. raise DataclassTypeError(class_name=cls.__name__)
  304. def create_pydantic_model_from_dataclass(
  305. dc_cls: Type['Dataclass'],
  306. config: Type[Any] = BaseConfig,
  307. dc_cls_doc: Optional[str] = None,
  308. ) -> Type['BaseModel']:
  309. field_definitions: Dict[str, Any] = {}
  310. for field in dataclasses.fields(dc_cls):
  311. default: Any = Undefined
  312. default_factory: Optional['NoArgAnyCallable'] = None
  313. field_info: FieldInfo
  314. if field.default is not dataclasses.MISSING:
  315. default = field.default
  316. elif field.default_factory is not dataclasses.MISSING:
  317. default_factory = field.default_factory
  318. else:
  319. default = Required
  320. if isinstance(default, FieldInfo):
  321. field_info = default
  322. dc_cls.__pydantic_has_field_info_default__ = True
  323. else:
  324. field_info = Field(default=default, default_factory=default_factory, **field.metadata)
  325. field_definitions[field.name] = (field.type, field_info)
  326. validators = gather_all_validators(dc_cls)
  327. model: Type['BaseModel'] = create_model(
  328. dc_cls.__name__,
  329. __config__=config,
  330. __module__=dc_cls.__module__,
  331. __validators__=validators,
  332. __cls_kwargs__={'__resolve_forward_refs__': False},
  333. **field_definitions,
  334. )
  335. model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
  336. return model
  337. def _dataclass_validate_values(self: 'Dataclass') -> None:
  338. # validation errors can occur if this function is called twice on an already initialised dataclass.
  339. # for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
  340. if getattr(self, '__pydantic_initialised__'):
  341. return
  342. if getattr(self, '__pydantic_has_field_info_default__', False):
  343. # We need to remove `FieldInfo` values since they are not valid as input
  344. # It's ok to do that because they are obviously the default values!
  345. input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
  346. else:
  347. input_data = self.__dict__
  348. d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
  349. if validation_error:
  350. raise validation_error
  351. self.__dict__.update(d)
  352. object.__setattr__(self, '__pydantic_initialised__', True)
  353. def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
  354. if self.__pydantic_initialised__:
  355. d = dict(self.__dict__)
  356. d.pop(name, None)
  357. known_field = self.__pydantic_model__.__fields__.get(name, None)
  358. if known_field:
  359. value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
  360. if error_:
  361. raise ValidationError([error_], self.__class__)
  362. object.__setattr__(self, name, value)
  363. def is_builtin_dataclass(_cls: Type[Any]) -> bool:
  364. """
  365. Whether a class is a stdlib dataclass
  366. (useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass)
  367. we check that
  368. - `_cls` is a dataclass
  369. - `_cls` is not a processed pydantic dataclass (with a basemodel attached)
  370. - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
  371. e.g.
  372. ```
  373. @dataclasses.dataclass
  374. class A:
  375. x: int
  376. @pydantic.dataclasses.dataclass
  377. class B(A):
  378. y: int
  379. ```
  380. In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
  381. which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
  382. """
  383. return (
  384. dataclasses.is_dataclass(_cls)
  385. and not hasattr(_cls, '__pydantic_model__')
  386. and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
  387. )
  388. def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
  389. """
  390. Create a pydantic.dataclass from a builtin dataclass to add type validation
  391. and yield the validators
  392. It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
  393. """
  394. yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))