paddle_table_cls.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import os
  2. from PIL import Image
  3. import cv2
  4. import numpy as np
  5. import onnxruntime
  6. from loguru import logger
  7. from tqdm import tqdm
  8. from mineru.backend.pipeline.model_list import AtomicModel
  9. from mineru.utils.enum_class import ModelPath
  10. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  11. class PaddleTableClsModel:
  12. def __init__(self):
  13. self.sess = onnxruntime.InferenceSession(
  14. os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_table_cls), ModelPath.paddle_table_cls)
  15. )
  16. self.less_length = 256
  17. self.cw, self.ch = 224, 224
  18. self.std = [0.229, 0.224, 0.225]
  19. self.scale = 0.00392156862745098
  20. self.mean = [0.485, 0.456, 0.406]
  21. self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]
  22. def preprocess(self, input_img):
  23. # 放大图片,使其最短边长为256
  24. h, w = input_img.shape[:2]
  25. scale = 256 / min(h, w)
  26. h_resize = round(h * scale)
  27. w_resize = round(w * scale)
  28. img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
  29. # 调整为224*224的正方形
  30. h, w = img.shape[:2]
  31. cw, ch = 224, 224
  32. x1 = max(0, (w - cw) // 2)
  33. y1 = max(0, (h - ch) // 2)
  34. x2 = min(w, x1 + cw)
  35. y2 = min(h, y1 + ch)
  36. if w < cw or h < ch:
  37. raise ValueError(
  38. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  39. )
  40. img = img[y1:y2, x1:x2, ...]
  41. # 正则化
  42. split_im = list(cv2.split(img))
  43. std = [0.229, 0.224, 0.225]
  44. scale = 0.00392156862745098
  45. mean = [0.485, 0.456, 0.406]
  46. alpha = [scale / std[i] for i in range(len(std))]
  47. beta = [-mean[i] / std[i] for i in range(len(std))]
  48. for c in range(img.shape[2]):
  49. split_im[c] = split_im[c].astype(np.float32)
  50. split_im[c] *= alpha[c]
  51. split_im[c] += beta[c]
  52. img = cv2.merge(split_im)
  53. # 5. 转换为 CHW 格式
  54. img = img.transpose((2, 0, 1))
  55. imgs = [img]
  56. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  57. return x
  58. def predict(self, input_img):
  59. if isinstance(input_img, Image.Image):
  60. np_img = np.asarray(input_img)
  61. elif isinstance(input_img, np.ndarray):
  62. np_img = input_img
  63. else:
  64. raise ValueError("Input must be a pillow object or a numpy array.")
  65. x = self.preprocess(np_img)
  66. result = self.sess.run(None, {"x": x})
  67. idx = np.argmax(result)
  68. conf = float(np.max(result))
  69. return self.labels[idx], conf
  70. def list_2_batch(self, img_list, batch_size=16):
  71. """
  72. 将任意长度的列表按照指定的batch size分成多个batch
  73. Args:
  74. img_list: 输入的列表
  75. batch_size: 每个batch的大小,默认为16
  76. Returns:
  77. 一个包含多个batch的列表,每个batch都是原列表的一个子列表
  78. """
  79. batches = []
  80. for i in range(0, len(img_list), batch_size):
  81. batch = img_list[i : min(i + batch_size, len(img_list))]
  82. batches.append(batch)
  83. return batches
  84. def batch_preprocess(self, imgs):
  85. res_imgs = []
  86. for img in imgs:
  87. img = np.asarray(img)
  88. # 放大图片,使其最短边长为256
  89. h, w = img.shape[:2]
  90. scale = 256 / min(h, w)
  91. h_resize = round(h * scale)
  92. w_resize = round(w * scale)
  93. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  94. # 调整为224*224的正方形
  95. h, w = img.shape[:2]
  96. cw, ch = 224, 224
  97. x1 = max(0, (w - cw) // 2)
  98. y1 = max(0, (h - ch) // 2)
  99. x2 = min(w, x1 + cw)
  100. y2 = min(h, y1 + ch)
  101. if w < cw or h < ch:
  102. raise ValueError(
  103. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  104. )
  105. img = img[y1:y2, x1:x2, ...]
  106. # 正则化
  107. split_im = list(cv2.split(img))
  108. std = [0.229, 0.224, 0.225]
  109. scale = 0.00392156862745098
  110. mean = [0.485, 0.456, 0.406]
  111. alpha = [scale / std[i] for i in range(len(std))]
  112. beta = [-mean[i] / std[i] for i in range(len(std))]
  113. for c in range(img.shape[2]):
  114. split_im[c] = split_im[c].astype(np.float32)
  115. split_im[c] *= alpha[c]
  116. split_im[c] += beta[c]
  117. img = cv2.merge(split_im)
  118. # 5. 转换为 CHW 格式
  119. img = img.transpose((2, 0, 1))
  120. res_imgs.append(img)
  121. x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
  122. return x
  123. def batch_predict(self, img_info_list, batch_size=16):
  124. imgs = [item["wired_table_img"] for item in img_info_list]
  125. imgs = self.list_2_batch(imgs, batch_size=batch_size)
  126. label_res = []
  127. with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict", disable=True) as pbar:
  128. for img_batch in imgs:
  129. x = self.batch_preprocess(img_batch)
  130. result = self.sess.run(None, {"x": x})
  131. for img_res in result[0]:
  132. idx = np.argmax(img_res)
  133. conf = float(np.max(img_res))
  134. label_res.append((self.labels[idx],conf))
  135. pbar.update(len(img_batch))
  136. for img_info, (label, conf) in zip(img_info_list, label_res):
  137. img_info['table_res']["cls_label"] = label
  138. img_info['table_res']["cls_score"] = round(conf, 3)