@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import time
 from typing import Any, List
 
+from .....utils import logging
 from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
@@ -35,9 +37,12 @@ if is_dep_available("openai"):
     from openai.types.chat import ChatCompletion
     from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
     from openai.types.chat.chat_completion_message import ChatCompletionMessage
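+# Pillow is used to downscale images that exceed the client-specified token budget.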
+if is_dep_available("pillow"):
+    from PIL import Image
 
 
-@function_requires_deps("fastapi", "openai")
+@function_requires_deps("fastapi", "openai", "pillow")
 def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
@@ -55,6 +60,33 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     )
     async def _infer(request: InferRequest) -> "ChatCompletion":
         pipeline = ctx.pipeline
+        aiohttp_session = ctx.aiohttp_session
+
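+        # Downscale the image so that its approximate visual token count,
+        # one token per tile_size x tile_size tile, stays within max_token_num.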
+        def _resize_image_with_token_limit(image, max_token_num=2200, tile_size=28):
+            image = Image.fromarray(image)
+            w0, h0 = image.width, image.height
+            tokens = math.ceil(w0 / tile_size) * math.ceil(h0 / tile_size)
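+            # e.g. 1920x1080 -> ceil(1920/28) * ceil(1080/28) = 69 * 39 = 2691 tokens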
+            if tokens <= max_token_num:
+                return image
+
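+            # Scaling both sides by k scales the tile count by about k**2,
+            # bringing it down to roughly max_token_num.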
+            k = math.sqrt(max_token_num / tokens)
+            k = min(1.0, k)
+            w_new = max(int(w0 * k), tile_size)
+            h_new = max(int(h0 * k), tile_size)
+            new_size = (w_new, h_new)
+            resized_image = image.resize(new_size)
+            tokens_new = math.ceil(w_new / tile_size) * math.ceil(h_new / tile_size)
+            logging.info(
+                f"Resizing image from {w0}x{h0} to {w_new}x{h_new}, "
+                f"which will reduce the image tokens from {tokens} to {tokens_new}."
+            )
+
+            return resized_image
 
         def _process_messages(messages: List[Message]):
             system_message = ""
@@ -88,9 +120,23 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
             return system_message, user_message, image_url
 
         system_message, user_message, image_url = _process_messages(request.messages)
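+        # With a token budget set, fetch and decode the image so it can be
+        # resized before inference; otherwise pass the original URL or data string through.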
+        if request.max_image_tokens is not None:
+            if image_url.startswith("data:image"):
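+                # Drop the "data:image/...;base64," prefix, keeping only the payload.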
+                _, image_url = image_url.split(",", 1)
+            img_bytes = await serving_utils.get_raw_bytes_async(
+                image_url, aiohttp_session
+            )
+            image = serving_utils.image_bytes_to_array(img_bytes)
+            image = _resize_image_with_token_limit(image, request.max_image_tokens)
+        else:
+            image = image_url
+
         result = (
             await pipeline.infer(
-                {"image": image_url, "query": user_message},
+                {"image": image, "query": user_message},
             )
         )[0]
 