|
|
@@ -15,6 +15,8 @@
|
|
|
import asyncio
|
|
|
import contextlib
|
|
|
import json
|
|
|
+from queue import Queue
|
|
|
+from threading import Thread
|
|
|
from typing import (
|
|
|
Any,
|
|
|
AsyncGenerator,
|
|
|
@@ -74,16 +76,22 @@ class PipelineWrapper(Generic[PipelineT]):
|
|
|
def __init__(self, pipeline: PipelineT) -> None:
|
|
|
super().__init__()
|
|
|
self._pipeline = pipeline
|
|
|
- self._lock = asyncio.Lock()
|
|
|
+ # HACK: We work around a bug in Paddle Inference by performing all
|
|
|
+ # inference in the same thread.
|
|
|
+ self._queue = Queue()
|
|
|
+ self._closed = False
|
|
|
+ self._loop = asyncio.get_running_loop()
|
|
|
+ self._thread = Thread(target=self._worker, daemon=False)
|
|
|
+ self._thread.start()
|
|
|
|
|
|
@property
|
|
|
def pipeline(self) -> PipelineT:
|
|
|
return self._pipeline
|
|
|
|
|
|
async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
|
|
|
- def _infer() -> List[Any]:
|
|
|
+ def _infer(*args, **kwargs) -> List[Any]:
|
|
|
output: list = []
|
|
|
- with contextlib.closing(self._pipeline(*args, **kwargs)) as it:
|
|
|
+ with contextlib.closing(self._pipeline.predict(*args, **kwargs)) as it:
|
|
|
for item in it:
|
|
|
if _is_error(item):
|
|
|
raise fastapi.HTTPException(
|
|
|
@@ -93,11 +101,33 @@ class PipelineWrapper(Generic[PipelineT]):
|
|
|
|
|
|
return output
|
|
|
|
|
|
- return await self.call(_infer)
|
|
|
+ return await self.call(_infer, *args, **kwargs)
|
|
|
|
|
|
async def call(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
|
|
- async with self._lock:
|
|
|
- return await call_async(func, *args, **kwargs)
|
|
|
+ if self._closed:
|
|
|
+ raise RuntimeError("`PipelineWrapper` has already been closed")
|
|
|
+ fut = self._loop.create_future()
|
|
|
+ self._queue.put((func, args, kwargs, fut))
|
|
|
+ return await fut
|
|
|
+
|
|
|
+ async def close(self):
|
|
|
+ if not self._closed:
|
|
|
+ self._queue.put(None)
|
|
|
+ await call_async(self._thread.join)
|
|
|
+
|
|
|
+ def _worker(self):
|
|
|
+ while not self._closed:
|
|
|
+ item = self._queue.get()
|
|
|
+ if item is None:
|
|
|
+ break
|
|
|
+ func, args, kwargs, fut = item
|
|
|
+ try:
|
|
|
+ result = func(*args, **kwargs)
|
|
|
+ self._loop.call_soon_threadsafe(fut.set_result, result)
|
|
|
+ except Exception as e:
|
|
|
+ self._loop.call_soon_threadsafe(fut.set_exception, e)
|
|
|
+ finally:
|
|
|
+ self._queue.task_done()
|
|
|
|
|
|
|
|
|
@class_requires_deps("aiohttp")
|
|
|
@@ -141,14 +171,17 @@ def create_app(
|
|
|
@contextlib.asynccontextmanager
|
|
|
async def _app_lifespan(app: "fastapi.FastAPI") -> AsyncGenerator[None, None]:
|
|
|
ctx.pipeline = PipelineWrapper[PipelineT](pipeline)
|
|
|
- if app_aiohttp_session:
|
|
|
- async with aiohttp.ClientSession(
|
|
|
- cookie_jar=aiohttp.DummyCookieJar()
|
|
|
- ) as aiohttp_session:
|
|
|
- ctx.aiohttp_session = aiohttp_session
|
|
|
+ try:
|
|
|
+ if app_aiohttp_session:
|
|
|
+ async with aiohttp.ClientSession(
|
|
|
+ cookie_jar=aiohttp.DummyCookieJar()
|
|
|
+ ) as aiohttp_session:
|
|
|
+ ctx.aiohttp_session = aiohttp_session
|
|
|
+ yield
|
|
|
+ else:
|
|
|
yield
|
|
|
- else:
|
|
|
- yield
|
|
|
+ finally:
|
|
|
+ await ctx.pipeline.close()
|
|
|
|
|
|
# Should we control API versions?
|
|
|
app = fastapi.FastAPI(lifespan=_app_lifespan)
|