templating.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import typing
  2. from os import PathLike
  3. from starlette.background import BackgroundTask
  4. from starlette.datastructures import URL
  5. from starlette.requests import Request
  6. from starlette.responses import Response
  7. from starlette.types import Receive, Scope, Send
  8. try:
  9. import jinja2
  10. # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1
  11. # hence we try to get pass_context (most installs will be >=3.1)
  12. # and fall back to contextfunction,
  13. # adding a type ignore for mypy to let us access an attribute that may not exist
  14. if hasattr(jinja2, "pass_context"):
  15. pass_context = jinja2.pass_context
  16. else: # pragma: nocover
  17. pass_context = jinja2.contextfunction # type: ignore[attr-defined]
  18. except ModuleNotFoundError: # pragma: nocover
  19. jinja2 = None # type: ignore[assignment]
  20. class _TemplateResponse(Response):
  21. media_type = "text/html"
  22. def __init__(
  23. self,
  24. template: typing.Any,
  25. context: dict,
  26. status_code: int = 200,
  27. headers: typing.Optional[typing.Mapping[str, str]] = None,
  28. media_type: typing.Optional[str] = None,
  29. background: typing.Optional[BackgroundTask] = None,
  30. ):
  31. self.template = template
  32. self.context = context
  33. content = template.render(context)
  34. super().__init__(content, status_code, headers, media_type, background)
  35. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  36. request = self.context.get("request", {})
  37. extensions = request.get("extensions", {})
  38. if "http.response.debug" in extensions:
  39. await send(
  40. {
  41. "type": "http.response.debug",
  42. "info": {
  43. "template": self.template,
  44. "context": self.context,
  45. },
  46. }
  47. )
  48. await super().__call__(scope, receive, send)
  49. class Jinja2Templates:
  50. """
  51. templates = Jinja2Templates("templates")
  52. return templates.TemplateResponse("index.html", {"request": request})
  53. """
  54. def __init__(
  55. self,
  56. directory: typing.Union[str, PathLike],
  57. context_processors: typing.Optional[
  58. typing.List[typing.Callable[[Request], typing.Dict[str, typing.Any]]]
  59. ] = None,
  60. **env_options: typing.Any,
  61. ) -> None:
  62. assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
  63. self.env = self._create_env(directory, **env_options)
  64. self.context_processors = context_processors or []
  65. def _create_env(
  66. self, directory: typing.Union[str, PathLike], **env_options: typing.Any
  67. ) -> "jinja2.Environment":
  68. @pass_context
  69. def url_for(context: dict, name: str, **path_params: typing.Any) -> URL:
  70. request = context["request"]
  71. return request.url_for(name, **path_params)
  72. loader = jinja2.FileSystemLoader(directory)
  73. env_options.setdefault("loader", loader)
  74. env_options.setdefault("autoescape", True)
  75. env = jinja2.Environment(**env_options)
  76. env.globals["url_for"] = url_for
  77. return env
  78. def get_template(self, name: str) -> "jinja2.Template":
  79. return self.env.get_template(name)
  80. def TemplateResponse(
  81. self,
  82. name: str,
  83. context: dict,
  84. status_code: int = 200,
  85. headers: typing.Optional[typing.Mapping[str, str]] = None,
  86. media_type: typing.Optional[str] = None,
  87. background: typing.Optional[BackgroundTask] = None,
  88. ) -> _TemplateResponse:
  89. if "request" not in context:
  90. raise ValueError('context must include a "request" key')
  91. request = typing.cast(Request, context["request"])
  92. for context_processor in self.context_processors:
  93. context.update(context_processor(request))
  94. template = self.get_template(name)
  95. return _TemplateResponse(
  96. template,
  97. context,
  98. status_code=status_code,
  99. headers=headers,
  100. media_type=media_type,
  101. background=background,
  102. )