schemas.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import inspect
  2. import re
  3. import typing
  4. from starlette.requests import Request
  5. from starlette.responses import Response
  6. from starlette.routing import BaseRoute, Mount, Route
  7. try:
  8. import yaml
  9. except ModuleNotFoundError: # pragma: nocover
  10. yaml = None # type: ignore[assignment]
  11. class OpenAPIResponse(Response):
  12. media_type = "application/vnd.oai.openapi"
  13. def render(self, content: typing.Any) -> bytes:
  14. assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
  15. assert isinstance(
  16. content, dict
  17. ), "The schema passed to OpenAPIResponse should be a dictionary."
  18. return yaml.dump(content, default_flow_style=False).encode("utf-8")
  19. class EndpointInfo(typing.NamedTuple):
  20. path: str
  21. http_method: str
  22. func: typing.Callable
  23. class BaseSchemaGenerator:
  24. def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
  25. raise NotImplementedError() # pragma: no cover
  26. def get_endpoints(
  27. self, routes: typing.List[BaseRoute]
  28. ) -> typing.List[EndpointInfo]:
  29. """
  30. Given the routes, yields the following information:
  31. - path
  32. eg: /users/
  33. - http_method
  34. one of 'get', 'post', 'put', 'patch', 'delete', 'options'
  35. - func
  36. method ready to extract the docstring
  37. """
  38. endpoints_info: list = []
  39. for route in routes:
  40. if isinstance(route, Mount):
  41. path = self._remove_converter(route.path)
  42. routes = route.routes or []
  43. sub_endpoints = [
  44. EndpointInfo(
  45. path="".join((path, sub_endpoint.path)),
  46. http_method=sub_endpoint.http_method,
  47. func=sub_endpoint.func,
  48. )
  49. for sub_endpoint in self.get_endpoints(routes)
  50. ]
  51. endpoints_info.extend(sub_endpoints)
  52. elif not isinstance(route, Route) or not route.include_in_schema:
  53. continue
  54. elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
  55. path = self._remove_converter(route.path)
  56. for method in route.methods or ["GET"]:
  57. if method == "HEAD":
  58. continue
  59. endpoints_info.append(
  60. EndpointInfo(path, method.lower(), route.endpoint)
  61. )
  62. else:
  63. path = self._remove_converter(route.path)
  64. for method in ["get", "post", "put", "patch", "delete", "options"]:
  65. if not hasattr(route.endpoint, method):
  66. continue
  67. func = getattr(route.endpoint, method)
  68. endpoints_info.append(EndpointInfo(path, method.lower(), func))
  69. return endpoints_info
  70. def _remove_converter(self, path: str) -> str:
  71. """
  72. Remove the converter from the path.
  73. For example, a route like this:
  74. Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
  75. Should be represented as `/users/{id}` in the OpenAPI schema.
  76. """
  77. return re.sub(r":\w+}", "}", path)
  78. def parse_docstring(self, func_or_method: typing.Callable) -> dict:
  79. """
  80. Given a function, parse the docstring as YAML and return a dictionary of info.
  81. """
  82. docstring = func_or_method.__doc__
  83. if not docstring:
  84. return {}
  85. assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
  86. # We support having regular docstrings before the schema
  87. # definition. Here we return just the schema part from
  88. # the docstring.
  89. docstring = docstring.split("---")[-1]
  90. parsed = yaml.safe_load(docstring)
  91. if not isinstance(parsed, dict):
  92. # A regular docstring (not yaml formatted) can return
  93. # a simple string here, which wouldn't follow the schema.
  94. return {}
  95. return parsed
  96. def OpenAPIResponse(self, request: Request) -> Response:
  97. routes = request.app.routes
  98. schema = self.get_schema(routes=routes)
  99. return OpenAPIResponse(schema)
  100. class SchemaGenerator(BaseSchemaGenerator):
  101. def __init__(self, base_schema: dict) -> None:
  102. self.base_schema = base_schema
  103. def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
  104. schema = dict(self.base_schema)
  105. schema.setdefault("paths", {})
  106. endpoints_info = self.get_endpoints(routes)
  107. for endpoint in endpoints_info:
  108. parsed = self.parse_docstring(endpoint.func)
  109. if not parsed:
  110. continue
  111. if endpoint.path not in schema["paths"]:
  112. schema["paths"][endpoint.path] = {}
  113. schema["paths"][endpoint.path][endpoint.http_method] = parsed
  114. return schema