image_utils.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import math
  2. import base64
  3. from PIL import Image
  4. from typing import Tuple
  5. import os
  6. from dots_ocr.utils.consts import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS
  7. from dots_ocr.utils.doc_utils import fitz_doc_to_image
  8. from io import BytesIO
  9. import fitz
  10. import requests
  11. import copy
  12. def round_by_factor(number: int, factor: int) -> int:
  13. """Returns the closest integer to 'number' that is divisible by 'factor'."""
  14. return round(number / factor) * factor
  15. def ceil_by_factor(number: int, factor: int) -> int:
  16. """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
  17. return math.ceil(number / factor) * factor
  18. def floor_by_factor(number: int, factor: int) -> int:
  19. """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
  20. return math.floor(number / factor) * factor
  21. def smart_resize(
  22. height: int,
  23. width: int,
  24. factor: int = 28,
  25. min_pixels: int = 3136,
  26. max_pixels: int = 11289600,
  27. ):
  28. """Rescales the image so that the following conditions are met:
  29. 1. Both dimensions (height and width) are divisible by 'factor'.
  30. 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
  31. 3. The aspect ratio of the image is maintained as closely as possible.
  32. """
  33. if max(height, width) / min(height, width) > 200:
  34. raise ValueError(
  35. f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
  36. )
  37. h_bar = max(factor, round_by_factor(height, factor))
  38. w_bar = max(factor, round_by_factor(width, factor))
  39. if h_bar * w_bar > max_pixels:
  40. beta = math.sqrt((height * width) / max_pixels)
  41. h_bar = max(factor, floor_by_factor(height / beta, factor))
  42. w_bar = max(factor, floor_by_factor(width / beta, factor))
  43. elif h_bar * w_bar < min_pixels:
  44. beta = math.sqrt(min_pixels / (height * width))
  45. h_bar = ceil_by_factor(height * beta, factor)
  46. w_bar = ceil_by_factor(width * beta, factor)
  47. if h_bar * w_bar > max_pixels: # max_pixels first to control the token length
  48. beta = math.sqrt((h_bar * w_bar) / max_pixels)
  49. h_bar = max(factor, floor_by_factor(h_bar / beta, factor))
  50. w_bar = max(factor, floor_by_factor(w_bar / beta, factor))
  51. return h_bar, w_bar
  52. def PILimage_to_base64(image, format='PNG'):
  53. buffered = BytesIO()
  54. image.save(buffered, format=format)
  55. base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
  56. return f"data:image/{format.lower()};base64,{base64_str}"
  57. def to_rgb(pil_image: Image.Image) -> Image.Image:
  58. if pil_image.mode == 'RGBA':
  59. white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
  60. white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
  61. return white_background
  62. else:
  63. return pil_image.convert("RGB")
  64. # copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
  65. def fetch_image(
  66. image,
  67. min_pixels=None,
  68. max_pixels=None,
  69. resized_height=None,
  70. resized_width=None,
  71. ) -> Image.Image:
  72. assert image is not None, f"image not found, maybe input format error: {image}"
  73. image_obj = None
  74. if isinstance(image, Image.Image):
  75. image_obj = image
  76. elif image.startswith("http://") or image.startswith("https://"):
  77. # fix memory leak issue while using BytesIO
  78. with requests.get(image, stream=True) as response:
  79. response.raise_for_status()
  80. with BytesIO(response.content) as bio:
  81. image_obj = copy.deepcopy(Image.open(bio))
  82. elif image.startswith("file://"):
  83. image_obj = Image.open(image[7:])
  84. elif image.startswith("data:image"):
  85. if "base64," in image:
  86. _, base64_data = image.split("base64,", 1)
  87. data = base64.b64decode(base64_data)
  88. # fix memory leak issue while using BytesIO
  89. with BytesIO(data) as bio:
  90. image_obj = copy.deepcopy(Image.open(bio))
  91. else:
  92. image_obj = Image.open(image)
  93. if image_obj is None:
  94. raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
  95. image = to_rgb(image_obj)
  96. ## resize
  97. if resized_height and resized_width:
  98. resized_height, resized_width = smart_resize(
  99. resized_height,
  100. resized_width,
  101. factor=IMAGE_FACTOR,
  102. )
  103. assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
  104. image = image.resize((resized_width, resized_height))
  105. elif min_pixels or max_pixels:
  106. width, height = image.size
  107. if not min_pixels:
  108. min_pixels = MIN_PIXELS
  109. if not max_pixels:
  110. max_pixels = MAX_PIXELS
  111. resized_height, resized_width = smart_resize(
  112. height,
  113. width,
  114. factor=IMAGE_FACTOR,
  115. min_pixels=min_pixels,
  116. max_pixels=max_pixels,
  117. )
  118. assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
  119. image = image.resize((resized_width, resized_height))
  120. return image
  121. def get_input_dimensions(
  122. image: Image.Image,
  123. min_pixels: int,
  124. max_pixels: int,
  125. factor: int = 28
  126. ) -> Tuple[int, int]:
  127. """
  128. Gets the resized dimensions of the input image.
  129. Args:
  130. image: The original image.
  131. min_pixels: The minimum number of pixels.
  132. max_pixels: The maximum number of pixels.
  133. factor: The resizing factor.
  134. Returns:
  135. The resized (width, height).
  136. """
  137. input_height, input_width = smart_resize(
  138. image.height,
  139. image.width,
  140. factor=factor,
  141. min_pixels=min_pixels,
  142. max_pixels=max_pixels
  143. )
  144. return input_width, input_height
  145. def get_image_by_fitz_doc(image, target_dpi=200):
  146. # get image through fitz, to get target dpi image, mainly for higher image
  147. if not isinstance(image, Image.Image):
  148. assert isinstance(image, str)
  149. _, file_ext = os.path.splitext(image)
  150. assert file_ext in {'.jpg', '.jpeg', '.png'}
  151. if image.startswith("http://") or image.startswith("https://"):
  152. with requests.get(image, stream=True) as response:
  153. response.raise_for_status()
  154. data_bytes = response.content
  155. else:
  156. with open(image, 'rb') as f:
  157. data_bytes = f.read()
  158. image = Image.open(BytesIO(data_bytes))
  159. else:
  160. data_bytes = BytesIO()
  161. image.save(data_bytes, format='PNG')
  162. origin_dpi = image.info.get('dpi', None)
  163. pdf_bytes = fitz.open(stream=data_bytes).convert_to_pdf()
  164. doc = fitz.open('pdf', pdf_bytes)
  165. page = doc[0]
  166. image_fitz = fitz_doc_to_image(page, target_dpi=target_dpi, origin_dpi=origin_dpi)
  167. return image_fitz