
add 'max_image_tokens' param for docbee serving (#3886)

学卿, 6 months ago
commit baddf1a269

+ 7 - 0
docs/pipeline_usage/tutorials/vlm_pipelines/doc_understanding.en.md

@@ -399,6 +399,13 @@ Below is a basic service deployment API reference and multilingual service call
 <td>Optional</td>
 <td>false</td>
 </tr>
+<tr>
+<td><code>max_image_tokens</code></td>
+<td><code>int</code></td>
+<td>Maximum number of tokens allowed for the input image</td>
+<td>Optional</td>
+<td>None</td>
+</tr>
 </tbody>
 </table>
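A minimal sketch of a service call that includes the new parameter, assuming the OpenAI-style chat-completions request format used in the call examples earlier in this document; the host, port, and endpoint path below are placeholders, not values confirmed by this change.

import requests

payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": "https://example.com/invoice.jpg"}},
                {"type": "text", "text": "What is the total amount on this page?"},
            ],
        }
    ],
    # New optional field; omit it (or leave it None) to keep the previous behaviour.
    "max_image_tokens": 2048,
}

# Placeholder address; substitute the actual service URL and path.
resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
print(resp.json())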
 

+ 7 - 0
docs/pipeline_usage/tutorials/vlm_pipelines/doc_understanding.md

@@ -399,6 +399,13 @@ for res in output:
 <td>否</td>
 <td>false</td>
 </tr>
+<tr>
+<td><code>max_image_tokens</code></td>
+<td><code>int</code></td>
+<td>图像的最大输入token数</td>
+<td>否</td>
+<td>None</td>
+</tr>
 </tbody>
 </table>
 

+ 41 - 2
paddlex/inference/serving/basic_serving/_pipeline_apps/doc_understanding.py

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import time
 from typing import Any, List
 
+from .....utils import logging
 from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
@@ -35,9 +37,11 @@ if is_dep_available("openai"):
     from openai.types.chat import ChatCompletion
     from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
     from openai.types.chat.chat_completion_message import ChatCompletionMessage
+if is_dep_available("pillow"):
+    from PIL import Image
 
 
-@function_requires_deps("fastapi", "openai")
+@function_requires_deps("fastapi", "openai", "pillow")
 def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
@@ -55,6 +59,30 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     )
     async def _infer(request: InferRequest) -> "ChatCompletion":
         pipeline = ctx.pipeline
+        aiohttp_session = ctx.aiohttp_session
+
+        def _resize_image_with_token_limit(image, max_token_num=2200, tile_size=28):
+            image = Image.fromarray(image)
+            w0, h0 = image.width, image.height
+            tokens = math.ceil(w0 / tile_size) * math.ceil(h0 / tile_size)
+            if tokens <= max_token_num:
+                return image
+
+            k = math.sqrt(
+                max_token_num / (math.ceil(w0 / tile_size) * math.ceil(h0 / tile_size))
+            )
+            k = min(1.0, k)
+            w_new = max(int(w0 * k), tile_size)
+            h_new = max(int(h0 * k), tile_size)
+            new_size = (w_new, h_new)
+            resized_image = image.resize(new_size)
+            tokens_new = math.ceil(w_new / tile_size) * math.ceil(h_new / tile_size)
+            logging.info(
+                f"Resizing image from {w0}x{h0} to {w_new}x{h_new}, "
+                f"which will reduce the image tokens from {tokens} to {tokens_new}."
+            )
+
+            return resized_image
 
         def _process_messages(messages: List[Message]):
             system_message = ""
@@ -88,9 +116,20 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
             return system_message, user_message, image_url
 
         system_message, user_message, image_url = _process_messages(request.messages)
+        if request.max_image_tokens is not None:
+            if image_url.startswith("data:image"):
+                _, image_url = image_url.split(",", 1)
+            img_bytes = await serving_utils.get_raw_bytes_async(
+                image_url, aiohttp_session
+            )
+            image = serving_utils.image_bytes_to_array(img_bytes)
+            image = _resize_image_with_token_limit(image, request.max_image_tokens)
+        else:
+            image = image_url
+
         result = (
             await pipeline.infer(
-                {"image": image_url, "query": user_message},
+                {"image": image, "query": user_message},
             )
         )[0]
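The helper treats each tile_size x tile_size patch as one image token: the token count is ceil(w / 28) * ceil(h / 28), and when it exceeds the budget the image is scaled by k = sqrt(max_token_num / tokens), which brings the tile count just under the limit. A standalone sketch of the same arithmetic (no PIL needed), handy for checking what a given budget does to a given page size:

import math

def estimate_tokens(w, h, tile_size=28):
    # One token per tile_size x tile_size tile covering the image.
    return math.ceil(w / tile_size) * math.ceil(h / tile_size)

def target_size(w, h, max_token_num=2200, tile_size=28):
    tokens = estimate_tokens(w, h, tile_size)
    if tokens <= max_token_num:
        return w, h
    k = min(1.0, math.sqrt(max_token_num / tokens))
    return max(int(w * k), tile_size), max(int(h * k), tile_size)

# Example: a 2000x1600 page is about 4176 tokens; with the default budget
# of 2200 it is scaled to roughly 1451x1161 (about 2184 tokens).
print(estimate_tokens(2000, 1600))  # 4176
print(target_size(2000, 1600))      # (1451, 1161)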
 

+ 1 - 0
paddlex/inference/serving/schemas/doc_understanding.py

@@ -70,6 +70,7 @@ class InferRequest(BaseModel):
     temperature: Optional[float] = 0.1
     top_p: Optional[float] = 0.95
     stream: Optional[bool] = False
+    max_image_tokens: Optional[int] = None
 
 
 PRIMARY_OPERATIONS: Final[PrimaryOperations] = {
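For reference, a sketch of how the new field sits in the request model; only the fields visible in the hunk above are reproduced, earlier fields are omitted. Omitting max_image_tokens (the default) keeps the previous behaviour, since the serving handler only downloads and resizes the image when the field is set.

from typing import Optional
from pydantic import BaseModel

class InferRequest(BaseModel):
    # Fields defined earlier in the real schema are omitted from this sketch.
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 0.95
    stream: Optional[bool] = False
    max_image_tokens: Optional[int] = None  # None: pass the image through unresized

print(InferRequest().max_image_tokens)  # None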