utils.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. import base64
  16. import io
  17. import mimetypes
  18. import re
  19. import tempfile
  20. import threading
  21. import uuid
  22. from functools import partial
  23. from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar, Union, overload
  24. from urllib.parse import parse_qs, urlparse
  25. import numpy as np
  26. import pandas as pd
  27. import requests
  28. from PIL import Image
  29. from typing_extensions import Literal, ParamSpec, TypeAlias, assert_never
  30. from ....utils.deps import function_requires_deps, is_dep_available
  31. from .models import ImageInfo, PDFInfo, PDFPageInfo
  32. if is_dep_available("aiohttp"):
  33. import aiohttp
  34. if is_dep_available("opencv-contrib-python"):
  35. import cv2
  36. if is_dep_available("filetype"):
  37. import filetype
  38. if is_dep_available("pypdfium2"):
  39. import pypdfium2 as pdfium
  40. if is_dep_available("yarl"):
  41. import yarl
# Public names of this module, i.e. what `from <module> import *` exposes.
__all__ = [
    "FileType",
    "generate_log_id",
    "is_url",
    "infer_file_type",
    "infer_file_ext",
    "image_bytes_to_array",
    "image_bytes_to_image",
    "image_to_bytes",
    "image_array_to_bytes",
    "csv_bytes_to_data_frame",
    "data_frame_to_bytes",
    "base64_encode",
    "read_pdf",
    "file_to_images",
    "get_image_info",
    "write_to_temp_file",
    "get_raw_bytes",
    "get_raw_bytes_async",
    "call_async",
]

# Broad file categories; this is the value set returned by `infer_file_type`.
FileType: TypeAlias = Literal["IMAGE", "PDF", "VIDEO", "AUDIO"]

# Parameter specification and return type of the callable wrapped by
# `call_async`, so that its signature is carried through to the wrapper.
P = ParamSpec("P")
R = TypeVar("R")
  66. def generate_log_id() -> str:
  67. return str(uuid.uuid4())
# TODO:
# 1. Use Pydantic to validate the URL and Base64-encoded string types for both
#    input and output data instead of handling this manually.
# 2. Define a `File` type for global use; this will be part of the contract.
# 3. Consider using two separate fields instead of a union of URL and Base64,
#    even though they are both strings. Backward compatibility should be
#    maintained.
  75. def is_url(s: str) -> bool:
  76. if not (s.startswith("http://") or s.startswith("https://")):
  77. # Quick rejection
  78. return False
  79. result = urlparse(s)
  80. return all([result.scheme, result.netloc]) and result.scheme in ("http", "https")
  81. def infer_file_type(url: str) -> Optional[FileType]:
  82. url_parts = urlparse(url)
  83. filename = url_parts.path.split("/")[-1]
  84. file_type = mimetypes.guess_type(filename)[0]
  85. if file_type is None:
  86. # HACK: The support for BOS URLs with query params is implementation-based,
  87. # not interface-based.
  88. is_bos_url = re.fullmatch(r"\w+\.bcebos\.com", url_parts.netloc) is not None
  89. if is_bos_url and url_parts.query:
  90. params = parse_qs(url_parts.query)
  91. if (
  92. "responseContentDisposition" in params
  93. and len(params["responseContentDisposition"]) == 1
  94. ):
  95. match_ = re.match(
  96. r"attachment;filename=(.*)", params["responseContentDisposition"][0]
  97. )
  98. if match_:
  99. file_type = mimetypes.guess_type(match_.group(1))[0]
  100. if file_type is None:
  101. return None
  102. if file_type.startswith("image/"):
  103. return "IMAGE"
  104. elif file_type == "application/pdf":
  105. return "PDF"
  106. elif file_type.startswith("video/"):
  107. return "VIDEO"
  108. elif file_type.startswith("audio/"):
  109. return "AUDIO"
  110. else:
  111. return None
  112. @function_requires_deps("filetype")
  113. def infer_file_ext(file: str) -> Optional[str]:
  114. if is_url(file):
  115. url_parts = urlparse(file)
  116. filename = url_parts.path.split("/")[-1]
  117. mime_type = mimetypes.guess_type(filename)[0]
  118. if mime_type is None:
  119. return None
  120. return mimetypes.guess_extension(mime_type)
  121. else:
  122. bytes_ = base64.b64decode(file)
  123. return "." + filetype.guess_extension(bytes_)
  124. @function_requires_deps("opencv-contrib-python")
  125. def image_bytes_to_array(data: bytes) -> np.ndarray:
  126. return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
  127. def image_bytes_to_image(data: bytes) -> Image.Image:
  128. return Image.open(io.BytesIO(data))
  129. def image_to_bytes(image: Image.Image, format: str = "JPEG") -> bytes:
  130. with io.BytesIO() as f:
  131. image.save(f, format=format)
  132. img_bytes = f.getvalue()
  133. return img_bytes
  134. @function_requires_deps("opencv-contrib-python")
  135. def image_array_to_bytes(image: np.ndarray, ext: str = ".jpg") -> bytes:
  136. image = cv2.imencode(ext, image)[1]
  137. return image.tobytes()
  138. def csv_bytes_to_data_frame(data: bytes) -> pd.DataFrame:
  139. with io.StringIO(data.decode("utf-8")) as f:
  140. df = pd.read_csv(f)
  141. return df
  142. def data_frame_to_bytes(df: pd.DataFrame) -> bytes:
  143. return df.to_csv().encode("utf-8")
  144. def base64_encode(data: bytes) -> str:
  145. return base64.b64encode(data).decode("ascii")
  146. _lock = threading.Lock()
  147. @function_requires_deps("pypdfium2", "opencv-contrib-python")
  148. def read_pdf(
  149. bytes_: bytes, max_num_imgs: Optional[int] = None
  150. ) -> Tuple[List[np.ndarray], PDFInfo]:
  151. images: List[np.ndarray] = []
  152. page_info_list: List[PDFPageInfo] = []
  153. with _lock:
  154. doc = pdfium.PdfDocument(bytes_)
  155. try:
  156. for page in doc:
  157. if max_num_imgs is not None and len(images) >= max_num_imgs:
  158. break
  159. # TODO: Do not always use zoom=2.0
  160. zoom = 2.0
  161. deg = 0
  162. image = page.render(scale=zoom, rotation=deg).to_pil()
  163. image = image.convert("RGB")
  164. image = np.array(image)
  165. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  166. images.append(image)
  167. page_info = PDFPageInfo(
  168. width=image.shape[1],
  169. height=image.shape[0],
  170. )
  171. page_info_list.append(page_info)
  172. finally:
  173. doc.close()
  174. pdf_info = PDFInfo(
  175. numPages=len(page_info_list),
  176. pages=page_info_list,
  177. )
  178. return images, pdf_info
  179. @overload
  180. def file_to_images(
  181. file_bytes: bytes,
  182. file_type: Literal["IMAGE"],
  183. *,
  184. max_num_imgs: Optional[int] = ...,
  185. ) -> Tuple[List[np.ndarray], ImageInfo]: ...
  186. @overload
  187. def file_to_images(
  188. file_bytes: bytes,
  189. file_type: Literal["PDF"],
  190. *,
  191. max_num_imgs: Optional[int] = ...,
  192. ) -> Tuple[List[np.ndarray], PDFInfo]: ...
  193. @overload
  194. def file_to_images(
  195. file_bytes: bytes,
  196. file_type: Literal["IMAGE", "PDF"],
  197. *,
  198. max_num_imgs: Optional[int] = ...,
  199. ) -> Union[Tuple[List[np.ndarray], ImageInfo], Tuple[List[np.ndarray], PDFInfo]]: ...
  200. def file_to_images(
  201. file_bytes: bytes,
  202. file_type: Literal["IMAGE", "PDF"],
  203. *,
  204. max_num_imgs: Optional[int] = None,
  205. ) -> Union[Tuple[List[np.ndarray], ImageInfo], Tuple[List[np.ndarray], PDFInfo]]:
  206. if file_type == "IMAGE":
  207. images = [image_bytes_to_array(file_bytes)]
  208. data_info = get_image_info(images[0])
  209. elif file_type == "PDF":
  210. images, data_info = read_pdf(file_bytes, max_num_imgs=max_num_imgs)
  211. else:
  212. assert_never(file_type)
  213. return images, data_info
  214. def get_image_info(image: np.ndarray) -> ImageInfo:
  215. return ImageInfo(width=image.shape[1], height=image.shape[0])
  216. def write_to_temp_file(file_bytes: bytes, suffix: str) -> str:
  217. with tempfile.NamedTemporaryFile("wb", suffix=suffix, delete=False) as f:
  218. f.write(file_bytes)
  219. return f.name
  220. def get_raw_bytes(file: str) -> bytes:
  221. if is_url(file):
  222. resp = requests.get(file, timeout=5)
  223. resp.raise_for_status()
  224. return resp.content
  225. else:
  226. return base64.b64decode(file)
  227. @function_requires_deps("aiohttp", "yarl")
  228. async def get_raw_bytes_async(file: str, session: "aiohttp.ClientSession") -> bytes:
  229. if is_url(file):
  230. async with session.get(yarl.URL(file, encoded=True)) as resp:
  231. return await resp.read()
  232. else:
  233. return base64.b64decode(file)
  234. def call_async(
  235. func: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs
  236. ) -> Awaitable[R]:
  237. return asyncio.get_running_loop().run_in_executor(
  238. None, partial(func, *args, **kwargs)
  239. )