# root_model.py (4.8 KB)
  1. """RootModel class and type definitions."""
  2. from __future__ import annotations as _annotations
  3. import typing
  4. from copy import copy, deepcopy
  5. from pydantic_core import PydanticUndefined
  6. from . import PydanticUserError
  7. from ._internal import _repr
  8. from .main import BaseModel, _object_setattr
if typing.TYPE_CHECKING:
    # Names needed only by annotations. In the code shown here they appear solely in
    # annotations, which `from __future__ import annotations` keeps from being evaluated
    # when the functions are defined.
    from typing import Any

    from typing_extensions import Literal

    # Type variable used to annotate `self`/`cls` and the matching return types.
    # NOTE(review): nesting under the TYPE_CHECKING guard was reconstructed from a
    # flattened source listing — confirm against the original file.
    Model = typing.TypeVar('Model', bound='BaseModel')

# Names exported by `from <module> import *`.
__all__ = ('RootModel',)

# Type variable standing for the type of the `root` field of `RootModel`.
RootModelRootType = typing.TypeVar('RootModelRootType')
  15. class RootModel(BaseModel, typing.Generic[RootModelRootType]):
  16. """Usage docs: https://docs.pydantic.dev/2.4/concepts/models/#rootmodel-and-custom-root-types
  17. A Pydantic `BaseModel` for the root object of the model.
  18. Attributes:
  19. root: The root object of the model.
  20. __pydantic_root_model__: Whether the model is a RootModel.
  21. __pydantic_private__: Private fields in the model.
  22. __pydantic_extra__: Extra fields in the model.
  23. """
  24. __pydantic_root_model__ = True
  25. __pydantic_private__ = None
  26. __pydantic_extra__ = None
  27. root: RootModelRootType
  28. def __init_subclass__(cls, **kwargs):
  29. extra = cls.model_config.get('extra')
  30. if extra is not None:
  31. raise PydanticUserError(
  32. "`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
  33. )
  34. super().__init_subclass__(**kwargs)
  35. def __init__(__pydantic_self__, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
  36. __tracebackhide__ = True
  37. if data:
  38. if root is not PydanticUndefined:
  39. raise ValueError(
  40. '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
  41. )
  42. root = data # type: ignore
  43. __pydantic_self__.__pydantic_validator__.validate_python(root, self_instance=__pydantic_self__)
  44. __init__.__pydantic_base_init__ = True
  45. @classmethod
  46. def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model:
  47. """Create a new model using the provided root object and update fields set.
  48. Args:
  49. root: The root object of the model.
  50. _fields_set: The set of fields to be updated.
  51. Returns:
  52. The new model.
  53. Raises:
  54. NotImplemented: If the model is not a subclass of `RootModel`.
  55. """
  56. return super().model_construct(root=root, _fields_set=_fields_set)
  57. def __getstate__(self) -> dict[Any, Any]:
  58. return {
  59. '__dict__': self.__dict__,
  60. '__pydantic_fields_set__': self.__pydantic_fields_set__,
  61. }
  62. def __setstate__(self, state: dict[Any, Any]) -> None:
  63. _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
  64. _object_setattr(self, '__dict__', state['__dict__'])
  65. def __copy__(self: Model) -> Model:
  66. """Returns a shallow copy of the model."""
  67. cls = type(self)
  68. m = cls.__new__(cls)
  69. _object_setattr(m, '__dict__', copy(self.__dict__))
  70. _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
  71. return m
  72. def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
  73. """Returns a deep copy of the model."""
  74. cls = type(self)
  75. m = cls.__new__(cls)
  76. _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
  77. # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
  78. # and attempting a deepcopy would be marginally slower.
  79. _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
  80. return m
  81. if typing.TYPE_CHECKING:
  82. def model_dump(
  83. self,
  84. *,
  85. mode: Literal['json', 'python'] | str = 'python',
  86. include: Any = None,
  87. exclude: Any = None,
  88. by_alias: bool = False,
  89. exclude_unset: bool = False,
  90. exclude_defaults: bool = False,
  91. exclude_none: bool = False,
  92. round_trip: bool = False,
  93. warnings: bool = True,
  94. ) -> RootModelRootType:
  95. """This method is included just to get a more accurate return type for type checkers.
  96. It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.
  97. See the documentation of `BaseModel.model_dump` for more details about the arguments.
  98. """
  99. ...
  100. def __eq__(self, other: Any) -> bool:
  101. if not isinstance(other, RootModel):
  102. return NotImplemented
  103. return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other)
  104. def __repr_args__(self) -> _repr.ReprArgs:
  105. yield 'root', self.root