background.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import sys
  2. import typing
  3. if sys.version_info >= (3, 10): # pragma: no cover
  4. from typing import ParamSpec
  5. else: # pragma: no cover
  6. from typing_extensions import ParamSpec
  7. from starlette._utils import is_async_callable
  8. from starlette.concurrency import run_in_threadpool
  9. P = ParamSpec("P")
  10. class BackgroundTask:
  11. def __init__(
  12. self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
  13. ) -> None:
  14. self.func = func
  15. self.args = args
  16. self.kwargs = kwargs
  17. self.is_async = is_async_callable(func)
  18. async def __call__(self) -> None:
  19. if self.is_async:
  20. await self.func(*self.args, **self.kwargs)
  21. else:
  22. await run_in_threadpool(self.func, *self.args, **self.kwargs)
  23. class BackgroundTasks(BackgroundTask):
  24. def __init__(self, tasks: typing.Optional[typing.Sequence[BackgroundTask]] = None):
  25. self.tasks = list(tasks) if tasks else []
  26. def add_task(
  27. self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
  28. ) -> None:
  29. task = BackgroundTask(func, *args, **kwargs)
  30. self.tasks.append(task)
  31. async def __call__(self) -> None:
  32. for task in self.tasks:
  33. await task()