paddle_table_cls.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import os
  2. from pathlib import Path
  3. import cv2
  4. import numpy as np
  5. import onnxruntime
  6. from loguru import logger
  7. from mineru.backend.pipeline.model_list import AtomicModel
  8. from mineru.utils.enum_class import ModelPath
  9. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  10. class PaddleTableClsModel:
  11. def __init__(self):
  12. self.sess = onnxruntime.InferenceSession(
  13. os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_table_cls), ModelPath.paddle_table_cls)
  14. )
  15. self.less_length = 256
  16. self.cw, self.ch = 224, 224
  17. self.std = [0.229, 0.224, 0.225]
  18. self.scale = 0.00392156862745098
  19. self.mean = [0.485, 0.456, 0.406]
  20. self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]
  21. def preprocess(self, img):
  22. # PIL图像转cv2
  23. img = np.array(img)
  24. # 放大图片,使其最短边长为256
  25. h, w = img.shape[:2]
  26. scale = 256 / min(h, w)
  27. h_resize = round(h * scale)
  28. w_resize = round(w * scale)
  29. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  30. # 调整为224*224的正方形
  31. h, w = img.shape[:2]
  32. cw, ch = 224, 224
  33. x1 = max(0, (w - cw) // 2)
  34. y1 = max(0, (h - ch) // 2)
  35. x2 = min(w, x1 + cw)
  36. y2 = min(h, y1 + ch)
  37. if w < cw or h < ch:
  38. raise ValueError(
  39. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  40. )
  41. img = img[y1:y2, x1:x2, ...]
  42. # 正则化
  43. split_im = list(cv2.split(img))
  44. std = [0.229, 0.224, 0.225]
  45. scale = 0.00392156862745098
  46. mean = [0.485, 0.456, 0.406]
  47. alpha = [scale / std[i] for i in range(len(std))]
  48. beta = [-mean[i] / std[i] for i in range(len(std))]
  49. for c in range(img.shape[2]):
  50. split_im[c] = split_im[c].astype(np.float32)
  51. split_im[c] *= alpha[c]
  52. split_im[c] += beta[c]
  53. img = cv2.merge(split_im)
  54. # 5. 转换为 CHW 格式
  55. img = img.transpose((2, 0, 1))
  56. imgs = [img]
  57. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  58. return x
  59. def list_2_batch(self, img_list, batch_size=16):
  60. """
  61. 将任意长度的列表按照指定的batch size分成多个batch
  62. Args:
  63. img_list: 输入的列表
  64. batch_size: 每个batch的大小,默认为16
  65. Returns:
  66. 一个包含多个batch的列表,每个batch都是原列表的一个子列表
  67. """
  68. batches = []
  69. for i in range(0, len(img_list), batch_size):
  70. batch = img_list[i : min(i + batch_size, len(img_list))]
  71. batches.append(batch)
  72. return batches
  73. def batch_preprocess(self, imgs):
  74. res_imgs = []
  75. for img in imgs:
  76. # PIL图像转cv2
  77. img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
  78. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  79. # 放大图片,使其最短边长为256
  80. h, w = img.shape[:2]
  81. scale = 256 / min(h, w)
  82. h_resize = round(h * scale)
  83. w_resize = round(w * scale)
  84. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  85. # 调整为224*224的正方形
  86. h, w = img.shape[:2]
  87. cw, ch = 224, 224
  88. x1 = max(0, (w - cw) // 2)
  89. y1 = max(0, (h - ch) // 2)
  90. x2 = min(w, x1 + cw)
  91. y2 = min(h, y1 + ch)
  92. if w < cw or h < ch:
  93. raise ValueError(
  94. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  95. )
  96. img = img[y1:y2, x1:x2, ...]
  97. # 正则化
  98. split_im = list(cv2.split(img))
  99. std = [0.229, 0.224, 0.225]
  100. scale = 0.00392156862745098
  101. mean = [0.485, 0.456, 0.406]
  102. alpha = [scale / std[i] for i in range(len(std))]
  103. beta = [-mean[i] / std[i] for i in range(len(std))]
  104. for c in range(img.shape[2]):
  105. split_im[c] = split_im[c].astype(np.float32)
  106. split_im[c] *= alpha[c]
  107. split_im[c] += beta[c]
  108. img = cv2.merge(split_im)
  109. # 5. 转换为 CHW 格式
  110. img = img.transpose((2, 0, 1))
  111. res_imgs.append(img)
  112. x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
  113. return x
  114. def predict(self, img):
  115. x = self.preprocess(img)
  116. result = self.sess.run(None, {"x": x})
  117. idx = np.argmax(result)
  118. conf = float(np.max(result))
  119. # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
  120. if idx == 0 and conf < 0.8:
  121. idx = 1
  122. return self.labels[idx], conf
  123. def batch_predict(self, img_info_list, batch_size=16):
  124. imgs = [item["table_img"] for item in img_info_list]
  125. imgs = self.list_2_batch(imgs, batch_size=batch_size)
  126. label_res = []
  127. for img_batch in imgs:
  128. x = self.batch_preprocess(img_batch)
  129. result = self.sess.run(None, {"x": x})
  130. for img_res in result[0]:
  131. idx = np.argmax(img_res)
  132. conf = float(np.max(img_res))
  133. if idx == 0 and conf < 0.9:
  134. idx = 1
  135. label_res.append((self.labels[idx],conf))
  136. for img_info, (label, conf) in zip(img_info_list, label_res):
  137. img_info['table_res']["cls_label"] = label
  138. img_info['table_res']["cls_score"] = conf