| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import platform
- import traceback
- from enum import Enum
- from pathlib import Path
- from typing import Any, Dict, List, Tuple, Union
- import cv2
- import numpy as np
- from onnxruntime import (
- GraphOptimizationLevel,
- InferenceSession,
- SessionOptions,
- get_available_providers,
- get_device,
- )
- from loguru import logger
class EP(Enum):
    """Execution-provider identifiers, valued with the ONNX Runtime provider names."""

    # Plain CPU execution.
    CPU_EP = "CPUExecutionProvider"
    # NVIDIA CUDA execution.
    CUDA_EP = "CUDAExecutionProvider"
    # DirectML execution (only selected on Windows by the session wrapper).
    DIRECTML_EP = "DmlExecutionProvider"
class OrtInferSession:
    """Wrapper around an ONNX Runtime ``InferenceSession``.

    Picks the execution providers (CPU, optionally CUDA and/or DirectML)
    requested by the config, builds the session, and exposes helpers for
    running inference and reading the model's custom metadata.
    """

    def __init__(self, config: Dict[str, Any]):
        """Build the inference session described by ``config``.

        Args:
            config: Settings dict. Keys read: ``model_path`` (required),
                ``use_cuda``, ``use_dml``, ``intra_op_num_threads`` and
                ``inter_op_num_threads``.

        Raises:
            ValueError: ``model_path`` is missing or ``None``.
            FileNotFoundError: ``model_path`` does not exist.
            FileExistsError: ``model_path`` exists but is not a regular file.
        """
        self.logger = logger

        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            # _verify_model accepts Path objects; hand the runtime a plain
            # string so the call does not depend on PathLike support.
            str(model_path),
            sess_options=sess_opt,
            providers=EP_list,
        )
        self._verify_providers()

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        """Create the ``SessionOptions`` for the session.

        Thread counts from ``config`` are applied only when they lie in
        ``[1, cpu_count]``; ``-1`` (the default) leaves the runtime default.
        """
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to 1 so the range checks below never compare with None.
        cpu_nums = os.cpu_count() or 1
        for opt_name in ("intra_op_num_threads", "inter_op_num_threads"):
            num_threads = config.get(opt_name, -1)
            if num_threads != -1 and 1 <= num_threads <= cpu_nums:
                setattr(sess_opt, opt_name, num_threads)

        return sess_opt

    def get_metadata(self, key: str = "character") -> list:
        """Return the custom metadata entry ``key`` split into lines.

        Raises:
            KeyError: the model has no custom metadata entry named ``key``.
        """
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Build the ordered ``(provider name, options)`` list for the session.

        CPU is always present as the last entry. CUDA and DirectML are
        inserted in front of it when requested and available; DirectML, if
        used, ends up first. Also sets ``self.use_cuda`` and
        ``self.use_directml``.
        """
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            self.logger.info(
                "Windows 10 or above detected, try to use DirectML as primary provider"
            )
            # NOTE(review): DirectML is handed the CUDA or CPU option dict;
            # confirm the DirectML provider accepts these keys.
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return EP_list

    def _check_cuda(self) -> bool:
        """Return True when CUDA was requested and is usable in this env.

        When CUDA was requested but is unavailable, installation hints are
        logged and False is returned.
        """
        if not self.cfg_use_cuda:
            return False

        cur_device = get_device()
        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
            return True

        # loguru substitutes positional arguments into "{}" placeholders.
        self.logger.warning(
            "{} is not in available providers ({}). Use {} inference by default.",
            EP.CUDA_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
        self.logger.info(
            "(For reference only) If you want to use GPU acceleration, you must do:"
        )
        self.logger.info(
            "First, uninstall all onnxruntime pakcages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
        )
        self.logger.info(
            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
        )
        self.logger.info(
            "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
        )
        self.logger.info(
            "Third, ensure {} is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
            EP.CUDA_EP.value,
        )
        return False

    def _check_dml(self) -> bool:
        """Return True when DirectML was requested and is usable in this env.

        Requires Windows with a major release of at least 10 and the DirectML
        provider in the available-provider list. Otherwise a warning (and,
        where relevant, installation hints) is logged and False is returned.
        """
        if not self.cfg_use_dml:
            return False

        cur_os = platform.system()
        if cur_os != "Windows":
            self.logger.warning(
                "DirectML is only supported in Windows OS. The current OS is {}. Use {} inference by default.",
                cur_os,
                self.had_providers[0],
            )
            return False

        release = platform.release()
        try:
            cur_window_version = int(release.split(".")[0])
        except ValueError:
            # Some Windows releases are not numeric (e.g. "Vista" or server
            # editions such as "2019Server"); fall back to the default
            # provider instead of aborting session construction.
            self.logger.warning(
                "Unable to determine the Windows version from release '{}'. Use {} inference by default.",
                release,
                self.had_providers[0],
            )
            return False

        if cur_window_version < 10:
            self.logger.warning(
                "DirectML is only supported in Windows 10 and above OS. The current Windows version is {}. Use {} inference by default.",
                cur_window_version,
                self.had_providers[0],
            )
            return False

        if EP.DIRECTML_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "{} is not in available providers ({}). Use {} inference by default.",
            EP.DIRECTML_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("If you want to use DirectML acceleration, you must do:")
        self.logger.info(
            "First, uninstall all onnxruntime pakcages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
        )
        self.logger.info(
            "Third, ensure {} is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
            EP.DIRECTML_EP.value,
        )
        return False

    def _verify_providers(self):
        """Warn when the session's first provider is not the requested one."""
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

        if self.use_cuda and first_provider != EP.CUDA_EP.value:
            self.logger.warning(
                "{} is not avaiable for current env, the inference part is automatically shifted to be executed under {}.",
                EP.CUDA_EP.value,
                first_provider,
            )

        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
            self.logger.warning(
                "{} is not available for current env, the inference part is automatically shifted to be executed under {}.",
                EP.DIRECTML_EP.value,
                first_provider,
            )

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        """Run the model.

        Args:
            input_content: Input arrays, paired positionally with the
                session's input names.

        Returns:
            Whatever ``session.run(None, ...)`` returns (all model outputs).

        Raises:
            ONNXRuntimeError: any exception raised by the runtime; the
                message carries the formatted traceback.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        """Return the names of the session's inputs, in order."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        """Return the names of the session's outputs, in order."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        """Return the metadata entry ``key`` split into lines.

        Same lookup as ``get_metadata``; kept as a separate name for callers.
        """
        return self.get_metadata(key)

    def have_key(self, key: str = "character") -> bool:
        """Return True when the model's custom metadata contains ``key``."""
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in meta_dict

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        """Validate that ``model_path`` points to an existing regular file.

        Raises:
            ValueError: ``model_path`` is ``None``.
            FileNotFoundError: the path does not exist.
            FileExistsError: the path exists but is not a file.
        """
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
class ONNXRuntimeError(Exception):
    """Raised when running the ONNX Runtime session fails."""
class TableLabelDecode:
    """Decode table-structure predictions into structure tokens and cell boxes."""

    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
        """Build the token vocabulary.

        Args:
            dict_character: List of structure tokens. It is copied, so the
                caller's list is left untouched.
            merge_no_span_structure: When true, ``"<td></td>"`` is added to
                the vocabulary (if missing) and ``"<td>"`` is removed.
            **kwargs: Accepted and ignored.
        """
        # Copy first: the edits below must not leak into the caller's list.
        dict_character = list(dict_character)
        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        dict_character = self.add_special_char(dict_character)
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character
        # Tokens that start a table cell and therefore carry a bounding box.
        self.td_token = ["<td>", "<td", "<td></td>"]

    def __call__(self, preds, batch=None):
        """Decode predictions, and ground-truth labels when present.

        Args:
            preds: Mapping with ``"structure_probs"`` and ``"loc_preds"``.
            batch: Sequence whose last element is the per-sample shape list.
                When it has more than one element, ``batch[1]`` and
                ``batch[2]`` are decoded as ground-truth structure indices
                and boxes. Must not be ``None`` despite the default.

        Returns:
            The prediction result dict when ``batch`` only holds the shapes,
            otherwise a ``(prediction result, label result)`` tuple.
        """
        structure_probs = preds["structure_probs"]
        bbox_preds = preds["loc_preds"]
        shape_list = batch[-1]
        result = self.decode(structure_probs, bbox_preds, shape_list)
        if len(batch) == 1:  # only contains shape
            return result

        label_decode_result = self.decode_label(batch)
        return result, label_decode_result

    def decode(self, structure_probs, bbox_preds, shape_list):
        """Convert predicted token probabilities into tokens and cell boxes.

        For each sample the arg-max token sequence is read until the end
        token (ignored at position 0), skipping the begin/end tokens. Every
        cell token also contributes its box, scaled by ``_bbox_decode``.

        Returns:
            Dict with ``"bbox_batch_list"`` (one array of boxes per sample)
            and ``"structure_batch_list"`` (one ``[tokens, mean score]`` pair
            per sample).
        """
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, token_ids in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx, raw_idx in enumerate(token_ids):
                char_idx = int(raw_idx)
                if idx > 0 and char_idx == end_idx:
                    break

                if char_idx in ignored_tokens:
                    continue

                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = self._bbox_decode(
                        bbox_preds[batch_idx, idx], shape_list[batch_idx]
                    )
                    bbox_list.append(bbox)
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            # np.mean of an empty list is nan (no token was kept).
            structure_batch_list.append([structure_list, np.mean(score_list)])
            bbox_batch_list.append(np.array(bbox_list))
        result = {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }
        return result

    def decode_label(self, batch):
        """Convert ground-truth token indices and boxes into tokens and boxes.

        Reads structure indices from ``batch[1]``, boxes from ``batch[2]``
        and shapes from ``batch[-1]``. Boxes whose elements sum to zero are
        skipped.

        Returns:
            Dict with ``"bbox_batch_list"`` and ``"structure_batch_list"``,
            each holding one list per sample.
        """
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, token_ids in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            for idx, raw_idx in enumerate(token_ids):
                char_idx = int(raw_idx)
                if idx > 0 and char_idx == end_idx:
                    break

                if char_idx in ignored_tokens:
                    continue

                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)
        result = {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }
        return result

    def _bbox_decode(self, bbox, shape):
        """Return a scaled copy of ``bbox``.

        Even-indexed elements are multiplied by ``shape[1]`` (width) and
        odd-indexed elements by ``shape[0]`` (height).
        """
        h, w = shape[:2]
        # Scale a copy: ``bbox`` is usually a view into the caller's array,
        # and scaling it in place would corrupt that array.
        bbox = np.array(bbox)
        bbox[0::2] *= w
        bbox[1::2] *= h
        return bbox

    def get_ignored_tokens(self):
        """Return the indices of the begin and end tokens."""
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the begin (``"beg"``) or end (``"end"``) token.

        Raises:
            TypeError: ``beg_or_end`` is neither ``"beg"`` nor ``"end"``.
        """
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])

        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])

        raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")

    def add_special_char(self, dict_character):
        """Return ``dict_character`` wrapped in the begin/end tokens.

        Also sets ``self.beg_str`` and ``self.end_str``.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character
class TablePreprocess:
    """Fixed preprocessing pipeline applied to one table image."""

    def __init__(self):
        self.table_max_len = 488
        self.build_pre_process_list()
        self.ops = self.create_operators()

    def __call__(self, data):
        """Run ``data`` through every operator in order.

        Returns the final operator's output, or ``None`` as soon as any
        operator returns ``None``.
        """
        if self.ops is None:
            self.ops = []

        current = data
        for transform in self.ops:
            current = transform(current)
            if current is None:
                break
        return current

    def create_operators(
        self,
    ):
        """Instantiate the operators listed in ``self.pre_process_list``.

        Each entry must be a one-item dict mapping an operator name to its
        keyword arguments (or ``None`` for no arguments).

        Returns:
            The list of instantiated operators, in configuration order.
        """
        assert isinstance(
            self.pre_process_list, list
        ), "operator config should be a list"

        built = []
        for entry in self.pre_process_list:
            assert (
                isinstance(entry, dict) and len(entry) == 1
            ), "yaml format error"
            op_name, raw_param = next(iter(entry.items()))
            kwargs = raw_param if raw_param is not None else {}
            # The name is resolved with eval; the names come from
            # build_pre_process_list, not from external input.
            built.append(eval(op_name)(**kwargs))
        return built

    def build_pre_process_list(self):
        """Populate ``self.pre_process_list`` with the operator configuration.

        Order: resize, normalize, pad, HWC->CHW, keep ``image`` and ``shape``.
        """
        side = self.table_max_len
        self.pre_process_list = [
            {"ResizeTableImage": {"max_len": side}},
            {
                "NormalizeImage": {
                    "std": [0.229, 0.224, 0.225],
                    "mean": [0.485, 0.456, 0.406],
                    "scale": "1./255.",
                    "order": "hwc",
                }
            },
            {"PaddingTableImage": {"size": [side, side]}},
            {"ToCHWImage": None},
            {"KeepKeys": {"keep_keys": ["image", "shape"]}},
        ]
class BatchTablePreprocess:
    """Apply ``TablePreprocess`` to every image of a batch."""

    def __init__(self):
        self.preprocess = TablePreprocess()

    def __call__(
        self, img_list: List[np.ndarray]
    ) -> Tuple[List[np.ndarray], List[List[float]]]:
        """Preprocess a batch of images.

        Args:
            img_list: List of images. ``None`` entries are skipped, so the
                outputs may be shorter than the input.

        Returns:
            The list of preprocessed images and the list of their shape
            information, in input order.
        """
        processed_imgs = []
        shape_lists = []
        outputs = (
            self.preprocess({"image": image})
            for image in img_list
            if image is not None
        )
        for processed, shape_info in outputs:
            processed_imgs.append(processed)
            shape_lists.append(shape_info)
        return processed_imgs, shape_lists
class ResizeTableImage:
    """Resize an image so that its longer side equals ``max_len``."""

    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
        """Store the resize settings.

        Args:
            max_len: Target length of the longer image side.
            resize_bboxes: When true (and not in infer mode),
                ``data["bboxes"]`` is scaled by the same ratio.
            infer_mode: When true, ``data["bboxes"]`` is never touched.
        """
        super().__init__()
        self.max_len = max_len
        self.resize_bboxes = resize_bboxes
        self.infer_mode = infer_mode

    def __call__(self, data):
        """Resize ``data["image"]`` and record the original geometry.

        Writes ``image`` (resized), ``src_img`` (original image), ``shape``
        (``[height, width, ratio, ratio]`` of the original) and ``max_len``
        into ``data`` and returns it.
        """
        img = data["image"]
        height, width = img.shape[0:2]
        ratio = self.max_len / (max(height, width) * 1.0)
        # Truncation can yield 0 for the short side of a very elongated
        # image, which cv2.resize rejects; keep at least one pixel.
        resize_h = max(1, int(height * ratio))
        resize_w = max(1, int(width * ratio))

        resize_img = cv2.resize(img, (resize_w, resize_h))
        if self.resize_bboxes and not self.infer_mode:
            data["bboxes"] = data["bboxes"] * ratio
        data["image"] = resize_img
        data["src_img"] = img
        data["shape"] = np.array([height, width, ratio, ratio])
        data["max_len"] = self.max_len
        return data
class PaddingTableImage:
    """Place an image in the top-left corner of a zero-filled canvas."""

    def __init__(self, size, **kwargs):
        """Store the canvas size as ``(height, width)``; extra kwargs are ignored."""
        super().__init__()
        self.size = size

    def __call__(self, data):
        """Pad ``data["image"]`` and append the canvas size to ``data["shape"]``.

        The canvas is a float32 array of shape ``(pad_h, pad_w, 3)``.
        """
        source = data["image"]
        canvas_h, canvas_w = self.size
        canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.float32)
        src_h, src_w = source.shape[0:2]
        canvas[:src_h, :src_w, :] = source.copy()
        data["image"] = canvas
        # Record the padded size after the existing shape entries.
        data["shape"] = np.array(data["shape"].tolist() + [canvas_h, canvas_w])
        return data
class NormalizeImage:
    """Normalize an image: scale it, subtract the mean, divide by the std."""

    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
        """Prepare the normalization constants.

        Args:
            scale: Multiplier applied before mean/std; a string is evaluated
                as a Python expression. Defaults to ``1/255``.
            mean: Per-channel mean; defaults to ``[0.485, 0.456, 0.406]``.
            std: Per-channel std; defaults to ``[0.229, 0.224, 0.225]``.
            order: ``"chw"`` shapes the constants as ``(3, 1, 1)``; any other
                value shapes them as ``(1, 1, 3)``.
            **kwargs: Accepted and ignored.
        """
        if isinstance(scale, str):
            # NOTE(review): eval runs arbitrary code; only safe while the
            # scale string comes from trusted configuration.
            scale = eval(scale)
        if scale is None:
            scale = 1.0 / 255.0
        self.scale = np.float32(scale)

        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]

        broadcast_shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
        self.mean = np.array(mean).reshape(broadcast_shape).astype("float32")
        self.std = np.array(std).reshape(broadcast_shape).astype("float32")

    def __call__(self, data):
        """Replace ``data["image"]`` with its normalized float32 version."""
        img = np.array(data["image"])
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        scaled = img.astype("float32") * self.scale
        data["image"] = (scaled - self.mean) / self.std
        return data
class ToCHWImage:
    """Reorder an image from height-width-channel to channel-height-width."""

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments."""

    def __call__(self, data):
        """Replace ``data["image"]`` with its axes permuted to ``(2, 0, 1)``."""
        hwc = np.array(data["image"])
        data["image"] = np.transpose(hwc, (2, 0, 1))
        return data
class KeepKeys:
    """Select a fixed set of entries from a data dict, as an ordered list."""

    def __init__(self, keep_keys, **kwargs):
        """Store the keys to keep; extra keyword arguments are ignored."""
        self.keep_keys = keep_keys

    def __call__(self, data):
        """Return the values of ``keep_keys`` from ``data``, in that order.

        Raises:
            KeyError: a requested key is missing from ``data``.
        """
        return [data[name] for name in self.keep_keys]
def trans_char_ocr_res(ocr_res):
    """Flatten OCR results into per-word ``[word_box, word, score]`` entries.

    For every result, element 2 is used as the score, and elements 3 and 4
    are paired up as word boxes and words. Each word inherits the score of
    the result it came from.
    """
    flattened = []
    for line in ocr_res:
        line_score = line[2]
        flattened.extend(
            [box, text, line_score] for box, text in zip(line[3], line[4])
        )
    return flattened
|