_app.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import contextlib
  16. import json
  17. from typing import (
  18. Any,
  19. AsyncGenerator,
  20. Callable,
  21. Dict,
  22. Generic,
  23. List,
  24. Optional,
  25. Tuple,
  26. TypedDict,
  27. TypeVar,
  28. )
  29. from typing_extensions import ParamSpec, TypeGuard
  30. from ....utils import logging
  31. from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
  32. from ...pipelines import BasePipeline
  33. from ..infra.config import AppConfig
  34. from ..infra.models import NoResultResponse
  35. from ..infra.utils import call_async, generate_log_id
  36. if is_dep_available("aiohttp"):
  37. import aiohttp
  38. if is_dep_available("fastapi"):
  39. import fastapi
  40. from fastapi.encoders import jsonable_encoder
  41. from fastapi.exceptions import RequestValidationError
  42. from fastapi.responses import JSONResponse
  43. if is_dep_available("starlette"):
  44. from starlette.exceptions import HTTPException
  45. PipelineT = TypeVar("PipelineT", bound=BasePipeline)
  46. P = ParamSpec("P")
  47. R = TypeVar("R")
  48. class _Error(TypedDict):
  49. error: str
  50. def _is_error(obj: object) -> TypeGuard[_Error]:
  51. return (
  52. isinstance(obj, dict)
  53. and obj.keys() == {"error"}
  54. and isinstance(obj["error"], str)
  55. )
  56. # XXX: Since typing info (e.g., the pipeline class) cannot be easily obtained
  57. # without abstraction leaks, generic classes do not offer additional benefits
  58. # for type hinting. However, I would stick with the current design, as it does
  59. # not introduce runtime overhead at the moment and may prove useful in the
  60. # future.
  61. @class_requires_deps("fastapi")
  62. class PipelineWrapper(Generic[PipelineT]):
  63. def __init__(self, pipeline: PipelineT) -> None:
  64. super().__init__()
  65. self._pipeline = pipeline
  66. self._lock = asyncio.Lock()
  67. @property
  68. def pipeline(self) -> PipelineT:
  69. return self._pipeline
  70. async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
  71. def _infer() -> List[Any]:
  72. output: list = []
  73. with contextlib.closing(self._pipeline(*args, **kwargs)) as it:
  74. for item in it:
  75. if _is_error(item):
  76. raise fastapi.HTTPException(
  77. status_code=500, detail=item["error"]
  78. )
  79. output.append(item)
  80. return output
  81. return await self.call(_infer)
  82. async def call(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
  83. async with self._lock:
  84. return await call_async(func, *args, **kwargs)
  85. @class_requires_deps("aiohttp")
  86. class AppContext(Generic[PipelineT]):
  87. def __init__(self, *, config: AppConfig) -> None:
  88. super().__init__()
  89. self._config = config
  90. self.extra: Dict[str, Any] = {}
  91. self._pipeline: Optional[PipelineWrapper[PipelineT]] = None
  92. self._aiohttp_session: Optional[aiohttp.ClientSession] = None
  93. @property
  94. def config(self) -> AppConfig:
  95. return self._config
  96. @property
  97. def pipeline(self) -> PipelineWrapper[PipelineT]:
  98. if not self._pipeline:
  99. raise AttributeError("`pipeline` has not been set.")
  100. return self._pipeline
  101. @pipeline.setter
  102. def pipeline(self, val: PipelineWrapper[PipelineT]) -> None:
  103. self._pipeline = val
  104. @property
  105. def aiohttp_session(self) -> "aiohttp.ClientSession":
  106. if not self._aiohttp_session:
  107. raise AttributeError("`aiohttp_session` has not been set.")
  108. return self._aiohttp_session
  109. @aiohttp_session.setter
  110. def aiohttp_session(self, val: "aiohttp.ClientSession") -> None:
  111. self._aiohttp_session = val
  112. @function_requires_deps("fastapi", "aiohttp", "starlette")
  113. def create_app(
  114. *, pipeline: PipelineT, app_config: AppConfig, app_aiohttp_session: bool = True
  115. ) -> Tuple["fastapi.FastAPI", AppContext[PipelineT]]:
  116. @contextlib.asynccontextmanager
  117. async def _app_lifespan(app: "fastapi.FastAPI") -> AsyncGenerator[None, None]:
  118. ctx.pipeline = PipelineWrapper[PipelineT](pipeline)
  119. if app_aiohttp_session:
  120. async with aiohttp.ClientSession(
  121. cookie_jar=aiohttp.DummyCookieJar()
  122. ) as aiohttp_session:
  123. ctx.aiohttp_session = aiohttp_session
  124. yield
  125. else:
  126. yield
  127. # Should we control API versions?
  128. app = fastapi.FastAPI(lifespan=_app_lifespan)
  129. ctx = AppContext[PipelineT](config=app_config)
  130. app.state.context = ctx
  131. @app.get("/health", operation_id="checkHealth")
  132. async def _check_health() -> NoResultResponse:
  133. return NoResultResponse(
  134. logId=generate_log_id(), errorCode=0, errorMsg="Healthy"
  135. )
  136. @app.exception_handler(RequestValidationError)
  137. async def _validation_exception_handler(
  138. request: fastapi.Request, exc: RequestValidationError
  139. ) -> JSONResponse:
  140. json_compatible_data = jsonable_encoder(
  141. NoResultResponse(
  142. logId=generate_log_id(),
  143. errorCode=422,
  144. errorMsg=json.dumps(exc.errors()),
  145. )
  146. )
  147. return JSONResponse(content=json_compatible_data, status_code=422)
  148. @app.exception_handler(HTTPException)
  149. async def _http_exception_handler(
  150. request: fastapi.Request, exc: HTTPException
  151. ) -> JSONResponse:
  152. json_compatible_data = jsonable_encoder(
  153. NoResultResponse(
  154. logId=generate_log_id(), errorCode=exc.status_code, errorMsg=exc.detail
  155. )
  156. )
  157. return JSONResponse(content=json_compatible_data, status_code=exc.status_code)
  158. @app.exception_handler(Exception)
  159. async def _unexpected_exception_handler(
  160. request: fastapi.Request, exc: Exception
  161. ) -> JSONResponse:
  162. # XXX: The default server will duplicate the error message. Is it
  163. # necessary to log the exception info here?
  164. logging.exception("Unhandled exception")
  165. json_compatible_data = jsonable_encoder(
  166. NoResultResponse(
  167. logId=generate_log_id(),
  168. errorCode=500,
  169. errorMsg="Internal server error",
  170. )
  171. )
  172. return JSONResponse(content=json_compatible_data, status_code=500)
  173. return app, ctx
  174. # TODO: Precise type hints
  175. @function_requires_deps("fastapi")
  176. def primary_operation(
  177. app: "fastapi.FastAPI", path: str, operation_id: str, **kwargs: Any
  178. ) -> Callable:
  179. return app.post(
  180. path,
  181. operation_id=operation_id,
  182. responses={422: {"model": NoResultResponse}, 500: {"model": NoResultResponse}},
  183. response_model_exclude_none=True,
  184. **kwargs,
  185. )