paddle_ori_cls.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. from collections import defaultdict
  4. from typing import List, Dict
  5. from tqdm import tqdm
  6. import cv2
  7. import numpy as np
  8. import onnxruntime
  9. from PIL import Image
  10. from mineru.utils.enum_class import ModelPath
  11. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  12. class PaddleOrientationClsModel:
  13. def __init__(self, ocr_engine):
  14. self.sess = onnxruntime.InferenceSession(
  15. os.path.join(
  16. auto_download_and_get_model_root_path(
  17. ModelPath.paddle_orientation_classification
  18. ),
  19. ModelPath.paddle_orientation_classification,
  20. )
  21. )
  22. self.ocr_engine = ocr_engine
  23. self.less_length = 256
  24. self.cw, self.ch = 224, 224
  25. self.std = [0.229, 0.224, 0.225]
  26. self.scale = 0.00392156862745098
  27. self.mean = [0.485, 0.456, 0.406]
  28. self.labels = ["0", "90", "180", "270"]
  29. def preprocess(self, img):
  30. # PIL图像转cv2
  31. img = np.array(img)
  32. # 放大图片,使其最短边长为256
  33. h, w = img.shape[:2]
  34. scale = 256 / min(h, w)
  35. h_resize = round(h * scale)
  36. w_resize = round(w * scale)
  37. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  38. # 调整为224*224的正方形
  39. h, w = img.shape[:2]
  40. cw, ch = 224, 224
  41. x1 = max(0, (w - cw) // 2)
  42. y1 = max(0, (h - ch) // 2)
  43. x2 = min(w, x1 + cw)
  44. y2 = min(h, y1 + ch)
  45. if w < cw or h < ch:
  46. raise ValueError(
  47. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  48. )
  49. img = img[y1:y2, x1:x2, ...]
  50. # 正则化
  51. split_im = list(cv2.split(img))
  52. std = [0.229, 0.224, 0.225]
  53. scale = 0.00392156862745098
  54. mean = [0.485, 0.456, 0.406]
  55. alpha = [scale / std[i] for i in range(len(std))]
  56. beta = [-mean[i] / std[i] for i in range(len(std))]
  57. for c in range(img.shape[2]):
  58. split_im[c] = split_im[c].astype(np.float32)
  59. split_im[c] *= alpha[c]
  60. split_im[c] += beta[c]
  61. img = cv2.merge(split_im)
  62. # 5. 转换为 CHW 格式
  63. img = img.transpose((2, 0, 1))
  64. imgs = [img]
  65. x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
  66. return x
  67. def predict(self, img):
  68. bgr_image = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
  69. # First check the overall image aspect ratio (height/width)
  70. img_height, img_width = bgr_image.shape[:2]
  71. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  72. img_is_portrait = img_aspect_ratio > 1.2
  73. if img_is_portrait:
  74. det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
  75. # Check if table is rotated by analyzing text box aspect ratios
  76. if det_res:
  77. vertical_count = 0
  78. is_rotated = False
  79. for box_ocr_res in det_res:
  80. p1, p2, p3, p4 = box_ocr_res
  81. # Calculate width and height
  82. width = p3[0] - p1[0]
  83. height = p3[1] - p1[1]
  84. aspect_ratio = width / height if height > 0 else 1.0
  85. # Count vertical vs horizontal text boxes
  86. if aspect_ratio < 0.8: # Taller than wide - vertical text
  87. vertical_count += 1
  88. # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
  89. # horizontal_count += 1
  90. if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
  91. is_rotated = True
  92. # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
  93. # If we have more vertical text boxes than horizontal ones,
  94. # and vertical ones are significant, table might be rotated
  95. if is_rotated:
  96. x = self.preprocess(img)
  97. (result,) = self.sess.run(None, {"x": x})
  98. label = self.labels[np.argmax(result)]
  99. # logger.debug(f"Orientation classification result: {label}")
  100. if label == "270":
  101. img = cv2.rotate(np.asarray(img), cv2.ROTATE_90_CLOCKWISE)
  102. elif label == "90":
  103. img = cv2.rotate(
  104. np.asarray(img), cv2.ROTATE_90_COUNTERCLOCKWISE
  105. )
  106. else:
  107. pass
  108. return img
  109. def list_2_batch(self, img_list, batch_size=16):
  110. """
  111. 将任意长度的列表按照指定的batch size分成多个batch
  112. Args:
  113. img_list: 输入的列表
  114. batch_size: 每个batch的大小,默认为16
  115. Returns:
  116. 一个包含多个batch的列表,每个batch都是原列表的一个子列表
  117. """
  118. batches = []
  119. for i in range(0, len(img_list), batch_size):
  120. batch = img_list[i : min(i + batch_size, len(img_list))]
  121. batches.append(batch)
  122. return batches
  123. def batch_preprocess(self, imgs):
  124. res_imgs = []
  125. for img_info in imgs:
  126. # PIL图像转cv2
  127. img = cv2.cvtColor(np.asarray(img_info["table_img"]), cv2.COLOR_RGB2BGR)
  128. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  129. # 放大图片,使其最短边长为256
  130. h, w = img.shape[:2]
  131. scale = 256 / min(h, w)
  132. h_resize = round(h * scale)
  133. w_resize = round(w * scale)
  134. img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
  135. # 调整为224*224的正方形
  136. h, w = img.shape[:2]
  137. cw, ch = 224, 224
  138. x1 = max(0, (w - cw) // 2)
  139. y1 = max(0, (h - ch) // 2)
  140. x2 = min(w, x1 + cw)
  141. y2 = min(h, y1 + ch)
  142. if w < cw or h < ch:
  143. raise ValueError(
  144. f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
  145. )
  146. img = img[y1:y2, x1:x2, ...]
  147. # 正则化
  148. split_im = list(cv2.split(img))
  149. std = [0.229, 0.224, 0.225]
  150. scale = 0.00392156862745098
  151. mean = [0.485, 0.456, 0.406]
  152. alpha = [scale / std[i] for i in range(len(std))]
  153. beta = [-mean[i] / std[i] for i in range(len(std))]
  154. for c in range(img.shape[2]):
  155. split_im[c] = split_im[c].astype(np.float32)
  156. split_im[c] *= alpha[c]
  157. split_im[c] += beta[c]
  158. img = cv2.merge(split_im)
  159. # 5. 转换为 CHW 格式
  160. img = img.transpose((2, 0, 1))
  161. res_imgs.append(img)
  162. x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
  163. return x
  164. def batch_predict(
  165. self, imgs: List[Dict], atom_model_manager, ocr_model_name: str, batch_size: int
  166. ) -> None:
  167. """
  168. 批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
  169. """
  170. # 按语言分组,跳过长宽比小于1.2的图片
  171. lang_groups = defaultdict(list)
  172. for img in imgs:
  173. # PIL RGB图像转换BGR
  174. table_img: np.ndarray = cv2.cvtColor(
  175. np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR
  176. )
  177. img["table_img_ndarray"] = table_img
  178. img_height, img_width = table_img.shape[:2]
  179. img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  180. img_is_portrait = img_aspect_ratio > 1.2
  181. if img_is_portrait:
  182. lang = img["lang"]
  183. lang_groups[lang].append(img)
  184. # 对每种语言按分辨率分组并批处理
  185. for lang, lang_group_img_list in lang_groups.items():
  186. if not lang_group_img_list:
  187. continue
  188. # 获取OCR模型
  189. ocr_model = atom_model_manager.get_atom_model(
  190. atom_model_name=ocr_model_name, det_db_box_thresh=0.3, lang=lang
  191. )
  192. # 按分辨率分组并同时完成padding
  193. resolution_groups = defaultdict(list)
  194. for img in lang_group_img_list:
  195. h, w = img["table_img_ndarray"].shape[:2]
  196. normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
  197. normalized_w = ((w + 32) // 32) * 32
  198. group_key = (normalized_h, normalized_w)
  199. resolution_groups[group_key].append(img)
  200. # 对每个分辨率组进行批处理
  201. for group_key, group_imgs in tqdm(
  202. resolution_groups.items(), desc=f"ORI CLS OCR-det {lang}"
  203. ):
  204. # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
  205. max_h = max(img["table_img_ndarray"].shape[0] for img in group_imgs)
  206. max_w = max(img["table_img_ndarray"].shape[1] for img in group_imgs)
  207. target_h = ((max_h + 32 - 1) // 32) * 32
  208. target_w = ((max_w + 32 - 1) // 32) * 32
  209. # 对所有图像进行padding到统一尺寸
  210. batch_images = []
  211. for img in group_imgs:
  212. table_img_ndarray = img["table_img_ndarray"]
  213. h, w = table_img_ndarray.shape[:2]
  214. # 创建目标尺寸的白色背景
  215. padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
  216. # 将原图像粘贴到左上角
  217. padded_img[:h, :w] = table_img_ndarray
  218. batch_images.append(padded_img)
  219. # 批处理检测
  220. det_batch_size = min(len(batch_images), batch_size) # 增加批处理大小
  221. batch_results = ocr_model.text_detector.batch_predict(
  222. batch_images, det_batch_size
  223. )
  224. rotated_imgs = []
  225. # 根据批处理结果检测图像是否旋转,旋转的图像放入列表中,继续进行旋转角度的预测
  226. for index, (img_info, (dt_boxes, elapse)) in enumerate(
  227. zip(group_imgs, batch_results)
  228. ):
  229. vertical_count = 0
  230. for box_ocr_res in dt_boxes:
  231. p1, p2, p3, p4 = box_ocr_res
  232. # Calculate width and height
  233. width = p3[0] - p1[0]
  234. height = p3[1] - p1[1]
  235. aspect_ratio = width / height if height > 0 else 1.0
  236. # Count vertical text boxes
  237. if aspect_ratio < 0.8: # Taller than wide - vertical text
  238. vertical_count += 1
  239. if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
  240. rotated_imgs.append(img_info)
  241. if len(rotated_imgs) > 0:
  242. x = self.batch_preprocess(rotated_imgs)
  243. results = self.sess.run(None, {"x": x})
  244. for img_info, res in zip(rotated_imgs, results[0]):
  245. label = self.labels[np.argmax(res)]
  246. if label == "270":
  247. img_info["table_img"] = Image.fromarray(
  248. cv2.rotate(
  249. np.asarray(img_info["table_img"]),
  250. cv2.ROTATE_90_CLOCKWISE,
  251. )
  252. )
  253. elif label == "90":
  254. img_info["table_img"] = Image.fromarray(
  255. cv2.rotate(
  256. np.asarray(img_info["table_img"]),
  257. cv2.ROTATE_90_COUNTERCLOCKWISE,
  258. )
  259. )