utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. import os
  2. import traceback
  3. from enum import Enum
  4. from io import BytesIO
  5. from pathlib import Path
  6. from typing import List, Union, Dict, Any, Tuple, Optional
  7. import cv2
  8. import loguru
  9. import numpy as np
  10. from onnxruntime import (
  11. GraphOptimizationLevel,
  12. InferenceSession,
  13. SessionOptions,
  14. get_available_providers,
  15. )
  16. from PIL import Image, UnidentifiedImageError
  17. root_dir = Path(__file__).resolve().parent
  18. InputType = Union[str, np.ndarray, bytes, Path]
  19. class EP(Enum):
  20. CPU_EP = "CPUExecutionProvider"
  21. class OrtInferSession:
  22. def __init__(self, config: Dict[str, Any]):
  23. self.logger = loguru.logger
  24. model_path = config.get("model_path", None)
  25. self.had_providers: List[str] = get_available_providers()
  26. EP_list = self._get_ep_list()
  27. sess_opt = self._init_sess_opts(config)
  28. self.session = InferenceSession(
  29. model_path,
  30. sess_options=sess_opt,
  31. providers=EP_list,
  32. )
  33. @staticmethod
  34. def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
  35. sess_opt = SessionOptions()
  36. sess_opt.log_severity_level = 4
  37. sess_opt.enable_cpu_mem_arena = False
  38. sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
  39. cpu_nums = os.cpu_count()
  40. intra_op_num_threads = config.get("intra_op_num_threads", -1)
  41. if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
  42. sess_opt.intra_op_num_threads = intra_op_num_threads
  43. inter_op_num_threads = config.get("inter_op_num_threads", -1)
  44. if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
  45. sess_opt.inter_op_num_threads = inter_op_num_threads
  46. return sess_opt
  47. def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
  48. cpu_provider_opts = {
  49. "arena_extend_strategy": "kSameAsRequested",
  50. }
  51. EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
  52. return EP_list
  53. def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
  54. input_dict = dict(zip(self.get_input_names(), input_content))
  55. try:
  56. return self.session.run(None, input_dict)
  57. except Exception as e:
  58. error_info = traceback.format_exc()
  59. raise ONNXRuntimeError(error_info) from e
  60. def get_input_names(self) -> List[str]:
  61. return [v.name for v in self.session.get_inputs()]
  62. class ONNXRuntimeError(Exception):
  63. pass
  64. class LoadImage:
  65. def __init__(
  66. self,
  67. ):
  68. pass
  69. def __call__(self, img: InputType) -> np.ndarray:
  70. if not isinstance(img, InputType.__args__):
  71. raise LoadImageError(
  72. f"The img type {type(img)} does not in {InputType.__args__}"
  73. )
  74. img = self.load_img(img)
  75. img = self.convert_img(img)
  76. return img
  77. def load_img(self, img: InputType) -> np.ndarray:
  78. if isinstance(img, (str, Path)):
  79. self.verify_exist(img)
  80. try:
  81. img = np.array(Image.open(img))
  82. except UnidentifiedImageError as e:
  83. raise LoadImageError(f"cannot identify image file {img}") from e
  84. return img
  85. if isinstance(img, bytes):
  86. img = np.array(Image.open(BytesIO(img)))
  87. return img
  88. if isinstance(img, np.ndarray):
  89. return img
  90. raise LoadImageError(f"{type(img)} is not supported!")
  91. def convert_img(self, img: np.ndarray):
  92. if img.ndim == 2:
  93. return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  94. if img.ndim == 3:
  95. channel = img.shape[2]
  96. if channel == 1:
  97. return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  98. if channel == 2:
  99. return self.cvt_two_to_three(img)
  100. if channel == 4:
  101. return self.cvt_four_to_three(img)
  102. if channel == 3:
  103. return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  104. raise LoadImageError(
  105. f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
  106. )
  107. raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
  108. @staticmethod
  109. def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
  110. """RGBA → BGR"""
  111. r, g, b, a = cv2.split(img)
  112. new_img = cv2.merge((b, g, r))
  113. not_a = cv2.bitwise_not(a)
  114. not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
  115. new_img = cv2.bitwise_and(new_img, new_img, mask=a)
  116. new_img = cv2.add(new_img, not_a)
  117. return new_img
  118. @staticmethod
  119. def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
  120. """gray + alpha → BGR"""
  121. img_gray = img[..., 0]
  122. img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
  123. img_alpha = img[..., 1]
  124. not_a = cv2.bitwise_not(img_alpha)
  125. not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
  126. new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
  127. new_img = cv2.add(new_img, not_a)
  128. return new_img
  129. @staticmethod
  130. def verify_exist(file_path: Union[str, Path]):
  131. if not Path(file_path).exists():
  132. raise LoadImageError(f"{file_path} does not exist.")
  133. class LoadImageError(Exception):
  134. pass
  135. # Pillow >=v9.1.0 use a slightly different naming scheme for filters.
  136. # Set pillow_interp_codes according to the naming scheme used.
  137. if Image is not None:
  138. if hasattr(Image, "Resampling"):
  139. pillow_interp_codes = {
  140. "nearest": Image.Resampling.NEAREST,
  141. "bilinear": Image.Resampling.BILINEAR,
  142. "bicubic": Image.Resampling.BICUBIC,
  143. "box": Image.Resampling.BOX,
  144. "lanczos": Image.Resampling.LANCZOS,
  145. "hamming": Image.Resampling.HAMMING,
  146. }
  147. else:
  148. pillow_interp_codes = {
  149. "nearest": Image.NEAREST,
  150. "bilinear": Image.BILINEAR,
  151. "bicubic": Image.BICUBIC,
  152. "box": Image.BOX,
  153. "lanczos": Image.LANCZOS,
  154. "hamming": Image.HAMMING,
  155. }
  156. cv2_interp_codes = {
  157. "nearest": cv2.INTER_NEAREST,
  158. "bilinear": cv2.INTER_LINEAR,
  159. "bicubic": cv2.INTER_CUBIC,
  160. "area": cv2.INTER_AREA,
  161. "lanczos": cv2.INTER_LANCZOS4,
  162. }
  163. def resize_img(img, scale, keep_ratio=True):
  164. if keep_ratio:
  165. # 缩小使用area更保真
  166. if min(img.shape[:2]) > min(scale):
  167. interpolation = "area"
  168. else:
  169. interpolation = "bicubic" # bilinear
  170. img_new, scale_factor = imrescale(
  171. img, scale, return_scale=True, interpolation=interpolation
  172. )
  173. # the w_scale and h_scale has minor difference
  174. # a real fix should be done in the mmcv.imrescale in the future
  175. new_h, new_w = img_new.shape[:2]
  176. h, w = img.shape[:2]
  177. w_scale = new_w / w
  178. h_scale = new_h / h
  179. else:
  180. img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
  181. return img_new, w_scale, h_scale
  182. def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
  183. """Resize image while keeping the aspect ratio.
  184. Args:
  185. img (ndarray): The input image.
  186. scale (float | tuple[int]): The scaling factor or maximum size.
  187. If it is a float number, then the image will be rescaled by this
  188. factor, else if it is a tuple of 2 integers, then the image will
  189. be rescaled as large as possible within the scale.
  190. return_scale (bool): Whether to return the scaling factor besides the
  191. rescaled image.
  192. interpolation (str): Same as :func:`resize`.
  193. backend (str | None): Same as :func:`resize`.
  194. Returns:
  195. ndarray: The rescaled image.
  196. """
  197. h, w = img.shape[:2]
  198. new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
  199. rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend)
  200. if return_scale:
  201. return rescaled_img, scale_factor
  202. else:
  203. return rescaled_img
  204. def imresize(
  205. img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
  206. ):
  207. """Resize image to a given size.
  208. Args:
  209. img (ndarray): The input image.
  210. size (tuple[int]): Target size (w, h).
  211. return_scale (bool): Whether to return `w_scale` and `h_scale`.
  212. interpolation (str): Interpolation method, accepted values are
  213. "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
  214. backend, "nearest", "bilinear" for 'pillow' backend.
  215. out (ndarray): The output destination.
  216. backend (str | None): The image resize backend type. Options are `cv2`,
  217. `pillow`, `None`. If backend is None, the global imread_backend
  218. specified by ``mmcv.use_backend()`` will be used. Default: None.
  219. Returns:
  220. tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
  221. `resized_img`.
  222. """
  223. h, w = img.shape[:2]
  224. if backend is None:
  225. backend = "cv2"
  226. if backend not in ["cv2", "pillow"]:
  227. raise ValueError(
  228. f"backend: {backend} is not supported for resize."
  229. f"Supported backends are 'cv2', 'pillow'"
  230. )
  231. if backend == "pillow":
  232. assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
  233. pil_image = Image.fromarray(img)
  234. pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
  235. resized_img = np.array(pil_image)
  236. else:
  237. resized_img = cv2.resize(
  238. img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
  239. )
  240. if not return_scale:
  241. return resized_img
  242. else:
  243. w_scale = size[0] / w
  244. h_scale = size[1] / h
  245. return resized_img, w_scale, h_scale
  246. def rescale_size(old_size, scale, return_scale=False):
  247. """Calculate the new size to be rescaled to.
  248. Args:
  249. old_size (tuple[int]): The old size (w, h) of image.
  250. scale (float | tuple[int]): The scaling factor or maximum size.
  251. If it is a float number, then the image will be rescaled by this
  252. factor, else if it is a tuple of 2 integers, then the image will
  253. be rescaled as large as possible within the scale.
  254. return_scale (bool): Whether to return the scaling factor besides the
  255. rescaled image size.
  256. Returns:
  257. tuple[int]: The new rescaled image size.
  258. """
  259. w, h = old_size
  260. if isinstance(scale, (float, int)):
  261. if scale <= 0:
  262. raise ValueError(f"Invalid scale {scale}, must be positive.")
  263. scale_factor = scale
  264. elif isinstance(scale, tuple):
  265. max_long_edge = max(scale)
  266. max_short_edge = min(scale)
  267. scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
  268. else:
  269. raise TypeError(
  270. f"Scale must be a number or tuple of int, but got {type(scale)}"
  271. )
  272. new_size = _scale_size((w, h), scale_factor)
  273. if return_scale:
  274. return new_size, scale_factor
  275. else:
  276. return new_size
  277. def _scale_size(size, scale):
  278. """Rescale a size by a ratio.
  279. Args:
  280. size (tuple[int]): (w, h).
  281. scale (float | tuple(float)): Scaling factor.
  282. Returns:
  283. tuple[int]: scaled size.
  284. """
  285. if isinstance(scale, (float, int)):
  286. scale = (scale, scale)
  287. w, h = size
  288. return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
  289. class VisTable:
  290. def __init__(self):
  291. self.load_img = LoadImage()
  292. def __call__(
  293. self,
  294. img_path: Union[str, Path],
  295. table_results,
  296. save_html_path: Optional[Union[str, Path]] = None,
  297. save_drawed_path: Optional[Union[str, Path]] = None,
  298. save_logic_path: Optional[Union[str, Path]] = None,
  299. ):
  300. if save_html_path:
  301. html_with_border = self.insert_border_style(table_results.pred_html)
  302. self.save_html(save_html_path, html_with_border)
  303. table_cell_bboxes = table_results.cell_bboxes
  304. table_logic_points = table_results.logic_points
  305. if table_cell_bboxes is None:
  306. return None
  307. img = self.load_img(img_path)
  308. dims_bboxes = table_cell_bboxes.shape[1]
  309. if dims_bboxes == 4:
  310. drawed_img = self.draw_rectangle(img, table_cell_bboxes)
  311. elif dims_bboxes == 8:
  312. drawed_img = self.draw_polylines(img, table_cell_bboxes)
  313. else:
  314. raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
  315. if save_drawed_path:
  316. self.save_img(save_drawed_path, drawed_img)
  317. if save_logic_path:
  318. polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
  319. self.plot_rec_box_with_logic_info(
  320. img, save_logic_path, table_logic_points, polygons
  321. )
  322. return drawed_img
  323. def insert_border_style(self, table_html_str: str):
  324. style_res = """<meta charset="UTF-8"><style>
  325. table {
  326. border-collapse: collapse;
  327. width: 100%;
  328. }
  329. th, td {
  330. border: 1px solid black;
  331. padding: 8px;
  332. text-align: center;
  333. }
  334. th {
  335. background-color: #f2f2f2;
  336. }
  337. </style>"""
  338. prefix_table, suffix_table = table_html_str.split("<body>")
  339. html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
  340. return html_with_border
  341. def plot_rec_box_with_logic_info(
  342. self, img, output_path, logic_points, sorted_polygons
  343. ):
  344. """
  345. :param img_path
  346. :param output_path
  347. :param logic_points: [row_start,row_end,col_start,col_end]
  348. :param sorted_polygons: [xmin,ymin,xmax,ymax]
  349. :return:
  350. """
  351. # 读取原图
  352. img = cv2.copyMakeBorder(
  353. img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
  354. )
  355. # 绘制 polygons 矩形
  356. for idx, polygon in enumerate(sorted_polygons):
  357. x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
  358. x0 = round(x0)
  359. y0 = round(y0)
  360. x1 = round(x1)
  361. y1 = round(y1)
  362. cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
  363. # 增大字体大小和线宽
  364. font_scale = 0.9 # 原先是0.5
  365. thickness = 1 # 原先是1
  366. logic_point = logic_points[idx]
  367. cv2.putText(
  368. img,
  369. f"row: {logic_point[0]}-{logic_point[1]}",
  370. (x0 + 3, y0 + 8),
  371. cv2.FONT_HERSHEY_PLAIN,
  372. font_scale,
  373. (0, 0, 255),
  374. thickness,
  375. )
  376. cv2.putText(
  377. img,
  378. f"col: {logic_point[2]}-{logic_point[3]}",
  379. (x0 + 3, y0 + 18),
  380. cv2.FONT_HERSHEY_PLAIN,
  381. font_scale,
  382. (0, 0, 255),
  383. thickness,
  384. )
  385. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  386. # 保存绘制后的图像
  387. self.save_img(output_path, img)
  388. @staticmethod
  389. def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
  390. img_copy = img.copy()
  391. for box in boxes.astype(int):
  392. x1, y1, x2, y2 = box
  393. cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
  394. return img_copy
  395. @staticmethod
  396. def draw_polylines(img: np.ndarray, points) -> np.ndarray:
  397. img_copy = img.copy()
  398. for point in points.astype(int):
  399. point = point.reshape(4, 2)
  400. cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
  401. return img_copy
  402. @staticmethod
  403. def save_img(save_path: Union[str, Path], img: np.ndarray):
  404. cv2.imwrite(str(save_path), img)
  405. @staticmethod
  406. def save_html(save_path: Union[str, Path], html: str):
  407. with open(save_path, "w", encoding="utf-8") as f:
  408. f.write(html)