# Copyright (c) Opendatalab. All rights reserved.
import os
from collections import defaultdict
from typing import List, Dict

import cv2
import numpy as np
import onnxruntime
from PIL import Image
from tqdm import tqdm

from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path


class PaddleOrientationClsModel:
    def __init__(self, ocr_engine):
        self.sess = onnxruntime.InferenceSession(
            os.path.join(
                auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification),
                ModelPath.paddle_orientation_classification,
            )
        )
        self.ocr_engine = ocr_engine
        self.less_length = 256
        self.cw, self.ch = 224, 224
        self.std = [0.229, 0.224, 0.225]
        self.scale = 0.00392156862745098  # 1 / 255
        self.mean = [0.485, 0.456, 0.406]
        self.labels = ["0", "90", "180", "270"]
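
    # Preprocessing mirrors a standard ImageNet eval transform: resize the
    # shorter side to `less_length` (256), center-crop to `cw` x `ch`
    # (224 x 224), then normalize as (pixel * scale - mean) / std with
    # scale == 1/255 and the ImageNet channel statistics above.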
    def _normalize(self, input_img):
        # Enlarge the image so its shorter side is `less_length` (256) pixels
        h, w = input_img.shape[:2]
        scale = self.less_length / min(h, w)
        h_resize = round(h * scale)
        w_resize = round(w * scale)
        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=cv2.INTER_LINEAR)
        # Center-crop to a 224x224 square
        h, w = img.shape[:2]
        cw, ch = self.cw, self.ch
        x1 = max(0, (w - cw) // 2)
        y1 = max(0, (h - ch) // 2)
        x2 = min(w, x1 + cw)
        y2 = min(h, y1 + ch)
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        img = img[y1:y2, x1:x2, ...]
        # Normalize each channel: out = pixel * alpha + beta
        split_im = list(cv2.split(img))
        alpha = [self.scale / s for s in self.std]
        beta = [-m / s for m, s in zip(self.mean, self.std)]
        for c in range(img.shape[2]):
            split_im[c] = split_im[c].astype(np.float32)
            split_im[c] *= alpha[c]
            split_im[c] += beta[c]
        img = cv2.merge(split_im)
        # Convert HWC to CHW layout
        return img.transpose((2, 0, 1))

    def preprocess(self, input_img):
        x = np.stack([self._normalize(input_img)], axis=0).astype(dtype=np.float32, copy=False)
        return x
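
    # The per-channel normalization in _normalize is a single fused affine
    # transform:
    #   (p * scale - mean) / std == p * (scale / std) + (-mean / std)
    #                            == p * alpha + beta
    # e.g. for the first channel, alpha = (1/255) / 0.229 ≈ 0.01712 and
    # beta = -0.485 / 0.229 ≈ -2.118.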
    def predict(self, input_img):
        rotate_label = "0"  # Default when no rotation is detected or the image is not portrait
        if isinstance(input_img, Image.Image):
            np_img = np.asarray(input_img)
        elif isinstance(input_img, np.ndarray):
            np_img = input_img
        else:
            raise ValueError("Input must be a pillow object or a numpy array.")
        bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
        # First check the overall image aspect ratio (height/width)
        img_height, img_width = bgr_image.shape[:2]
        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
        img_is_portrait = img_aspect_ratio > 1.2
        if img_is_portrait:
            det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
            # Check whether the table is rotated by analyzing text box aspect ratios
            if det_res:
                vertical_count = 0
                is_rotated = False
                for box_ocr_res in det_res:
                    p1, p2, p3, p4 = box_ocr_res
                    # Box width and height from the top-left (p1) and bottom-right (p3) corners
                    width = p3[0] - p1[0]
                    height = p3[1] - p1[1]
                    aspect_ratio = width / height if height > 0 else 1.0
                    # Count vertical text boxes
                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
                        vertical_count += 1
                # If a significant share of the text boxes are vertical, the
                # table is likely rotated
                if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
                    is_rotated = True
                if is_rotated:
                    x = self.preprocess(np_img)
                    (result,) = self.sess.run(None, {"x": x})
                    rotate_label = self.labels[np.argmax(result)]
        return rotate_label
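
    # predict() is a two-stage check: the cheap aspect-ratio and
    # text-direction heuristics gate the ONNX classifier, so most upright
    # images never pay for model inference.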
    def list_2_batch(self, img_list, batch_size=16):
        """
        Split a list of arbitrary length into batches of at most batch_size items.
        Args:
            img_list: the input list
            batch_size: maximum size of each batch, 16 by default
        Returns:
            A list of batches, each a contiguous sub-list of the input
        """
        batches = []
        for i in range(0, len(img_list), batch_size):
            batches.append(img_list[i : i + batch_size])
        return batches
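
    # list_2_batch example:
    #   list_2_batch(list(range(5)), batch_size=2) -> [[0, 1], [2, 3], [4]]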
    def batch_preprocess(self, imgs):
        # Reuse the single-image pipeline for every table crop in the batch
        res_imgs = [self._normalize(np.asarray(img_info["table_img"])) for img_info in imgs]
        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
        return x

    def batch_predict(
        self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
    ) -> None:
        """
        Batch-predict the rotation of the images in the given list of image-info
        dicts, and rotate any rotated images back in place.
        """
        import torch
        from packaging import version

        # Orientation classification is skipped entirely on torch >= 2.8.0,
        # presumably as a compatibility workaround
        if version.parse(torch.__version__) >= version.parse("2.8.0"):
            return None
        RESOLUTION_GROUP_STRIDE = 128
        # Skip images whose aspect ratio (height/width) is 1.2 or below
        resolution_groups = defaultdict(list)
        for img in imgs:
            # Convert the RGB image to BGR for OpenCV/PaddleOCR
            bgr_img: np.ndarray = cv2.cvtColor(np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR)
            img["table_img_bgr"] = bgr_img
            img_height, img_width = bgr_img.shape[:2]
            img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
            if img_aspect_ratio > 1.2:
                # Bucket by size rounded up to the next multiple of RESOLUTION_GROUP_STRIDE
                normalized_h = ((img_height + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
                normalized_w = ((img_width + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
                group_key = (normalized_h, normalized_w)
                resolution_groups[group_key].append(img)
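        # e.g. a 300x500 (h x w) crop lands in the (384, 512) bucket; grouping
        # similarly sized crops keeps the padding overhead of each batch small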
        # Batch-process each resolution group
        rotated_imgs = []
        for group_key, group_imgs in tqdm(resolution_groups.items(), desc="Table-ori cls stage1 predict", disable=True):
            # Target size: the largest image in the group, rounded up to a
            # multiple of RESOLUTION_GROUP_STRIDE
            max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
            max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
            target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
            target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
            # Pad every image in the group to the shared target size
            batch_images = []
            for img in group_imgs:
                bgr_img = img["table_img_bgr"]
                h, w = bgr_img.shape[:2]
                # Create a white canvas of the target size
                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
                # Paste the original image into the top-left corner
                padded_img[:h, :w] = bgr_img
                batch_images.append(padded_img)
            # Run batched text detection on the group
            batch_results = self.ocr_engine.text_detector.batch_predict(
                batch_images, min(len(batch_images), det_batch_size)
            )
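            # Each entry of batch_results is a (dt_boxes, elapse) pair in the
            # same order as batch_images, where dt_boxes holds one 4-point
            # quadrilateral per detected text line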
            # Use the detection results to decide whether each image is
            # rotated; rotated images are queued for angle prediction
            for img_info, (dt_boxes, elapse) in zip(group_imgs, batch_results):
                vertical_count = 0
                for box_ocr_res in dt_boxes:
                    p1, p2, p3, p4 = box_ocr_res
                    # Box width and height from the top-left (p1) and
                    # bottom-right (p3) corners
                    width = p3[0] - p1[0]
                    height = p3[1] - p1[1]
                    aspect_ratio = width / height if height > 0 else 1.0
                    # Count vertical text boxes
                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
                        vertical_count += 1
                if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
                    rotated_imgs.append(img_info)
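        # The 0.28-fraction and 3-box thresholds look like empirical guards:
        # they require both a meaningful share and an absolute minimum of
        # vertical boxes before the ONNX classifier is invoked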
        # Predict the rotation angle of the queued images and rotate them back
        if len(rotated_imgs) > 0:
            batches = self.list_2_batch(rotated_imgs, batch_size=batch_size)
            with tqdm(total=len(rotated_imgs), desc="Table-ori cls stage2 predict", disable=True) as pbar:
                for img_batch in batches:
                    x = self.batch_preprocess(img_batch)
                    results = self.sess.run(None, {"x": x})
                    # Pair each image with its own batch results; zipping the
                    # full rotated_imgs list here would misalign every batch
                    # after the first
                    for img_info, res in zip(img_batch, results[0]):
                        label = self.labels[np.argmax(res)]
                        self.img_rotate(img_info, label)
                        pbar.update(1)
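
    # `label` is the detected rotation of the content; img_rotate applies the
    # inverse rotation so downstream table recognition sees upright text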
    def img_rotate(self, img_info, label):
        if label == "270":
            img_info["table_img"] = cv2.rotate(
                np.asarray(img_info["table_img"]),
                cv2.ROTATE_90_CLOCKWISE,
            )
            img_info["wired_table_img"] = cv2.rotate(
                np.asarray(img_info["wired_table_img"]),
                cv2.ROTATE_90_CLOCKWISE,
            )
        elif label == "90":
            img_info["table_img"] = cv2.rotate(
                np.asarray(img_info["table_img"]),
                cv2.ROTATE_90_COUNTERCLOCKWISE,
            )
            img_info["wired_table_img"] = cv2.rotate(
                np.asarray(img_info["wired_table_img"]),
                cv2.ROTATE_90_COUNTERCLOCKWISE,
            )
        else:
            # 0 and 180 degrees are left untouched
            pass
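

# A minimal usage sketch (hypothetical inputs; assumes an OCR engine exposing
# .ocr(img, rec=False) and .text_detector.batch_predict(imgs, n), as the
# methods above expect):
#
#   from PIL import Image
#   ocr_engine = ...  # e.g. a MinerU PaddleOCR-style engine
#   model = PaddleOrientationClsModel(ocr_engine)
#   table = Image.open("table_crop.png")
#   print(model.predict(table))  # one of "0" / "90" / "180" / "270"
#
#   tables = [{"table_img": table, "wired_table_img": table}]
#   model.batch_predict(tables, det_batch_size=8, batch_size=16)
#   # any rotated "table_img" / "wired_table_img" entries are corrected in place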