image_processor.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import ast
  2. import asyncio
  3. import re
  4. from typing import List, Optional, Union
  5. import numpy as np
  6. from sglang.version import __version__ as sglang_version
  7. from packaging import version
  8. if version.parse(sglang_version) >= version.parse("0.4.9"):
  9. # sglang >= 0.4.9
  10. from sglang.srt.multimodal.processors.base_processor import (
  11. BaseMultimodalProcessor as BaseProcessor,
  12. )
  13. from sglang.srt.multimodal.mm_utils import divide_to_patches, expand2square, select_best_resolution
  14. else:
  15. # 0.4.7 <= sglang < 0.4.9
  16. from sglang.srt.managers.multimodal_processors.base_processor import (
  17. BaseMultimodalProcessor as BaseProcessor,
  18. )
  19. from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
  20. get_global_processor = None
  21. from sglang.srt.utils import load_image, logger
  22. from sglang.utils import get_exception_traceback
  23. from .model import Mineru2QwenForCausalLM
  24. # image_best_res is only resized (not padded).
  25. def process_anyres_image(image, processor, grid_pinpoints):
  26. if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
  27. patch_size = processor.crop_size["height"]
  28. assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
  29. matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
  30. range_start = tuple(map(int, matches[0]))
  31. range_end = tuple(map(int, matches[-1]))
  32. grid_pinpoints = [
  33. (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
  34. ]
  35. grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
  36. if type(grid_pinpoints) is list:
  37. possible_resolutions = grid_pinpoints
  38. else:
  39. possible_resolutions = ast.literal_eval(grid_pinpoints)
  40. best_resolution = select_best_resolution(image.size, possible_resolutions)
  41. image_best_res = image.resize(best_resolution) # <<<<<<< Here changed
  42. patches = divide_to_patches(image_best_res, processor.crop_size["height"])
  43. image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
  44. image_patches = [image_original_resize] + patches
  45. image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
  46. return np.stack(image_patches, axis=0)
  47. class Mineru2ImageProcessor(BaseProcessor):
  48. def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
  49. super().__init__(hf_config, server_args, _processor, *args, **kwargs)
  50. @staticmethod
  51. def _process_single_image_task(
  52. image_data: Union[str, bytes],
  53. image_aspect_ratio: Optional[str] = None,
  54. image_grid_pinpoints: Optional[str] = None,
  55. image_processor=None,
  56. ):
  57. if image_processor is None:
  58. assert get_global_processor is not None
  59. image_processor = get_global_processor().image_processor
  60. try:
  61. image, image_size = load_image(image_data)
  62. if image_size is not None:
  63. # It is a video with multiple images
  64. image_hash = hash(image_data)
  65. pixel_values = image_processor(image)["pixel_values"]
  66. pixel_values = np.stack(pixel_values, axis=0)
  67. return pixel_values, image_hash, image_size
  68. else:
  69. # It is an image
  70. image_hash = hash(image_data)
  71. if image_aspect_ratio == "pad":
  72. image = expand2square(
  73. image,
  74. tuple(int(x * 255) for x in image_processor.image_mean),
  75. )
  76. pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
  77. elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
  78. pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
  79. else:
  80. pixel_values = image_processor(image)["pixel_values"][0]
  81. return pixel_values, image_hash, image.size
  82. except Exception:
  83. logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
  84. async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
  85. if hasattr(self, "cpu_executor"):
  86. executor = self.cpu_executor
  87. else:
  88. executor = self.executor
  89. if get_global_processor is not None:
  90. image_processor = None # save ipc cost
  91. else:
  92. image_processor = self._processor.image_processor
  93. if executor is not None:
  94. loop = asyncio.get_running_loop()
  95. return await loop.run_in_executor(
  96. executor,
  97. Mineru2ImageProcessor._process_single_image_task,
  98. image_data,
  99. aspect_ratio,
  100. grid_pinpoints,
  101. image_processor,
  102. )
  103. else:
  104. return self._process_single_image_task(
  105. image_data,
  106. aspect_ratio,
  107. grid_pinpoints,
  108. image_processor,
  109. )
  110. async def process_mm_data_async(
  111. self,
  112. image_data: List[Union[str, bytes]],
  113. input_text,
  114. request_obj,
  115. *args,
  116. **kwargs,
  117. ):
  118. from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
  119. if not image_data:
  120. return None
  121. modalities = request_obj.modalities or ["image"]
  122. aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
  123. grid_pinpoints = (
  124. self.hf_config.image_grid_pinpoints
  125. if hasattr(self.hf_config, "image_grid_pinpoints")
  126. and "anyres" in aspect_ratio
  127. else None
  128. )
  129. if isinstance(image_data, str):
  130. image_data = [image_data]
  131. if isinstance(image_data, list) and len(image_data) > 0:
  132. if "multi-images" in modalities or "video" in modalities:
  133. # Multiple images
  134. aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
  135. pixel_values, data_hashes, image_sizes = [], [], []
  136. res = []
  137. for img_data in image_data:
  138. res.append(
  139. self._process_single_image(
  140. img_data, aspect_ratio, grid_pinpoints
  141. )
  142. )
  143. res = await asyncio.gather(*res)
  144. for pixel_v, image_h, image_s in res:
  145. pixel_values.append(pixel_v)
  146. data_hashes.append(image_h)
  147. image_sizes.append(image_s)
  148. if isinstance(pixel_values[0], np.ndarray):
  149. pixel_values = np.stack(pixel_values, axis=0)
  150. else:
  151. # A single image
  152. pixel_values, image_hash, image_size = await self._process_single_image(
  153. image_data[0], aspect_ratio, grid_pinpoints
  154. )
  155. image_sizes = [image_size]
  156. else:
  157. raise ValueError(f"Invalid image data: {image_data}")
  158. modality = Modality.IMAGE
  159. if isinstance(request_obj.modalities, list):
  160. if request_obj.modalities[0] == "multi-images":
  161. modality = Modality.MULTI_IMAGES
  162. elif request_obj.modalities[0] == "video":
  163. modality = Modality.VIDEO
  164. if version.parse(sglang_version) >= version.parse("0.4.9.post3"):
  165. # sglang >= 0.4.9.post3
  166. return {
  167. "mm_items": [
  168. MultimodalDataItem(
  169. feature=pixel_values,
  170. model_specific_data={
  171. "image_sizes": image_sizes,
  172. },
  173. modality=modality,
  174. )
  175. ],
  176. }
  177. else:
  178. # 0.4.7 <= sglang <= 0.4.9.post2
  179. return {
  180. "mm_items": [
  181. MultimodalDataItem(
  182. pixel_values=pixel_values,
  183. image_sizes=image_sizes,
  184. modality=modality,
  185. )
  186. ],
  187. }
  188. ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}