| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493 |
- import os
- import traceback
- from enum import Enum
- from io import BytesIO
- from pathlib import Path
- from typing import List, Union, Dict, Any, Tuple, Optional
- import cv2
- import loguru
- import numpy as np
- from onnxruntime import (
- GraphOptimizationLevel,
- InferenceSession,
- SessionOptions,
- get_available_providers,
- )
- from PIL import Image, UnidentifiedImageError
# Directory that contains this module file.
root_dir = Path(__file__).resolve().parents[0]

# Every form in which an image may be handed to the loader: a filesystem path
# (as ``str`` or ``Path``), an already-decoded array, or raw encoded bytes.
InputType = Union[str, np.ndarray, bytes, Path]
class EP(Enum):
    """Names of the ONNX Runtime execution providers used by this module."""

    CPU_EP = "CPUExecutionProvider"
class OrtInferSession:
    """Wrapper around an ONNX Runtime ``InferenceSession`` using the CPU provider.

    Recognised ``config`` keys:
        model_path: Model location forwarded to ``InferenceSession``.
        intra_op_num_threads: Optional intra-op thread count.
        inter_op_num_threads: Optional inter-op thread count.
    """

    def __init__(self, config: Dict[str, Any]):
        self.logger = loguru.logger
        model_path = config.get("model_path", None)

        # Recorded for inspection; provider selection below is CPU-only.
        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            model_path,
            sess_options=sess_opt,
            providers=EP_list,
        )

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        """Build the session options, applying thread counts from ``config``.

        A thread count is applied only when it lies in ``[1, cpu_count]``;
        any other value (including the default ``-1``) leaves the ONNX Runtime
        default untouched.
        """
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to 1 so the range check below never compares against None.
        cpu_nums = os.cpu_count() or 1

        for option in ("intra_op_num_threads", "inter_op_num_threads"):
            requested = config.get(option, -1)
            if requested != -1 and 1 <= requested <= cpu_nums:
                setattr(sess_opt, option, requested)

        return sess_opt

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        """Return the ``(provider name, provider options)`` pairs for the session."""
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        return [(EP.CPU_EP.value, cpu_provider_opts)]

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        """Run inference.

        Args:
            input_content: Input arrays, paired positionally with the model's
                input names.

        Returns:
            The value returned by ``InferenceSession.run`` for all outputs.

        Raises:
            ONNXRuntimeError: If the underlying session raises; the message
                carries the formatted traceback.
        """
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        """Return the names of the model inputs in session order."""
        return [v.name for v in self.session.get_inputs()]
class ONNXRuntimeError(Exception):
    """Raised when running the ONNX Runtime session fails."""
class LoadImage:
    """Load an image from a path, bytes or array and return a 3-channel BGR array."""

    def __init__(
        self,
    ):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        """Load ``img`` and convert it to a 3-channel image.

        Raises:
            LoadImageError: If the type is unsupported, the file is missing or
                unidentifiable, or the array has an unsupported layout.
        """
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} does not in {InputType.__args__}"
            )

        img = self.load_img(img)
        img = self.convert_img(img)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        """Decode ``img`` into an array; arrays are returned unchanged."""
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                # The context manager closes the file handle once the pixels
                # have been copied into the array.
                with Image.open(img) as pil_img:
                    return np.array(pil_img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e

        if isinstance(img, bytes):
            with Image.open(BytesIO(img)) as pil_img:
                return np.array(pil_img)

        if isinstance(img, np.ndarray):
            return img

        raise LoadImageError(f"{type(img)} is not supported!")

    def convert_img(self, img: np.ndarray):
        """Convert a 2-D or 3-D array with 1, 2, 3 or 4 channels to 3 channels.

        Three-channel input is passed through an RGB->BGR swap; this applies
        to arrays supplied directly by the caller as well.
        """
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 4:
                return self.cvt_four_to_three(img)

            if channel == 3:
                return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA -> BGR.

        Pixels whose alpha is zero are blanked and the inverted alpha is then
        added, so fully transparent regions come out white.
        """
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha -> BGR, handling alpha the same way as the RGBA path."""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        """Raise ``LoadImageError`` if ``file_path`` does not exist."""
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
    """Raised when an image cannot be loaded or converted."""
# Pillow >= 9.1.0 exposes the resampling filters on ``Image.Resampling``;
# older releases keep them directly on the ``Image`` module. Each filter is
# looked up on whichever of the two namespaces is available.
if Image is not None:
    pillow_interp_codes = {
        name: getattr(getattr(Image, "Resampling", Image), name.upper())
        for name in ("nearest", "bilinear", "bicubic", "box", "lanczos", "hamming")
    }

# Interpolation name -> OpenCV interpolation flag.
cv2_interp_codes = dict(
    nearest=cv2.INTER_NEAREST,
    bilinear=cv2.INTER_LINEAR,
    bicubic=cv2.INTER_CUBIC,
    area=cv2.INTER_AREA,
    lanczos=cv2.INTER_LANCZOS4,
)
def resize_img(img, scale, keep_ratio=True):
    """Resize ``img`` and report the horizontal and vertical scale factors.

    Args:
        img (ndarray): The input image.
        scale (float | int | tuple[int]): With ``keep_ratio=True`` either a
            scaling factor or a maximum size, as accepted by ``imrescale``.
            With ``keep_ratio=False`` the exact target size ``(w, h)``.
        keep_ratio (bool): Whether to preserve the aspect ratio.

    Returns:
        tuple: ``(resized_img, w_scale, h_scale)``.
    """
    if not keep_ratio:
        # imresize already returns (resized_img, w_scale, h_scale).
        return imresize(img, scale, return_scale=True)

    # "area" interpolation is more faithful when shrinking.
    if isinstance(scale, (int, float)):
        # A plain number is a scaling factor, so it cannot be fed to min().
        shrinking = scale < 1
    else:
        shrinking = min(img.shape[:2]) > min(scale)
    interpolation = "area" if shrinking else "bicubic"

    img_new = imrescale(img, scale, interpolation=interpolation)

    # Rounding of the target size makes the two axis scales differ slightly,
    # so they are measured from the actual output shape.
    new_h, new_w = img_new.shape[:2]
    h, w = img.shape[:2]
    return img_new, new_w / w, new_h / h
def imrescale(img, scale, return_scale=False, interpolation="bilinear", backend=None):
    """Resize an image while preserving its aspect ratio.

    Args:
        img (ndarray): The input image.
        scale (float | tuple[int]): A scaling factor, or a pair of integers
            giving the maximum size the result must fit within.
        return_scale (bool): When true, also return the scaling factor.
        interpolation (str): Forwarded to ``imresize``.
        backend (str | None): Forwarded to ``imresize``.

    Returns:
        ndarray | tuple: The rescaled image, or ``(image, scale_factor)`` when
        ``return_scale`` is true.
    """
    height, width = img.shape[:2]
    target_size, factor = rescale_size((width, height), scale, return_scale=True)
    result = imresize(
        img, target_size, interpolation=interpolation, backend=backend
    )
    return (result, factor) if return_scale else result
def imresize(
    img, size, return_scale=False, interpolation="bilinear", out=None, backend=None
):
    """Resize image to a given size.

    Args:
        img (ndarray): The input image.
        size (tuple[int]): Target size (w, h).
        return_scale (bool): Whether to return `w_scale` and `h_scale`.
        interpolation (str): Interpolation method. The 'cv2' backend accepts
            the keys of ``cv2_interp_codes``; the 'pillow' backend accepts the
            keys of ``pillow_interp_codes``.
        out (ndarray): The output destination (used by the 'cv2' backend only).
        backend (str | None): The image resize backend, `cv2` or `pillow`.
            ``None`` selects `cv2`.

    Returns:
        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
        `resized_img`.

    Raises:
        ValueError: If ``backend`` is not one of the supported backends.
        AssertionError: If the 'pillow' backend is given a non-uint8 image.
        KeyError: If ``interpolation`` is unknown to the selected backend.
    """
    h, w = img.shape[:2]
    if backend is None:
        backend = "cv2"
    if backend not in ("cv2", "pillow"):
        raise ValueError(
            f"backend: {backend} is not supported for resize. "
            "Supported backends are 'cv2', 'pillow'"
        )

    if backend == "pillow":
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if img.dtype != np.uint8:
            raise AssertionError("Pillow backend only support uint8 type")
        pil_image = Image.fromarray(img)
        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
        resized_img = np.array(pil_image)
    else:
        resized_img = cv2.resize(
            img, size, dst=out, interpolation=cv2_interp_codes[interpolation]
        )

    if not return_scale:
        return resized_img

    w_scale = size[0] / w
    h_scale = size[1] / h
    return resized_img, w_scale, h_scale
def rescale_size(old_size, scale, return_scale=False):
    """Compute the size an image should be rescaled to.

    Args:
        old_size (tuple[int]): Current size as (w, h).
        scale (float | tuple[int]): A positive scaling factor, or a tuple of
            two integers; for a tuple, the largest factor is chosen that keeps
            the long edge within ``max(scale)`` and the short edge within
            ``min(scale)``.
        return_scale (bool): When true, also return the scaling factor.

    Returns:
        tuple[int] | tuple: The new (w, h), or ``(new_size, scale_factor)``
        when ``return_scale`` is true.

    Raises:
        ValueError: If a numeric ``scale`` is not positive.
        TypeError: If ``scale`` is neither a number nor a tuple.
    """
    width, height = old_size

    if isinstance(scale, (float, int)):
        if scale <= 0:
            raise ValueError(f"Invalid scale {scale}, must be positive.")
        factor = scale
    elif isinstance(scale, tuple):
        long_limit, short_limit = max(scale), min(scale)
        factor = min(
            long_limit / max(height, width),
            short_limit / min(height, width),
        )
    else:
        raise TypeError(
            f"Scale must be a number or tuple of int, but got {type(scale)}"
        )

    scaled = _scale_size((width, height), factor)
    return (scaled, factor) if return_scale else scaled
- def _scale_size(size, scale):
- """Rescale a size by a ratio.
- Args:
- size (tuple[int]): (w, h).
- scale (float | tuple(float)): Scaling factor.
- Returns:
- tuple[int]: scaled size.
- """
- if isinstance(scale, (float, int)):
- scale = (scale, scale)
- w, h = size
- return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
class VisTable:
    """Render table-recognition results as HTML and as annotated images."""

    def __init__(self):
        self.load_img = LoadImage()

    def __call__(
        self,
        img_path: Union[str, Path],
        table_results,
        save_html_path: Optional[Union[str, Path]] = None,
        save_drawed_path: Optional[Union[str, Path]] = None,
        save_logic_path: Optional[Union[str, Path]] = None,
    ):
        """Visualise ``table_results`` for the image at ``img_path``.

        Args:
            img_path: Image the results belong to.
            table_results: Object exposing ``pred_html``, ``cell_bboxes`` and
                ``logic_points``.
            save_html_path: If given, the bordered HTML is written here.
            save_drawed_path: If given, the image with cell boxes is saved here.
            save_logic_path: If given, an image annotated with the logical
                row/column span of every cell is saved here.

        Returns:
            The image with cell boxes drawn, or ``None`` when
            ``table_results.cell_bboxes`` is ``None``.

        Raises:
            ValueError: If the boxes have neither 4 nor 8 values per cell.
        """
        if save_html_path:
            html_with_border = self.insert_border_style(table_results.pred_html)
            self.save_html(save_html_path, html_with_border)

        table_cell_bboxes = table_results.cell_bboxes
        table_logic_points = table_results.logic_points
        if table_cell_bboxes is None:
            return None

        img = self.load_img(img_path)

        dims_bboxes = table_cell_bboxes.shape[1]
        if dims_bboxes == 4:
            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
        elif dims_bboxes == 8:
            drawed_img = self.draw_polylines(img, table_cell_bboxes)
        else:
            raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")

        if save_drawed_path:
            self.save_img(save_drawed_path, drawed_img)

        if save_logic_path:
            if dims_bboxes == 8:
                # Points 0 and 2 of the quadrilateral are used as the rectangle
                # corners (presumably top-left and bottom-right).
                polygons = [
                    [box[0], box[1], box[4], box[5]] for box in table_cell_bboxes
                ]
            else:
                # 4-value boxes are already (x0, y0, x1, y1); they have no
                # index 4/5 to read.
                polygons = [
                    [box[0], box[1], box[2], box[3]] for box in table_cell_bboxes
                ]
            self.plot_rec_box_with_logic_info(
                img_path, save_logic_path, table_logic_points, polygons
            )
        return drawed_img

    def insert_border_style(self, table_html_str: str):
        """Insert a border stylesheet immediately before the first ``<body>`` tag.

        If the HTML has no ``<body>`` tag the stylesheet is prepended instead.
        """
        style_res = """<meta charset="UTF-8"><style>
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            border: 1px solid black;
            padding: 8px;
            text-align: center;
        }
        th {
            background-color: #f2f2f2;
        }
                </style>"""

        # partition() copes with a missing or repeated "<body>", where
        # unpacking the result of split() would raise.
        prefix_table, body_tag, suffix_table = table_html_str.partition("<body>")
        if not body_tag:
            return f"{style_res}{table_html_str}"
        return f"{prefix_table}{style_res}<body>{suffix_table}"

    def plot_rec_box_with_logic_info(
        self, img_path, output_path, logic_points, sorted_polygons
    ):
        """Draw every cell rectangle with its logical row/column span.

        :param img_path: path of the source image
        :param output_path: where the annotated image is written
        :param logic_points: [row_start,row_end,col_start,col_end] per cell
        :param sorted_polygons: [xmin,ymin,xmax,ymax] per cell
        :return: None
        """
        # Read the original image; cv2.imread needs a str, not a Path.
        img = cv2.imread(str(img_path))
        if img is None:
            raise LoadImageError(f"cannot read image file {img_path}")

        # Add a 100-pixel white margin on the right-hand side.
        img = cv2.copyMakeBorder(
            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )

        # Enlarged font size (previously 0.5) and line width.
        font_scale = 0.9
        thickness = 1
        color = (0, 0, 255)

        # Draw the polygon rectangles.
        for idx, polygon in enumerate(sorted_polygons):
            # Plain Python ints keep OpenCV's point parsing happy.
            x0, y0, x1, y1 = (int(round(float(v))) for v in polygon[:4])
            cv2.rectangle(img, (x0, y0), (x1, y1), color, 1)

            logic_point = logic_points[idx]
            cv2.putText(
                img,
                f"row: {logic_point[0]}-{logic_point[1]}",
                (x0 + 3, y0 + 8),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                color,
                thickness,
            )
            cv2.putText(
                img,
                f"col: {logic_point[2]}-{logic_point[3]}",
                (x0 + 3, y0 + 18),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                color,
                thickness,
            )

        # Save the annotated image (save_img creates the parent directory).
        self.save_img(output_path, img)

    @staticmethod
    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        """Return a copy of ``img`` with each ``(x1, y1, x2, y2)`` box outlined."""
        img_copy = img.copy()
        for x1, y1, x2, y2 in boxes.astype(int).tolist():
            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
        """Return a copy of ``img`` with each 8-value quadrilateral outlined."""
        img_copy = img.copy()
        for point in points.astype(int):
            cv2.polylines(img_copy, [point.reshape(4, 2)], True, (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def save_img(save_path: Union[str, Path], img: np.ndarray):
        """Write ``img`` to ``save_path``, creating the parent directory if needed."""
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(str(save_path), img)

    @staticmethod
    def save_html(save_path: Union[str, Path], html: str):
        """Write ``html`` to ``save_path`` as UTF-8, creating the parent directory."""
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(html)
|