app.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import contextlib
  16. import json
  17. from typing import (
  18. Any,
  19. AsyncGenerator,
  20. Callable,
  21. Dict,
  22. Final,
  23. Generic,
  24. List,
  25. Mapping,
  26. Optional,
  27. Tuple,
  28. TypeVar,
  29. )
  30. import aiohttp
  31. import fastapi
  32. from fastapi.encoders import jsonable_encoder
  33. from fastapi.exceptions import RequestValidationError
  34. from fastapi.responses import JSONResponse
  35. from pydantic import BaseModel
  36. from starlette.exceptions import HTTPException
  37. from typing_extensions import ParamSpec
  38. from ..base import BasePipeline
  39. from .models import NoResultResponse
  40. from .utils import call_async, generate_log_id
  41. SERVING_CONFIG_KEY: Final[str] = "Serving"
  42. _PipelineT = TypeVar("_PipelineT", bound=BasePipeline)
  43. _P = ParamSpec("_P")
  44. _R = TypeVar("_R")
  45. class PipelineWrapper(Generic[_PipelineT]):
  46. def __init__(self, pipeline: _PipelineT) -> None:
  47. super().__init__()
  48. self._pipeline = pipeline
  49. self._lock = asyncio.Lock()
  50. @property
  51. def pipeline(self) -> _PipelineT:
  52. return self._pipeline
  53. async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
  54. def _infer() -> List[Any]:
  55. output = list(self._pipeline(*args, **kwargs))
  56. return output
  57. return await self.call(_infer)
  58. async def call(
  59. self, func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
  60. ) -> _R:
  61. async with self._lock:
  62. return await call_async(func, *args, **kwargs)
class AppConfig(BaseModel):
    # Serving options validated by pydantic (built by ``create_app_config`` and
    # passed to ``create_app``).
    # Free-form extra settings; ``None`` when not provided.
    extra: Optional[Dict[str, Any]] = None
  65. class AppContext(Generic[_PipelineT]):
  66. def __init__(self, *, config: AppConfig) -> None:
  67. super().__init__()
  68. self._config = config
  69. self.extra: Dict[str, Any] = {}
  70. self._pipeline: Optional[PipelineWrapper[_PipelineT]] = None
  71. self._aiohttp_session: Optional[aiohttp.ClientSession] = None
  72. @property
  73. def config(self) -> AppConfig:
  74. return self._config
  75. @property
  76. def pipeline(self) -> PipelineWrapper[_PipelineT]:
  77. if not self._pipeline:
  78. raise AttributeError("`pipeline` has not been set.")
  79. return self._pipeline
  80. @pipeline.setter
  81. def pipeline(self, val: PipelineWrapper[_PipelineT]) -> None:
  82. self._pipeline = val
  83. @property
  84. def aiohttp_session(self) -> aiohttp.ClientSession:
  85. if not self._aiohttp_session:
  86. raise AttributeError("`aiohttp_session` has not been set.")
  87. return self._aiohttp_session
  88. @aiohttp_session.setter
  89. def aiohttp_session(self, val: aiohttp.ClientSession) -> None:
  90. self._aiohttp_session = val
  91. def create_app_config(pipeline_config: Mapping[str, Any], **kwargs: Any) -> AppConfig:
  92. app_config = pipeline_config.get(SERVING_CONFIG_KEY, {})
  93. app_config.update(kwargs)
  94. return AppConfig.model_validate(app_config)
  95. def create_app(
  96. *, pipeline: _PipelineT, app_config: AppConfig, app_aiohttp_session: bool = True
  97. ) -> Tuple[fastapi.FastAPI, AppContext[_PipelineT]]:
  98. @contextlib.asynccontextmanager
  99. async def _app_lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]:
  100. ctx.pipeline = PipelineWrapper[_PipelineT](pipeline)
  101. if app_aiohttp_session:
  102. async with aiohttp.ClientSession(
  103. cookie_jar=aiohttp.DummyCookieJar()
  104. ) as aiohttp_session:
  105. ctx.aiohttp_session = aiohttp_session
  106. yield
  107. else:
  108. yield
  109. app = fastapi.FastAPI(lifespan=_app_lifespan)
  110. ctx = AppContext[_PipelineT](config=app_config)
  111. app.state.context = ctx
  112. @app.get("/health", operation_id="checkHealth")
  113. async def _check_health() -> NoResultResponse:
  114. return NoResultResponse(
  115. logId=generate_log_id(), errorCode=0, errorMsg="Healthy"
  116. )
  117. @app.exception_handler(RequestValidationError)
  118. async def _validation_exception_handler(
  119. request: fastapi.Request, exc: RequestValidationError
  120. ) -> JSONResponse:
  121. json_compatible_data = jsonable_encoder(
  122. NoResultResponse(
  123. logId=generate_log_id(),
  124. errorCode=422,
  125. errorMsg=json.dumps(exc.errors()),
  126. )
  127. )
  128. return JSONResponse(content=json_compatible_data, status_code=422)
  129. @app.exception_handler(HTTPException)
  130. async def _http_exception_handler(
  131. request: fastapi.Request, exc: HTTPException
  132. ) -> JSONResponse:
  133. json_compatible_data = jsonable_encoder(
  134. NoResultResponse(
  135. logId=generate_log_id(), errorCode=exc.status_code, errorMsg=exc.detail
  136. )
  137. )
  138. return JSONResponse(content=json_compatible_data, status_code=exc.status_code)
  139. return app, ctx