# image_processing_mineru2.py

import ast
import math
import re
from functools import partial, reduce
from typing import Dict, Optional, Union

import numpy as np
import torch
from PIL import Image
from transformers.image_processing_utils import (
    BaseImageProcessor,
    BatchFeature,
    get_size_dict,
)
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    PILImageResampling,
    to_numpy_array,
)
from transformers.utils import TensorType


def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """Pick the candidate (width, height) that keeps the most image detail
    while wasting the least padding area."""
    original_width, original_height = original_size
    best_fit = (0, 0)
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")
    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution
        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)
    return best_fit
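
# Worked example (a sketch, not from the original source):
#   select_best_resolution((800, 600), [(768, 768), (1536, 768)])
#   (768, 768):  scale 0.96 -> 768x576, effective 442368, wasted 147456
#   (1536, 768): scale 1.28 -> upscaled area capped at the original
#                480000 pixels, effective 480000, wasted 699648
#   The higher effective resolution wins, so (1536, 768) is returned.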


def divide_to_patches(image, patch_size):
    """Split a PIL image into patch_size x patch_size tiles, row-major."""
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)
    return patches
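
# Usage sketch: a 1152x768 (width x height) image with patch_size=384 yields
# 2 rows x 3 columns = 6 tiles. PIL's crop() zero-fills any box extending past
# the image border, so dimensions that are not exact multiples of patch_size
# still produce full-size tiles.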


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centered on background_color."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    if pil_img.mode == "L":
        pil_img = pil_img.convert("RGB")
    if width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
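
# Example: a 640x480 image becomes 640x640, pasted at y = (640 - 480) // 2 = 80,
# so the content sits vertically centered over the background color (below,
# the normalization mean scaled to 0-255 is used as that color).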


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """Return the (columns, rows) patch grid chosen for image_size."""
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Expand a range spec such as "(1x1),(6x6)" into every grid shape
        # between the first and last match, then scale to pixel resolutions.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
        ]
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
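
# Worked example (hypothetical config): grid_pinpoints="(1x1),(3x3)" with
# patch_size=384 expands to nine candidate resolutions from 384x384 up to
# 1152x1152; for a 1024x768 image the best fit is 1152x768, so the function
# returns (3, 2): three patch columns by two patch rows.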


# This function is not used.
def resize_and_pad_image(image, target_resolution):
    original_width, original_height = image.size
    target_width, target_height = target_resolution
    scale_w = target_width / original_width
    scale_h = target_height / original_height
    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)
    # Resize the image, then center it on a black canvas of the target size.
    resized_image = image.resize((new_width, new_height))
    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))
    return new_image


# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
def process_anyres_image(image, processor, grid_pinpoints):
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        patch_size = processor.crop_size["height"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
        ]
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    # image_padded = resize_and_pad_image(image, best_resolution)
    image_padded = image.resize(best_resolution)
    patches = divide_to_patches(image_padded, processor.crop_size["height"])
    # Prepend a resize of the full image to the per-tile crops.
    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
    image_patches = [image_original_resize] + patches
    image_patches = [
        processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches
    ]
    return torch.stack(image_patches, dim=0)
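
# Shape sketch (assuming crop_size 384 and grid_pinpoints "(1x1),(3x3)"): a
# 1024x768 image is resized to its best-fit 1152x768 resolution and split into
# six 384x384 tiles; with the base resize prepended, the result is a
# (7, 3, 384, 384) tensor.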


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
    new_images = []
    if image_aspect_ratio == "pad":
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors="pt")["pixel_values"]
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
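
# Note on the return type: when the per-image tensors share a shape they are
# stacked into a single batch tensor; otherwise a list of tensors is returned,
# so callers should handle both cases.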


class Mineru2ImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(384, 384),
        crop_size: Optional[Dict[str, int]] = None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[list] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.image_aspect_ratio = image_aspect_ratio
        self.image_grid_pinpoints = image_grid_pinpoints
        # Re-entrancy guard: process_anyres_image() calls preprocess() on each
        # tile, and those calls must take the plain path instead of recursing
        # into the end-to-end path again.
        self.in_e2e_processing = False

    def _preprocess(self, images):
        if isinstance(images, Image.Image):
            images = [images]
        else:
            # Handle video data: convert each frame to a numpy array.
            images = [to_numpy_array(image) for image in images]
        assert isinstance(images, list)
        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
        return {"pixel_values": images}

    def _preprocess_end_to_end(self, images):
        image_aspect_ratio = self.image_aspect_ratio
        image_grid_pinpoints = self.image_grid_pinpoints
        assert image_aspect_ratio is not None
        assert image_grid_pinpoints is not None
        pixel_values = []
        if image_aspect_ratio == "pad":
            for image in images:
                image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
                image = self._preprocess(image)["pixel_values"][0]
                pixel_values.append(image)
        elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
            for image in images:
                image = process_anyres_image(image, self, self.image_grid_pinpoints)
                pixel_values.append(image.numpy())
        else:
            pixel_values = self._preprocess(images)["pixel_values"]
        if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = np.stack(pixel_values, axis=0)
        # CAUTION: image_sizes uses (height, width) order, not PIL's (width, height).
        image_sizes = [(image.height, image.width) for image in images]
        assert len(pixel_values) == len(image_sizes)
        return {"pixel_values": pixel_values, "image_sizes": image_sizes}

    def preprocess(
        self,
        images,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        if self.image_aspect_ratio is None or self.in_e2e_processing:
            data = self._preprocess(images)
        else:
            assert self.image_grid_pinpoints is not None
            self.in_e2e_processing = True
            try:
                data = self._preprocess_end_to_end(images)
            finally:
                self.in_e2e_processing = False
        return BatchFeature(data=data, tensor_type=return_tensors)
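

# Minimal usage sketch (hypothetical values; a real checkpoint config supplies
# image_aspect_ratio and image_grid_pinpoints):
#
#   processor = Mineru2ImageProcessor(
#       image_aspect_ratio="anyres",
#       image_grid_pinpoints="(1x1),(3x3)",
#   )
#   out = processor.preprocess([Image.open("page.png")], return_tensors="pt")
#   out["pixel_values"]  # (1, tiles + 1, 3, 384, 384) for a single image
#   out["image_sizes"]   # [(height, width)] of the original image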