app.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import contextlib
  16. import json
  17. from typing import (
  18. Any,
  19. AsyncGenerator,
  20. Callable,
  21. Dict,
  22. Generic,
  23. List,
  24. Mapping,
  25. Optional,
  26. Tuple,
  27. TypeVar,
  28. )
  29. import aiohttp
  30. import fastapi
  31. from fastapi.encoders import jsonable_encoder
  32. from fastapi.exceptions import RequestValidationError
  33. from fastapi.responses import JSONResponse
  34. from pydantic import BaseModel
  35. from starlette.exceptions import HTTPException
  36. from typing_extensions import Final, ParamSpec
  37. from ..base import BasePipeline
  38. from .models import NoResultResponse
  39. from .utils import call_async, generate_log_id
  # Key of the section in a pipeline config that holds the serving settings.
  SERVING_CONFIG_KEY: Final[str] = "Serving"

  # Parameter spec and return type used to type `PipelineWrapper.call`.
  _P = ParamSpec("_P")
  _R = TypeVar("_R")
  # Any concrete pipeline class; served apps are generic over it.
  _PipelineT = TypeVar("_PipelineT", bound=BasePipeline)
  class PipelineWrapper(Generic[_PipelineT]):
      """Serialize access to a pipeline behind a single ``asyncio.Lock``.

      All work goes through :meth:`call`, which holds the lock while awaiting
      ``call_async``; hence at most one call made through a given wrapper is
      in flight at any time.
      """

      def __init__(self, pipeline: _PipelineT) -> None:
          super().__init__()
          self._lock = asyncio.Lock()
          self._pipeline = pipeline

      @property
      def pipeline(self) -> _PipelineT:
          """The wrapped pipeline instance."""
          return self._pipeline

      async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
          """Run the pipeline and collect everything it yields into a list.

          ``args``/``kwargs`` are forwarded unchanged to the pipeline call; the
          whole call, including draining its output, happens under the lock.
          """

          def _run_pipeline() -> List[Any]:
              return list(self._pipeline(*args, **kwargs))

          return await self.call(_run_pipeline)

      async def call(
          self, func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
      ) -> _R:
          """Await ``call_async(func, *args, **kwargs)`` while holding the lock."""
          async with self._lock:
              result = await call_async(func, *args, **kwargs)
          return result
  class AppConfig(BaseModel):
      """Validated serving configuration of an app.

      ``extra`` holds arbitrary additional settings as a string-keyed dict,
      or ``None`` when the configuration supplies none.
      """

      extra: Optional[Dict[str, Any]] = None
  class AppContext(Generic[_PipelineT]):
      """Mutable state shared by a serving app.

      Holds the app's ``AppConfig``, a free-form ``extra`` dict, and two
      late-bound resources (the pipeline wrapper and an aiohttp session) that
      start unset and are assigned through their property setters
      (``create_app`` does this when the app starts).
      """

      def __init__(self, *, config: AppConfig) -> None:
          super().__init__()
          self._config = config
          # Free-form per-app storage; starts empty.
          self.extra: Dict[str, Any] = {}
          # Late-bound resources; ``None`` means "not set yet".
          self._pipeline: Optional[PipelineWrapper[_PipelineT]] = None
          self._aiohttp_session: Optional[aiohttp.ClientSession] = None

      @property
      def config(self) -> AppConfig:
          """The serving configuration (read-only)."""
          return self._config

      @property
      def pipeline(self) -> PipelineWrapper[_PipelineT]:
          """The pipeline wrapper.

          Raises:
              AttributeError: If the pipeline has not been set.
          """
          # Test the sentinel explicitly instead of relying on truthiness.
          if self._pipeline is None:
              raise AttributeError("`pipeline` has not been set.")
          return self._pipeline

      @pipeline.setter
      def pipeline(self, val: PipelineWrapper[_PipelineT]) -> None:
          self._pipeline = val

      @property
      def aiohttp_session(self) -> aiohttp.ClientSession:
          """The shared aiohttp client session.

          Raises:
              AttributeError: If the session has not been set.
          """
          if self._aiohttp_session is None:
              raise AttributeError("`aiohttp_session` has not been set.")
          return self._aiohttp_session

      @aiohttp_session.setter
      def aiohttp_session(self, val: aiohttp.ClientSession) -> None:
          self._aiohttp_session = val
  def create_app_config(pipeline_config: Mapping[str, Any], **kwargs: Any) -> AppConfig:
      """Build an ``AppConfig`` from the serving section of a pipeline config.

      The mapping stored under ``SERVING_CONFIG_KEY`` (if present) is merged
      with ``kwargs``, with ``kwargs`` taking precedence, and the result is
      validated with ``AppConfig.model_validate``.

      Args:
          pipeline_config: The full pipeline configuration. It is not modified.
          **kwargs: Overrides applied on top of the serving section.

      Returns:
          The validated ``AppConfig``.
      """
      # A missing or empty (``None``) section means "use defaults".
      section = pipeline_config.get(SERVING_CONFIG_KEY) or {}
      # Copy before applying overrides so they never leak back into the
      # caller's pipeline config.
      app_config: Dict[str, Any] = dict(section)
      app_config.update(kwargs)
      return AppConfig.model_validate(app_config)
  def create_app(
      *, pipeline: _PipelineT, app_config: AppConfig, app_aiohttp_session: bool = True
  ) -> Tuple[fastapi.FastAPI, AppContext[_PipelineT]]:
      """Create the FastAPI app that serves ``pipeline``.

      The app exposes ``GET /health`` and reports request-validation errors
      and HTTP exceptions as ``NoResultResponse`` JSON bodies.

      Args:
          pipeline: The pipeline to serve. It is wrapped in a
              ``PipelineWrapper`` when the app starts up.
          app_config: Serving configuration stored on the returned context.
          app_aiohttp_session: If true, an ``aiohttp.ClientSession`` that does
              not keep cookies is opened for the app's lifetime and exposed as
              ``ctx.aiohttp_session``.

      Returns:
          ``(app, ctx)``; ``ctx`` is also attached as ``app.state.context``.
      """
      ctx = AppContext[_PipelineT](config=app_config)

      @contextlib.asynccontextmanager
      async def _app_lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]:
          # Startup: wrap the pipeline and, optionally, open the shared session.
          ctx.pipeline = PipelineWrapper[_PipelineT](pipeline)
          if app_aiohttp_session:
              # DummyCookieJar stores no cookies between outgoing requests.
              async with aiohttp.ClientSession(
                  cookie_jar=aiohttp.DummyCookieJar()
              ) as aiohttp_session:
                  ctx.aiohttp_session = aiohttp_session
                  yield
          else:
              yield

      app = fastapi.FastAPI(lifespan=_app_lifespan)
      app.state.context = ctx

      @app.get("/health", operation_id="checkHealth")
      async def _check_health() -> NoResultResponse:
          return NoResultResponse(
              logId=generate_log_id(), errorCode=0, errorMsg="Healthy"
          )

      @app.exception_handler(RequestValidationError)
      async def _validation_exception_handler(
          request: fastapi.Request, exc: RequestValidationError
      ) -> JSONResponse:
          # ``exc.errors()`` can contain values ``json.dumps`` rejects (e.g. the
          # exception object pydantic puts under ``ctx`` for custom validator
          # errors), so make it JSON-compatible first, as FastAPI's default
          # handler does.
          error_msg = json.dumps(jsonable_encoder(exc.errors()))
          json_compatible_data = jsonable_encoder(
              NoResultResponse(
                  logId=generate_log_id(),
                  errorCode=422,
                  errorMsg=error_msg,
              )
          )
          return JSONResponse(content=json_compatible_data, status_code=422)

      @app.exception_handler(HTTPException)
      async def _http_exception_handler(
          request: fastapi.Request, exc: HTTPException
      ) -> JSONResponse:
          detail = exc.detail
          if not isinstance(detail, str):
              # ``fastapi.HTTPException`` accepts any JSON-able ``detail``; the
              # other responses in this app put plain text in ``errorMsg``, so
              # render non-string details as JSON text.
              detail = json.dumps(jsonable_encoder(detail))
          json_compatible_data = jsonable_encoder(
              NoResultResponse(
                  logId=generate_log_id(), errorCode=exc.status_code, errorMsg=detail
              )
          )
          # Forward headers attached to the exception (e.g. ``Allow`` on a 405)
          # instead of dropping them.
          return JSONResponse(
              content=json_compatible_data,
              status_code=exc.status_code,
              headers=getattr(exc, "headers", None),
          )

      return app, ctx