table_stucture_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import platform
  16. import traceback
  17. from enum import Enum
  18. from pathlib import Path
  19. from typing import Any, Dict, List, Tuple, Union
  20. import cv2
  21. import numpy as np
  22. from onnxruntime import (
  23. GraphOptimizationLevel,
  24. InferenceSession,
  25. SessionOptions,
  26. get_available_providers,
  27. get_device,
  28. )
  29. from loguru import logger
  30. class EP(Enum):
  31. CPU_EP = "CPUExecutionProvider"
  32. CUDA_EP = "CUDAExecutionProvider"
  33. DIRECTML_EP = "DmlExecutionProvider"
  34. class OrtInferSession:
  35. def __init__(self, config: Dict[str, Any]):
  36. self.logger = logger
  37. model_path = config.get("model_path", None)
  38. self._verify_model(model_path)
  39. self.cfg_use_cuda = config.get("use_cuda", None)
  40. self.cfg_use_dml = config.get("use_dml", None)
  41. self.had_providers: List[str] = get_available_providers()
  42. EP_list = self._get_ep_list()
  43. sess_opt = self._init_sess_opts(config)
  44. self.session = InferenceSession(
  45. model_path,
  46. sess_options=sess_opt,
  47. providers=EP_list,
  48. )
  49. self._verify_providers()
  50. @staticmethod
  51. def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
  52. sess_opt = SessionOptions()
  53. sess_opt.log_severity_level = 4
  54. sess_opt.enable_cpu_mem_arena = False
  55. sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
  56. cpu_nums = os.cpu_count()
  57. intra_op_num_threads = config.get("intra_op_num_threads", -1)
  58. if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
  59. sess_opt.intra_op_num_threads = intra_op_num_threads
  60. inter_op_num_threads = config.get("inter_op_num_threads", -1)
  61. if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
  62. sess_opt.inter_op_num_threads = inter_op_num_threads
  63. return sess_opt
  64. def get_metadata(self, key: str = "character") -> list:
  65. meta_dict = self.session.get_modelmeta().custom_metadata_map
  66. content_list = meta_dict[key].splitlines()
  67. return content_list
  68. def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
  69. cpu_provider_opts = {
  70. "arena_extend_strategy": "kSameAsRequested",
  71. }
  72. EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
  73. cuda_provider_opts = {
  74. "device_id": 0,
  75. "arena_extend_strategy": "kNextPowerOfTwo",
  76. "cudnn_conv_algo_search": "EXHAUSTIVE",
  77. "do_copy_in_default_stream": True,
  78. }
  79. self.use_cuda = self._check_cuda()
  80. if self.use_cuda:
  81. EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
  82. self.use_directml = self._check_dml()
  83. if self.use_directml:
  84. self.logger.info(
  85. "Windows 10 or above detected, try to use DirectML as primary provider"
  86. )
  87. directml_options = (
  88. cuda_provider_opts if self.use_cuda else cpu_provider_opts
  89. )
  90. EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
  91. return EP_list
  92. def _check_cuda(self) -> bool:
  93. if not self.cfg_use_cuda:
  94. return False
  95. cur_device = get_device()
  96. if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
  97. return True
  98. self.logger.warning(
  99. "%s is not in available providers (%s). Use %s inference by default.",
  100. EP.CUDA_EP.value,
  101. self.had_providers,
  102. self.had_providers[0],
  103. )
  104. self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
  105. self.logger.info(
  106. "(For reference only) If you want to use GPU acceleration, you must do:"
  107. )
  108. self.logger.info(
  109. "First, uninstall all onnxruntime pakcages in current environment."
  110. )
  111. self.logger.info(
  112. "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
  113. )
  114. self.logger.info(
  115. "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
  116. )
  117. self.logger.info(
  118. "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
  119. )
  120. self.logger.info(
  121. "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
  122. EP.CUDA_EP.value,
  123. )
  124. return False
  125. def _check_dml(self) -> bool:
  126. if not self.cfg_use_dml:
  127. return False
  128. cur_os = platform.system()
  129. if cur_os != "Windows":
  130. self.logger.warning(
  131. "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
  132. cur_os,
  133. self.had_providers[0],
  134. )
  135. return False
  136. cur_window_version = int(platform.release().split(".")[0])
  137. if cur_window_version < 10:
  138. self.logger.warning(
  139. "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
  140. cur_window_version,
  141. self.had_providers[0],
  142. )
  143. return False
  144. if EP.DIRECTML_EP.value in self.had_providers:
  145. return True
  146. self.logger.warning(
  147. "%s is not in available providers (%s). Use %s inference by default.",
  148. EP.DIRECTML_EP.value,
  149. self.had_providers,
  150. self.had_providers[0],
  151. )
  152. self.logger.info("If you want to use DirectML acceleration, you must do:")
  153. self.logger.info(
  154. "First, uninstall all onnxruntime pakcages in current environment."
  155. )
  156. self.logger.info(
  157. "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
  158. )
  159. self.logger.info(
  160. "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
  161. EP.DIRECTML_EP.value,
  162. )
  163. return False
  164. def _verify_providers(self):
  165. session_providers = self.session.get_providers()
  166. first_provider = session_providers[0]
  167. if self.use_cuda and first_provider != EP.CUDA_EP.value:
  168. self.logger.warning(
  169. "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
  170. EP.CUDA_EP.value,
  171. first_provider,
  172. )
  173. if self.use_directml and first_provider != EP.DIRECTML_EP.value:
  174. self.logger.warning(
  175. "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
  176. EP.DIRECTML_EP.value,
  177. first_provider,
  178. )
  179. def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
  180. input_dict = dict(zip(self.get_input_names(), input_content))
  181. try:
  182. return self.session.run(None, input_dict)
  183. except Exception as e:
  184. error_info = traceback.format_exc()
  185. raise ONNXRuntimeError(error_info) from e
  186. def get_input_names(self) -> List[str]:
  187. return [v.name for v in self.session.get_inputs()]
  188. def get_output_names(self) -> List[str]:
  189. return [v.name for v in self.session.get_outputs()]
  190. def get_character_list(self, key: str = "character") -> List[str]:
  191. meta_dict = self.session.get_modelmeta().custom_metadata_map
  192. return meta_dict[key].splitlines()
  193. def have_key(self, key: str = "character") -> bool:
  194. meta_dict = self.session.get_modelmeta().custom_metadata_map
  195. if key in meta_dict.keys():
  196. return True
  197. return False
  198. @staticmethod
  199. def _verify_model(model_path: Union[str, Path, None]):
  200. if model_path is None:
  201. raise ValueError("model_path is None!")
  202. model_path = Path(model_path)
  203. if not model_path.exists():
  204. raise FileNotFoundError(f"{model_path} does not exists.")
  205. if not model_path.is_file():
  206. raise FileExistsError(f"{model_path} is not a file.")
  207. class ONNXRuntimeError(Exception):
  208. pass
  209. class TableLabelDecode:
  210. def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
  211. if merge_no_span_structure:
  212. if "<td></td>" not in dict_character:
  213. dict_character.append("<td></td>")
  214. if "<td>" in dict_character:
  215. dict_character.remove("<td>")
  216. dict_character = self.add_special_char(dict_character)
  217. self.dict = {}
  218. for i, char in enumerate(dict_character):
  219. self.dict[char] = i
  220. self.character = dict_character
  221. self.td_token = ["<td>", "<td", "<td></td>"]
  222. def __call__(self, preds, batch=None):
  223. structure_probs = preds["structure_probs"]
  224. bbox_preds = preds["loc_preds"]
  225. shape_list = batch[-1]
  226. result = self.decode(structure_probs, bbox_preds, shape_list)
  227. if len(batch) == 1: # only contains shape
  228. return result
  229. label_decode_result = self.decode_label(batch)
  230. return result, label_decode_result
  231. def decode(self, structure_probs, bbox_preds, shape_list):
  232. """convert text-label into text-index."""
  233. ignored_tokens = self.get_ignored_tokens()
  234. end_idx = self.dict[self.end_str]
  235. structure_idx = structure_probs.argmax(axis=2)
  236. structure_probs = structure_probs.max(axis=2)
  237. structure_batch_list = []
  238. bbox_batch_list = []
  239. batch_size = len(structure_idx)
  240. for batch_idx in range(batch_size):
  241. structure_list = []
  242. bbox_list = []
  243. score_list = []
  244. for idx in range(len(structure_idx[batch_idx])):
  245. char_idx = int(structure_idx[batch_idx][idx])
  246. if idx > 0 and char_idx == end_idx:
  247. break
  248. if char_idx in ignored_tokens:
  249. continue
  250. text = self.character[char_idx]
  251. if text in self.td_token:
  252. bbox = bbox_preds[batch_idx, idx]
  253. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  254. bbox_list.append(bbox)
  255. structure_list.append(text)
  256. score_list.append(structure_probs[batch_idx, idx])
  257. structure_batch_list.append([structure_list, np.mean(score_list)])
  258. bbox_batch_list.append(np.array(bbox_list))
  259. result = {
  260. "bbox_batch_list": bbox_batch_list,
  261. "structure_batch_list": structure_batch_list,
  262. }
  263. return result
  264. def decode_label(self, batch):
  265. """convert text-label into text-index."""
  266. structure_idx = batch[1]
  267. gt_bbox_list = batch[2]
  268. shape_list = batch[-1]
  269. ignored_tokens = self.get_ignored_tokens()
  270. end_idx = self.dict[self.end_str]
  271. structure_batch_list = []
  272. bbox_batch_list = []
  273. batch_size = len(structure_idx)
  274. for batch_idx in range(batch_size):
  275. structure_list = []
  276. bbox_list = []
  277. for idx in range(len(structure_idx[batch_idx])):
  278. char_idx = int(structure_idx[batch_idx][idx])
  279. if idx > 0 and char_idx == end_idx:
  280. break
  281. if char_idx in ignored_tokens:
  282. continue
  283. structure_list.append(self.character[char_idx])
  284. bbox = gt_bbox_list[batch_idx][idx]
  285. if bbox.sum() != 0:
  286. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  287. bbox_list.append(bbox)
  288. structure_batch_list.append(structure_list)
  289. bbox_batch_list.append(bbox_list)
  290. result = {
  291. "bbox_batch_list": bbox_batch_list,
  292. "structure_batch_list": structure_batch_list,
  293. }
  294. return result
  295. def _bbox_decode(self, bbox, shape):
  296. h, w = shape[:2]
  297. bbox[0::2] *= w
  298. bbox[1::2] *= h
  299. return bbox
  300. def get_ignored_tokens(self):
  301. beg_idx = self.get_beg_end_flag_idx("beg")
  302. end_idx = self.get_beg_end_flag_idx("end")
  303. return [beg_idx, end_idx]
  304. def get_beg_end_flag_idx(self, beg_or_end):
  305. if beg_or_end == "beg":
  306. return np.array(self.dict[self.beg_str])
  307. if beg_or_end == "end":
  308. return np.array(self.dict[self.end_str])
  309. raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")
  310. def add_special_char(self, dict_character):
  311. self.beg_str = "sos"
  312. self.end_str = "eos"
  313. dict_character = [self.beg_str] + dict_character + [self.end_str]
  314. return dict_character
  315. class TablePreprocess:
  316. def __init__(self):
  317. self.table_max_len = 488
  318. self.build_pre_process_list()
  319. self.ops = self.create_operators()
  320. def __call__(self, data):
  321. """transform"""
  322. if self.ops is None:
  323. self.ops = []
  324. for op in self.ops:
  325. data = op(data)
  326. if data is None:
  327. return None
  328. return data
  329. def create_operators(
  330. self,
  331. ):
  332. """
  333. create operators based on the config
  334. Args:
  335. params(list): a dict list, used to create some operators
  336. """
  337. assert isinstance(
  338. self.pre_process_list, list
  339. ), "operator config should be a list"
  340. ops = []
  341. for operator in self.pre_process_list:
  342. assert (
  343. isinstance(operator, dict) and len(operator) == 1
  344. ), "yaml format error"
  345. op_name = list(operator)[0]
  346. param = {} if operator[op_name] is None else operator[op_name]
  347. op = eval(op_name)(**param)
  348. ops.append(op)
  349. return ops
  350. def build_pre_process_list(self):
  351. resize_op = {
  352. "ResizeTableImage": {
  353. "max_len": self.table_max_len,
  354. }
  355. }
  356. pad_op = {
  357. "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
  358. }
  359. normalize_op = {
  360. "NormalizeImage": {
  361. "std": [0.229, 0.224, 0.225],
  362. "mean": [0.485, 0.456, 0.406],
  363. "scale": "1./255.",
  364. "order": "hwc",
  365. }
  366. }
  367. to_chw_op = {"ToCHWImage": None}
  368. keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
  369. self.pre_process_list = [
  370. resize_op,
  371. normalize_op,
  372. pad_op,
  373. to_chw_op,
  374. keep_keys_op,
  375. ]
  376. class BatchTablePreprocess:
  377. def __init__(self):
  378. self.preprocess = TablePreprocess()
  379. def __call__(
  380. self, img_list: List[np.ndarray]
  381. ) -> Tuple[List[np.ndarray], List[List[float]]]:
  382. """批量处理图像
  383. Args:
  384. img_list: 图像列表
  385. Returns:
  386. 预处理后的图像列表和形状信息列表
  387. """
  388. processed_imgs = []
  389. shape_lists = []
  390. for img in img_list:
  391. if img is None:
  392. continue
  393. data = {"image": img}
  394. img_processed, shape_list = self.preprocess(data)
  395. processed_imgs.append(img_processed)
  396. shape_lists.append(shape_list)
  397. return processed_imgs, shape_lists
  398. class ResizeTableImage:
  399. def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
  400. super(ResizeTableImage, self).__init__()
  401. self.max_len = max_len
  402. self.resize_bboxes = resize_bboxes
  403. self.infer_mode = infer_mode
  404. def __call__(self, data):
  405. img = data["image"]
  406. height, width = img.shape[0:2]
  407. ratio = self.max_len / (max(height, width) * 1.0)
  408. resize_h = int(height * ratio)
  409. resize_w = int(width * ratio)
  410. resize_img = cv2.resize(img, (resize_w, resize_h))
  411. if self.resize_bboxes and not self.infer_mode:
  412. data["bboxes"] = data["bboxes"] * ratio
  413. data["image"] = resize_img
  414. data["src_img"] = img
  415. data["shape"] = np.array([height, width, ratio, ratio])
  416. data["max_len"] = self.max_len
  417. return data
  418. class PaddingTableImage:
  419. def __init__(self, size, **kwargs):
  420. super(PaddingTableImage, self).__init__()
  421. self.size = size
  422. def __call__(self, data):
  423. img = data["image"]
  424. pad_h, pad_w = self.size
  425. padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
  426. height, width = img.shape[0:2]
  427. padding_img[0:height, 0:width, :] = img.copy()
  428. data["image"] = padding_img
  429. shape = data["shape"].tolist()
  430. shape.extend([pad_h, pad_w])
  431. data["shape"] = np.array(shape)
  432. return data
  433. class NormalizeImage:
  434. """normalize image such as substract mean, divide std"""
  435. def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
  436. if isinstance(scale, str):
  437. scale = eval(scale)
  438. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  439. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  440. std = std if std is not None else [0.229, 0.224, 0.225]
  441. shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
  442. self.mean = np.array(mean).reshape(shape).astype("float32")
  443. self.std = np.array(std).reshape(shape).astype("float32")
  444. def __call__(self, data):
  445. img = np.array(data["image"])
  446. assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
  447. data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
  448. return data
  449. class ToCHWImage:
  450. """convert hwc image to chw image"""
  451. def __init__(self, **kwargs):
  452. pass
  453. def __call__(self, data):
  454. img = np.array(data["image"])
  455. data["image"] = img.transpose((2, 0, 1))
  456. return data
  457. class KeepKeys:
  458. def __init__(self, keep_keys, **kwargs):
  459. self.keep_keys = keep_keys
  460. def __call__(self, data):
  461. data_list = []
  462. for key in self.keep_keys:
  463. data_list.append(data[key])
  464. return data_list
  465. def trans_char_ocr_res(ocr_res):
  466. word_result = []
  467. for res in ocr_res:
  468. score = res[2]
  469. for word_box, word in zip(res[3], res[4]):
  470. word_res = []
  471. word_res.append(word_box)
  472. word_res.append(word)
  473. word_res.append(score)
  474. word_result.append(word_res)
  475. return word_result