@@ -0,0 +1,570 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
+
+from loguru import logger
+
+
+class EP(Enum):
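+    """ONNX Runtime execution-provider identifiers."""
+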
+    CPU_EP = "CPUExecutionProvider"
+    CUDA_EP = "CUDAExecutionProvider"
+    DIRECTML_EP = "DmlExecutionProvider"
+
+
+class OrtInferSession:
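+    """Thin wrapper around ``onnxruntime.InferenceSession``.
+
+    Config keys read by this class ("model_path" is required, the rest are
+    optional). A minimal sketch, with a hypothetical model path:
+
+        session = OrtInferSession({
+            "model_path": "models/table_rec.onnx",  # hypothetical path
+            "use_cuda": False,
+            "use_dml": False,
+            "intra_op_num_threads": -1,
+            "inter_op_num_threads": -1,
+        })
+        outputs = session([input_array])
+    """
+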
+    def __init__(self, config: Dict[str, Any]):
+        self.logger = logger
+
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.cfg_use_cuda = config.get("use_cuda", None)
+        self.cfg_use_dml = config.get("use_dml", None)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            model_path,
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+        self._verify_providers()
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+        sess_opt = SessionOptions()
+        sess_opt.log_severity_level = 4
+        sess_opt.enable_cpu_mem_arena = False
+        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        cpu_nums = os.cpu_count() or 1  # os.cpu_count() may return None
+        intra_op_num_threads = config.get("intra_op_num_threads", -1)
+        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+            sess_opt.intra_op_num_threads = intra_op_num_threads
+
+        inter_op_num_threads = config.get("inter_op_num_threads", -1)
+        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def get_metadata(self, key: str = "character") -> List[str]:
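+        """Return the newline-separated model metadata entry as a list of lines."""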
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        content_list = meta_dict[key].splitlines()
+        return content_list
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
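+        # ONNX Runtime tries providers in list order, so preferred EPs are
+        # inserted at the front of this list below.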
+
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, trying to use DirectML as the primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        # loguru formats log messages with str.format-style {} placeholders,
+        # not %-style ones.
+        self.logger.warning(
+            "{} is not in available providers ({}). Use {} inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("Recommended: use rapidocr_paddle for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, you must do the following:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote: the onnxruntime-gpu version must match your CUDA and cuDNN versions."
+        )
+        self.logger.info(
+            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure {} is in the available providers list, e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported on Windows. The current OS is {}. Use {} inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_window_version = int(platform.release().split(".")[0])
+        if cur_window_version < 10:
+            self.logger.warning(
+                "DirectML is only supported on Windows 10 and above. The current Windows version is {}. Use {} inference by default.",
+                cur_window_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "{} is not in available providers ({}). Use {} inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, you must do the following:")
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`."
+        )
+        self.logger.info(
+            "Third, ensure {} is in the available providers list, e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "{} is not available in the current env; inference automatically falls back to {}.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "{} is not available in the current env; inference automatically falls back to {}.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
+        input_dict = dict(zip(self.get_input_names(), input_content))
+        try:
+            return self.session.run(None, input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return key in meta_dict
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
+
+
+class TableLabelDecode:
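+    """Decode table-structure model outputs ("structure_probs" and "loc_preds")
+    into HTML structure tokens and cell bounding boxes."""
+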
+    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
+        if merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+        self.td_token = ["<td>", "<td", "<td></td>"]
+
+    def __call__(self, preds, batch=None):
+        structure_probs = preds["structure_probs"]
+        bbox_preds = preds["loc_preds"]
+        shape_list = batch[-1]
+        result = self.decode(structure_probs, bbox_preds, shape_list)
+        if len(batch) == 1:  # only contains shape
+            return result
+
+        label_decode_result = self.decode_label(batch)
+        return result, label_decode_result
+
+    def decode(self, structure_probs, bbox_preds, shape_list):
+        """Convert predicted token indices into structure tokens and pixel-space bboxes."""
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_idx = structure_probs.argmax(axis=2)
+        structure_probs = structure_probs.max(axis=2)
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            score_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                text = self.character[char_idx]
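+                # Only cell tokens ("<td>"-like) carry a predicted box.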
+                if text in self.td_token:
+                    bbox = bbox_preds[batch_idx, idx]
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+                structure_list.append(text)
+                score_list.append(structure_probs[batch_idx, idx])
+            structure_batch_list.append([structure_list, np.mean(score_list)])
+            bbox_batch_list.append(np.array(bbox_list))
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def decode_label(self, batch):
+        """Decode ground-truth label indices into structure tokens and bboxes."""
+        structure_idx = batch[1]
+        gt_bbox_list = batch[2]
+        shape_list = batch[-1]
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                structure_list.append(self.character[char_idx])
+
+                bbox = gt_bbox_list[batch_idx][idx]
+                if bbox.sum() != 0:
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+
+            structure_batch_list.append(structure_list)
+            bbox_batch_list.append(bbox_list)
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def _bbox_decode(self, bbox, shape):
+        h, w = shape[:2]
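+        # Scale normalized coords back to pixels: even indices are x (times w),
+        # odd indices are y (times h).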
+        bbox[0::2] *= w
+        bbox[1::2] *= h
+        return bbox
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            return np.array(self.dict[self.beg_str])
+
+        if beg_or_end == "end":
+            return np.array(self.dict[self.end_str])
+
+        raise TypeError(f"unsupported type {beg_or_end} in get_beg_end_flag_idx")
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+
+class TablePreprocess:
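+    """Build and apply the table-image preprocessing pipeline
+    (resize -> normalize -> pad -> CHW -> key filtering)."""
+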
+    def __init__(self):
+        self.table_max_len = 488
+        self.build_pre_process_list()
+        self.ops = self.create_operators()
+
+    def __call__(self, data):
+        """Apply each preprocessing op in order; return None if any op returns None."""
+        if self.ops is None:
+            self.ops = []
+
+        for op in self.ops:
+            data = op(data)
+            if data is None:
+                return None
+        return data
+
+    def create_operators(self):
+        """Instantiate operators from ``self.pre_process_list``.
+
+        Each entry is a single-key dict mapping an operator class name to its
+        keyword-argument dict (or None).
+        """
+        assert isinstance(
+            self.pre_process_list, list
+        ), "operator config should be a list"
+        ops = []
+        for operator in self.pre_process_list:
+            assert (
+                isinstance(operator, dict) and len(operator) == 1
+            ), "yaml format error"
+            op_name = list(operator)[0]
+            param = {} if operator[op_name] is None else operator[op_name]
+            op = eval(op_name)(**param)
+            ops.append(op)
+        return ops
+
+    def build_pre_process_list(self):
+        resize_op = {
+            "ResizeTableImage": {
+                "max_len": self.table_max_len,
+            }
+        }
+        pad_op = {
+            "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
+        }
+        normalize_op = {
+            "NormalizeImage": {
+                "std": [0.229, 0.224, 0.225],
+                "mean": [0.485, 0.456, 0.406],
+                "scale": "1./255.",
+                "order": "hwc",
+            }
+        }
+        to_chw_op = {"ToCHWImage": None}
+        keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
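+        # Order matters: resize, normalize, then pad to a square canvas,
+        # convert HWC->CHW, and finally keep only "image" and "shape".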
+        self.pre_process_list = [
+            resize_op,
+            normalize_op,
+            pad_op,
+            to_chw_op,
+            keep_keys_op,
+        ]
+
+
+class BatchTablePreprocess:
+    def __init__(self):
+        self.preprocess = TablePreprocess()
+
+    def __call__(
+        self, img_list: List[np.ndarray]
+    ) -> Tuple[List[np.ndarray], List[List[float]]]:
+        """Preprocess a batch of images.
+
+        Args:
+            img_list: list of images.
+
+        Returns:
+            A list of preprocessed images and a list of their shape info.
+        """
+        processed_imgs = []
+        shape_lists = []
+
+        for img in img_list:
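+            # Skip entries that failed to load (e.g. a reader returned None).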
+            if img is None:
+                continue
+            data = {"image": img}
+            img_processed, shape_list = self.preprocess(data)
+            processed_imgs.append(img_processed)
+            shape_lists.append(shape_list)
+        return processed_imgs, shape_lists
+
+
+class ResizeTableImage:
+    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
+        super().__init__()
+        self.max_len = max_len
+        self.resize_bboxes = resize_bboxes
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        img = data["image"]
+        height, width = img.shape[0:2]
+        ratio = self.max_len / (max(height, width) * 1.0)
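+        # Scale so the longer side equals max_len, preserving the aspect ratio.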
+        resize_h = int(height * ratio)
+        resize_w = int(width * ratio)
+        resize_img = cv2.resize(img, (resize_w, resize_h))
+        if self.resize_bboxes and not self.infer_mode:
+            data["bboxes"] = data["bboxes"] * ratio
+        data["image"] = resize_img
+        data["src_img"] = img
+        data["shape"] = np.array([height, width, ratio, ratio])
+        data["max_len"] = self.max_len
+        return data
+
+
+class PaddingTableImage:
+    def __init__(self, size, **kwargs):
+        super().__init__()
+        self.size = size
+
+    def __call__(self, data):
+        img = data["image"]
+        pad_h, pad_w = self.size
+        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+        height, width = img.shape[0:2]
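+        # Copy the image into the top-left corner of a zero-filled square canvas.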
+        padding_img[0:height, 0:width, :] = img.copy()
+        data["image"] = padding_img
+        shape = data["shape"].tolist()
+        shape.extend([pad_h, pad_w])
+        data["shape"] = np.array(shape)
+        return data
+
+
+class NormalizeImage:
+    """Normalize an image: subtract the mean and divide by the std."""
+
+    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)  # allows config strings like "1./255."
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+        return data
+
+
+class ToCHWImage:
+    """Convert an HWC image to CHW layout."""
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        data["image"] = img.transpose((2, 0, 1))
+        return data
+
+
+class KeepKeys:
+    def __init__(self, keep_keys, **kwargs):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        return [data[key] for key in self.keep_keys]
+
+
+def trans_char_ocr_res(ocr_res):
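+    """Flatten line-level OCR results into word-level [box, text, score] triples.
+
+    Assumes each entry is indexable with res[2] = line score, res[3] = word
+    boxes, res[4] = words (only these indices are read here).
+    """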
+    word_result = []
+    for res in ocr_res:
+        score = res[2]
+        for word_box, word in zip(res[3], res[4]):
+            word_result.append([word_box, word, score])
+    return word_result