# paddle_table_cls.py
  1. import os
  2. from pathlib import Path
  3. from PIL import Image
  4. import cv2
  5. import numpy as np
  6. import onnxruntime
  7. from loguru import logger
  8. from mineru.backend.pipeline.model_list import AtomicModel
  9. from mineru.utils.enum_class import ModelPath
  10. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  11. class PaddleTableClsModel:
  12. def __init__(self):
  13. self.sess = onnxruntime.InferenceSession(
  14. os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_table_cls), ModelPath.paddle_table_cls)
  15. )
  16. self.less_length = 256
  17. self.cw, self.ch = 224, 224
  18. self.std = [0.229, 0.224, 0.225]
  19. self.scale = 0.00392156862745098
  20. self.mean = [0.485, 0.456, 0.406]
  21. self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]
  22. def preprocess(self, input_img):
  23. # 放大图片,使其最短边长为256
  24. h, w = input_img.shape[:2]
  25. scale = 256 / min(h, w)
  26. h_resize = round(h * scale)
  27. w_resize = round(w * scale)
  28. img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
  29. # 调整为224*224的正方形
  30. h, w = img.shape[:2]
  31. cw, ch = 224, 224
  32. x1 = max(0, (w - cw) // 2)
  33. y1 = max(0, (h - ch) // 2)
  34. x2 = min(w, x1 + cw)
  35. y2 = min(h, y1 + ch)
  36. if w < cw or h < ch:
  37. raise ValueError(
  38. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  39. )
  40. img = img[y1:y2, x1:x2, ...]
  41. # 正则化
  42. split_im = list(cv2.split(img))
  43. std = [0.229, 0.224, 0.225]
  44. scale = 0.00392156862745098
  45. mean = [0.485, 0.456, 0.406]
  46. alpha = [scale / std[i] for i in range(len(std))]
  47. beta = [-mean[i] / std[i] for i in range(len(std))]
  48. for c in range(img.shape[2]):
  49. split_im[c] = split_im[c].astype(np.float32)
  50. split_im[c] *= alpha[c]
  51. split_im[c] += beta[c]
  52. img = cv2.merge(split_im)
  53. # 5. 转换为 CHW 格式
  54. img = img.transpose((2, 0, 1))
  55. imgs = [img]
  56. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  57. return x
  58. def predict(self, input_img):
  59. if isinstance(input_img, Image.Image):
  60. np_img = np.asarray(input_img)
  61. elif isinstance(input_img, np.ndarray):
  62. np_img = input_img
  63. else:
  64. raise ValueError("Input must be a pillow object or a numpy array.")
  65. x = self.preprocess(np_img)
  66. result = self.sess.run(None, {"x": x})
  67. idx = np.argmax(result)
  68. conf = float(np.max(result))
  69. # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
  70. if idx == 0 and conf < 0.8:
  71. idx = 1
  72. return self.labels[idx], conf
  73. def list_2_batch(self, img_list, batch_size=16):
  74. """
  75. 将任意长度的列表按照指定的batch size分成多个batch
  76. Args:
  77. img_list: 输入的列表
  78. batch_size: 每个batch的大小,默认为16
  79. Returns:
  80. 一个包含多个batch的列表,每个batch都是原列表的一个子列表
  81. """
  82. batches = []
  83. for i in range(0, len(img_list), batch_size):
  84. batch = img_list[i : min(i + batch_size, len(img_list))]
  85. batches.append(batch)
  86. return batches
  87. def batch_preprocess(self, imgs):
  88. res_imgs = []
  89. for img in imgs:
  90. # PIL图像转cv2
  91. img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
  92. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  93. # 放大图片,使其最短边长为256
  94. h, w = img.shape[:2]
  95. scale = 256 / min(h, w)
  96. h_resize = round(h * scale)
  97. w_resize = round(w * scale)
  98. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  99. # 调整为224*224的正方形
  100. h, w = img.shape[:2]
  101. cw, ch = 224, 224
  102. x1 = max(0, (w - cw) // 2)
  103. y1 = max(0, (h - ch) // 2)
  104. x2 = min(w, x1 + cw)
  105. y2 = min(h, y1 + ch)
  106. if w < cw or h < ch:
  107. raise ValueError(
  108. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  109. )
  110. img = img[y1:y2, x1:x2, ...]
  111. # 正则化
  112. split_im = list(cv2.split(img))
  113. std = [0.229, 0.224, 0.225]
  114. scale = 0.00392156862745098
  115. mean = [0.485, 0.456, 0.406]
  116. alpha = [scale / std[i] for i in range(len(std))]
  117. beta = [-mean[i] / std[i] for i in range(len(std))]
  118. for c in range(img.shape[2]):
  119. split_im[c] = split_im[c].astype(np.float32)
  120. split_im[c] *= alpha[c]
  121. split_im[c] += beta[c]
  122. img = cv2.merge(split_im)
  123. # 5. 转换为 CHW 格式
  124. img = img.transpose((2, 0, 1))
  125. res_imgs.append(img)
  126. x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
  127. return x
  128. def batch_predict(self, img_info_list, batch_size=16):
  129. imgs = [item["table_img"] for item in img_info_list]
  130. imgs = self.list_2_batch(imgs, batch_size=batch_size)
  131. label_res = []
  132. for img_batch in imgs:
  133. x = self.batch_preprocess(img_batch)
  134. result = self.sess.run(None, {"x": x})
  135. for img_res in result[0]:
  136. idx = np.argmax(img_res)
  137. conf = float(np.max(img_res))
  138. if idx == 0 and conf < 0.9:
  139. idx = 1
  140. label_res.append((self.labels[idx],conf))
  141. for img_info, (label, conf) in zip(img_info_list, label_res):
  142. img_info['table_res']["cls_label"] = label
  143. img_info['table_res']["cls_score"] = conf