paddle_table_cls.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import os
  2. from pathlib import Path
  3. from PIL import Image
  4. import cv2
  5. import numpy as np
  6. import onnxruntime
  7. from loguru import logger
  8. from tqdm import tqdm
  9. from mineru.backend.pipeline.model_list import AtomicModel
  10. from mineru.utils.enum_class import ModelPath
  11. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  12. class PaddleTableClsModel:
  13. def __init__(self):
  14. self.sess = onnxruntime.InferenceSession(
  15. os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_table_cls), ModelPath.paddle_table_cls)
  16. )
  17. self.less_length = 256
  18. self.cw, self.ch = 224, 224
  19. self.std = [0.229, 0.224, 0.225]
  20. self.scale = 0.00392156862745098
  21. self.mean = [0.485, 0.456, 0.406]
  22. self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]
  23. def preprocess(self, input_img):
  24. # 放大图片,使其最短边长为256
  25. h, w = input_img.shape[:2]
  26. scale = 256 / min(h, w)
  27. h_resize = round(h * scale)
  28. w_resize = round(w * scale)
  29. img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
  30. # 调整为224*224的正方形
  31. h, w = img.shape[:2]
  32. cw, ch = 224, 224
  33. x1 = max(0, (w - cw) // 2)
  34. y1 = max(0, (h - ch) // 2)
  35. x2 = min(w, x1 + cw)
  36. y2 = min(h, y1 + ch)
  37. if w < cw or h < ch:
  38. raise ValueError(
  39. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  40. )
  41. img = img[y1:y2, x1:x2, ...]
  42. # 正则化
  43. split_im = list(cv2.split(img))
  44. std = [0.229, 0.224, 0.225]
  45. scale = 0.00392156862745098
  46. mean = [0.485, 0.456, 0.406]
  47. alpha = [scale / std[i] for i in range(len(std))]
  48. beta = [-mean[i] / std[i] for i in range(len(std))]
  49. for c in range(img.shape[2]):
  50. split_im[c] = split_im[c].astype(np.float32)
  51. split_im[c] *= alpha[c]
  52. split_im[c] += beta[c]
  53. img = cv2.merge(split_im)
  54. # 5. 转换为 CHW 格式
  55. img = img.transpose((2, 0, 1))
  56. imgs = [img]
  57. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  58. return x
  59. def predict(self, input_img):
  60. if isinstance(input_img, Image.Image):
  61. np_img = np.asarray(input_img)
  62. elif isinstance(input_img, np.ndarray):
  63. np_img = input_img
  64. else:
  65. raise ValueError("Input must be a pillow object or a numpy array.")
  66. x = self.preprocess(np_img)
  67. result = self.sess.run(None, {"x": x})
  68. idx = np.argmax(result)
  69. conf = float(np.max(result))
  70. # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
  71. if idx == 0 and conf < 0.8:
  72. idx = 1
  73. return self.labels[idx], conf
  74. def list_2_batch(self, img_list, batch_size=16):
  75. """
  76. 将任意长度的列表按照指定的batch size分成多个batch
  77. Args:
  78. img_list: 输入的列表
  79. batch_size: 每个batch的大小,默认为16
  80. Returns:
  81. 一个包含多个batch的列表,每个batch都是原列表的一个子列表
  82. """
  83. batches = []
  84. for i in range(0, len(img_list), batch_size):
  85. batch = img_list[i : min(i + batch_size, len(img_list))]
  86. batches.append(batch)
  87. return batches
  88. def batch_preprocess(self, imgs):
  89. res_imgs = []
  90. for img in imgs:
  91. img = np.asarray(img)
  92. # 放大图片,使其最短边长为256
  93. h, w = img.shape[:2]
  94. scale = 256 / min(h, w)
  95. h_resize = round(h * scale)
  96. w_resize = round(w * scale)
  97. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  98. # 调整为224*224的正方形
  99. h, w = img.shape[:2]
  100. cw, ch = 224, 224
  101. x1 = max(0, (w - cw) // 2)
  102. y1 = max(0, (h - ch) // 2)
  103. x2 = min(w, x1 + cw)
  104. y2 = min(h, y1 + ch)
  105. if w < cw or h < ch:
  106. raise ValueError(
  107. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  108. )
  109. img = img[y1:y2, x1:x2, ...]
  110. # 正则化
  111. split_im = list(cv2.split(img))
  112. std = [0.229, 0.224, 0.225]
  113. scale = 0.00392156862745098
  114. mean = [0.485, 0.456, 0.406]
  115. alpha = [scale / std[i] for i in range(len(std))]
  116. beta = [-mean[i] / std[i] for i in range(len(std))]
  117. for c in range(img.shape[2]):
  118. split_im[c] = split_im[c].astype(np.float32)
  119. split_im[c] *= alpha[c]
  120. split_im[c] += beta[c]
  121. img = cv2.merge(split_im)
  122. # 5. 转换为 CHW 格式
  123. img = img.transpose((2, 0, 1))
  124. res_imgs.append(img)
  125. x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
  126. return x
  127. def batch_predict(self, img_info_list, batch_size=16):
  128. imgs = [item["table_img"] for item in img_info_list]
  129. imgs = self.list_2_batch(imgs, batch_size=batch_size)
  130. label_res = []
  131. with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict") as pbar:
  132. for img_batch in imgs:
  133. x = self.batch_preprocess(img_batch)
  134. result = self.sess.run(None, {"x": x})
  135. for img_res in result[0]:
  136. idx = np.argmax(img_res)
  137. conf = float(np.max(img_res))
  138. # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
  139. if idx == 0 and conf < 0.8:
  140. idx = 1
  141. label_res.append((self.labels[idx],conf))
  142. pbar.update(len(img_batch))
  143. for img_info, (label, conf) in zip(img_info_list, label_res):
  144. img_info['table_res']["cls_label"] = label
  145. img_info['table_res']["cls_score"] = round(conf, 3)