app.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import contextlib
  16. import json
  17. from typing import (
  18. Any,
  19. AsyncGenerator,
  20. Callable,
  21. Dict,
  22. Generic,
  23. List,
  24. Optional,
  25. Tuple,
  26. TypeVar,
  27. )
  28. import aiohttp
  29. import fastapi
  30. from fastapi.encoders import jsonable_encoder
  31. from fastapi.exceptions import RequestValidationError
  32. from fastapi.responses import JSONResponse
  33. from pydantic import BaseModel
  34. from starlette.exceptions import HTTPException
  35. from typing_extensions import Final, ParamSpec
  36. from ..base import BasePipeline
  37. from .models import Response
  38. from .utils import call_async, generate_log_id
  39. SERVING_CONFIG_KEY: Final[str] = "Serving"
  40. _PipelineT = TypeVar("_PipelineT", bound=BasePipeline)
  41. _P = ParamSpec("_P")
  42. _R = TypeVar("_R")
  43. class PipelineWrapper(Generic[_PipelineT]):
  44. def __init__(self, pipeline: _PipelineT) -> None:
  45. super().__init__()
  46. self._pipeline = pipeline
  47. self._lock = asyncio.Lock()
  48. @property
  49. def pipeline(self) -> _PipelineT:
  50. return self._pipeline
  51. async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
  52. def _infer() -> List[Any]:
  53. output = list(self._pipeline(*args, **kwargs))
  54. return output
  55. return await self.call(_infer)
  56. async def call(
  57. self, func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
  58. ) -> _R:
  59. async with self._lock:
  60. return await call_async(func, *args, **kwargs)
class AppConfig(BaseModel):
    """Validated settings for the serving app (built by ``create_app_config``)."""

    # Optional free-form settings; ``None`` when the config does not provide them.
    extra: Optional[Dict[str, Any]] = None
  63. class AppContext(Generic[_PipelineT]):
  64. def __init__(self, *, config: AppConfig) -> None:
  65. super().__init__()
  66. self._config = config
  67. self.extra: Dict[str, Any] = {}
  68. self._pipeline: Optional[PipelineWrapper[_PipelineT]] = None
  69. self._aiohttp_session: Optional[aiohttp.ClientSession] = None
  70. @property
  71. def config(self) -> AppConfig:
  72. return self._config
  73. @property
  74. def pipeline(self) -> PipelineWrapper[_PipelineT]:
  75. if not self._pipeline:
  76. raise AttributeError("`pipeline` has not been set.")
  77. return self._pipeline
  78. @pipeline.setter
  79. def pipeline(self, val: PipelineWrapper[_PipelineT]) -> None:
  80. self._pipeline = val
  81. @property
  82. def aiohttp_session(self) -> aiohttp.ClientSession:
  83. if not self._aiohttp_session:
  84. raise AttributeError("`aiohttp_session` has not been set.")
  85. return self._aiohttp_session
  86. @aiohttp_session.setter
  87. def aiohttp_session(self, val: aiohttp.ClientSession) -> None:
  88. self._aiohttp_session = val
  89. def create_app_config(pipeline_config: Dict[str, Any], **kwargs: Any) -> AppConfig:
  90. app_config = pipeline_config.get(SERVING_CONFIG_KEY, {})
  91. app_config.update(kwargs)
  92. return AppConfig.model_validate(app_config)
  93. def create_app(
  94. *, pipeline: _PipelineT, app_config: AppConfig, app_aiohttp_session: bool = True
  95. ) -> Tuple[fastapi.FastAPI, AppContext[_PipelineT]]:
  96. @contextlib.asynccontextmanager
  97. async def _app_lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]:
  98. ctx.pipeline = PipelineWrapper[_PipelineT](pipeline)
  99. if app_aiohttp_session:
  100. ctx.aiohttp_session = aiohttp.ClientSession(
  101. cookie_jar=aiohttp.DummyCookieJar()
  102. )
  103. yield
  104. if app_aiohttp_session:
  105. await ctx.aiohttp_session.close()
  106. app = fastapi.FastAPI(lifespan=_app_lifespan)
  107. ctx = AppContext[_PipelineT](config=app_config)
  108. @app.get("/health", operation_id="checkHealth")
  109. async def _check_health() -> Response:
  110. return Response(logId=generate_log_id(), errorCode=0, errorMsg="Healthy")
  111. @app.exception_handler(RequestValidationError)
  112. async def _validation_exception_handler(
  113. request: fastapi.Request, exc: RequestValidationError
  114. ) -> JSONResponse:
  115. json_compatible_data = jsonable_encoder(
  116. Response(
  117. logId=generate_log_id(),
  118. errorCode=422,
  119. errorMsg=json.dumps(exc.errors()),
  120. )
  121. )
  122. return JSONResponse(content=json_compatible_data, status_code=422)
  123. @app.exception_handler(HTTPException)
  124. async def _http_exception_handler(
  125. request: fastapi.Request, exc: HTTPException
  126. ) -> JSONResponse:
  127. json_compatible_data = jsonable_encoder(
  128. Response(
  129. logId=generate_log_id(), errorCode=exc.status_code, errorMsg=exc.detail
  130. )
  131. )
  132. return JSONResponse(content=json_compatible_data, status_code=exc.status_code)
  133. return app, ctx