image_processor.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import ast
  2. import asyncio
  3. import re
  4. from typing import List, Optional, Union
  5. import numpy as np
  6. try:
  7. # sglang==0.4.5.post3
  8. from sglang.srt.managers.multimodal_processors.base_processor import (
  9. BaseMultimodalProcessor as BaseProcessor,
  10. )
  11. get_global_processor = None
  12. except ImportError:
  13. # sglang==0.4.4.post1
  14. from sglang.srt.managers.image_processors.base_image_processor import (
  15. BaseImageProcessor as BaseProcessor,
  16. get_global_processor,
  17. )
  18. from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
  19. from sglang.srt.utils import load_image, logger
  20. from sglang.utils import get_exception_traceback
  21. from .model import Mineru2QwenForCausalLM
  22. # image_best_res is only resized (not padded).
  23. def process_anyres_image(image, processor, grid_pinpoints):
  24. if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
  25. patch_size = processor.crop_size["height"]
  26. assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
  27. matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
  28. range_start = tuple(map(int, matches[0]))
  29. range_end = tuple(map(int, matches[-1]))
  30. grid_pinpoints = [
  31. (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
  32. ]
  33. grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
  34. if type(grid_pinpoints) is list:
  35. possible_resolutions = grid_pinpoints
  36. else:
  37. possible_resolutions = ast.literal_eval(grid_pinpoints)
  38. best_resolution = select_best_resolution(image.size, possible_resolutions)
  39. image_best_res = image.resize(best_resolution) # <<<<<<< Here changed
  40. patches = divide_to_patches(image_best_res, processor.crop_size["height"])
  41. image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
  42. image_patches = [image_original_resize] + patches
  43. image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
  44. return np.stack(image_patches, axis=0)
  45. class Mineru2ImageProcessor(BaseProcessor):
  46. def __init__(self, hf_config, server_args, _processor):
  47. super().__init__(hf_config, server_args, _processor)
  48. @staticmethod
  49. def _process_single_image_task(
  50. image_data: Union[str, bytes],
  51. image_aspect_ratio: Optional[str] = None,
  52. image_grid_pinpoints: Optional[str] = None,
  53. image_processor=None,
  54. ):
  55. if image_processor is None:
  56. assert get_global_processor is not None
  57. image_processor = get_global_processor().image_processor
  58. try:
  59. image, image_size = load_image(image_data)
  60. if image_size is not None:
  61. # It is a video with multiple images
  62. image_hash = hash(image_data)
  63. pixel_values = image_processor(image)["pixel_values"]
  64. pixel_values = np.stack(pixel_values, axis=0)
  65. return pixel_values, image_hash, image_size
  66. else:
  67. # It is an image
  68. image_hash = hash(image_data)
  69. if image_aspect_ratio == "pad":
  70. image = expand2square(
  71. image,
  72. tuple(int(x * 255) for x in image_processor.image_mean),
  73. )
  74. pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
  75. elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
  76. pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
  77. else:
  78. pixel_values = image_processor(image)["pixel_values"][0]
  79. return pixel_values, image_hash, image.size
  80. except Exception:
  81. logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
  82. async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
  83. if hasattr(self, "cpu_executor"):
  84. executor = self.cpu_executor
  85. else:
  86. executor = self.executor
  87. if get_global_processor is not None:
  88. image_processor = None # save ipc cost
  89. else:
  90. image_processor = self._processor.image_processor
  91. if executor is not None:
  92. loop = asyncio.get_running_loop()
  93. return await loop.run_in_executor(
  94. executor,
  95. Mineru2ImageProcessor._process_single_image_task,
  96. image_data,
  97. aspect_ratio,
  98. grid_pinpoints,
  99. image_processor,
  100. )
  101. else:
  102. return self._process_single_image_task(
  103. image_data,
  104. aspect_ratio,
  105. grid_pinpoints,
  106. image_processor,
  107. )
  108. # sglang==0.4.4.post1
  109. async def process_images_async(
  110. self,
  111. image_data: List[Union[str, bytes]],
  112. input_text,
  113. request_obj,
  114. *args,
  115. **kwargs,
  116. ):
  117. if not image_data:
  118. return None
  119. modalities = request_obj.modalities or ["image"]
  120. aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
  121. grid_pinpoints = (
  122. self.hf_config.image_grid_pinpoints
  123. if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
  124. else None
  125. )
  126. if isinstance(image_data, str):
  127. image_data = [image_data]
  128. if isinstance(image_data, list) and len(image_data) > 0:
  129. if "multi-images" in modalities or "video" in modalities:
  130. # Multiple images
  131. aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
  132. pixel_values, image_hashes, image_sizes = [], [], []
  133. res = []
  134. for img_data in image_data:
  135. res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
  136. res = await asyncio.gather(*res)
  137. for pixel_v, image_h, image_s in res:
  138. pixel_values.append(pixel_v)
  139. image_hashes.append(image_h)
  140. image_sizes.append(image_s)
  141. if isinstance(pixel_values[0], np.ndarray):
  142. pixel_values = np.stack(pixel_values, axis=0)
  143. else:
  144. # A single image
  145. pixel_values, image_hash, image_size = await self._process_single_image(
  146. image_data[0], aspect_ratio, grid_pinpoints
  147. )
  148. image_hashes = [image_hash]
  149. image_sizes = [image_size]
  150. else:
  151. raise ValueError(f"Invalid image data: {image_data}")
  152. return {
  153. "pixel_values": pixel_values,
  154. "image_hashes": image_hashes,
  155. "image_sizes": image_sizes,
  156. "modalities": request_obj.modalities or ["image"],
  157. }
  158. # sglang==0.4.5.post3
  159. async def process_mm_data_async(
  160. self,
  161. image_data: List[Union[str, bytes]],
  162. input_text,
  163. request_obj,
  164. *args,
  165. **kwargs,
  166. ):
  167. from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
  168. result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
  169. if result is None:
  170. return None
  171. modality = Modality.IMAGE
  172. if isinstance(request_obj.modalities, list):
  173. if request_obj.modalities[0] == "multi-images":
  174. modality = Modality.MULTI_IMAGES
  175. elif request_obj.modalities[0] == "video":
  176. modality = Modality.VIDEO
  177. return {
  178. "mm_items": [
  179. MultimodalDataItem(
  180. pixel_values=result["pixel_values"],
  181. image_sizes=result["image_sizes"],
  182. modality=modality,
  183. )
  184. ],
  185. }
  186. ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}