table_stucture_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # -*- encoding: utf-8 -*-
  15. # @Author: SWHL
  16. # @Contact: liekkaskono@163.com
  17. import os
  18. import platform
  19. import traceback
  20. from enum import Enum
  21. from pathlib import Path
  22. from typing import Any, Dict, List, Tuple, Union
  23. import cv2
  24. import numpy as np
  25. from onnxruntime import (
  26. GraphOptimizationLevel,
  27. InferenceSession,
  28. SessionOptions,
  29. get_available_providers,
  30. get_device,
  31. )
  32. from rapid_table.utils import Logger
  33. class EP(Enum):
  34. CPU_EP = "CPUExecutionProvider"
  35. CUDA_EP = "CUDAExecutionProvider"
  36. DIRECTML_EP = "DmlExecutionProvider"
  37. class OrtInferSession:
  38. def __init__(self, config: Dict[str, Any]):
  39. self.logger = Logger(logger_name=__name__).get_log()
  40. model_path = config.get("model_path", None)
  41. self._verify_model(model_path)
  42. self.cfg_use_cuda = config.get("use_cuda", None)
  43. self.cfg_use_dml = config.get("use_dml", None)
  44. self.had_providers: List[str] = get_available_providers()
  45. EP_list = self._get_ep_list()
  46. sess_opt = self._init_sess_opts(config)
  47. self.session = InferenceSession(
  48. model_path,
  49. sess_options=sess_opt,
  50. providers=EP_list,
  51. )
  52. self._verify_providers()
  53. @staticmethod
  54. def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
  55. sess_opt = SessionOptions()
  56. sess_opt.log_severity_level = 4
  57. sess_opt.enable_cpu_mem_arena = False
  58. sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
  59. cpu_nums = os.cpu_count()
  60. intra_op_num_threads = config.get("intra_op_num_threads", -1)
  61. if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
  62. sess_opt.intra_op_num_threads = intra_op_num_threads
  63. inter_op_num_threads = config.get("inter_op_num_threads", -1)
  64. if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
  65. sess_opt.inter_op_num_threads = inter_op_num_threads
  66. return sess_opt
  67. def get_metadata(self, key: str = "character") -> list:
  68. meta_dict = self.session.get_modelmeta().custom_metadata_map
  69. content_list = meta_dict[key].splitlines()
  70. return content_list
  71. def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
  72. cpu_provider_opts = {
  73. "arena_extend_strategy": "kSameAsRequested",
  74. }
  75. EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
  76. cuda_provider_opts = {
  77. "device_id": 0,
  78. "arena_extend_strategy": "kNextPowerOfTwo",
  79. "cudnn_conv_algo_search": "EXHAUSTIVE",
  80. "do_copy_in_default_stream": True,
  81. }
  82. self.use_cuda = self._check_cuda()
  83. if self.use_cuda:
  84. EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
  85. self.use_directml = self._check_dml()
  86. if self.use_directml:
  87. self.logger.info(
  88. "Windows 10 or above detected, try to use DirectML as primary provider"
  89. )
  90. directml_options = (
  91. cuda_provider_opts if self.use_cuda else cpu_provider_opts
  92. )
  93. EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
  94. return EP_list
  95. def _check_cuda(self) -> bool:
  96. if not self.cfg_use_cuda:
  97. return False
  98. cur_device = get_device()
  99. if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
  100. return True
  101. self.logger.warning(
  102. "%s is not in available providers (%s). Use %s inference by default.",
  103. EP.CUDA_EP.value,
  104. self.had_providers,
  105. self.had_providers[0],
  106. )
  107. self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
  108. self.logger.info(
  109. "(For reference only) If you want to use GPU acceleration, you must do:"
  110. )
  111. self.logger.info(
  112. "First, uninstall all onnxruntime pakcages in current environment."
  113. )
  114. self.logger.info(
  115. "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
  116. )
  117. self.logger.info(
  118. "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
  119. )
  120. self.logger.info(
  121. "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
  122. )
  123. self.logger.info(
  124. "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
  125. EP.CUDA_EP.value,
  126. )
  127. return False
  128. def _check_dml(self) -> bool:
  129. if not self.cfg_use_dml:
  130. return False
  131. cur_os = platform.system()
  132. if cur_os != "Windows":
  133. self.logger.warning(
  134. "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
  135. cur_os,
  136. self.had_providers[0],
  137. )
  138. return False
  139. cur_window_version = int(platform.release().split(".")[0])
  140. if cur_window_version < 10:
  141. self.logger.warning(
  142. "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
  143. cur_window_version,
  144. self.had_providers[0],
  145. )
  146. return False
  147. if EP.DIRECTML_EP.value in self.had_providers:
  148. return True
  149. self.logger.warning(
  150. "%s is not in available providers (%s). Use %s inference by default.",
  151. EP.DIRECTML_EP.value,
  152. self.had_providers,
  153. self.had_providers[0],
  154. )
  155. self.logger.info("If you want to use DirectML acceleration, you must do:")
  156. self.logger.info(
  157. "First, uninstall all onnxruntime pakcages in current environment."
  158. )
  159. self.logger.info(
  160. "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
  161. )
  162. self.logger.info(
  163. "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
  164. EP.DIRECTML_EP.value,
  165. )
  166. return False
  167. def _verify_providers(self):
  168. session_providers = self.session.get_providers()
  169. first_provider = session_providers[0]
  170. if self.use_cuda and first_provider != EP.CUDA_EP.value:
  171. self.logger.warning(
  172. "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
  173. EP.CUDA_EP.value,
  174. first_provider,
  175. )
  176. if self.use_directml and first_provider != EP.DIRECTML_EP.value:
  177. self.logger.warning(
  178. "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
  179. EP.DIRECTML_EP.value,
  180. first_provider,
  181. )
  182. def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
  183. input_dict = dict(zip(self.get_input_names(), input_content))
  184. try:
  185. return self.session.run(None, input_dict)
  186. except Exception as e:
  187. error_info = traceback.format_exc()
  188. raise ONNXRuntimeError(error_info) from e
  189. def get_input_names(self) -> List[str]:
  190. return [v.name for v in self.session.get_inputs()]
  191. def get_output_names(self) -> List[str]:
  192. return [v.name for v in self.session.get_outputs()]
  193. def get_character_list(self, key: str = "character") -> List[str]:
  194. meta_dict = self.session.get_modelmeta().custom_metadata_map
  195. return meta_dict[key].splitlines()
  196. def have_key(self, key: str = "character") -> bool:
  197. meta_dict = self.session.get_modelmeta().custom_metadata_map
  198. if key in meta_dict.keys():
  199. return True
  200. return False
  201. @staticmethod
  202. def _verify_model(model_path: Union[str, Path, None]):
  203. if model_path is None:
  204. raise ValueError("model_path is None!")
  205. model_path = Path(model_path)
  206. if not model_path.exists():
  207. raise FileNotFoundError(f"{model_path} does not exists.")
  208. if not model_path.is_file():
  209. raise FileExistsError(f"{model_path} is not a file.")
  210. class ONNXRuntimeError(Exception):
  211. pass
  212. class TableLabelDecode:
  213. def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
  214. if merge_no_span_structure:
  215. if "<td></td>" not in dict_character:
  216. dict_character.append("<td></td>")
  217. if "<td>" in dict_character:
  218. dict_character.remove("<td>")
  219. dict_character = self.add_special_char(dict_character)
  220. self.dict = {}
  221. for i, char in enumerate(dict_character):
  222. self.dict[char] = i
  223. self.character = dict_character
  224. self.td_token = ["<td>", "<td", "<td></td>"]
  225. def __call__(self, preds, batch=None):
  226. structure_probs = preds["structure_probs"]
  227. bbox_preds = preds["loc_preds"]
  228. shape_list = batch[-1]
  229. result = self.decode(structure_probs, bbox_preds, shape_list)
  230. if len(batch) == 1: # only contains shape
  231. return result
  232. label_decode_result = self.decode_label(batch)
  233. return result, label_decode_result
  234. def decode(self, structure_probs, bbox_preds, shape_list):
  235. """convert text-label into text-index."""
  236. ignored_tokens = self.get_ignored_tokens()
  237. end_idx = self.dict[self.end_str]
  238. structure_idx = structure_probs.argmax(axis=2)
  239. structure_probs = structure_probs.max(axis=2)
  240. structure_batch_list = []
  241. bbox_batch_list = []
  242. batch_size = len(structure_idx)
  243. for batch_idx in range(batch_size):
  244. structure_list = []
  245. bbox_list = []
  246. score_list = []
  247. for idx in range(len(structure_idx[batch_idx])):
  248. char_idx = int(structure_idx[batch_idx][idx])
  249. if idx > 0 and char_idx == end_idx:
  250. break
  251. if char_idx in ignored_tokens:
  252. continue
  253. text = self.character[char_idx]
  254. if text in self.td_token:
  255. bbox = bbox_preds[batch_idx, idx]
  256. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  257. bbox_list.append(bbox)
  258. structure_list.append(text)
  259. score_list.append(structure_probs[batch_idx, idx])
  260. structure_batch_list.append([structure_list, np.mean(score_list)])
  261. bbox_batch_list.append(np.array(bbox_list))
  262. result = {
  263. "bbox_batch_list": bbox_batch_list,
  264. "structure_batch_list": structure_batch_list,
  265. }
  266. return result
  267. def decode_label(self, batch):
  268. """convert text-label into text-index."""
  269. structure_idx = batch[1]
  270. gt_bbox_list = batch[2]
  271. shape_list = batch[-1]
  272. ignored_tokens = self.get_ignored_tokens()
  273. end_idx = self.dict[self.end_str]
  274. structure_batch_list = []
  275. bbox_batch_list = []
  276. batch_size = len(structure_idx)
  277. for batch_idx in range(batch_size):
  278. structure_list = []
  279. bbox_list = []
  280. for idx in range(len(structure_idx[batch_idx])):
  281. char_idx = int(structure_idx[batch_idx][idx])
  282. if idx > 0 and char_idx == end_idx:
  283. break
  284. if char_idx in ignored_tokens:
  285. continue
  286. structure_list.append(self.character[char_idx])
  287. bbox = gt_bbox_list[batch_idx][idx]
  288. if bbox.sum() != 0:
  289. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  290. bbox_list.append(bbox)
  291. structure_batch_list.append(structure_list)
  292. bbox_batch_list.append(bbox_list)
  293. result = {
  294. "bbox_batch_list": bbox_batch_list,
  295. "structure_batch_list": structure_batch_list,
  296. }
  297. return result
  298. def _bbox_decode(self, bbox, shape):
  299. h, w = shape[:2]
  300. bbox[0::2] *= w
  301. bbox[1::2] *= h
  302. return bbox
  303. def get_ignored_tokens(self):
  304. beg_idx = self.get_beg_end_flag_idx("beg")
  305. end_idx = self.get_beg_end_flag_idx("end")
  306. return [beg_idx, end_idx]
  307. def get_beg_end_flag_idx(self, beg_or_end):
  308. if beg_or_end == "beg":
  309. return np.array(self.dict[self.beg_str])
  310. if beg_or_end == "end":
  311. return np.array(self.dict[self.end_str])
  312. raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")
  313. def add_special_char(self, dict_character):
  314. self.beg_str = "sos"
  315. self.end_str = "eos"
  316. dict_character = [self.beg_str] + dict_character + [self.end_str]
  317. return dict_character
  318. class TablePreprocess:
  319. def __init__(self):
  320. self.table_max_len = 488
  321. self.build_pre_process_list()
  322. self.ops = self.create_operators()
  323. def __call__(self, data):
  324. """transform"""
  325. if self.ops is None:
  326. self.ops = []
  327. for op in self.ops:
  328. data = op(data)
  329. if data is None:
  330. return None
  331. return data
  332. def create_operators(
  333. self,
  334. ):
  335. """
  336. create operators based on the config
  337. Args:
  338. params(list): a dict list, used to create some operators
  339. """
  340. assert isinstance(
  341. self.pre_process_list, list
  342. ), "operator config should be a list"
  343. ops = []
  344. for operator in self.pre_process_list:
  345. assert (
  346. isinstance(operator, dict) and len(operator) == 1
  347. ), "yaml format error"
  348. op_name = list(operator)[0]
  349. param = {} if operator[op_name] is None else operator[op_name]
  350. op = eval(op_name)(**param)
  351. ops.append(op)
  352. return ops
  353. def build_pre_process_list(self):
  354. resize_op = {
  355. "ResizeTableImage": {
  356. "max_len": self.table_max_len,
  357. }
  358. }
  359. pad_op = {
  360. "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
  361. }
  362. normalize_op = {
  363. "NormalizeImage": {
  364. "std": [0.229, 0.224, 0.225],
  365. "mean": [0.485, 0.456, 0.406],
  366. "scale": "1./255.",
  367. "order": "hwc",
  368. }
  369. }
  370. to_chw_op = {"ToCHWImage": None}
  371. keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
  372. self.pre_process_list = [
  373. resize_op,
  374. normalize_op,
  375. pad_op,
  376. to_chw_op,
  377. keep_keys_op,
  378. ]
  379. class ResizeTableImage:
  380. def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
  381. super(ResizeTableImage, self).__init__()
  382. self.max_len = max_len
  383. self.resize_bboxes = resize_bboxes
  384. self.infer_mode = infer_mode
  385. def __call__(self, data):
  386. img = data["image"]
  387. height, width = img.shape[0:2]
  388. ratio = self.max_len / (max(height, width) * 1.0)
  389. resize_h = int(height * ratio)
  390. resize_w = int(width * ratio)
  391. resize_img = cv2.resize(img, (resize_w, resize_h))
  392. if self.resize_bboxes and not self.infer_mode:
  393. data["bboxes"] = data["bboxes"] * ratio
  394. data["image"] = resize_img
  395. data["src_img"] = img
  396. data["shape"] = np.array([height, width, ratio, ratio])
  397. data["max_len"] = self.max_len
  398. return data
  399. class PaddingTableImage:
  400. def __init__(self, size, **kwargs):
  401. super(PaddingTableImage, self).__init__()
  402. self.size = size
  403. def __call__(self, data):
  404. img = data["image"]
  405. pad_h, pad_w = self.size
  406. padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
  407. height, width = img.shape[0:2]
  408. padding_img[0:height, 0:width, :] = img.copy()
  409. data["image"] = padding_img
  410. shape = data["shape"].tolist()
  411. shape.extend([pad_h, pad_w])
  412. data["shape"] = np.array(shape)
  413. return data
  414. class NormalizeImage:
  415. """normalize image such as substract mean, divide std"""
  416. def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
  417. if isinstance(scale, str):
  418. scale = eval(scale)
  419. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  420. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  421. std = std if std is not None else [0.229, 0.224, 0.225]
  422. shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
  423. self.mean = np.array(mean).reshape(shape).astype("float32")
  424. self.std = np.array(std).reshape(shape).astype("float32")
  425. def __call__(self, data):
  426. img = np.array(data["image"])
  427. assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
  428. data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
  429. return data
  430. class ToCHWImage:
  431. """convert hwc image to chw image"""
  432. def __init__(self, **kwargs):
  433. pass
  434. def __call__(self, data):
  435. img = np.array(data["image"])
  436. data["image"] = img.transpose((2, 0, 1))
  437. return data
  438. class KeepKeys:
  439. def __init__(self, keep_keys, **kwargs):
  440. self.keep_keys = keep_keys
  441. def __call__(self, data):
  442. data_list = []
  443. for key in self.keep_keys:
  444. data_list.append(data[key])
  445. return data_list
  446. def trans_char_ocr_res(ocr_res):
  447. word_result = []
  448. for res in ocr_res:
  449. score = res[2]
  450. for word_box, word in zip(res[3], res[4]):
  451. word_res = []
  452. word_res.append(word_box)
  453. word_res.append(word)
  454. word_res.append(score)
  455. word_result.append(word_res)
  456. return word_result