| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- import os
- from PIL import Image
- import cv2
- import numpy as np
- import onnxruntime
- from loguru import logger
- from mineru.backend.pipeline.model_list import AtomicModel
- from mineru.utils.enum_class import ModelPath
- from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
class PaddleTableClsModel:
    """Classify a table image as wired or wireless using an ONNX model.

    The model file is located via ``auto_download_and_get_model_root_path``
    and loaded into an ``onnxruntime.InferenceSession``. Images are resized so
    the shorter side equals ``less_length``, center-cropped to ``cw`` x ``ch``,
    normalized per channel with ``scale``/``mean``/``std`` and fed to the
    session as a ``(1, 3, ch, cw)`` float32 tensor under the input name ``"x"``.
    """

    # A "wired" prediction (label index 0) with a score below this value is
    # reported as "wireless" instead.
    WIRED_MIN_CONFIDENCE = 0.8

    def __init__(self):
        """Load the ONNX session and set the preprocessing parameters."""
        model_root = auto_download_and_get_model_root_path(ModelPath.paddle_table_cls)
        self.sess = onnxruntime.InferenceSession(
            os.path.join(model_root, ModelPath.paddle_table_cls)
        )
        # Target length of the shorter image side after resizing.
        self.less_length = 256
        # Width and height of the center crop fed to the model.
        self.cw, self.ch = 224, 224
        # Per-channel normalization: (pixel * scale - mean) / std.
        # NOTE(review): these look like the usual ImageNet statistics, which
        # presumably assumes RGB channel order -- confirm against callers.
        self.std = [0.229, 0.224, 0.225]
        self.scale = 0.00392156862745098  # 1 / 255
        self.mean = [0.485, 0.456, 0.406]
        # Index i of the model output corresponds to labels[i].
        self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]

    @staticmethod
    def _to_three_channels(img: np.ndarray) -> np.ndarray:
        """Return *img* as an (H, W, 3) array.

        Grayscale input (2-D or a single channel) is replicated across three
        channels; a fourth (alpha) channel is dropped.

        Raises:
            ValueError: If the array is not 2-D/3-D or has an unsupported
                number of channels.
        """
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        if img.ndim != 3:
            raise ValueError(
                f"Input image must be 2-D or 3-D, got {img.ndim} dimensions."
            )
        channels = img.shape[2]
        if channels == 3:
            return img
        if channels == 1:
            return np.repeat(img, 3, axis=2)
        if channels == 4:
            # Keep the slice contiguous so downstream native code gets a
            # plain C-ordered buffer.
            return np.ascontiguousarray(img[:, :, :3])
        raise ValueError(f"Unsupported number of image channels: {channels}.")

    def _resize_shorter_side(self, img: np.ndarray) -> np.ndarray:
        """Resize *img* so that its shorter side equals ``self.less_length``."""
        h, w = img.shape[:2]
        if h == 0 or w == 0:
            raise ValueError(f"Input image ({w}, {h}) is empty.")
        ratio = self.less_length / min(h, w)
        new_size = (round(w * ratio), round(h * ratio))  # cv2 expects (w, h)
        return cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR)

    def _center_crop(self, img: np.ndarray) -> np.ndarray:
        """Cut the central ``self.cw`` x ``self.ch`` window out of *img*.

        Raises:
            ValueError: If *img* is smaller than the crop window.
        """
        h, w = img.shape[:2]
        cw, ch = self.cw, self.ch
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        return img[y1:y1 + ch, x1:x1 + cw, ...]

    def _normalize(self, img: np.ndarray) -> np.ndarray:
        """Normalize an (H, W, 3) image and return a (1, 3, H, W) float32 batch."""
        # (pixel * scale - mean) / std, folded into one multiply-add.
        alpha = np.array([self.scale / s for s in self.std], dtype=np.float32)
        beta = np.array(
            [-m / s for m, s in zip(self.mean, self.std)], dtype=np.float32
        )
        normalized = img.astype(np.float32) * alpha + beta
        # HWC -> CHW, then add the batch axis.
        chw = normalized.transpose((2, 0, 1))
        return np.ascontiguousarray(chw[np.newaxis, ...], dtype=np.float32)

    def preprocess(self, input_img: np.ndarray) -> np.ndarray:
        """Turn an image array into the model's input tensor.

        Args:
            input_img: Image as an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)
                array.

        Returns:
            A float32 array of shape ``(1, 3, self.ch, self.cw)``.

        Raises:
            ValueError: If the image is empty, has an unsupported shape, or is
                smaller than the crop window after resizing.
        """
        img = self._to_three_channels(input_img)
        # Scale the image so that its shorter side matches less_length.
        img = self._resize_shorter_side(img)
        # Crop the central cw x ch window.
        img = self._center_crop(img)
        # Normalize and convert to a batched CHW tensor.
        return self._normalize(img)

    def predict(self, input_img):
        """Classify a table image.

        Args:
            input_img: A ``PIL.Image.Image`` or a numpy array.

        Returns:
            A ``(label, confidence)`` tuple. ``label`` is one of
            ``self.labels``; ``confidence`` is the highest raw model score.
            When a wired prediction is downgraded to wireless because its
            score is below ``WIRED_MIN_CONFIDENCE``, the returned confidence
            is still the original wired score.

        Raises:
            ValueError: If *input_img* is neither a PIL image nor a numpy
                array, or if preprocessing rejects it.
        """
        if isinstance(input_img, Image.Image):
            # Palette, grayscale, RGBA, ... images do not yield an (H, W, 3)
            # array directly; let Pillow convert them.
            if input_img.mode != "RGB":
                input_img = input_img.convert("RGB")
            np_img = np.asarray(input_img)
        elif isinstance(input_img, np.ndarray):
            np_img = input_img
        else:
            raise ValueError("Input must be a pillow object or a numpy array.")

        x = self.preprocess(np_img)
        result = self.sess.run(None, {"x": x})
        # Use the first output of the session, flattened to one score per label.
        scores = np.asarray(result[0]).reshape(-1)
        idx = int(np.argmax(scores))
        conf = float(scores[idx])
        # Only trust a "wired" prediction when the model is confident enough.
        if idx == 0 and conf < self.WIRED_MIN_CONFIDENCE:
            idx = 1
        return self.labels[idx], conf
|