_app.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import contextlib
  16. import json
  17. from queue import Queue
  18. from threading import Thread
  19. from typing import (
  20. Any,
  21. AsyncGenerator,
  22. Callable,
  23. Dict,
  24. Generic,
  25. List,
  26. Optional,
  27. Tuple,
  28. TypedDict,
  29. TypeVar,
  30. )
  31. from typing_extensions import ParamSpec, TypeGuard
  32. from ....utils import logging
  33. from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
  34. from ...pipelines import BasePipeline
  35. from ..infra.config import AppConfig
  36. from ..infra.models import AIStudioNoResultResponse
  37. from ..infra.utils import call_async, generate_log_id
  38. if is_dep_available("aiohttp"):
  39. import aiohttp
  40. if is_dep_available("fastapi"):
  41. import fastapi
  42. from fastapi.encoders import jsonable_encoder
  43. from fastapi.exceptions import RequestValidationError
  44. from fastapi.responses import JSONResponse
  45. if is_dep_available("starlette"):
  46. from starlette.exceptions import HTTPException
  47. PipelineT = TypeVar("PipelineT", bound=BasePipeline)
  48. P = ParamSpec("P")
  49. R = TypeVar("R")
  50. class _Error(TypedDict):
  51. error: str
  52. def _is_error(obj: object) -> TypeGuard[_Error]:
  53. return (
  54. isinstance(obj, dict)
  55. and obj.keys() == {"error"}
  56. and isinstance(obj["error"], str)
  57. )
  58. # XXX: Since typing info (e.g., the pipeline class) cannot be easily obtained
  59. # without abstraction leaks, generic classes do not offer additional benefits
  60. # for type hinting. However, I would stick with the current design, as it does
  61. # not introduce runtime overhead at the moment and may prove useful in the
  62. # future.
  63. @class_requires_deps("fastapi")
  64. class PipelineWrapper(Generic[PipelineT]):
  65. def __init__(self, pipeline: PipelineT) -> None:
  66. super().__init__()
  67. self._pipeline = pipeline
  68. # HACK: We work around a bug in Paddle Inference by performing all
  69. # inference in the same thread.
  70. self._queue = Queue()
  71. self._closed = False
  72. self._loop = asyncio.get_running_loop()
  73. self._thread = Thread(target=self._worker, daemon=False)
  74. self._thread.start()
  75. @property
  76. def pipeline(self) -> PipelineT:
  77. return self._pipeline
  78. async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
  79. def _infer(*args, **kwargs) -> List[Any]:
  80. output: list = []
  81. with contextlib.closing(self._pipeline.predict(*args, **kwargs)) as it:
  82. for item in it:
  83. if _is_error(item):
  84. raise fastapi.HTTPException(
  85. status_code=500, detail=item["error"]
  86. )
  87. output.append(item)
  88. return output
  89. return await self.call(_infer, *args, **kwargs)
  90. async def call(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
  91. if self._closed:
  92. raise RuntimeError("`PipelineWrapper` has already been closed")
  93. fut = self._loop.create_future()
  94. self._queue.put((func, args, kwargs, fut))
  95. return await fut
  96. async def close(self):
  97. if not self._closed:
  98. self._queue.put(None)
  99. await call_async(self._thread.join)
  100. self._closed = True
  101. def _worker(self):
  102. while not self._closed:
  103. item = self._queue.get()
  104. if item is None:
  105. break
  106. func, args, kwargs, fut = item
  107. try:
  108. result = func(*args, **kwargs)
  109. self._loop.call_soon_threadsafe(fut.set_result, result)
  110. except Exception as e:
  111. self._loop.call_soon_threadsafe(fut.set_exception, e)
  112. finally:
  113. self._queue.task_done()
  114. @class_requires_deps("aiohttp")
  115. class AppContext(Generic[PipelineT]):
  116. def __init__(self, *, config: AppConfig) -> None:
  117. super().__init__()
  118. self._config = config
  119. self.extra: Dict[str, Any] = {}
  120. self._pipeline: Optional[PipelineWrapper[PipelineT]] = None
  121. self._aiohttp_session: Optional[aiohttp.ClientSession] = None
  122. @property
  123. def config(self) -> AppConfig:
  124. return self._config
  125. @property
  126. def pipeline(self) -> PipelineWrapper[PipelineT]:
  127. if not self._pipeline:
  128. raise AttributeError("`pipeline` has not been set.")
  129. return self._pipeline
  130. @pipeline.setter
  131. def pipeline(self, val: PipelineWrapper[PipelineT]) -> None:
  132. self._pipeline = val
  133. @property
  134. def aiohttp_session(self) -> "aiohttp.ClientSession":
  135. if not self._aiohttp_session:
  136. raise AttributeError("`aiohttp_session` has not been set.")
  137. return self._aiohttp_session
  138. @aiohttp_session.setter
  139. def aiohttp_session(self, val: "aiohttp.ClientSession") -> None:
  140. self._aiohttp_session = val
  141. @function_requires_deps("fastapi", "aiohttp", "starlette")
  142. def create_app(
  143. *, pipeline: PipelineT, app_config: AppConfig, app_aiohttp_session: bool = True
  144. ) -> Tuple["fastapi.FastAPI", AppContext[PipelineT]]:
  145. @contextlib.asynccontextmanager
  146. async def _app_lifespan(app: "fastapi.FastAPI") -> AsyncGenerator[None, None]:
  147. ctx.pipeline = PipelineWrapper[PipelineT](pipeline)
  148. try:
  149. if app_aiohttp_session:
  150. async with aiohttp.ClientSession(
  151. cookie_jar=aiohttp.DummyCookieJar()
  152. ) as aiohttp_session:
  153. ctx.aiohttp_session = aiohttp_session
  154. yield
  155. else:
  156. yield
  157. finally:
  158. await ctx.pipeline.close()
  159. # Should we control API versions?
  160. app = fastapi.FastAPI(lifespan=_app_lifespan)
  161. ctx = AppContext[PipelineT](config=app_config)
  162. app.state.context = ctx
  163. @app.get("/health", operation_id="checkHealth")
  164. async def _check_health() -> AIStudioNoResultResponse:
  165. return AIStudioNoResultResponse(
  166. logId=generate_log_id(), errorCode=0, errorMsg="Healthy"
  167. )
  168. @app.exception_handler(RequestValidationError)
  169. async def _validation_exception_handler(
  170. request: fastapi.Request, exc: RequestValidationError
  171. ) -> JSONResponse:
  172. json_compatible_data = jsonable_encoder(
  173. AIStudioNoResultResponse(
  174. logId=generate_log_id(),
  175. errorCode=422,
  176. errorMsg=json.dumps(exc.errors()),
  177. )
  178. )
  179. return JSONResponse(content=json_compatible_data, status_code=422)
  180. @app.exception_handler(HTTPException)
  181. async def _http_exception_handler(
  182. request: fastapi.Request, exc: HTTPException
  183. ) -> JSONResponse:
  184. json_compatible_data = jsonable_encoder(
  185. AIStudioNoResultResponse(
  186. logId=generate_log_id(), errorCode=exc.status_code, errorMsg=exc.detail
  187. )
  188. )
  189. return JSONResponse(content=json_compatible_data, status_code=exc.status_code)
  190. @app.exception_handler(Exception)
  191. async def _unexpected_exception_handler(
  192. request: fastapi.Request, exc: Exception
  193. ) -> JSONResponse:
  194. # XXX: The default server will duplicate the error message. Is it
  195. # necessary to log the exception info here?
  196. logging.exception("Unhandled exception")
  197. json_compatible_data = jsonable_encoder(
  198. AIStudioNoResultResponse(
  199. logId=generate_log_id(),
  200. errorCode=500,
  201. errorMsg="Internal server error",
  202. )
  203. )
  204. return JSONResponse(content=json_compatible_data, status_code=500)
  205. return app, ctx
  206. # TODO: Precise type hints
  207. @function_requires_deps("fastapi")
  208. def primary_operation(
  209. app: "fastapi.FastAPI", path: str, operation_id: str, **kwargs: Any
  210. ) -> Callable:
  211. return app.post(
  212. path,
  213. operation_id=operation_id,
  214. responses={
  215. 422: {"model": AIStudioNoResultResponse},
  216. 500: {"model": AIStudioNoResultResponse},
  217. },
  218. response_model_exclude_none=True,
  219. **kwargs,
  220. )