| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- import os
- from PIL import Image
- import cv2
- import numpy as np
- import onnxruntime
- from loguru import logger
- from tqdm import tqdm
- from mineru.backend.pipeline.model_list import AtomicModel
- from mineru.utils.enum_class import ModelPath
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
class PaddleTableClsModel:
    """Wired / wireless table classifier backed by an ONNX Runtime session.

    Images are resized so that their shortest side equals ``less_length``,
    center-cropped to ``cw`` x ``ch``, normalized per channel with
    ``scale``, ``mean`` and ``std``, and fed to the model as a float32
    NCHW tensor under the input name ``"x"``.

    The index of the highest score in the model output selects the entry
    of ``labels``: index 0 is ``AtomicModel.WiredTable`` and index 1 is
    ``AtomicModel.WirelessTable``.
    """

    def __init__(self):
        model_path = os.path.join(
            auto_download_and_get_model_root_path(ModelPath.paddle_table_cls),
            ModelPath.paddle_table_cls,
        )
        self.sess = onnxruntime.InferenceSession(model_path)
        # Preprocessing parameters; every preprocessing method reads these
        # attributes so that there is a single source of truth.
        self.less_length = 256  # target length of the shortest side
        self.cw, self.ch = 224, 224  # center-crop width / height
        self.std = [0.229, 0.224, 0.225]
        self.scale = 0.00392156862745098  # 1 / 255
        self.mean = [0.485, 0.456, 0.406]
        self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]

    def _preprocess_single(self, img):
        """Turn one image into a normalized CHW float32 array.

        Args:
            img: Array-like image of shape (H, W, C) where C equals
                ``len(self.mean)`` (3 with the default parameters).

        Returns:
            A float32 array of shape (C, ch, cw).

        Raises:
            ValueError: If the image is not a 3-D array with the expected
                channel count, has a zero-sized side, or is smaller than
                the crop size after resizing.
        """
        img = np.asarray(img)
        n_channels = len(self.mean)
        # Fail early with a clear message instead of an IndexError deep in
        # the normalization step (e.g. for grayscale or RGBA inputs).
        if img.ndim != 3 or img.shape[2] != n_channels:
            raise ValueError(
                f"Expected an image of shape (H, W, {n_channels}), "
                f"got an array of shape {img.shape}."
            )
        h, w = img.shape[:2]
        if min(h, w) == 0:
            # Would otherwise surface as a ZeroDivisionError below.
            raise ValueError(f"Input image has an empty dimension: ({w}, {h}).")

        # Resize so that the shortest side equals `less_length`.
        ratio = self.less_length / min(h, w)
        img = cv2.resize(
            img,
            (round(w * ratio), round(h * ratio)),
            interpolation=cv2.INTER_LINEAR,
        )

        # Center-crop to cw x ch.
        h, w = img.shape[:2]
        cw, ch = self.cw, self.ch
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        img = img[y1:y1 + ch, x1:x1 + cw, ...]

        # Normalize: (pixel * scale - mean) / std, expressed as
        # pixel * alpha + beta. The coefficients are computed in float64
        # and then cast, so the float32 arithmetic matches a per-channel
        # multiply/add with Python float coefficients.
        std = np.asarray(self.std, dtype=np.float64)
        mean = np.asarray(self.mean, dtype=np.float64)
        alpha = (self.scale / std).astype(np.float32)
        beta = (-mean / std).astype(np.float32)
        img = img.astype(np.float32) * alpha + beta

        # HWC -> CHW
        return img.transpose((2, 0, 1))

    def preprocess(self, input_img):
        """Preprocess a single image into a batch of size one.

        Args:
            input_img: Image array of shape (H, W, C).

        Returns:
            A float32 array of shape (1, C, ch, cw).

        Raises:
            ValueError: See ``_preprocess_single``.
        """
        return self.batch_preprocess([input_img])

    def predict(self, input_img):
        """Classify a single image.

        Args:
            input_img: A ``PIL.Image.Image`` or a numpy array.

        Returns:
            A ``(label, confidence)`` tuple, where ``label`` is an entry of
            ``self.labels`` and ``confidence`` is the highest score as a
            Python float.

        Raises:
            ValueError: If the input is neither a pillow image nor a numpy
                array, or if preprocessing rejects it.
        """
        if isinstance(input_img, Image.Image):
            np_img = np.asarray(input_img)
        elif isinstance(input_img, np.ndarray):
            np_img = input_img
        else:
            raise ValueError("Input must be a pillow object or a numpy array.")

        x = self.preprocess(np_img)
        result = self.sess.run(None, {"x": x})
        # First output, first (and only) row of the batch; this is the same
        # output that batch_predict reads row by row.
        scores = result[0][0]
        idx = np.argmax(scores)
        conf = float(np.max(scores))
        return self.labels[idx], conf

    def list_2_batch(self, img_list, batch_size=16):
        """Split a list of any length into consecutive batches.

        Args:
            img_list: The input list.
            batch_size: Size of each batch, 16 by default.

        Returns:
            A list of batches; each batch is a sub-list of the original
            list, and the last one may be shorter than ``batch_size``.
        """
        # Slicing past the end of a list is safe, so no explicit clamp is needed.
        return [
            img_list[start:start + batch_size]
            for start in range(0, len(img_list), batch_size)
        ]

    def batch_preprocess(self, imgs):
        """Preprocess several images into a single NCHW batch.

        Args:
            imgs: Non-empty iterable of array-like images, each of shape
                (H, W, C).

        Returns:
            A float32 array of shape (N, C, ch, cw).

        Raises:
            ValueError: If any image is rejected by ``_preprocess_single``,
                or if ``imgs`` is empty (raised by ``np.stack``).
        """
        res_imgs = [self._preprocess_single(img) for img in imgs]
        return np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)

    def batch_predict(self, img_info_list, batch_size=16):
        """Classify many table images and store the results in place.

        Args:
            img_info_list: List of dicts. Each dict provides the image under
                the ``"wired_table_img"`` key and a dict under the
                ``'table_res'`` key that receives the results.
            batch_size: Number of images per inference call, 16 by default.

        Returns:
            None. For every item, ``item['table_res']["cls_label"]`` is set
            to the predicted label and ``item['table_res']["cls_score"]`` to
            the confidence rounded to 3 decimals.
        """
        imgs = [item["wired_table_img"] for item in img_info_list]
        batches = self.list_2_batch(imgs, batch_size=batch_size)
        label_res = []
        with tqdm(
            total=len(img_info_list),
            desc="Table-wired/wireless cls predict",
            disable=True,
        ) as pbar:
            for img_batch in batches:
                x = self.batch_preprocess(img_batch)
                result = self.sess.run(None, {"x": x})
                # result[0] holds one row of class scores per image.
                for img_res in result[0]:
                    idx = np.argmax(img_res)
                    conf = float(np.max(img_res))
                    label_res.append((self.labels[idx], conf))
                pbar.update(len(img_batch))

        for img_info, (label, conf) in zip(img_info_list, label_res):
            img_info['table_res']["cls_label"] = label
            img_info['table_res']["cls_score"] = round(conf, 3)
|