import os
import traceback
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_available_providers,
)
from PIL import Image, UnidentifiedImageError

root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path]


class EP(Enum):
    """Execution providers this module configures for ONNX Runtime."""

    CPU_EP = "CPUExecutionProvider"


class OrtInferSession:
    """Thin wrapper around an ONNX Runtime ``InferenceSession`` on the CPU.

    Args:
        config: Session configuration. Keys read here:
            ``model_path`` (required): path to the ONNX model file.
            ``intra_op_num_threads`` / ``inter_op_num_threads`` (optional):
            thread counts, applied only when they are >= 1 and do not exceed
            ``os.cpu_count()`` (when that count is known).

    Raises:
        ValueError: If ``model_path`` is missing.
        FileNotFoundError: If ``model_path`` does not exist.
        FileExistsError: If ``model_path`` exists but is not a file.
    """

    def __init__(self, config: Dict[str, Any]):
        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        # str() because older onnxruntime releases reject pathlib.Path here,
        # while _verify_model explicitly accepts Path objects.
        self.session = InferenceSession(
            str(model_path),
            sess_options=sess_opt,
            providers=EP_list,
        )

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        """Build ``SessionOptions`` from ``config``.

        Thread-count keys that are absent, ``None``, -1, below 1, or larger
        than the CPU count are ignored, leaving ONNX Runtime's default.
        """
        sess_opt = SessionOptions()
        # Severity 4 is FATAL: suppress warning/error log output from ORT.
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # os.cpu_count() returns None when the count cannot be determined;
        # in that case only the lower bound is enforced instead of crashing.
        cpu_nums = os.cpu_count()

        def _valid_thread_num(num: Any) -> bool:
            if num is None or num < 1:
                return False
            return cpu_nums is None or num <= cpu_nums

        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if _valid_thread_num(intra_op_num_threads):
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if _valid_thread_num(inter_op_num_threads):
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def get_metadata(self, key: str = "character") -> list:
        """Return the custom-metadata entry ``key`` split into lines.

        Raises:
            KeyError: If the model has no such metadata key.
        """
        return self.get_character_list(key)

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Return the (provider name, provider options) list for the session."""
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        return [(EP.CPU_EP.value, cpu_provider_opts)]

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        """Run the model and return all of its outputs.

        Args:
            input_content: Input arrays, paired positionally with the model's
                input names (extra items on either side are dropped by zip).

        Returns:
            The list returned by ``InferenceSession.run`` (one entry per
            model output).

        Raises:
            ONNXRuntimeError: If inference fails; the message carries the
                full original traceback.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        """Return the model's input names in declaration order."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        """Return the model's output names in declaration order."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        """Return the custom-metadata entry ``key`` split into lines.

        Raises:
            KeyError: If the model has no such metadata key.
        """
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        """Return whether the model's custom metadata contains ``key``."""
        return key in self.session.get_modelmeta().custom_metadata_map

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        """Check that ``model_path`` is given and points to an existing file."""
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")


class ONNXRuntimeError(Exception):
    """Raised when ONNX Runtime inference fails."""


class LoadImage:
    """Load an image from several input types as a BGR ``numpy.ndarray``.

    Supported input types:
        - ``str`` or ``pathlib.Path``: path to an image file.
        - ``bytes``: encoded image data.
        - ``numpy.ndarray``: an already decoded image array.

    Arrays are then normalized to 3-channel BGR by ``convert_img``. Note that
    3-channel input is treated as RGB and channel-swapped, including arrays
    passed in directly.

    Raises:
        LoadImageError: For unsupported input types, missing files,
            undecodable data, or unsupported array shapes.
    """

    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        img = self.load_img(img)
        img = self.convert_img(img)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an array without channel normalization."""
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                # Context manager closes the file handle Pillow opens.
                with Image.open(img) as pil_img:
                    return self._pil_to_ndarray(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        if isinstance(img, bytes):
            try:
                with Image.open(BytesIO(img)) as pil_img:
                    return self._pil_to_ndarray(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError("cannot identify image from bytes data") from e

        if isinstance(img, np.ndarray):
            return img

        raise LoadImageError(f"{type(img)} is not supported!")

    @staticmethod
    def _pil_to_ndarray(pil_img: Image.Image) -> np.ndarray:
        """Convert a Pillow image to an array that ``convert_img`` handles."""
        if pil_img.mode == "P":
            # Palette images would otherwise become a 2-D array of palette
            # indices, which convert_img would misread as gray intensities.
            target_mode = "RGBA" if "transparency" in pil_img.info else "RGB"
            pil_img = pil_img.convert(target_mode)
        elif pil_img.mode == "1":
            # Bilevel images become bool arrays, which cv2.cvtColor rejects.
            pil_img = pil_img.convert("L")
        return np.array(pil_img)

    def convert_img(self, img: np.ndarray):
        """Normalize a 2-D or 3-D image array to 3-channel BGR.

        Channel handling: 1 -> gray to BGR; 2 -> gray+alpha composited on
        white; 3 -> assumed RGB, swapped to BGR; 4 -> RGBA composited on white.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 4:
                return self.cvt_four_to_three(img)

            if channel == 3:
                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA -> BGR, filling transparent areas with white.

        Pixels with non-zero alpha keep their color and are then brightened
        by the inverted alpha (saturating add), so alpha == 0 becomes white.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """Gray + alpha -> BGR, filling transparent areas with white."""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise ``LoadImageError`` if ``file_path`` does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")


class LoadImageError(Exception):
    """Raised when an image cannot be loaded or converted."""


# Pillow >= 9.1.0 exposes resampling filters under Image.Resampling; older
# releases expose them as module-level attributes of Image.
if hasattr(Image, "Resampling"):
    pillow_interp_codes = {
        "nearest": Image.Resampling.NEAREST,
        "bilinear": Image.Resampling.BILINEAR,
        "bicubic": Image.Resampling.BICUBIC,
        "box": Image.Resampling.BOX,
        "lanczos": Image.Resampling.LANCZOS,
        "hamming": Image.Resampling.HAMMING,
    }
else:
    pillow_interp_codes = {
        "nearest": Image.NEAREST,
        "bilinear": Image.BILINEAR,
        "bicubic": Image.BICUBIC,
        "box": Image.BOX,
        "lanczos": Image.LANCZOS,
        "hamming": Image.HAMMING,
    }

cv2_interp_codes = {
    "nearest": cv2.INTER_NEAREST,
    "bilinear": cv2.INTER_LINEAR,
    "bicubic": cv2.INTER_CUBIC,
    "area": cv2.INTER_AREA,
    "lanczos": cv2.INTER_LANCZOS4,
}


def resize_img(img, scale, keep_ratio=True):
    """Resize ``img`` and report the per-axis scale actually applied.

    Args:
        img (ndarray): Input image.
        scale (float | tuple[int]): With ``keep_ratio``, a scale factor or a
            maximum size (see :func:`rescale_size`). Without it, the exact
            target size ``(w, h)``.
        keep_ratio (bool): Whether to preserve the aspect ratio.

    Returns:
        tuple: ``(resized_img, w_scale, h_scale)``.
    """
    if not keep_ratio:
        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
        return img_new, w_scale, h_scale

    h, w = img.shape[:2]
    _, scale_factor = rescale_size((w, h), scale, return_scale=True)
    # Use "area" when shrinking (it preserves detail better); bicubic otherwise.
    # Deciding from the real factor also covers long-edge-limited downscales
    # and numeric scales.
    interpolation = "area" if scale_factor < 1 else "bicubic"
    img_new = imrescale(img, scale, interpolation=interpolation)

    # Derive per-axis scales from the output size: the rounding in
    # _scale_size makes them differ slightly from scale_factor.
    new_h, new_w = img_new.shape[:2]
    w_scale = new_w / w
    h_scale = new_h / h
    return img_new, w_scale, h_scale


def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
    """Resize image while keeping the aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): The scaling factor or maximum size.
            If it is a float number, then the image will be rescaled by this
            factor, else if it is a tuple of 2 integers, then the image will
            be rescaled as large as possible within the scale.
        return_scale (bool): Whether to return the scaling factor besides the
            rescaled image.
        interpolation (str): Same as :func:`imresize`.
        backend (str | None): Same as :func:`imresize`.

    Returns:
        ndarray | tuple: The rescaled image, or ``(rescaled_img, scale_factor)``
        when ``return_scale`` is true.
    """
    h, w = img.shape[:2]
    new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
    rescaled_img = imresize(img, new_size, interpolation=interpolation, backend=backend)
    if return_scale:
        return rescaled_img, scale_factor
    return rescaled_img


def imresize(
    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
):
    """Resize image to a given size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size (w, h).
        return_scale (bool): Whether to return `w_scale` and `h_scale`.
        interpolation (str): Key into ``cv2_interp_codes`` ("nearest",
            "bilinear", "bicubic", "area", "lanczos") for the 'cv2' backend,
            or into ``pillow_interp_codes`` ("nearest", "bilinear", "bicubic",
            "box", "lanczos", "hamming") for the 'pillow' backend.
        out (ndarray): The output destination (cv2 backend only).
        backend (str | None): 'cv2', 'pillow', or None (treated as 'cv2').

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
        `resized_img`.

    Raises:
        ValueError: If ``backend`` is not supported.
    """
    h, w = img.shape[:2]
    if backend is None:
        backend = "cv2"
    if backend not in ["cv2", "pillow"]:
        raise ValueError(
            f"backend: {backend} is not supported for resize. "
            f"Supported backends are 'cv2', 'pillow'"
        )

    if backend == "pillow":
        assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
        pil_image = Image.fromarray(img)
        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
        resized_img = np.array(pil_image)
    else:
        resized_img = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
        )

    if not return_scale:
        return resized_img

    w_scale = size[0] / w
    h_scale = size[1] / h
    return resized_img, w_scale, h_scale


def rescale_size(old_size, scale, return_scale=False):
    """Calculate the new size to be rescaled to.

    Args:
        old_size (tuple[int]): The old size (w, h) of image.
        scale (float | tuple[int] | list[int]): The scaling factor or maximum
            size. If it is a number, the image will be rescaled by this
            factor; if it is a pair of integers, the image will be rescaled
            as large as possible so its long edge fits ``max(scale)`` and its
            short edge fits ``min(scale)``.
        return_scale (bool): Whether to return the scaling factor besides the
            rescaled image size.

    Returns:
        tuple[int] | tuple: The new size ``(w, h)``, or
        ``((w, h), scale_factor)`` when ``return_scale`` is true.

    Raises:
        ValueError: If a numeric ``scale`` is not positive, or if ``old_size``
            has a zero or negative edge when ``scale`` is a size.
        TypeError: If ``scale`` is neither a number nor a tuple/list.
    """
    w, h = old_size
    if isinstance(scale, (float, int)):
        if scale <= 0:
            raise ValueError(f"Invalid scale {scale}, must be positive.")
        scale_factor = scale
    elif isinstance(scale, (tuple, list)):
        if min(h, w) <= 0:
            raise ValueError(f"Invalid image size {old_size}, must be positive.")
        max_long_edge = max(scale)
        max_short_edge = min(scale)
        scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
    else:
        raise TypeError(
            f"Scale must be a number or tuple of int, but got {type(scale)}"
        )

    new_size = _scale_size((w, h), scale_factor)

    if return_scale:
        return new_size, scale_factor
    return new_size


def _scale_size(size, scale):
    """Rescale a size by a ratio, rounding half up to the nearest int.

    Args:
        size (tuple[int]): (w, h).
        scale (float | tuple(float)): Scaling factor, or per-axis factors
            ``(w_scale, h_scale)``.

    Returns:
        tuple[int]: scaled size.
    """
    if isinstance(scale, (float, int)):
        scale = (scale, scale)
    w, h = size
    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)