Pārlūkot izejas kodu

[Fix] Use one thread to perform all inference to circumvent Paddle Inference multi-thread bug (#4156)

* One thread for all inference

* Reset others

* Add closed flag
Lin Manhui 5 mēneši atpakaļ
vecāks
revīzija
81d34e4a46
1 mainītis faili ar 46 papildinājumiem un 13 dzēšanām
  1. 46 13
      paddlex/inference/serving/basic_serving/_app.py

+ 46 - 13
paddlex/inference/serving/basic_serving/_app.py

@@ -15,6 +15,8 @@
 import asyncio
 import contextlib
 import json
+from queue import Queue
+from threading import Thread
 from typing import (
     Any,
     AsyncGenerator,
@@ -74,16 +76,22 @@ class PipelineWrapper(Generic[PipelineT]):
     def __init__(self, pipeline: PipelineT) -> None:
         super().__init__()
         self._pipeline = pipeline
-        self._lock = asyncio.Lock()
+        # HACK: We work around a bug in Paddle Inference by performing all
+        # inference in the same thread.
+        self._queue = Queue()
+        self._closed = False
+        self._loop = asyncio.get_running_loop()
+        self._thread = Thread(target=self._worker, daemon=False)
+        self._thread.start()
 
     @property
     def pipeline(self) -> PipelineT:
         return self._pipeline
 
     async def infer(self, *args: Any, **kwargs: Any) -> List[Any]:
-        def _infer() -> List[Any]:
+        def _infer(*args, **kwargs) -> List[Any]:
             output: list = []
-            with contextlib.closing(self._pipeline(*args, **kwargs)) as it:
+            with contextlib.closing(self._pipeline.predict(*args, **kwargs)) as it:
                 for item in it:
                     if _is_error(item):
                         raise fastapi.HTTPException(
@@ -93,11 +101,33 @@ class PipelineWrapper(Generic[PipelineT]):
 
             return output
 
-        return await self.call(_infer)
+        return await self.call(_infer, *args, **kwargs)
 
     async def call(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
-        async with self._lock:
-            return await call_async(func, *args, **kwargs)
+        if self._closed:
+            raise RuntimeError("`PipelineWrapper` has already been closed")
+        fut = self._loop.create_future()
+        self._queue.put((func, args, kwargs, fut))
+        return await fut
+
+    async def close(self):
+        """Stop the worker thread and wait for it to exit.
+
+        Work that was queued before this call is still executed; `call`
+        rejects new work afterwards. Calling `close` again is a no-op.
+        """
+        if self._closed:
+            return
+        # Set the flag before enqueuing the sentinel so that no request can
+        # slip in behind it and wait forever on a future nobody resolves.
+        self._closed = True
+        self._queue.put(None)
+        await call_async(self._thread.join)
+
+    def _worker(self):
+        """Run queued callables one by one on the dedicated inference thread.
+
+        Each queue item is a `(func, args, kwargs, fut)` tuple; the outcome of
+        `func(*args, **kwargs)` is handed back to the event loop through
+        `fut`. A `None` item is the shutdown sentinel.
+        """
+
+        def _resolve(fut, setter, value):
+            # Runs on the event loop thread. The awaiting task may have been
+            # cancelled in the meantime, and resolving a finished future
+            # raises `asyncio.InvalidStateError`.
+            if not fut.done():
+                setter(value)
+
+        # Exit only on the sentinel, never on `self._closed`: `close` sets the
+        # flag before the sentinel is dequeued, and stopping on the flag
+        # would strand the futures of items still waiting in the queue.
+        while True:
+            item = self._queue.get()
+            if item is None:
+                break
+            func, args, kwargs, fut = item
+            try:
+                result = func(*args, **kwargs)
+                self._loop.call_soon_threadsafe(_resolve, fut, fut.set_result, result)
+            except Exception as e:
+                self._loop.call_soon_threadsafe(_resolve, fut, fut.set_exception, e)
+            finally:
+                self._queue.task_done()
 
 
 @class_requires_deps("aiohttp")
@@ -141,14 +171,17 @@ def create_app(
     @contextlib.asynccontextmanager
     async def _app_lifespan(app: "fastapi.FastAPI") -> AsyncGenerator[None, None]:
         ctx.pipeline = PipelineWrapper[PipelineT](pipeline)
-        if app_aiohttp_session:
-            async with aiohttp.ClientSession(
-                cookie_jar=aiohttp.DummyCookieJar()
-            ) as aiohttp_session:
-                ctx.aiohttp_session = aiohttp_session
+        try:
+            if app_aiohttp_session:
+                async with aiohttp.ClientSession(
+                    cookie_jar=aiohttp.DummyCookieJar()
+                ) as aiohttp_session:
+                    ctx.aiohttp_session = aiohttp_session
+                    yield
+            else:
                 yield
-        else:
-            yield
+        finally:
+            await ctx.pipeline.close()
 
     # Should we control API versions?
     app = fastapi.FastAPI(lifespan=_app_lifespan)