| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- import inspect
- import re
- import typing
- from starlette.requests import Request
- from starlette.responses import Response
- from starlette.routing import BaseRoute, Mount, Route
- try:
- import yaml
- except ModuleNotFoundError: # pragma: nocover
- yaml = None # type: ignore[assignment]
class OpenAPIResponse(Response):
    """Response that serialises a schema dictionary to YAML.

    The body is produced with ``yaml.dump`` in block style and sent with the
    ``application/vnd.oai.openapi`` media type.
    """

    media_type = "application/vnd.oai.openapi"

    def render(self, content: typing.Any) -> bytes:
        """Serialise ``content`` to UTF-8 encoded YAML.

        ``content`` must be a dictionary and ``pyyaml`` must be importable;
        both conditions are checked with assertions.
        """
        # `yaml` is None when the optional import at module level failed.
        assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
        assert isinstance(content, dict), (
            "The schema passed to OpenAPIResponse should be a dictionary."
        )
        document = yaml.dump(content, default_flow_style=False)
        return document.encode("utf-8")
class EndpointInfo(typing.NamedTuple):
    """Immutable record describing a single documented endpoint."""

    # URL path of the endpoint, e.g. "/users/".
    path: str

    # Lower-case HTTP method name, e.g. "get".
    http_method: str

    # Callable whose docstring is read when building the schema.
    func: typing.Callable
class BaseSchemaGenerator:
    """Shared machinery for building a schema from an application's routes.

    Subclasses provide ``get_schema``; this class supplies the helpers that
    walk the routes and read YAML out of endpoint docstrings.
    """

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """Build the schema for ``routes``. Must be overridden by subclasses."""
        raise NotImplementedError()  # pragma: no cover

    def get_endpoints(
        self, routes: typing.List[BaseRoute]
    ) -> typing.List[EndpointInfo]:
        """
        Given the routes, returns a list of ``EndpointInfo`` entries holding:

        - path
            eg: /users/
        - http_method
            one of 'get', 'post', 'put', 'patch', 'delete', 'options'
        - func
            method ready to extract the docstring

        Mounts are walked recursively, with the mount path prefixed to the
        path of every endpoint found beneath it. Routes that are not
        ``Route`` instances, or that opt out via ``include_in_schema``, are
        skipped.
        """
        collected: list = []

        for route in routes:
            if isinstance(route, Mount):
                # Recurse into the mounted routes and prefix their paths.
                prefix = self._remove_converter(route.path)
                nested = self.get_endpoints(route.routes or [])
                for child in nested:
                    collected.append(
                        EndpointInfo(
                            path="".join((prefix, child.path)),
                            http_method=child.http_method,
                            func=child.func,
                        )
                    )
                continue

            if not (isinstance(route, Route) and route.include_in_schema):
                continue

            endpoint = route.endpoint
            path = self._remove_converter(route.path)

            if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
                # Plain function or bound method: one entry per declared
                # HTTP method, except HEAD.
                for method in route.methods or ["GET"]:
                    if method != "HEAD":
                        collected.append(
                            EndpointInfo(path, method.lower(), endpoint)
                        )
            else:
                # Any other endpoint object: one entry for each handler
                # attribute it exposes, named after the HTTP method.
                for name in ("get", "post", "put", "patch", "delete", "options"):
                    if hasattr(endpoint, name):
                        collected.append(
                            EndpointInfo(path, name, getattr(endpoint, name))
                        )

        return collected

    def _remove_converter(self, path: str) -> str:
        """
        Strip converter suffixes from path parameters.

        A route declared as:
            Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
        is represented as `/users/{id}` in the OpenAPI schema.
        """
        converter_suffix = re.compile(r":\w+}")
        return converter_suffix.sub("}", path)

    def parse_docstring(self, func_or_method: typing.Callable) -> dict:
        """
        Parse the docstring of a callable as YAML and return the resulting dict.

        An empty dict is returned when there is no docstring, or when the
        parsed YAML is anything other than a dictionary.
        """
        text = func_or_method.__doc__
        if not text:
            return {}

        assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."

        # A regular docstring may precede the schema definition, separated
        # by "---". Only the part after the last separator is parsed.
        schema_part = text.split("---")[-1]
        loaded = yaml.safe_load(schema_part)

        # A plain (non-YAML) docstring can load as a simple string, which
        # wouldn't follow the schema, so it is discarded.
        return loaded if isinstance(loaded, dict) else {}

    def OpenAPIResponse(self, request: Request) -> Response:
        """Return the schema for the requesting app's routes as a response."""
        schema = self.get_schema(routes=request.app.routes)
        # Inside the method body this name resolves to the module-level
        # response class, not to this method.
        return OpenAPIResponse(schema)
class SchemaGenerator(BaseSchemaGenerator):
    """Schema generator that merges a base schema with endpoint docstrings.

    Each endpoint whose docstring parses to a non-empty dictionary is added
    under ``schema["paths"][<path>][<http_method>]``.
    """

    def __init__(self, base_schema: dict) -> None:
        """Store ``base_schema``, the dictionary each generated schema starts from."""
        self.base_schema = base_schema

    def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
        """Build the schema for ``routes`` on top of the base schema.

        The base schema is left untouched: the top-level dictionary, its
        ``"paths"`` mapping, and every per-path mapping that receives an
        operation are copied before being written to. Other nested values
        are shared with the base schema.

        Returns:
            A new dictionary containing the base schema entries plus a
            ``"paths"`` mapping filled from the endpoint docstrings.
        """
        schema = dict(self.base_schema)

        # `dict(...)` above is only a shallow copy, so an existing "paths"
        # mapping would still be the caller's object. Copy it so repeated
        # calls never mutate (or accumulate stale entries in) the base schema.
        paths = dict(schema.get("paths", {}))
        schema["paths"] = paths

        # Paths whose operations mapping has already been copied/created by
        # this call and is therefore safe to write to.
        owned_paths: set = set()

        for endpoint in self.get_endpoints(routes):
            parsed = self.parse_docstring(endpoint.func)
            if not parsed:
                continue

            if endpoint.path not in owned_paths:
                # Start from a copy of any operations the base schema
                # already declares for this path.
                paths[endpoint.path] = dict(paths.get(endpoint.path, {}))
                owned_paths.add(endpoint.path)

            paths[endpoint.path][endpoint.http_method] = parsed

        return schema
|