| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- import os
- import traceback
- from enum import Enum
- from io import BytesIO
- from pathlib import Path
- from typing import List, Union, Dict, Any, Tuple
- import cv2
- import numpy as np
- from onnxruntime import (
- GraphOptimizationLevel,
- InferenceSession,
- SessionOptions,
- get_available_providers,
- )
- from PIL import Image, UnidentifiedImageError
# Absolute directory that contains this module file.
root_dir = Path(__file__).resolve().parent
# Image inputs accepted by the loader: a filesystem path (str or Path),
# raw encoded image bytes, or an already-decoded numpy array.
InputType = Union[str, np.ndarray, bytes, Path]
class EP(Enum):
    """Names of the ONNX Runtime execution providers used by this module."""

    CPU_EP = "CPUExecutionProvider"
class OrtInferSession:
    """Wrapper around an ONNX Runtime ``InferenceSession`` running on the CPU provider.

    The session is created from ``config["model_path"]``. The optional
    ``intra_op_num_threads`` and ``inter_op_num_threads`` config entries are
    applied to the session options when they hold a usable thread count.
    """

    def __init__(self, config: Dict[str, Any]):
        """Validate the model path and build the inference session.

        Args:
            config: Settings dict. ``model_path`` is required;
                ``intra_op_num_threads`` / ``inter_op_num_threads`` are optional.

        Raises:
            ValueError: If ``model_path`` is missing or ``None``.
            FileNotFoundError: If ``model_path`` does not exist.
            FileExistsError: If ``model_path`` exists but is not a file.
        """
        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            model_path,
            sess_options=sess_opt,
            providers=EP_list,
        )

    @staticmethod
    def _is_valid_thread_num(num_threads: Any, cpu_nums: Union[int, None]) -> bool:
        """Return True when *num_threads* is an explicit, usable thread count.

        Args:
            num_threads: Value read from the config (``-1`` means "not set").
            cpu_nums: Result of ``os.cpu_count()``; ``None`` when the CPU
                count cannot be determined.
        """
        # Non-numeric values (e.g. an explicit None in the config) are treated
        # like "not set" instead of crashing the comparison below.
        if not isinstance(num_threads, int):
            return False
        if num_threads < 1:
            return False
        # os.cpu_count() may return None; without a known upper bound any
        # positive count is accepted.
        return cpu_nums is None or num_threads <= cpu_nums

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> "SessionOptions":
        """Build the ``SessionOptions`` for the session.

        Thread-count entries outside ``1..cpu_count`` (or left at ``-1``) are
        ignored, leaving ONNX Runtime's own defaults in place.
        """
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4  # 4 = fatal: silence everything below
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cpu_nums = os.cpu_count()

        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if OrtInferSession._is_valid_thread_num(intra_op_num_threads, cpu_nums):
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if OrtInferSession._is_valid_thread_num(inter_op_num_threads, cpu_nums):
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def get_metadata(self, key: str = "character") -> list:
        """Return the model's custom metadata entry *key*, split into lines.

        Raises:
            KeyError: If the model has no custom metadata named *key*.
        """
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Return the ``(provider name, provider options)`` list for the session."""
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        return [(EP.CPU_EP.value, cpu_provider_opts)]

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        """Run the model and return all of its outputs.

        Args:
            input_content: Input arrays, matched positionally to the model's
                input names.

        Raises:
            ONNXRuntimeError: If the session fails; the message carries the
                formatted traceback and the original exception is chained.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        """Return the names of the model inputs, in session order."""
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        """Return the names of the model outputs, in session order."""
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        """Return the metadata entry *key* split into lines (see ``get_metadata``)."""
        return self.get_metadata(key)

    def have_key(self, key: str = "character") -> bool:
        """Return True when the model's custom metadata contains *key*."""
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in meta_dict

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        """Check that *model_path* points at an existing regular file.

        Raises:
            ValueError: If *model_path* is ``None``.
            FileNotFoundError: If the path does not exist.
            FileExistsError: If the path exists but is not a file.
        """
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exists.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
class ONNXRuntimeError(Exception):
    """Raised when running the ONNX Runtime session fails."""
class LoadImage:
    """Load an image from several input types and return it as a BGR ndarray.

    Supported input types:
        - str or pathlib.Path: path to an image file.
        - bytes: encoded image data.
        - numpy.ndarray: an already decoded image array.

    Raises LoadImageError for unsupported input types, missing files,
    unidentifiable image data, or arrays with an unsupported layout.
    """

    def __init__(
        self,
    ):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load *img* and convert the result to a 3-channel BGR array."""
        img = self.load_img(img)
        img = self.convert_img(img)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode *img* into an ndarray without changing its channel layout.

        Raises:
            LoadImageError: If the path does not exist, the data cannot be
                identified as an image, or the input type is unsupported.
        """
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                # The context manager closes the underlying file as soon as
                # the pixel data has been copied into the array.
                with Image.open(img) as pil_img:
                    return np.array(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        if isinstance(img, bytes):
            try:
                with Image.open(BytesIO(img)) as pil_img:
                    return np.array(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError("cannot identify image from bytes data") from e

        if isinstance(img, np.ndarray):
            return img

        raise LoadImageError(f"{type(img)} is not supported!")

    def convert_img(self, img: np.ndarray):
        """Convert a 2-D or 3-D image array to 3-channel BGR.

        Raises:
            LoadImageError: If the array is not 2-D/3-D or its channel count
                is not one of 1, 2, 3, 4.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 4:
                return self.cvt_four_to_three(img)

            if channel == 3:
                # 3-channel data is treated as RGB and swapped to BGR.
                # NOTE(review): this also applies to ndarrays passed in by the
                # caller, so an array that is already BGR gets flipped —
                # confirm callers hand over RGB arrays.
                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR, filling transparent areas with white.

        The alpha channel is used as a mask (pixels with alpha 0 are cleared)
        and the inverted alpha is then added to every colour channel, so
        fully transparent pixels end up white.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR, filling transparent areas with white.

        Channel 0 is the gray level and channel 1 the alpha; the alpha is
        applied the same way as in ``cvt_four_to_three``.
        """
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise LoadImageError when *file_path* does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
    """Raised when an image cannot be loaded or converted."""
# Pillow >= 9.1.0 exposes its resampling filters on ``Image.Resampling``;
# older releases keep them as constants directly on the ``Image`` module.
# Look each filter up in whichever namespace the installed Pillow provides.
if Image is not None:
    pillow_interp_codes = {
        name: getattr(getattr(Image, "Resampling", Image), name.upper())
        for name in ("nearest", "bilinear", "bicubic", "box", "lanczos", "hamming")
    }

# Interpolation names understood by the cv2 resize backend.
cv2_interp_codes = {
    "nearest": cv2.INTER_NEAREST,
    "bilinear": cv2.INTER_LINEAR,
    "bicubic": cv2.INTER_CUBIC,
    "area": cv2.INTER_AREA,
    "lanczos": cv2.INTER_LANCZOS4,
}
def resize_img(img, scale, keep_ratio=True):
    """Resize *img* and report the scale actually applied along each axis.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): With ``keep_ratio=True`` either a scaling
            factor or the bounding size the image must fit in; with
            ``keep_ratio=False`` the exact target size ``(w, h)``.
        keep_ratio (bool): Whether to preserve the aspect ratio.

    Returns:
        tuple: ``(resized_img, w_scale, h_scale)``.
    """
    if not keep_ratio:
        img_new, w_scale, h_scale = imresize(img, scale, return_scale=True)
        return img_new, w_scale, h_scale

    # "area" keeps more fidelity when shrinking; "bicubic" otherwise.
    if isinstance(scale, (int, float)):
        # A plain factor shrinks the image exactly when it is below 1.
        shrinking = scale < 1
    else:
        shrinking = min(img.shape[:2]) > min(scale)
    interpolation = "area" if shrinking else "bicubic"

    img_new = imrescale(img, scale, interpolation=interpolation)

    # The target size is rounded to whole pixels, so the per-axis ratios can
    # differ slightly from the nominal factor; report the achieved ratios.
    new_h, new_w = img_new.shape[:2]
    h, w = img.shape[:2]
    w_scale = new_w / w
    h_scale = new_h / h
    return img_new, w_scale, h_scale
def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
    """Resize an image while preserving its aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): A scaling factor, or a tuple of two
            integers giving the maximum size; in the latter case the image is
            made as large as possible while still fitting inside it.
        return_scale (bool): Whether to also return the scaling factor.
        interpolation (str): Interpolation name, forwarded to ``imresize``.
        backend (str | None): Resize backend, forwarded to ``imresize``.

    Returns:
        ndarray | tuple: The rescaled image, or ``(rescaled_img, scale_factor)``
        when ``return_scale`` is true.
    """
    height, width = img.shape[:2]
    target_size, factor = rescale_size((width, height), scale, return_scale=True)
    resized = imresize(img, target_size, interpolation=interpolation, backend=backend)
    return (resized, factor) if return_scale else resized
def imresize(
    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
):
    """Resize an image to an exact size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size ``(w, h)``.
        return_scale (bool): Whether to also return ``w_scale`` and ``h_scale``.
        interpolation (str): Interpolation name; must be a key of
            ``cv2_interp_codes`` for the 'cv2' backend or of
            ``pillow_interp_codes`` for the 'pillow' backend.
        out (ndarray): Output destination, passed to ``cv2.resize`` as ``dst``
            (only used by the 'cv2' backend).
        backend (str | None): ``"cv2"`` or ``"pillow"``; ``None`` selects
            ``"cv2"``.

    Returns:
        tuple | ndarray: ``(resized_img, w_scale, h_scale)`` or ``resized_img``.

    Raises:
        ValueError: If *backend* is neither ``"cv2"`` nor ``"pillow"``.
    """
    src_h, src_w = img.shape[:2]

    chosen = "cv2" if backend is None else backend
    if chosen not in ("cv2", "pillow"):
        raise ValueError(
            f"backend: {chosen} is not supported for resize."
            "Supported backends are 'cv2', 'pillow'"
        )

    if chosen == "cv2":
        resized = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
        )
    else:
        assert img.dtype == np.uint8, "Pillow backend only support uint8 type"
        pil_resized = Image.fromarray(img).resize(
            size, pillow_interp_codes[interpolation]
        )
        resized = np.array(pil_resized)

    if return_scale:
        return resized, size[0] / src_w, size[1] / src_h
    return resized
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image should be rescaled to.

    Args:
        old_size (tuple[int]): The current size ``(w, h)``.
        scale (float | tuple[int]): A positive scaling factor, or a tuple of
            two integers giving the maximum size; in the latter case the
            largest factor that keeps the image inside that size is used.
        return_scale (bool): Whether to also return the scaling factor.

    Returns:
        tuple[int] | tuple: The new ``(w, h)``, or ``(new_size, scale_factor)``
        when ``return_scale`` is true.

    Raises:
        ValueError: If a numeric *scale* is not positive.
        TypeError: If *scale* is neither a number nor a tuple.
    """
    width, height = old_size

    if isinstance(scale, tuple):
        # Match the longer limit to the longer edge and the shorter limit to
        # the shorter edge, then take the tighter of the two ratios.
        long_limit = max(scale)
        short_limit = min(scale)
        factor = min(
            long_limit / max(height, width),
            short_limit / min(height, width),
        )
    elif isinstance(scale, (float, int)):
        if scale <= 0:
            raise ValueError(f"Invalid scale {scale}, must be positive.")
        factor = scale
    else:
        raise TypeError(
            f"Scale must be a number or tuple of int, but got {type(scale)}"
        )

    target = _scale_size((width, height), factor)
    return (target, factor) if return_scale else target
def _scale_size(size, scale):
    """Scale a ``(w, h)`` size by a ratio, rounding half up to whole pixels.

    Args:
        size (tuple[int]): ``(w, h)``.
        scale (float | tuple(float)): One factor for both axes, or a
            ``(w_factor, h_factor)`` pair.

    Returns:
        tuple[int]: The scaled ``(w, h)``.
    """
    ratios = (scale, scale) if isinstance(scale, (float, int)) else scale
    width, height = size
    # Adding 0.5 before truncating rounds to the nearest integer.
    scaled_w = int(width * float(ratios[0]) + 0.5)
    scaled_h = int(height * float(ratios[1]) + 0.5)
    return scaled_w, scaled_h
|