|
@@ -1,10 +1,12 @@
|
|
|
# Copyright (c) Opendatalab. All rights reserved.
|
|
# Copyright (c) Opendatalab. All rights reserved.
|
|
|
import os
|
|
import os
|
|
|
-
|
|
|
|
|
|
|
+from collections import defaultdict
|
|
|
|
|
+from typing import List, Dict
|
|
|
|
|
+from tqdm import tqdm
|
|
|
import cv2
|
|
import cv2
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import onnxruntime
|
|
import onnxruntime
|
|
|
-from loguru import logger
|
|
|
|
|
|
|
+from PIL import Image
|
|
|
|
|
|
|
|
from mineru.utils.enum_class import ModelPath
|
|
from mineru.utils.enum_class import ModelPath
|
|
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
@@ -13,7 +15,12 @@ from mineru.utils.models_download_utils import auto_download_and_get_model_root_
|
|
|
class PaddleOrientationClsModel:
|
|
class PaddleOrientationClsModel:
|
|
|
def __init__(self, ocr_engine):
|
|
def __init__(self, ocr_engine):
|
|
|
self.sess = onnxruntime.InferenceSession(
|
|
self.sess = onnxruntime.InferenceSession(
|
|
|
- os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
|
|
|
|
|
|
|
+ os.path.join(
|
|
|
|
|
+ auto_download_and_get_model_root_path(
|
|
|
|
|
+ ModelPath.paddle_orientation_classification
|
|
|
|
|
+ ),
|
|
|
|
|
+ ModelPath.paddle_orientation_classification,
|
|
|
|
|
+ )
|
|
|
)
|
|
)
|
|
|
self.ocr_engine = ocr_engine
|
|
self.ocr_engine = ocr_engine
|
|
|
self.less_length = 256
|
|
self.less_length = 256
|
|
@@ -104,11 +111,179 @@ class PaddleOrientationClsModel:
|
|
|
label = self.labels[np.argmax(result)]
|
|
label = self.labels[np.argmax(result)]
|
|
|
# logger.debug(f"Orientation classification result: {label}")
|
|
# logger.debug(f"Orientation classification result: {label}")
|
|
|
if label == "270":
|
|
if label == "270":
|
|
|
- rotation = cv2.ROTATE_90_CLOCKWISE
|
|
|
|
|
- img = cv2.rotate(np.asarray(img), rotation)
|
|
|
|
|
|
|
+ img = cv2.rotate(np.asarray(img), cv2.ROTATE_90_CLOCKWISE)
|
|
|
elif label == "90":
|
|
elif label == "90":
|
|
|
- rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
|
|
|
|
- img = cv2.rotate(np.asarray(img), rotation)
|
|
|
|
|
|
|
+ img = cv2.rotate(
|
|
|
|
|
+ np.asarray(img), cv2.ROTATE_90_COUNTERCLOCKWISE
|
|
|
|
|
+ )
|
|
|
else:
|
|
else:
|
|
|
pass
|
|
pass
|
|
|
return img
|
|
return img
|
|
|
|
|
+
|
|
|
|
|
+ def list_2_batch(self, img_list, batch_size=16):
|
|
|
|
|
+ """
|
|
|
|
|
+ 将任意长度的列表按照指定的batch size分成多个batch
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ img_list: 输入的列表
|
|
|
|
|
+ batch_size: 每个batch的大小,默认为16
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 一个包含多个batch的列表,每个batch都是原列表的一个子列表
|
|
|
|
|
+ """
|
|
|
|
|
+ batches = []
|
|
|
|
|
+ for i in range(0, len(img_list), batch_size):
|
|
|
|
|
+ batch = img_list[i : min(i + batch_size, len(img_list))]
|
|
|
|
|
+ batches.append(batch)
|
|
|
|
|
+ return batches
|
|
|
|
|
+
|
|
|
|
|
+ def batch_preprocess(self, imgs):
|
|
|
|
|
+ res_imgs = []
|
|
|
|
|
+ for img_info in imgs:
|
|
|
|
|
+ # PIL图像转cv2
|
|
|
|
|
+ img = cv2.cvtColor(np.asarray(img_info["table_img"]), cv2.COLOR_RGB2BGR)
|
|
|
|
|
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
|
+ # 放大图片,使其最短边长为256
|
|
|
|
|
+ h, w = img.shape[:2]
|
|
|
|
|
+ scale = 256 / min(h, w)
|
|
|
|
|
+ h_resize = round(h * scale)
|
|
|
|
|
+ w_resize = round(w * scale)
|
|
|
|
|
+ img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
|
|
|
|
|
+ # 调整为224*224的正方形
|
|
|
|
|
+ h, w = img.shape[:2]
|
|
|
|
|
+ cw, ch = 224, 224
|
|
|
|
|
+ x1 = max(0, (w - cw) // 2)
|
|
|
|
|
+ y1 = max(0, (h - ch) // 2)
|
|
|
|
|
+ x2 = min(w, x1 + cw)
|
|
|
|
|
+ y2 = min(h, y1 + ch)
|
|
|
|
|
+ if w < cw or h < ch:
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
|
|
|
|
|
+ )
|
|
|
|
|
+ img = img[y1:y2, x1:x2, ...]
|
|
|
|
|
+ # 正则化
|
|
|
|
|
+ split_im = list(cv2.split(img))
|
|
|
|
|
+ std = [0.229, 0.224, 0.225]
|
|
|
|
|
+ scale = 0.00392156862745098
|
|
|
|
|
+ mean = [0.485, 0.456, 0.406]
|
|
|
|
|
+ alpha = [scale / std[i] for i in range(len(std))]
|
|
|
|
|
+ beta = [-mean[i] / std[i] for i in range(len(std))]
|
|
|
|
|
+ for c in range(img.shape[2]):
|
|
|
|
|
+ split_im[c] = split_im[c].astype(np.float32)
|
|
|
|
|
+ split_im[c] *= alpha[c]
|
|
|
|
|
+ split_im[c] += beta[c]
|
|
|
|
|
+ img = cv2.merge(split_im)
|
|
|
|
|
+ # 5. 转换为 CHW 格式
|
|
|
|
|
+ img = img.transpose((2, 0, 1))
|
|
|
|
|
+ res_imgs.append(img)
|
|
|
|
|
+ x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
|
|
|
|
|
+ return x
|
|
|
|
|
+
|
|
|
|
|
+ def batch_predict(
|
|
|
|
|
+ self, imgs: List[Dict], atom_model_manager, ocr_model_name: str, batch_size: int
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ 批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
|
|
|
|
|
+ """
|
|
|
|
|
+ # 按语言分组,跳过长宽比小于1.2的图片
|
|
|
|
|
+ lang_groups = defaultdict(list)
|
|
|
|
|
+ for img in imgs:
|
|
|
|
|
+ # PIL RGB图像转换BGR
|
|
|
|
|
+ table_img: np.ndarray = cv2.cvtColor(
|
|
|
|
|
+ np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR
|
|
|
|
|
+ )
|
|
|
|
|
+ img["table_img_ndarray"] = table_img
|
|
|
|
|
+ img_height, img_width = table_img.shape[:2]
|
|
|
|
|
+ img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
|
|
|
|
|
+ img_is_portrait = img_aspect_ratio > 1.2
|
|
|
|
|
+ if img_is_portrait:
|
|
|
|
|
+ lang = img["lang"]
|
|
|
|
|
+ lang_groups[lang].append(img)
|
|
|
|
|
+
|
|
|
|
|
+ # 对每种语言按分辨率分组并批处理
|
|
|
|
|
+ for lang, lang_group_img_list in lang_groups.items():
|
|
|
|
|
+ if not lang_group_img_list:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # 获取OCR模型
|
|
|
|
|
+ ocr_model = atom_model_manager.get_atom_model(
|
|
|
|
|
+ atom_model_name=ocr_model_name, det_db_box_thresh=0.3, lang=lang
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 按分辨率分组并同时完成padding
|
|
|
|
|
+ resolution_groups = defaultdict(list)
|
|
|
|
|
+ for img in lang_group_img_list:
|
|
|
|
|
+ h, w = img["table_img_ndarray"].shape[:2]
|
|
|
|
|
+ normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
|
|
|
|
|
+ normalized_w = ((w + 32) // 32) * 32
|
|
|
|
|
+ group_key = (normalized_h, normalized_w)
|
|
|
|
|
+ resolution_groups[group_key].append(img)
|
|
|
|
|
+
|
|
|
|
|
+ # 对每个分辨率组进行批处理
|
|
|
|
|
+ for group_key, group_imgs in tqdm(
|
|
|
|
|
+ resolution_groups.items(), desc=f"ORI CLS OCR-det {lang}"
|
|
|
|
|
+ ):
|
|
|
|
|
+
|
|
|
|
|
+ # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
|
|
|
|
|
+ max_h = max(img["table_img_ndarray"].shape[0] for img in group_imgs)
|
|
|
|
|
+ max_w = max(img["table_img_ndarray"].shape[1] for img in group_imgs)
|
|
|
|
|
+ target_h = ((max_h + 32 - 1) // 32) * 32
|
|
|
|
|
+ target_w = ((max_w + 32 - 1) // 32) * 32
|
|
|
|
|
+
|
|
|
|
|
+ # 对所有图像进行padding到统一尺寸
|
|
|
|
|
+ batch_images = []
|
|
|
|
|
+ for img in group_imgs:
|
|
|
|
|
+ table_img_ndarray = img["table_img_ndarray"]
|
|
|
|
|
+ h, w = table_img_ndarray.shape[:2]
|
|
|
|
|
+ # 创建目标尺寸的白色背景
|
|
|
|
|
+ padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
|
|
|
|
|
+ # 将原图像粘贴到左上角
|
|
|
|
|
+ padded_img[:h, :w] = table_img_ndarray
|
|
|
|
|
+ batch_images.append(padded_img)
|
|
|
|
|
+
|
|
|
|
|
+ # 批处理检测
|
|
|
|
|
+ det_batch_size = min(len(batch_images), batch_size) # 增加批处理大小
|
|
|
|
|
+ batch_results = ocr_model.text_detector.batch_predict(
|
|
|
|
|
+ batch_images, det_batch_size
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ rotated_imgs = []
|
|
|
|
|
+ # 根据批处理结果检测图像是否旋转,旋转的图像放入列表中,继续进行旋转角度的预测
|
|
|
|
|
+ for index, (img_info, (dt_boxes, elapse)) in enumerate(
|
|
|
|
|
+ zip(group_imgs, batch_results)
|
|
|
|
|
+ ):
|
|
|
|
|
+ vertical_count = 0
|
|
|
|
|
+ for box_ocr_res in dt_boxes:
|
|
|
|
|
+ p1, p2, p3, p4 = box_ocr_res
|
|
|
|
|
+
|
|
|
|
|
+ # Calculate width and height
|
|
|
|
|
+ width = p3[0] - p1[0]
|
|
|
|
|
+ height = p3[1] - p1[1]
|
|
|
|
|
+
|
|
|
|
|
+ aspect_ratio = width / height if height > 0 else 1.0
|
|
|
|
|
+
|
|
|
|
|
+ # Count vertical text boxes
|
|
|
|
|
+ if aspect_ratio < 0.8: # Taller than wide - vertical text
|
|
|
|
|
+ vertical_count += 1
|
|
|
|
|
+
|
|
|
|
|
+ if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
|
|
|
|
|
+ rotated_imgs.append(img_info)
|
|
|
|
|
+ if len(rotated_imgs) > 0:
|
|
|
|
|
+ x = self.batch_preprocess(rotated_imgs)
|
|
|
|
|
+ results = self.sess.run(None, {"x": x})
|
|
|
|
|
+ for img_info, res in zip(rotated_imgs, results[0]):
|
|
|
|
|
+ label = self.labels[np.argmax(res)]
|
|
|
|
|
+ if label == "270":
|
|
|
|
|
+ img_info["table_img"] = Image.fromarray(
|
|
|
|
|
+ cv2.rotate(
|
|
|
|
|
+ np.asarray(img_info["table_img"]),
|
|
|
|
|
+ cv2.ROTATE_90_CLOCKWISE,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ elif label == "90":
|
|
|
|
|
+ img_info["table_img"] = Image.fromarray(
|
|
|
|
|
+ cv2.rotate(
|
|
|
|
|
+ np.asarray(img_info["table_img"]),
|
|
|
|
|
+ cv2.ROTATE_90_COUNTERCLOCKWISE,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|