# doc_understanding.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import time
  16. from typing import Any, List
  17. from .....utils import logging
  18. from .....utils.deps import function_requires_deps, is_dep_available
  19. from ...infra import utils as serving_utils
  20. from ...infra.config import AppConfig
  21. from ...schemas.doc_understanding import (
  22. INFER_ENDPOINT,
  23. ImageContent,
  24. ImageUrl,
  25. InferRequest,
  26. Message,
  27. RoleType,
  28. TextContent,
  29. )
  30. from .._app import create_app, primary_operation
  31. if is_dep_available("fastapi"):
  32. from fastapi import FastAPI
  33. if is_dep_available("openai"):
  34. from openai.types.chat import ChatCompletion
  35. from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
  36. from openai.types.chat.chat_completion_message import ChatCompletionMessage
  37. if is_dep_available("pillow"):
  38. from PIL import Image
  39. @function_requires_deps("fastapi", "openai", "pillow")
  40. def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
  41. app, ctx = create_app(
  42. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  43. )
  44. @primary_operation(
  45. app,
  46. "/chat/completions",
  47. "inferA",
  48. )
  49. @primary_operation(
  50. app,
  51. INFER_ENDPOINT,
  52. "infer",
  53. )
  54. async def _infer(request: InferRequest) -> "ChatCompletion":
  55. pipeline = ctx.pipeline
  56. aiohttp_session = ctx.aiohttp_session
  57. def _resize_image_with_token_limit(image, max_token_num=2200, tile_size=28):
  58. image = Image.fromarray(image)
  59. w0, h0 = image.width, image.height
  60. tokens = math.ceil(w0 / tile_size) * math.ceil(h0 / tile_size)
  61. if tokens <= max_token_num:
  62. return image
  63. k = math.sqrt(
  64. max_token_num / (math.ceil(w0 / tile_size) * math.ceil(h0 / tile_size))
  65. )
  66. k = min(1.0, k)
  67. w_new = max(int(w0 * k), tile_size)
  68. h_new = max(int(h0 * k), tile_size)
  69. new_size = (w_new, h_new)
  70. resized_image = image.resize(new_size)
  71. tokens_new = math.ceil(w_new / tile_size) * math.ceil(h_new / tile_size)
  72. logging.info(
  73. f"Resizing image from {w0}x{h0} to {w_new}x{h_new}, "
  74. f"which will reduce the image tokens from {tokens} to {tokens_new}."
  75. )
  76. return resized_image
  77. def _process_messages(messages: List[Message]):
  78. system_message = ""
  79. user_message = ""
  80. image_url = ""
  81. for msg in messages:
  82. if msg.role == RoleType.SYSTEM:
  83. if isinstance(msg.content, list):
  84. for content in msg.content:
  85. if isinstance(content, TextContent):
  86. system_message = content.text
  87. break
  88. else:
  89. system_message = msg.content
  90. elif msg.role == RoleType.USER:
  91. if isinstance(msg.content, list):
  92. for content in msg.content:
  93. if isinstance(content, str):
  94. user_message = content
  95. else:
  96. if isinstance(content, TextContent):
  97. user_message = content.text
  98. elif isinstance(content, ImageContent):
  99. image_url = content.image_url
  100. if isinstance(image_url, ImageUrl):
  101. image_url = image_url.url
  102. else:
  103. user_message = msg.content
  104. return system_message, user_message, image_url
  105. system_message, user_message, image_url = _process_messages(request.messages)
  106. if request.max_image_tokens is not None:
  107. if image_url.startswith("data:image"):
  108. _, image_url = image_url.split(",", 1)
  109. img_bytes = await serving_utils.get_raw_bytes_async(
  110. image_url, aiohttp_session
  111. )
  112. image = serving_utils.image_bytes_to_array(img_bytes)
  113. image = _resize_image_with_token_limit(image, request.max_image_tokens)
  114. else:
  115. image = image_url
  116. result = (
  117. await pipeline.infer(
  118. {"image": image, "query": user_message},
  119. )
  120. )[0]
  121. return ChatCompletion(
  122. id=serving_utils.generate_log_id(),
  123. model=request.model,
  124. choices=[
  125. ChatCompletionChoice(
  126. index=0,
  127. finish_reason="stop",
  128. message=ChatCompletionMessage(
  129. role="assistant",
  130. content=result["result"],
  131. ),
  132. )
  133. ],
  134. created=int(time.time()),
  135. object="chat.completion",
  136. )
  137. return app