| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import asyncio
- import contextlib
- import json
- from queue import Queue
- from threading import Thread
- from typing import (
- Any,
- AsyncGenerator,
- Callable,
- Dict,
- Generic,
- List,
- Optional,
- Tuple,
- TypedDict,
- TypeVar,
- )
- from typing_extensions import ParamSpec, TypeGuard
- from ....utils import logging
- from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
- from ...pipelines import BasePipeline
- from ..infra.config import AppConfig
- from ..infra.models import AIStudioNoResultResponse
- from ..infra.utils import call_async, generate_log_id
- if is_dep_available("aiohttp"):
- import aiohttp
- if is_dep_available("fastapi"):
- import fastapi
- from fastapi.encoders import jsonable_encoder
- from fastapi.exceptions import RequestValidationError
- from fastapi.responses import JSONResponse
- if is_dep_available("starlette"):
- from starlette.exceptions import HTTPException
# Type of the wrapped pipeline; constrained to `BasePipeline` subclasses and
# used to parameterize the generic classes in this module.
PipelineT = TypeVar("PipelineT", bound=BasePipeline)
# Parameter specification and return type of the callable handed to
# `PipelineWrapper.call`, so the call site keeps the callable's signature.
P = ParamSpec("P")
R = TypeVar("R")
class _Error(TypedDict):
    """Shape of an error item: a dict holding a single string under "error"."""

    error: str
def _is_error(obj: object) -> TypeGuard[_Error]:
    """Tell whether `obj` is an error item.

    An error item is a dict whose only key is "error" and whose value for
    that key is a string.

    Args:
        obj: Any object.

    Returns:
        True if `obj` matches the `_Error` shape, False otherwise.
    """
    if not isinstance(obj, dict):
        return False
    # The key set must be exactly {"error"}; extra keys disqualify the dict.
    if obj.keys() != {"error"}:
        return False
    return isinstance(obj["error"], str)
- # XXX: Since typing info (e.g., the pipeline class) cannot be easily obtained
- # without abstraction leaks, generic classes do not offer additional benefits
- # for type hinting. However, I would stick with the current design, as it does
- # not introduce runtime overhead at the moment and may prove useful in the
- # future.
@class_requires_deps("fastapi")
class PipelineWrapper(Generic[PipelineT]):
    """Expose a pipeline to asyncio code through one dedicated worker thread.

    Every callable submitted through `call` (and therefore every `infer`
    request) is executed sequentially on the same background thread. The
    result is handed back to the event loop that was running when the wrapper
    was created.
    """

    def __init__(self, pipeline: PipelineT) -> None:
        """Bind to the running event loop and start the worker thread.

        Args:
            pipeline: The pipeline object to wrap.

        Raises:
            RuntimeError: If there is no running event loop (raised by
                `asyncio.get_running_loop`).
        """
        super().__init__()
        self._pipeline = pipeline
        # HACK: We work around a bug in Paddle Inference by performing all
        # inference in the same thread.
        self._queue = Queue()
        self._closed = False
        self._loop = asyncio.get_running_loop()
        self._thread = Thread(target=self._worker, daemon=False)
        self._thread.start()

    @property
    def pipeline(self) -> PipelineT:
        """The wrapped pipeline object."""
        return self._pipeline

    async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
        """Run `pipeline.predict` on the worker thread and collect its items.

        Args:
            *args: Positional arguments forwarded to `pipeline.predict`.
            **kwargs: Keyword arguments forwarded to `pipeline.predict`.

        Returns:
            A list with every item produced by `pipeline.predict`.

        Raises:
            fastapi.HTTPException: With status 500 if the pipeline yields an
                error item (see `_is_error`).
            RuntimeError: If the wrapper has already been closed.
        """

        def _infer(*args, **kwargs) -> List[Any]:
            output: list = []
            # `closing` makes sure the iterator is closed even when we bail
            # out early because of an error item.
            with contextlib.closing(self._pipeline.predict(*args, **kwargs)) as it:
                for item in it:
                    if _is_error(item):
                        raise fastapi.HTTPException(
                            status_code=500, detail=item["error"]
                        )
                    output.append(item)
            return output

        return await self.call(_infer, *args, **kwargs)

    async def call(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
        """Execute `func(*args, **kwargs)` on the worker thread.

        Args:
            func: The callable to run.
            *args: Positional arguments for `func`.
            **kwargs: Keyword arguments for `func`.

        Returns:
            Whatever `func` returns.

        Raises:
            RuntimeError: If the wrapper has already been closed.
            Exception: Any exception raised by `func` is re-raised here.
        """
        if self._closed:
            raise RuntimeError("`PipelineWrapper` has already been closed")
        fut = self._loop.create_future()
        self._queue.put((func, args, kwargs, fut))
        return await fut

    async def close(self):
        """Stop accepting work, drain the queue and wait for the worker.

        Work submitted before this call is still executed. Calling `close`
        more than once is safe.
        """
        if not self._closed:
            # Flip the flag *before* enqueueing the sentinel: otherwise a
            # `call` made while we are waiting for the thread could put work
            # behind the sentinel, and its future would never be resolved.
            self._closed = True
            self._queue.put(None)
        if self._thread.is_alive():
            await call_async(self._thread.join)

    @staticmethod
    def _settle(
        fut: "asyncio.Future[Any]", result: Any, error: Optional[BaseException]
    ) -> None:
        """Resolve `fut` on the event loop thread unless it is already done.

        The awaiting task may have been cancelled while the work was running;
        setting a result or exception on a cancelled future would raise
        `asyncio.InvalidStateError` inside the event loop.
        """
        if fut.done():
            return
        if error is None:
            fut.set_result(result)
        else:
            fut.set_exception(error)

    def _worker(self):
        """Worker thread body: run queued work until the `None` sentinel."""
        # The loop is terminated by the sentinel rather than by `_closed`, so
        # that everything queued before `close` is still processed.
        while True:
            item = self._queue.get()
            if item is None:
                break
            func, args, kwargs, fut = item
            try:
                try:
                    result, error = func(*args, **kwargs), None
                except Exception as e:
                    result, error = None, e
                # Futures are not thread-safe; hand the outcome to the loop.
                self._loop.call_soon_threadsafe(self._settle, fut, result, error)
            finally:
                self._queue.task_done()
@class_requires_deps("aiohttp")
class AppContext(Generic[PipelineT]):
    """Per-application state holder.

    Stores the application config, an `extra` dict, and two resources that
    are assigned after construction: the pipeline wrapper and the aiohttp
    client session. Reading either resource before it has been assigned
    raises `AttributeError`.
    """

    def __init__(self, *, config: AppConfig) -> None:
        """Create a context with no pipeline and no aiohttp session yet.

        Args:
            config: The application configuration.
        """
        super().__init__()
        self._config = config
        # Free-form storage; this class itself never reads it.
        self.extra: Dict[str, Any] = {}
        self._pipeline: Optional[PipelineWrapper[PipelineT]] = None
        self._aiohttp_session: Optional[aiohttp.ClientSession] = None

    @property
    def config(self) -> AppConfig:
        """The application configuration passed at construction."""
        return self._config

    @property
    def pipeline(self) -> PipelineWrapper[PipelineT]:
        """The pipeline wrapper.

        Raises:
            AttributeError: If no pipeline wrapper has been assigned.
        """
        # Compare with `None` explicitly: "unset" is what we are testing for,
        # not the truthiness of the stored object.
        if self._pipeline is None:
            raise AttributeError("`pipeline` has not been set.")
        return self._pipeline

    @pipeline.setter
    def pipeline(self, val: PipelineWrapper[PipelineT]) -> None:
        self._pipeline = val

    @property
    def aiohttp_session(self) -> "aiohttp.ClientSession":
        """The shared aiohttp client session.

        Raises:
            AttributeError: If no session has been assigned.
        """
        if self._aiohttp_session is None:
            raise AttributeError("`aiohttp_session` has not been set.")
        return self._aiohttp_session

    @aiohttp_session.setter
    def aiohttp_session(self, val: "aiohttp.ClientSession") -> None:
        self._aiohttp_session = val
@function_requires_deps("fastapi", "aiohttp", "starlette")
def create_app(
    *, pipeline: PipelineT, app_config: AppConfig, app_aiohttp_session: bool = True
) -> Tuple["fastapi.FastAPI", AppContext[PipelineT]]:
    """Build the FastAPI application and its context for `pipeline`.

    The returned app has a lifespan handler that wraps `pipeline` in a
    `PipelineWrapper` on startup and closes it on shutdown, a `GET /health`
    endpoint, and exception handlers that report errors as
    `AIStudioNoResultResponse` bodies.

    Args:
        pipeline: The pipeline to serve.
        app_config: Application configuration stored on the context.
        app_aiohttp_session: If true, an `aiohttp.ClientSession` is opened for
            the lifetime of the app and stored on the context.

    Returns:
        A `(app, ctx)` tuple; `ctx` is also available as `app.state.context`.
    """

    @contextlib.asynccontextmanager
    async def _app_lifespan(app: "fastapi.FastAPI") -> AsyncGenerator[None, None]:
        # The wrapper must be created here, inside the running event loop.
        ctx.pipeline = PipelineWrapper[PipelineT](pipeline)
        try:
            if app_aiohttp_session:
                async with aiohttp.ClientSession(
                    cookie_jar=aiohttp.DummyCookieJar()
                ) as aiohttp_session:
                    ctx.aiohttp_session = aiohttp_session
                    yield
            else:
                yield
        finally:
            await ctx.pipeline.close()

    def _error_response(error_code: int, error_msg: Any) -> JSONResponse:
        # Shared by all exception handlers: the HTTP status always mirrors the
        # `errorCode` reported in the body.
        json_compatible_data = jsonable_encoder(
            AIStudioNoResultResponse(
                logId=generate_log_id(), errorCode=error_code, errorMsg=error_msg
            )
        )
        return JSONResponse(content=json_compatible_data, status_code=error_code)

    # Should we control API versions?
    app = fastapi.FastAPI(lifespan=_app_lifespan)
    ctx = AppContext[PipelineT](config=app_config)
    app.state.context = ctx

    @app.get("/health", operation_id="checkHealth")
    async def _check_health() -> AIStudioNoResultResponse:
        return AIStudioNoResultResponse(
            logId=generate_log_id(), errorCode=0, errorMsg="Healthy"
        )

    @app.exception_handler(RequestValidationError)
    async def _validation_exception_handler(
        request: fastapi.Request, exc: RequestValidationError
    ) -> JSONResponse:
        # Error entries can carry values `json.dumps` cannot serialise (for
        # example exception objects or bytes), so make them JSON-compatible
        # first; otherwise this handler itself would fail.
        error_details = jsonable_encoder(exc.errors())
        return _error_response(422, json.dumps(error_details))

    @app.exception_handler(HTTPException)
    async def _http_exception_handler(
        request: fastapi.Request, exc: HTTPException
    ) -> JSONResponse:
        return _error_response(exc.status_code, exc.detail)

    @app.exception_handler(Exception)
    async def _unexpected_exception_handler(
        request: fastapi.Request, exc: Exception
    ) -> JSONResponse:
        # XXX: The default server will duplicate the error message. Is it
        # necessary to log the exception info here?
        logging.exception("Unhandled exception")
        return _error_response(500, "Internal server error")

    return app, ctx
- # TODO: Precise type hints
@function_requires_deps("fastapi")
def primary_operation(
    app: "fastapi.FastAPI", path: str, operation_id: str, **kwargs: Any
) -> Callable:
    """Return the `app.post` route decorator for an operation endpoint.

    The route documents `AIStudioNoResultResponse` as the body for 422 and
    500 responses and omits `None` fields from the response model.

    Args:
        app: The FastAPI application to register the route on.
        path: URL path of the route.
        operation_id: Operation ID of the route.
        **kwargs: Extra keyword arguments passed through to `app.post`.

    Returns:
        The decorator returned by `app.post`.
    """
    error_responses = {
        status_code: {"model": AIStudioNoResultResponse}
        for status_code in (422, 500)
    }
    route_decorator = app.post(
        path,
        operation_id=operation_id,
        responses=error_responses,
        response_model_exclude_none=True,
        **kwargs,
    )
    return route_decorator
|