_pydantic.py

from __future__ import annotations

import inspect
from typing import Any, TypeVar

from typing_extensions import TypeGuard

import pydantic

from .._types import NOT_GIVEN
from .._utils import is_dict as _is_dict, is_list
from .._compat import PYDANTIC_V1, model_json_schema

_T = TypeVar("_T")


def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
    """Build a JSON schema for the given model and coerce it into the `strict` form the API expects."""
    if inspect.isclass(model) and is_basemodel_type(model):
        schema = model_json_schema(model)
    elif (not PYDANTIC_V1) and isinstance(model, pydantic.TypeAdapter):
        schema = model.json_schema()
    else:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")

    return _ensure_strict_json_schema(schema, path=(), root=schema)
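
# Illustrative sketch (not part of the library itself): for a simple model the
# strict schema marks every property as required and forbids extra keys.
# Assuming Pydantic v2, the result looks roughly like:
#
#     class Point(pydantic.BaseModel):
#         x: int
#         y: int
#
#     to_strict_json_schema(Point)
#     # -> {'title': 'Point', 'type': 'object',
#     #     'properties': {'x': {'title': 'X', 'type': 'integer'},
#     #                    'y': {'title': 'Y', 'type': 'integer'}},
#     #     'required': ['x', 'y'],
#     #     'additionalProperties': False}
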
def _ensure_strict_json_schema(
    json_schema: object,
    *,
    path: tuple[str, ...],
    root: dict[str, object],
) -> dict[str, Any]:
    """Mutates the given JSON schema to ensure it conforms to the `strict` standard
    that the API expects.
    """
    if not is_dict(json_schema):
        raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")

    defs = json_schema.get("$defs")
    if is_dict(defs):
        for def_name, def_schema in defs.items():
            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)

    definitions = json_schema.get("definitions")
    if is_dict(definitions):
        for definition_name, definition_schema in definitions.items():
            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)

    typ = json_schema.get("type")
    if typ == "object" and "additionalProperties" not in json_schema:
        json_schema["additionalProperties"] = False

    # object types
    # { 'type': 'object', 'properties': { 'a': {...} } }
    properties = json_schema.get("properties")
    if is_dict(properties):
        json_schema["required"] = [prop for prop in properties.keys()]
        json_schema["properties"] = {
            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
            for key, prop_schema in properties.items()
        }

    # arrays
    # { 'type': 'array', 'items': {...} }
    items = json_schema.get("items")
    if is_dict(items):
        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)

    # unions
    any_of = json_schema.get("anyOf")
    if is_list(any_of):
        json_schema["anyOf"] = [
            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
            for i, variant in enumerate(any_of)
        ]

    # intersections
    all_of = json_schema.get("allOf")
    if is_list(all_of):
        if len(all_of) == 1:
            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
            json_schema.pop("allOf")
        else:
            json_schema["allOf"] = [
                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
                for i, entry in enumerate(all_of)
            ]

    # strip `None` defaults as there's no meaningful distinction here;
    # the schema will still be nullable and the model will default
    # to using `None` anyway
    if json_schema.get("default", NOT_GIVEN) is None:
        json_schema.pop("default")

    # we can't use `$ref`s if there are also other properties defined, e.g.
    # `{"$ref": "...", "description": "my description"}`
    #
    # so we unravel the ref
    # `{"type": "string", "description": "my description"}`
    ref = json_schema.get("$ref")
    if ref and has_more_than_n_keys(json_schema, 1):
        assert isinstance(ref, str), f"Received non-string $ref - {ref}"

        resolved = resolve_ref(root=root, ref=ref)
        if not is_dict(resolved):
            raise ValueError(f"Expected `$ref: {ref}` to resolve to a dictionary but got {resolved}")

        # properties from the json schema take priority over the ones on the `$ref`
        json_schema.update({**resolved, **json_schema})
        json_schema.pop("$ref")

        # Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
        # we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
        return _ensure_strict_json_schema(json_schema, path=path, root=root)

    return json_schema
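
# Illustrative sketch of the `$ref` unravelling above (hypothetical schema, not
# taken from the library's tests): a reference that carries sibling keys, e.g.
#
#     {'$ref': '#/$defs/Item', 'description': 'my description'}
#
# is replaced by the resolved `#/$defs/Item` schema merged with its sibling keys
# (the sibling keys win on conflicts), the `$ref` key is dropped, and the merged
# result is passed through `_ensure_strict_json_schema` once more so rules such
# as `additionalProperties: false` also apply to the inlined schema.
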
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
    if not ref.startswith("#/"):
        raise ValueError(f"Unexpected $ref format {ref!r}; does not start with #/")

    path = ref[2:].split("/")
    resolved = root
    for key in path:
        value = resolved[key]
        assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
        resolved = value

    return resolved
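
# Illustrative sketch (hypothetical values): `resolve_ref` walks a JSON
# pointer-style fragment through the root schema, e.g.
#
#     root = {'$defs': {'Item': {'type': 'string'}}}
#     resolve_ref(root=root, ref='#/$defs/Item')  # -> {'type': 'string'}
#
# Only `#/`-prefixed refs are supported; anything else raises `ValueError`.
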
def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
    if not inspect.isclass(typ):
        return False
    return issubclass(typ, pydantic.BaseModel)


def is_dataclass_like_type(typ: type) -> bool:
    """Returns True if the given type likely used `@pydantic.dataclass`"""
    return hasattr(typ, "__pydantic_config__")


def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
    # just pretend that we know there are only `str` keys
    # as that check is not worth the performance cost
    return _is_dict(obj)


def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
    i = 0
    for _ in obj.keys():
        i += 1
        if i > n:
            return True
    return False
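
# Illustrative note (not from the original source): `has_more_than_n_keys`
# short-circuits as soon as it has counted `n + 1` keys, e.g.
#
#     has_more_than_n_keys({'$ref': '#/$defs/Item'}, 1)                      # False
#     has_more_than_n_keys({'$ref': '#/$defs/Item', 'description': 'x'}, 1)  # True
#
# which is how the `$ref` handling above detects sibling keys next to `$ref`.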