# wired_table_rec_utils.py
# -*- encoding: utf-8 -*-
import os
import traceback
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_available_providers,
)
from PIL import Image, UnidentifiedImageError
  17. root_dir = Path(__file__).resolve().parent
  18. InputType = Union[str, np.ndarray, bytes, Path]
  19. class EP(Enum):
  20. CPU_EP = "CPUExecutionProvider"
  21. class OrtInferSession:
  22. def __init__(self, config: Dict[str, Any]):
  23. model_path = config.get("model_path", None)
  24. self._verify_model(model_path)
  25. self.had_providers: List[str] = get_available_providers()
  26. EP_list = self._get_ep_list()
  27. sess_opt = self._init_sess_opts(config)
  28. self.session = InferenceSession(
  29. model_path,
  30. sess_options=sess_opt,
  31. providers=EP_list,
  32. )
  33. @staticmethod
  34. def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
  35. sess_opt = SessionOptions()
  36. sess_opt.log_severity_level = 4
  37. sess_opt.enable_cpu_mem_arena = False
  38. sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
  39. cpu_nums = os.cpu_count()
  40. intra_op_num_threads = config.get("intra_op_num_threads", -1)
  41. if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
  42. sess_opt.intra_op_num_threads = intra_op_num_threads
  43. inter_op_num_threads = config.get("inter_op_num_threads", -1)
  44. if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
  45. sess_opt.inter_op_num_threads = inter_op_num_threads
  46. return sess_opt
  47. def get_metadata(self, key: str = "character") -> list:
  48. meta_dict = self.session.get_modelmeta().custom_metadata_map
  49. content_list = meta_dict[key].splitlines()
  50. return content_list
  51. def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
  52. cpu_provider_opts = {
  53. "arena_extend_strategy": "kSameAsRequested",
  54. }
  55. EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
  56. return EP_list
  57. def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
  58. input_dict = dict(zip(self.get_input_names(), input_content))
  59. try:
  60. return self.session.run(None, input_dict)
  61. except Exception as e:
  62. error_info = traceback.format_exc()
  63. raise ONNXRuntimeError(error_info) from e
  64. def get_input_names(self) -> List[str]:
  65. return [v.name for v in self.session.get_inputs()]
  66. def get_output_names(self) -> List[str]:
  67. return [v.name for v in self.session.get_outputs()]
  68. def get_character_list(self, key: str = "character") -> List[str]:
  69. meta_dict = self.session.get_modelmeta().custom_metadata_map
  70. return meta_dict[key].splitlines()
  71. def have_key(self, key: str = "character") -> bool:
  72. meta_dict = self.session.get_modelmeta().custom_metadata_map
  73. if key in meta_dict.keys():
  74. return True
  75. return False
  76. @staticmethod
  77. def _verify_model(model_path: Union[str, Path, None]):
  78. if model_path is None:
  79. raise ValueError("model_path is None!")
  80. model_path = Path(model_path)
  81. if not model_path.exists():
  82. raise FileNotFoundError(f"{model_path} does not exists.")
  83. if not model_path.is_file():
  84. raise FileExistsError(f"{model_path} is not a file.")
  85. class ONNXRuntimeError(Exception):
  86. pass
  87. class LoadImage:
  88. def __init__(
  89. self,
  90. ):
  91. pass
  92. def __call__(self, img: InputType) -> np.ndarray:
  93. if not isinstance(img, InputType.__args__):
  94. raise LoadImageError(
  95. f"The img type {type(img)} does not in {InputType.__args__}"
  96. )
  97. img = self.load_img(img)
  98. img = self.convert_img(img)
  99. return img
  100. def load_img(self, img: InputType) -> np.ndarray:
  101. if isinstance(img, (str, Path)):
  102. self.verify_exist(img)
  103. try:
  104. img = np.array(Image.open(img))
  105. except UnidentifiedImageError as e:
  106. raise LoadImageError(f"cannot identify image file {img}") from e
  107. return img
  108. if isinstance(img, bytes):
  109. img = np.array(Image.open(BytesIO(img)))
  110. return img
  111. if isinstance(img, np.ndarray):
  112. return img
  113. raise LoadImageError(f"{type(img)} is not supported!")
  114. def convert_img(self, img: np.ndarray):
  115. if img.ndim == 2:
  116. return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  117. if img.ndim == 3:
  118. channel = img.shape[2]
  119. if channel == 1:
  120. return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  121. if channel == 2:
  122. return self.cvt_two_to_three(img)
  123. if channel == 4:
  124. return self.cvt_four_to_three(img)
  125. if channel == 3:
  126. return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  127. raise LoadImageError(
  128. f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
  129. )
  130. raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
  131. @staticmethod
  132. def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
  133. """RGBA → BGR"""
  134. r, g, b, a = cv2.split(img)
  135. new_img = cv2.merge((b, g, r))
  136. not_a = cv2.bitwise_not(a)
  137. not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
  138. new_img = cv2.bitwise_and(new_img, new_img, mask=a)
  139. new_img = cv2.add(new_img, not_a)
  140. return new_img
  141. @staticmethod
  142. def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
  143. """gray + alpha → BGR"""
  144. img_gray = img[..., 0]
  145. img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
  146. img_alpha = img[..., 1]
  147. not_a = cv2.bitwise_not(img_alpha)
  148. not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
  149. new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
  150. new_img = cv2.add(new_img, not_a)
  151. return new_img
  152. @staticmethod
  153. def verify_exist(file_path: Union[str, Path]):
  154. if not Path(file_path).exists():
  155. raise LoadImageError(f"{file_path} does not exist.")
  156. class LoadImageError(Exception):
  157. pass
  158. # Pillow >=v9.1.0 use a slightly different naming scheme for filters.
  159. # Set pillow_interp_codes according to the naming scheme used.
  160. if Image is not None:
  161. if hasattr(Image, "Resampling"):
  162. pillow_interp_codes = {
  163. "nearest": Image.Resampling.NEAREST,
  164. "bilinear": Image.Resampling.BILINEAR,
  165. "bicubic": Image.Resampling.BICUBIC,
  166. "box": Image.Resampling.BOX,
  167. "lanczos": Image.Resampling.LANCZOS,
  168. "hamming": Image.Resampling.HAMMING,
  169. }
  170. else:
  171. pillow_interp_codes = {
  172. "nearest": Image.NEAREST,
  173. "bilinear": Image.BILINEAR,
  174. "bicubic": Image.BICUBIC,
  175. "box": Image.BOX,
  176. "lanczos": Image.LANCZOS,
  177. "hamming": Image.HAMMING,
  178. }
  179. cv2_interp_codes = {
  180. "nearest": cv2.INTER_NEAREST,
  181. "bilinear": cv2.INTER_LINEAR,
  182. "bicubic": cv2.INTER_CUBIC,
  183. "area": cv2.INTER_AREA,
  184. "lanczos": cv2.INTER_LANCZOS4,
  185. }
  186. def resize_img(img, scale, keep_ratio=True):
  187. if keep_ratio:
  188. # 缩小使用area更保真
  189. if min(img.shape[:2]) > min(scale):
  190. interpolation = "area"
  191. else:
  192. interpolation = "bicubic" # bilinear
  193. img_new, scale_factor = imrescale(
  194. img, scale, return_scale=True, interpolation=interpolation
  195. )
  196. # the w_scale and h_scale has minor difference
  197. # a real fix should be done in the mmcv.imrescale in the future
  198. new_h, new_w = img_new.shape[:2]
  199. h, w = img.shape[:2]
  200. w_scale = new_w / w
  201. h_scale = new_h / h
  202. else:
  203. img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
  204. return img_new, w_scale, h_scale
  205. def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
  206. """Resize image while keeping the aspect ratio.
  207. Args:
  208. img (ndarray): The input image.
  209. scale (float | tuple[int]): The scaling factor or maximum size.
  210. If it is a float number, then the image will be rescaled by this
  211. factor, else if it is a tuple of 2 integers, then the image will
  212. be rescaled as large as possible within the scale.
  213. return_scale (bool): Whether to return the scaling factor besides the
  214. rescaled image.
  215. interpolation (str): Same as :func:`resize`.
  216. backend (str | None): Same as :func:`resize`.
  217. Returns:
  218. ndarray: The rescaled image.
  219. """
  220. h, w = img.shape[:2]
  221. new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
  222. rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend)
  223. if return_scale:
  224. return rescaled_img, scale_factor
  225. else:
  226. return rescaled_img
  227. def imresize(
  228. img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
  229. ):
  230. """Resize image to a given size.
  231. Args:
  232. img (ndarray): The input image.
  233. size (tuple[int]): Target size (w, h).
  234. return_scale (bool): Whether to return `w_scale` and `h_scale`.
  235. interpolation (str): Interpolation method, accepted values are
  236. "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
  237. backend, "nearest", "bilinear" for 'pillow' backend.
  238. out (ndarray): The output destination.
  239. backend (str | None): The image resize backend type. Options are `cv2`,
  240. `pillow`, `None`. If backend is None, the global imread_backend
  241. specified by ``mmcv.use_backend()`` will be used. Default: None.
  242. Returns:
  243. tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
  244. `resized_img`.
  245. """
  246. h, w = img.shape[:2]
  247. if backend is None:
  248. backend = "cv2"
  249. if backend not in ["cv2", "pillow"]:
  250. raise ValueError(
  251. f"backend: {backend} is not supported for resize."
  252. f"Supported backends are 'cv2', 'pillow'"
  253. )
  254. if backend == "pillow":
  255. assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
  256. pil_image = Image.fromarray(img)
  257. pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
  258. resized_img = np.array(pil_image)
  259. else:
  260. resized_img = cv2.resize(
  261. img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
  262. )
  263. if not return_scale:
  264. return resized_img
  265. else:
  266. w_scale = size[0] / w
  267. h_scale = size[1] / h
  268. return resized_img, w_scale, h_scale
  269. def rescale_size(old_size, scale, return_scale=False):
  270. """Calculate the new size to be rescaled to.
  271. Args:
  272. old_size (tuple[int]): The old size (w, h) of image.
  273. scale (float | tuple[int]): The scaling factor or maximum size.
  274. If it is a float number, then the image will be rescaled by this
  275. factor, else if it is a tuple of 2 integers, then the image will
  276. be rescaled as large as possible within the scale.
  277. return_scale (bool): Whether to return the scaling factor besides the
  278. rescaled image size.
  279. Returns:
  280. tuple[int]: The new rescaled image size.
  281. """
  282. w, h = old_size
  283. if isinstance(scale, (float, int)):
  284. if scale <= 0:
  285. raise ValueError(f"Invalid scale {scale}, must be positive.")
  286. scale_factor = scale
  287. elif isinstance(scale, tuple):
  288. max_long_edge = max(scale)
  289. max_short_edge = min(scale)
  290. scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
  291. else:
  292. raise TypeError(
  293. f"Scale must be a number or tuple of int, but got {type(scale)}"
  294. )
  295. new_size = _scale_size((w, h), scale_factor)
  296. if return_scale:
  297. return new_size, scale_factor
  298. else:
  299. return new_size
  300. def _scale_size(size, scale):
  301. """Rescale a size by a ratio.
  302. Args:
  303. size (tuple[int]): (w, h).
  304. scale (float | tuple(float)): Scaling factor.
  305. Returns:
  306. tuple[int]: scaled size.
  307. """
  308. if isinstance(scale, (float, int)):
  309. scale = (scale, scale)
  310. w, h = size
  311. return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)