paddle_ori_cls.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. from PIL import Image
  4. from collections import defaultdict
  5. from typing import List, Dict
  6. from tqdm import tqdm
  7. import cv2
  8. import numpy as np
  9. import onnxruntime
  10. from mineru.utils.enum_class import ModelPath
  11. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  12. class PaddleOrientationClsModel:
  def __init__(self, ocr_engine):
      """Load the ONNX orientation classifier and remember the OCR engine.

      Args:
          ocr_engine: engine whose ``ocr(image, rec=False)`` method is called by
              ``predict`` to obtain text-detection boxes.
      """
      model_name = ModelPath.paddle_orientation_classification
      model_root = auto_download_and_get_model_root_path(model_name)
      self.sess = onnxruntime.InferenceSession(os.path.join(model_root, model_name))
      self.ocr_engine = ocr_engine

      # Resize target for the shortest image side (same value as the 256 used in preprocessing).
      self.less_length = 256
      # Center-crop width and height fed to the classifier.
      self.cw = 224
      self.ch = 224
      # Per-channel normalization: (pixel * scale - mean) / std, with scale == 1 / 255.
      self.std = [0.229, 0.224, 0.225]
      self.scale = 0.00392156862745098
      self.mean = [0.485, 0.456, 0.406]
      # Rotation angle for each classifier output index (argmax of the ONNX output).
      self.labels = ["0", "90", "180", "270"]
  def preprocess(self, input_img):
      """Turn one image into the (1, 3, ch, cw) float32 tensor fed to the ONNX classifier.

      Steps: resize so the shortest side equals ``self.less_length`` (bilinear),
      center-crop to ``self.cw`` x ``self.ch``, normalize each channel as
      ``(pixel * self.scale - mean[c]) / std[c]``, and reorder HWC -> NCHW.

      Args:
          input_img: HxWxC image as a numpy array (or anything ``np.asarray``
              accepts, e.g. a PIL image). The callers in this class pass RGB data.
              A 4th (alpha) channel is dropped and single-channel input is
              replicated to 3 channels, because the normalization has exactly
              three per-channel coefficients.

      Returns:
          float32 array of shape (1, 3, self.ch, self.cw).

      Raises:
          ValueError: if the image has zero height or width, or is smaller than
              the crop size after resizing.
      """
      img = np.asarray(input_img)
      # Bring the image to exactly 3 channels. Before this, a 4-channel (RGBA)
      # input failed with IndexError in the per-channel normalization.
      if img.ndim == 3 and img.shape[2] == 1:
          img = img[:, :, 0]
      if img.ndim == 2:
          img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
      elif img.shape[2] == 4:
          img = np.ascontiguousarray(img[:, :, :3])

      h, w = img.shape[:2]
      if h == 0 or w == 0:
          raise ValueError(f"Input image ({w}, {h}) is empty.")

      # Resize so that the shortest side becomes less_length (256 by default).
      ratio = self.less_length / min(h, w)
      img = cv2.resize(
          img, (round(w * ratio), round(h * ratio)), interpolation=cv2.INTER_LINEAR
      )

      # Center-crop to cw x ch (224 x 224 by default).
      h, w = img.shape[:2]
      cw, ch = self.cw, self.ch
      if w < cw or h < ch:
          raise ValueError(
              f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
          )
      x1 = (w - cw) // 2
      y1 = (h - ch) // 2
      img = img[y1 : y1 + ch, x1 : x1 + cw]

      # x * alpha + beta == (x * scale - mean) / std. The coefficients are
      # computed in double precision and applied in float32, as before.
      alpha = np.array([self.scale / s for s in self.std], dtype=np.float32)
      beta = np.array([-m / s for m, s in zip(self.mean, self.std)], dtype=np.float32)
      img = img.astype(np.float32) * alpha + beta

      # HWC -> CHW, plus a leading batch axis; contiguous for onnxruntime.
      return np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis], dtype=np.float32)
  def predict(self, input_img):
      """Return the rotation label ("0", "90", "180" or "270") for one table image.

      The ONNX classifier is consulted only when two conditions hold:
      - the image is portrait (height / width > 1.2);
      - the OCR text detector finds enough tall, narrow text boxes to suggest
        the content is rotated.

      Otherwise "0" is returned.

      Args:
          input_img: PIL image or numpy array. It is treated as RGB: it is
              converted with COLOR_RGB2BGR before text detection.

      Returns:
          One of ``self.labels``.

      Raises:
          ValueError: if ``input_img`` is neither a PIL image nor a numpy array.
      """
      if isinstance(input_img, Image.Image):
          np_img = np.asarray(input_img)
      elif isinstance(input_img, np.ndarray):
          np_img = input_img
      else:
          raise ValueError("Input must be a pillow object or a numpy array.")
      bgr_image = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)

      # Only portrait images are candidates for a rotated table.
      img_height, img_width = bgr_image.shape[:2]
      img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
      if img_aspect_ratio <= 1.2:
          return "0"

      det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
      if not det_res or not self._has_vertical_text_layout(det_res):
          return "0"

      # Give the classifier the same 3-channel RGB input batch_preprocess uses.
      # Converting back from bgr_image drops any alpha channel; passing the raw
      # 4-channel np_img used to crash in the 3-coefficient normalization.
      rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
      x = self.preprocess(rgb_image)
      (result,) = self.sess.run(None, {"x": x})
      return self.labels[np.argmax(result)]

  @staticmethod
  def _has_vertical_text_layout(boxes):
      """Return True when enough detected text boxes are taller than they are wide.

      Width and height come from the first and third corner points
      (``p3 - p1``). A box counts as vertical when width / height < 0.8;
      a non-positive height counts as ratio 1.0. The layout is treated as
      rotated when at least 3 boxes, and at least 28% of all boxes, are vertical.

      Args:
          boxes: non-empty sequence of 4-point text boxes from the OCR detector.
      """
      vertical_count = 0
      for box in boxes:
          p1, _, p3, _ = box
          width = p3[0] - p1[0]
          height = p3[1] - p1[1]
          aspect_ratio = width / height if height > 0 else 1.0
          if aspect_ratio < 0.8:
              vertical_count += 1
      return vertical_count >= len(boxes) * 0.28 and vertical_count >= 3
  def list_2_batch(self, img_list, batch_size=16):
      """Split a sequence into consecutive batches of at most ``batch_size`` items.

      Args:
          img_list: sliceable sequence to split; each batch is a slice of it,
              so it has the same type (list -> list, str -> str, ...).
          batch_size: maximum number of items per batch; must be >= 1.
              Defaults to 16.

      Returns:
          List of batches in original order; the last one may be shorter.
          An empty input yields an empty list.

      Raises:
          ValueError: if ``batch_size`` is less than 1. A negative value used to
              silently return no batches, dropping every item.
      """
      if batch_size < 1:
          raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
      # Slicing already clamps at the end of the sequence, so no min() is needed.
      return [img_list[i : i + batch_size] for i in range(0, len(img_list), batch_size)]
  def batch_preprocess(self, imgs):
      """Build one classifier input batch from a list of table-image records.

      Each image goes through ``self.preprocess``, so single-image and batched
      inference share the same resize / crop / normalization code.

      Args:
          imgs: list of dicts whose ``"table_img"`` entry is an image (e.g. a PIL
              image) in RGB channel order.

      Returns:
          float32 array of shape (len(imgs), 3, self.ch, self.cw). An empty input
          gives an empty batch of that shape instead of crashing.
      """
      tensors = []
      for img_info in imgs:
          rgb = np.asarray(img_info["table_img"])
          # Normalize to 3 channels. The former RGB->BGR->RGB cvtColor round
          # trip only served to drop an alpha channel; do that explicitly.
          if rgb.ndim == 3 and rgb.shape[2] == 1:
              rgb = rgb[:, :, 0]
          if rgb.ndim == 2:
              rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
          elif rgb.shape[2] == 4:
              rgb = np.ascontiguousarray(rgb[:, :, :3])
          # preprocess returns a (1, 3, ch, cw) batch of one.
          tensors.append(self.preprocess(rgb))
      if not tensors:
          return np.zeros((0, 3, self.ch, self.cw), dtype=np.float32)
      return np.concatenate(tensors, axis=0).astype(np.float32, copy=False)
  156. def batch_predict(
  157. self, imgs: List[Dict], atom_model_manager, ocr_model_name: str, batch_size: int
  158. ) -> None:
  159. """
  160. 批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
  161. """
  162. # 按语言分组,跳过长宽比小于1.2的图片
  163. lang_groups = defaultdict(list)
  164. for img in imgs:
  165. # PIL RGB图像转换BGR
  166. table_img: np.ndarray = cv2.cvtColor(
  167. np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR
  168. )
  169. img["table_img_ndarray"] = table_img
  170. img_height, img_width = table_img.shape[:2]
  171. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  172. img_is_portrait = img_aspect_ratio > 1.2
  173. if img_is_portrait:
  174. lang = img["lang"]
  175. lang_groups[lang].append(img)
  176. # 对每种语言按分辨率分组并批处理
  177. for lang, lang_group_img_list in lang_groups.items():
  178. if not lang_group_img_list:
  179. continue
  180. # 获取OCR模型
  181. ocr_model = atom_model_manager.get_atom_model(
  182. atom_model_name=ocr_model_name, det_db_box_thresh=0.3, lang=lang
  183. )
  184. # 按分辨率分组并同时完成padding
  185. resolution_groups = defaultdict(list)
  186. for img in lang_group_img_list:
  187. h, w = img["table_img_ndarray"].shape[:2]
  188. normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
  189. normalized_w = ((w + 32) // 32) * 32
  190. group_key = (normalized_h, normalized_w)
  191. resolution_groups[group_key].append(img)
  192. # 对每个分辨率组进行批处理
  193. for group_key, group_imgs in tqdm(
  194. resolution_groups.items(), desc=f"ORI CLS OCR-det {lang}"
  195. ):
  196. # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
  197. max_h = max(img["table_img_ndarray"].shape[0] for img in group_imgs)
  198. max_w = max(img["table_img_ndarray"].shape[1] for img in group_imgs)
  199. target_h = ((max_h + 32 - 1) // 32) * 32
  200. target_w = ((max_w + 32 - 1) // 32) * 32
  201. # 对所有图像进行padding到统一尺寸
  202. batch_images = []
  203. for img in group_imgs:
  204. table_img_ndarray = img["table_img_ndarray"]
  205. h, w = table_img_ndarray.shape[:2]
  206. # 创建目标尺寸的白色背景
  207. padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
  208. # 将原图像粘贴到左上角
  209. padded_img[:h, :w] = table_img_ndarray
  210. batch_images.append(padded_img)
  211. # 批处理检测
  212. det_batch_size = min(len(batch_images), batch_size) # 增加批处理大小
  213. batch_results = ocr_model.text_detector.batch_predict(
  214. batch_images, det_batch_size
  215. )
  216. rotated_imgs = []
  217. # 根据批处理结果检测图像是否旋转,旋转的图像放入列表中,继续进行旋转角度的预测
  218. for index, (img_info, (dt_boxes, elapse)) in enumerate(
  219. zip(group_imgs, batch_results)
  220. ):
  221. vertical_count = 0
  222. for box_ocr_res in dt_boxes:
  223. p1, p2, p3, p4 = box_ocr_res
  224. # Calculate width and height
  225. width = p3[0] - p1[0]
  226. height = p3[1] - p1[1]
  227. aspect_ratio = width / height if height > 0 else 1.0
  228. # Count vertical text boxes
  229. if aspect_ratio < 0.8: # Taller than wide - vertical text
  230. vertical_count += 1
  231. if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
  232. rotated_imgs.append(img_info)
  233. if len(rotated_imgs) > 0:
  234. x = self.batch_preprocess(rotated_imgs)
  235. results = self.sess.run(None, {"x": x})
  236. for img_info, res in zip(rotated_imgs, results[0]):
  237. label = self.labels[np.argmax(res)]
  238. if label == "270":
  239. img_info["table_img"] = Image.fromarray(
  240. cv2.rotate(
  241. np.asarray(img_info["table_img"]),
  242. cv2.ROTATE_90_CLOCKWISE,
  243. )
  244. )
  245. elif label == "90":
  246. img_info["table_img"] = Image.fromarray(
  247. cv2.rotate(
  248. np.asarray(img_info["table_img"]),
  249. cv2.ROTATE_90_COUNTERCLOCKWISE,
  250. )
  251. )