| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import typing
- from os import PathLike
- from starlette.background import BackgroundTask
- from starlette.datastructures import URL
- from starlette.requests import Request
- from starlette.responses import Response
- from starlette.types import Receive, Scope, Send
# Jinja2 is an optional dependency. When it is importable, `pass_context` is
# bound to the decorator that passes the active template context to a global
# function; when it is not, `jinja2` is set to None and `pass_context` is left
# undefined (Jinja2Templates checks `jinja2` before it is ever used).
try:
    import jinja2

    # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
    # hence we try to get pass_context (most installs will be >=3.1)
    # and fall back to contextfunction,
    # adding a type ignore for mypy to let us access an attribute that may not exist
    if hasattr(jinja2, "pass_context"):
        pass_context = jinja2.pass_context
    else:  # pragma: nocover
        pass_context = jinja2.contextfunction  # type: ignore[attr-defined]
except ModuleNotFoundError:  # pragma: nocover
    jinja2 = None  # type: ignore[assignment]
class _TemplateResponse(Response):
    """Response whose body is a template rendered with a context dict.

    The template is rendered once, at construction time. Both the template
    object and the context are kept on the instance so they can be reported
    through the ``http.response.debug`` extension when the response is sent.
    """

    media_type = "text/html"

    def __init__(
        self,
        template: typing.Any,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        """Render ``template`` with ``context`` and initialise the response.

        Args:
            template: Object exposing ``render(context)``; its return value
                becomes the response content.
            context: Mapping handed to ``template.render``. Stored as-is
                (not copied) on ``self.context``.
            status_code: Forwarded to the parent ``Response``.
            headers: Forwarded to the parent ``Response``.
            media_type: Forwarded to the parent ``Response``.
            background: Forwarded to the parent ``Response``.
        """
        self.template = template
        self.context = context
        content = template.render(context)
        super().__init__(content, status_code, headers, media_type, background)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the response, preceded by a debug message when supported.

        If the ``request`` entry of the context advertises the
        ``http.response.debug`` extension, a message carrying the template
        and context is sent first. The parent ``Response`` is then invoked
        unconditionally.
        """
        request = self.context.get("request", {})
        # `or {}` also covers an "extensions" key that is present but None,
        # which would otherwise make the membership test raise TypeError.
        extensions = request.get("extensions") or {}
        if "http.response.debug" in extensions:
            await send(
                {
                    "type": "http.response.debug",
                    "info": {
                        "template": self.template,
                        "context": self.context,
                    },
                }
            )
        await super().__call__(scope, receive, send)
class Jinja2Templates:
    """Load Jinja2 templates from a directory and build template responses.

    Usage::

        templates = Jinja2Templates("templates")

        return templates.TemplateResponse("index.html", {"request": request})
    """

    def __init__(
        self,
        directory: typing.Union[str, PathLike],
        context_processors: typing.Optional[
            typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
        ] = None,
        **env_options: typing.Any,
    ) -> None:
        """Create the Jinja2 environment and register context processors.

        Args:
            directory: Directory passed to ``jinja2.FileSystemLoader``.
            context_processors: Callables that receive the request and return
                a dict of extra context values. ``None`` means no processors.
            **env_options: Extra keyword arguments for ``jinja2.Environment``.

        Raises:
            AssertionError: If jinja2 is not installed.
        """
        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; the exception type is the one `assert` produced.
        if jinja2 is None:
            raise AssertionError("jinja2 must be installed to use Jinja2Templates")
        self.env = self._create_env(directory, **env_options)
        self.context_processors = context_processors or []

    def _create_env(
        self, directory: typing.Union[str, PathLike], **env_options: typing.Any
    ) -> "jinja2.Environment":
        """Build the environment and expose ``url_for`` as a template global.

        ``loader`` (a ``FileSystemLoader`` on ``directory``) and
        ``autoescape=True`` are only defaults: values supplied by the caller
        in ``env_options`` take precedence.
        """

        @pass_context
        def url_for(context: dict, name: str, **path_params: typing.Any) -> URL:
            # Resolve the URL through the request stored in the template context.
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(directory)
        env_options.setdefault("loader", loader)
        env_options.setdefault("autoescape", True)

        env = jinja2.Environment(**env_options)
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> "jinja2.Template":
        """Return the template called ``name`` from the environment."""
        return self.env.get_template(name)

    def TemplateResponse(
        self,
        name: str,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> _TemplateResponse:
        """Build a response that renders template ``name`` with ``context``.

        Args:
            name: Template name looked up via ``get_template``.
            context: Template context; must contain a ``"request"`` key.
                NOTE: this dict is updated in place with the output of every
                context processor, so the caller's dict is mutated and
                processor values overwrite existing keys of the same name.
            status_code: Forwarded to the response.
            headers: Forwarded to the response.
            media_type: Forwarded to the response.
            background: Forwarded to the response.

        Raises:
            ValueError: If ``context`` has no ``"request"`` key.
        """
        if "request" not in context:
            raise ValueError('context must include a "request" key')

        request = typing.cast(Request, context["request"])
        for context_processor in self.context_processors:
            context.update(context_processor(request))

        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )
|