paddle_table_cls.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import os
  2. from PIL import Image
  3. import cv2
  4. import numpy as np
  5. import onnxruntime
  6. from loguru import logger
  7. from mineru.backend.pipeline.model_list import AtomicModel
  8. from mineru.utils.enum_class import ModelPath
  9. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  class PaddleTableClsModel:
      """Classify a table image as ``AtomicModel.WiredTable`` or ``AtomicModel.WirelessTable``.

      The ONNX model is located via ``auto_download_and_get_model_root_path`` and
      executed with onnxruntime; ``predict`` is the public entry point.
      """

      # A top-1 "wired" (index 0) prediction scoring below this value is reported
      # as "wireless" (index 1) instead; see ``predict``.
      WIRED_CONF_THRESHOLD = 0.8

      def __init__(self):
          model_root = auto_download_and_get_model_root_path(ModelPath.paddle_table_cls)
          self.sess = onnxruntime.InferenceSession(
              os.path.join(model_root, ModelPath.paddle_table_cls)
          )
          # Target length of the shorter image side after resizing.
          self.less_length = 256
          # Center-crop width and height fed to the model.
          self.cw, self.ch = 224, 224
          # Per-channel normalisation: (pixel * scale - mean) / std; scale == 1 / 255.
          self.std = [0.229, 0.224, 0.225]
          self.scale = 0.00392156862745098
          self.mean = [0.485, 0.456, 0.406]
          # Model output index i corresponds to self.labels[i].
          self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]

      def preprocess(self, input_img):
          """Convert an image array into the model's input tensor.

          Steps: coerce to 3 channels, resize so the shorter side equals
          ``self.less_length`` (aspect ratio kept), center-crop to
          ``self.cw`` x ``self.ch``, normalise each channel as
          ``(pixel * self.scale - mean) / std``, reorder HWC -> CHW and add a
          leading batch axis.

          Args:
              input_img: image array of shape (H, W), (H, W, 1), (H, W, 3) or
                  (H, W, 4); a 4th (alpha) channel is discarded. Channel order is
                  passed through unchanged (presumably RGB -- confirm with callers).

          Returns:
              Contiguous float32 array of shape (1, 3, self.ch, self.cw).

          Raises:
              ValueError: if the array shape is unsupported, the image is empty,
                  or the resized image is smaller than the crop size.
          """
          img = self._to_three_channels(input_img)
          h, w = img.shape[:2]
          if h == 0 or w == 0:
              raise ValueError(f"Input image ({w}, {h}) is empty.")

          # Resize so the shorter side becomes self.less_length.
          ratio = self.less_length / min(h, w)
          img = cv2.resize(
              img, (round(w * ratio), round(h * ratio)), interpolation=cv2.INTER_LINEAR
          )

          # Center-crop to cw x ch (the size check makes the offsets non-negative).
          h, w = img.shape[:2]
          cw, ch = self.cw, self.ch
          if w < cw or h < ch:
              raise ValueError(
                  f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
              )
          x1 = (w - cw) // 2
          y1 = (h - ch) // 2
          img = img[y1:y1 + ch, x1:x1 + cw]

          # (pixel * scale - mean) / std folded into pixel * alpha + beta;
          # alpha/beta are computed in float64, then cast to float32.
          std = np.asarray(self.std, dtype=np.float64)
          alpha = (self.scale / std).astype(np.float32)
          beta = (-np.asarray(self.mean, dtype=np.float64) / std).astype(np.float32)
          img = img.astype(np.float32) * alpha + beta

          # HWC -> CHW, then prepend the batch axis.
          return np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)

      def predict(self, input_img):
          """Classify a table image as wired or wireless.

          Args:
              input_img: a PIL image (converted to RGB when in another mode) or a
                  numpy array accepted by ``preprocess``.

          Returns:
              ``(label, conf)``: ``label`` is an entry of ``self.labels`` and
              ``conf`` is the highest model score. When the top class is index 0
              (``WiredTable``) with a score below ``WIRED_CONF_THRESHOLD``, the
              label is switched to index 1 (``WirelessTable``) while ``conf``
              still reports the index-0 score.

          Raises:
              ValueError: for unsupported input types or image shapes.
          """
          if isinstance(input_img, Image.Image):
              # Normalise palette/gray/alpha modes so the array has 3 channels.
              if input_img.mode != "RGB":
                  input_img = input_img.convert("RGB")
              np_img = np.asarray(input_img)
          elif isinstance(input_img, np.ndarray):
              np_img = input_img
          else:
              raise ValueError("Input must be a pillow object or a numpy array.")

          x = self.preprocess(np_img)
          outputs = self.sess.run(None, {"x": x})
          # NOTE(review): assumes the first session output holds one score per
          # entry of self.labels (single-image batch) -- confirm against the model.
          scores = np.asarray(outputs[0]).ravel()
          idx = int(np.argmax(scores))
          conf = float(scores[idx])
          if idx == 0 and conf < self.WIRED_CONF_THRESHOLD:
              idx = 1
          return self.labels[idx], conf

      @staticmethod
      def _to_three_channels(img):
          """Return *img* as a contiguous (H, W, 3) array: gray is replicated, alpha dropped."""
          img = np.asarray(img)
          if img.ndim == 2:
              img = img[:, :, np.newaxis]
          if img.ndim != 3 or img.shape[2] not in (1, 3, 4):
              raise ValueError(
                  "Expected an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) image, "
                  f"got shape {img.shape}."
              )
          if img.shape[2] == 1:
              img = np.repeat(img, 3, axis=2)
          elif img.shape[2] == 4:
              # The normalisation constants cover 3 channels; discard alpha.
              img = img[:, :, :3]
          # Hand cv2 a contiguous buffer (the alpha slice above is a strided view).
          return np.ascontiguousarray(img)