Selaa lähdekoodia

feat: Add batch prediction for image rotation classification and table classification

Sidney233 3 kuukautta sitten
vanhempi
commit
dfbccbc624

+ 30 - 27
mineru/backend/pipeline/batch_analyze.py

@@ -14,6 +14,7 @@ YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 OCR_DET_BASE_BATCH_SIZE = 16
+ORI_TAB_CLS_BATCH_SIZE = 16
 
 
 class BatchAnalyze:
@@ -44,7 +45,6 @@ class BatchAnalyze:
         for image_index, image in enumerate(images):
             layout_images.append(image)
 
-
         images_layout_res += self.model.layout_model.batch_predict(
             layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
@@ -248,49 +248,52 @@ class BatchAnalyze:
 
         # 表格识别 table recognition
         if self.table_enable:
-            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
-                _lang = table_res_dict['lang']
+            # 图片旋转批量处理
 
-                # 调整图片方向
-                img_orientation_cls_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.ImgOrientationCls,
+            img_orientation_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.ImgOrientationCls,
+            )
+            try:
+                img_orientation_cls_model.batch_predict(table_res_list_all_page, atom_model_manager, AtomicModel.OCR, self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
+            except Exception as e:
+                logger.warning(
+                    f"Image orientation classification failed: {e}, using original image"
                 )
-                try:
-                    table_img = img_orientation_cls_model.predict(
-                        table_res_dict["table_img"]
-                    )
-                except Exception as e:
-                    logger.warning(
-                        f"Image orientation classification failed: {e}, using original image"
-                    )
-                    table_img = table_res_dict["table_img"]
-
-                # 有线表/无线表分类
-                table_cls_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.TableCls,
+            # 表格分类
+            table_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.TableCls,
+            )
+            try:
+                table_cls_model.batch_predict(table_res_list_all_page)
+            except Exception as e:
+                logger.warning(
+                    f"Table classification failed: {e}, using default model"
                 )
+            # 遍历表格,根据分类识别结构
+            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+                _lang = table_res_dict['lang']
                 table_cls_score = 0.5
                 try:
-                    table_label, table_cls_score = table_cls_model.predict(table_img)
+                    table_label, table_cls_score = table_res_dict['table_res']["cls_label"], table_res_dict['table_res']["cls_score"]
                 except Exception as e:
-                    table_label = AtomicModel.WirelessTable
                     logger.warning(
-                        f"Table classification failed: {e}, using default model {table_label}"
+                        f"Table classification failed: {e}, return error classification result: {table_res_dict}"
                     )
-                # table_label = AtomicModel.WirelessTable
-                # logger.debug(f"Table classification result: {table_label}")
-                if table_label not in [AtomicModel.WirelessTable, AtomicModel.WiredTable]:
+                    table_label = AtomicModel.WirelessTable
+                if table_label not in [
+                    AtomicModel.WirelessTable,
+                    AtomicModel.WiredTable,
+                ]:
                     raise ValueError(
                         "Table classification failed, please check the model"
                     )
-
                 # 根据表格分类结果选择有线表格识别模型和无线表格识别模型
                 table_model = atom_model_manager.get_atom_model(
                     atom_model_name=table_label,
                     lang=_lang,
                 )
 
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img, table_cls_score)
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict["table_img"], table_cls_score)
                 # 判断是否返回正常
                 if html_code:
                     # 检查html_code是否包含'<table>'和'</table>'

+ 182 - 7
mineru/model/ori_cls/paddle_ori_cls.py

@@ -1,10 +1,12 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import os
-
+from collections import defaultdict
+from typing import List, Dict
+from tqdm import tqdm
 import cv2
 import numpy as np
 import onnxruntime
-from loguru import logger
+from PIL import Image
 
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -13,7 +15,12 @@ from mineru.utils.models_download_utils import auto_download_and_get_model_root_
 class PaddleOrientationClsModel:
     def __init__(self, ocr_engine):
         self.sess = onnxruntime.InferenceSession(
-            os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_orientation_classification), ModelPath.paddle_orientation_classification)
+            os.path.join(
+                auto_download_and_get_model_root_path(
+                    ModelPath.paddle_orientation_classification
+                ),
+                ModelPath.paddle_orientation_classification,
+            )
         )
         self.ocr_engine = ocr_engine
         self.less_length = 256
@@ -104,11 +111,179 @@ class PaddleOrientationClsModel:
                     label = self.labels[np.argmax(result)]
                     # logger.debug(f"Orientation classification result: {label}")
                     if label == "270":
-                        rotation = cv2.ROTATE_90_CLOCKWISE
-                        img = cv2.rotate(np.asarray(img), rotation)
+                        img = cv2.rotate(np.asarray(img), cv2.ROTATE_90_CLOCKWISE)
                     elif label == "90":
-                        rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
-                        img = cv2.rotate(np.asarray(img), rotation)
+                        img = cv2.rotate(
+                            np.asarray(img), cv2.ROTATE_90_COUNTERCLOCKWISE
+                        )
                     else:
                         pass
         return img
+
+    def list_2_batch(self, img_list, batch_size=16):
+        """
+        将任意长度的列表按照指定的batch size分成多个batch
+
+        Args:
+            img_list: 输入的列表
+            batch_size: 每个batch的大小,默认为16
+
+        Returns:
+            一个包含多个batch的列表,每个batch都是原列表的一个子列表
+        """
+        batches = []
+        for i in range(0, len(img_list), batch_size):
+            batch = img_list[i : min(i + batch_size, len(img_list))]
+            batches.append(batch)
+        return batches
+
+    def batch_preprocess(self, imgs):
+        res_imgs = []
+        for img_info in imgs:
+            # PIL图像转cv2
+            img = cv2.cvtColor(np.asarray(img_info["table_img"]), cv2.COLOR_RGB2BGR)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            # 放大图片,使其最短边长为256
+            h, w = img.shape[:2]
+            scale = 256 / min(h, w)
+            h_resize = round(h * scale)
+            w_resize = round(w * scale)
+            img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+            # 调整为224*224的正方形
+            h, w = img.shape[:2]
+            cw, ch = 224, 224
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+            x2 = min(w, x1 + cw)
+            y2 = min(h, y1 + ch)
+            if w < cw or h < ch:
+                raise ValueError(
+                    f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+                )
+            img = img[y1:y2, x1:x2, ...]
+            # 正则化
+            split_im = list(cv2.split(img))
+            std = [0.229, 0.224, 0.225]
+            scale = 0.00392156862745098
+            mean = [0.485, 0.456, 0.406]
+            alpha = [scale / std[i] for i in range(len(std))]
+            beta = [-mean[i] / std[i] for i in range(len(std))]
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= alpha[c]
+                split_im[c] += beta[c]
+            img = cv2.merge(split_im)
+            # 5. 转换为 CHW 格式
+            img = img.transpose((2, 0, 1))
+            res_imgs.append(img)
+        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
+    def batch_predict(
+        self, imgs: List[Dict], atom_model_manager, ocr_model_name: str, batch_size: int
+    ) -> None:
+        """
+        批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
+        """
+        # 按语言分组,跳过长宽比小于1.2的图片
+        lang_groups = defaultdict(list)
+        for img in imgs:
+            # PIL RGB图像转换BGR
+            table_img: np.ndarray = cv2.cvtColor(
+                np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR
+            )
+            img["table_img_ndarray"] = table_img
+            img_height, img_width = table_img.shape[:2]
+            img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+            img_is_portrait = img_aspect_ratio > 1.2
+            if img_is_portrait:
+                lang = img["lang"]
+                lang_groups[lang].append(img)
+
+        # 对每种语言按分辨率分组并批处理
+        for lang, lang_group_img_list in lang_groups.items():
+            if not lang_group_img_list:
+                continue
+
+            # 获取OCR模型
+            ocr_model = atom_model_manager.get_atom_model(
+                atom_model_name=ocr_model_name, det_db_box_thresh=0.3, lang=lang
+            )
+
+            # 按分辨率分组并同时完成padding
+            resolution_groups = defaultdict(list)
+            for img in lang_group_img_list:
+                h, w = img["table_img_ndarray"].shape[:2]
+                normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
+                normalized_w = ((w + 32) // 32) * 32
+                group_key = (normalized_h, normalized_w)
+                resolution_groups[group_key].append(img)
+
+            # 对每个分辨率组进行批处理
+            for group_key, group_imgs in tqdm(
+                resolution_groups.items(), desc=f"ORI CLS OCR-det {lang}"
+            ):
+
+                # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
+                max_h = max(img["table_img_ndarray"].shape[0] for img in group_imgs)
+                max_w = max(img["table_img_ndarray"].shape[1] for img in group_imgs)
+                target_h = ((max_h + 32 - 1) // 32) * 32
+                target_w = ((max_w + 32 - 1) // 32) * 32
+
+                # 对所有图像进行padding到统一尺寸
+                batch_images = []
+                for img in group_imgs:
+                    table_img_ndarray = img["table_img_ndarray"]
+                    h, w = table_img_ndarray.shape[:2]
+                    # 创建目标尺寸的白色背景
+                    padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                    # 将原图像粘贴到左上角
+                    padded_img[:h, :w] = table_img_ndarray
+                    batch_images.append(padded_img)
+
+                # 批处理检测
+                det_batch_size = min(len(batch_images), batch_size)  # 增加批处理大小
+                batch_results = ocr_model.text_detector.batch_predict(
+                    batch_images, det_batch_size
+                )
+
+                rotated_imgs = []
+                # 根据批处理结果检测图像是否旋转,旋转的图像放入列表中,继续进行旋转角度的预测
+                for index, (img_info, (dt_boxes, elapse)) in enumerate(
+                    zip(group_imgs, batch_results)
+                ):
+                    vertical_count = 0
+                    for box_ocr_res in dt_boxes:
+                        p1, p2, p3, p4 = box_ocr_res
+
+                        # Calculate width and height
+                        width = p3[0] - p1[0]
+                        height = p3[1] - p1[1]
+
+                        aspect_ratio = width / height if height > 0 else 1.0
+
+                        # Count vertical text boxes
+                        if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                            vertical_count += 1
+
+                    if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
+                        rotated_imgs.append(img_info)
+                if len(rotated_imgs) > 0:
+                    x = self.batch_preprocess(rotated_imgs)
+                    results = self.sess.run(None, {"x": x})
+                    for img_info, res in zip(rotated_imgs, results[0]):
+                        label = self.labels[np.argmax(res)]
+                        if label == "270":
+                            img_info["table_img"] = Image.fromarray(
+                                cv2.rotate(
+                                    np.asarray(img_info["table_img"]),
+                                    cv2.ROTATE_90_CLOCKWISE,
+                                )
+                            )
+                        elif label == "90":
+                            img_info["table_img"] = Image.fromarray(
+                                cv2.rotate(
+                                    np.asarray(img_info["table_img"]),
+                                    cv2.ROTATE_90_COUNTERCLOCKWISE,
+                                )
+                            )

+ 77 - 0
mineru/model/table/cls/paddle_table_cls.py

@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 
 import cv2
 import numpy as np
@@ -61,6 +62,65 @@ class PaddleTableClsModel:
         x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
 
+    def list_2_batch(self, img_list, batch_size=16):
+        """
+        将任意长度的列表按照指定的batch size分成多个batch
+
+        Args:
+            img_list: 输入的列表
+            batch_size: 每个batch的大小,默认为16
+
+        Returns:
+            一个包含多个batch的列表,每个batch都是原列表的一个子列表
+        """
+        batches = []
+        for i in range(0, len(img_list), batch_size):
+            batch = img_list[i : min(i + batch_size, len(img_list))]
+            batches.append(batch)
+        return batches
+
+    def batch_preprocess(self, imgs):
+        res_imgs = []
+        for img in imgs:
+            # PIL图像转cv2
+            img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            # 放大图片,使其最短边长为256
+            h, w = img.shape[:2]
+            scale = 256 / min(h, w)
+            h_resize = round(h * scale)
+            w_resize = round(w * scale)
+            img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+            # 调整为224*224的正方形
+            h, w = img.shape[:2]
+            cw, ch = 224, 224
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+            x2 = min(w, x1 + cw)
+            y2 = min(h, y1 + ch)
+            if w < cw or h < ch:
+                raise ValueError(
+                    f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+                )
+            img = img[y1:y2, x1:x2, ...]
+            # 正则化
+            split_im = list(cv2.split(img))
+            std = [0.229, 0.224, 0.225]
+            scale = 0.00392156862745098
+            mean = [0.485, 0.456, 0.406]
+            alpha = [scale / std[i] for i in range(len(std))]
+            beta = [-mean[i] / std[i] for i in range(len(std))]
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= alpha[c]
+                split_im[c] += beta[c]
+            img = cv2.merge(split_im)
+            # 5. 转换为 CHW 格式
+            img = img.transpose((2, 0, 1))
+            res_imgs.append(img)
+        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
     def predict(self, img):
         x = self.preprocess(img)
         result = self.sess.run(None, {"x": x})
@@ -70,3 +130,20 @@ class PaddleTableClsModel:
         if idx == 0 and conf < 0.8:
             idx = 1
         return self.labels[idx], conf
+
+    def batch_predict(self, img_info_list, batch_size=16):
+        imgs = [item["table_img"] for item in img_info_list]
+        imgs = self.list_2_batch(imgs, batch_size=batch_size)
+        label_res = []
+        for img_batch in imgs:
+            x = self.batch_preprocess(img_batch)
+            result = self.sess.run(None, {"x": x})
+            for img_res in result[0]:
+                idx = np.argmax(img_res)
+                conf = float(np.max(img_res))
+                if idx == 0 and conf < 0.9:
+                    idx = 1
+                label_res.append((self.labels[idx],conf))
+        for img_info, (label, conf) in zip(img_info_list, label_res):
+            img_info['table_res']["cls_label"] = label
+            img_info['table_res']["cls_score"] = conf