Browse source code

Merge pull request #3382 from myhloli/feat_table_batch

Feat table batch
Xiaomeng Zhao 2 months ago
parent
commit
959163a5b5

+ 124 - 75
mineru/backend/pipeline/batch_analyze.py

@@ -1,3 +1,5 @@
+import html
+
 import cv2
 from loguru import logger
 from tqdm import tqdm
@@ -8,13 +10,16 @@ from .model_init import AtomModelSingleton
 from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
-from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
+from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
 from ...utils.pdf_image_tools import get_crop_np_img

 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 OCR_DET_BASE_BATCH_SIZE = 16
+TABLE_ORI_CLS_BATCH_SIZE = 16
+TABLE_Wired_Wireless_CLS_BATCH_SIZE = 16


 class BatchAnalyze:
@@ -38,12 +43,14 @@ class BatchAnalyze:
         )
         atom_model_manager = AtomModelSingleton()

+        pil_images = [image for image, _, _ in images_with_extra_info]
+
         np_images = [np.asarray(image) for image, _, _ in images_with_extra_info]

         # doclayout_yolo

         images_layout_res += self.model.layout_model.batch_predict(
-            np_images, YOLO_LAYOUT_BASE_BATCH_SIZE
+            pil_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )

         if self.formula_enable:
@@ -99,7 +106,120 @@ class BatchAnalyze:
                                                 'table_img':table_img,
                                               })

-        # OCR detection processing
+        # Table recognition
+        if self.table_enable:
+
+            # Batch image-orientation handling
+            img_orientation_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.ImgOrientationCls,
+            )
+            try:
+                img_orientation_cls_model.batch_predict(table_res_list_all_page,
+                                                        det_batch_size=self.batch_ratio * OCR_DET_BASE_BATCH_SIZE,
+                                                        batch_size=TABLE_ORI_CLS_BATCH_SIZE)
+            except Exception as e:
+                logger.warning(
+                    f"Image orientation classification failed: {e}, using original image"
+                )
+
+            # Table classification
+            table_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.TableCls,
+            )
+            try:
+                table_cls_model.batch_predict(table_res_list_all_page,
+                                              batch_size=TABLE_Wired_Wireless_CLS_BATCH_SIZE)
+            except Exception as e:
+                logger.warning(
+                    f"Table classification failed: {e}, using default model"
+                )
+
+            # OCR det stage, executed sequentially
+            rec_img_lang_group = defaultdict(list)
+            for index, table_res_dict in enumerate(
+                    tqdm(table_res_list_all_page, desc="Table-ocr det")
+            ):
+                _lang = table_res_dict["lang"]
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.OCR,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
+                    # lang= table_res_dict["lang"],
+                    enable_merge_det_boxes=False,
+                )
+                bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
+                ocr_result = ocr_engine.ocr(bgr_image, rec=False)[0]
+                # Build dicts of crops that need OCR recognition (cropped_img, dt_box, table_id), grouped by language
+                for dt_box in ocr_result:
+                    rec_img_lang_group[_lang].append(
+                        {
+                            "cropped_img": get_rotate_crop_image(
+                                bgr_image, np.asarray(dt_box, dtype=np.float32)
+                            ),
+                            "dt_box": np.asarray(dt_box, dtype=np.float32),
+                            "table_id": index,
+                        }
+                    )
+
+            # OCR rec, batched by language
+            for _lang, rec_img_list in rec_img_lang_group.items():
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.OCR,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
+                    lang=_lang,
+                    enable_merge_det_boxes=False,
+                )
+                cropped_img_list = [item["cropped_img"] for item in rec_img_list]
+                ocr_res_list = ocr_engine.ocr(
+                    cropped_img_list, det=False, tqdm_enable=True, tqdm_desc="Table-ocr rec"
+                )[0]
+                # Back-fill the recognition results by table_id
+                for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
+                    if table_res_list_all_page[img_dict["table_id"]].get("ocr_result"):
+                        table_res_list_all_page[img_dict["table_id"]]["ocr_result"].append(
+                            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
+                        )
+                    else:
+                        table_res_list_all_page[img_dict["table_id"]]["ocr_result"] = [
+                            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
+                        ]
+
+            # Run the wireless-table model on all tables first, then the wired-table model on tables classified as wired
+            wireless_table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.WirelessTable,
+            )
+            wireless_table_model.batch_predict(table_res_list_all_page)
+
+            # Pull out the wired tables for separate prediction
+            wired_table_res_list = []
+            for table_res_dict in table_res_list_all_page:
+                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                    wired_table_res_list.append(table_res_dict)
+            if wired_table_res_list:
+                for table_res_dict in tqdm(
+                        wired_table_res_list, desc="Table-wired Predict"
+                ):
+                    wired_table_model = atom_model_manager.get_atom_model(
+                        atom_model_name=AtomicModel.WiredTable,
+                        lang=table_res_dict["lang"],
+                    )
+                    table_res_dict["table_res"]["html"] = wired_table_model.predict(
+                        table_res_dict["table_img"],
+                        table_res_dict["ocr_result"],
+                        table_res_dict["table_res"].get("html", None)
+                    )
+
+            # Table HTML cleanup
+            for table_res_dict in table_res_list_all_page:
+                html_code = table_res_dict["table_res"].get("html", "")
+
+                # Check that html_code contains '<table>' and '</table>'
+                if "<table>" in html_code and "</table>" in html_code:
+                    # Keep only the <table>...</table> span and store it in table_res_dict['table_res']['html']
+                    start_index = html_code.find("<table>")
+                    end_index = html_code.rfind("</table>") + len("</table>")
+                    table_res_dict["table_res"]["html"] = html_code[start_index:end_index]
+
+        # OCR det
         if self.enable_ocr_det_batch:
             # Batch mode - group by language and resolution
             # Collect all cropped images that need OCR detection
@@ -189,9 +309,6 @@ class BatchAnalyze:

                         if dt_boxes is not None and len(dt_boxes) > 0:
                             # Apply the key processing steps of the original OCR pipeline directly
-                            from mineru.utils.ocr_utils import (
-                                merge_det_boxes, update_det_boxes, sorted_boxes
-                            )

                             # 1. Sort the detection boxes
                             if len(dt_boxes) > 0:
@@ -253,75 +370,7 @@

                         ocr_res_list_dict['layout_res'].extend(ocr_result_list)

-        # Table recognition
-        if self.table_enable:
-            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
-                _lang = table_res_dict['lang']
-
-                # Adjust the image orientation
-                img_orientation_cls_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.ImgOrientationCls,
-                )
-                try:
-                    rotate_label = img_orientation_cls_model.predict(
-                        table_res_dict["table_img"]
-                    )
-                except Exception as e:
-                    logger.warning(
-                        f"Image orientation classification failed: {e}, using original image"
-                    )
-                    rotate_label = "0"
-
-                np_table_img = table_res_dict["table_img"]
-                if rotate_label == "270":
-                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
-                elif rotate_label == "90":
-                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
-                else:
-                    pass
-
-                # Wired/wireless table classification
-                table_cls_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.TableCls,
-                )
-                table_cls_score = 0.5
-                try:
-                    table_label, table_cls_score = table_cls_model.predict(np_table_img)
-                except Exception as e:
-                    table_label = AtomicModel.WirelessTable
-                    logger.warning(
-                        f"Table classification failed: {e}, using default model {table_label}"
-                    )
-                # table_label = AtomicModel.WirelessTable
-                # logger.debug(f"Table classification result: {table_label}")
-                if table_label not in [AtomicModel.WirelessTable, AtomicModel.WiredTable]:
-                    raise ValueError(
-                        "Table classification failed, please check the model"
-                    )
-
-                # Pick the wired or wireless table recognition model based on the classification result
-                table_model = atom_model_manager.get_atom_model(
-                    atom_model_name=table_label,
-                    lang=_lang,
-                )
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(np_table_img, table_cls_score)
-                # Check whether the call returned a result
-                if html_code:
-                    # Check that html_code contains '<table>' and '</table>'
-                    if '<table>' in html_code and '</table>' in html_code:
-                        # Keep only the <table>...</table> span and store it in table_res_dict['table_res']['html']
-                        start_index = html_code.find('<table>')
-                        end_index = html_code.rfind('</table>') + len('</table>')
-                        table_res_dict['table_res']['html'] = html_code[start_index:end_index]
-                    else:
-                        logger.warning(
-                            'table recognition processing fails, not found expected HTML table end'
-                        )
-                else:
-                    logger.warning(
-                        'table recognition processing fails, not get html return'
-                    )
-
+        # OCR rec
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language
         img_crop_lists_by_lang = {}  # Dict of lists for each language
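
The new table path splits OCR into a pooled gather/scatter: detection still runs per table, but the recognition crops from every table are grouped by language and recognized in one batched call, then written back to their owning table via table_id. A minimal, self-contained sketch of that pattern (recognize_batch and the item fields are illustrative stand-ins, not the repo's API):

    from collections import defaultdict

    def gather_scatter_rec(tables, recognize_batch):
        """Group text crops by language, recognize each group in one batch,
        then scatter results back to the owning table via table_id."""
        groups = defaultdict(list)
        for table_id, table in enumerate(tables):
            for crop in table["crops"]:  # one entry per detected text box
                groups[table["lang"]].append({"crop": crop, "table_id": table_id})

        for lang, items in groups.items():
            texts = recognize_batch([it["crop"] for it in items], lang=lang)
            for it, text in zip(items, texts):
                tables[it["table_id"]].setdefault("ocr_result", []).append(text)

    # Toy run with a fake recognizer that uppercases its inputs.
    tables = [{"lang": "ch", "crops": ["a", "b"]}, {"lang": "en", "crops": ["c"]}]
    gather_scatter_rec(tables, lambda crops, lang: [c.upper() for c in crops])
    assert tables[0]["ocr_result"] == ["A", "B"] and tables[1]["ocr_result"] == ["C"]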

+ 12 - 7
mineru/backend/pipeline/model_init.py

@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
-from ...model.table.rec.rapid_table import RapidTableModel
+from ...model.table.rec.slanet_plus.main import RapidTableModel
 from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
@@ -114,13 +114,18 @@ class AtomModelSingleton:
         lang = kwargs.get('lang', None)

         if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
-            key = (atom_model_name, lang)
+            key = (
+                atom_model_name,
+                lang
+            )
         elif atom_model_name in [AtomicModel.OCR]:
-            key = (atom_model_name,
-                   kwargs.get('det_db_box_thresh', 0.3),
-                   lang, kwargs.get('det_db_unclip_ratio', 1.8),
-                   kwargs.get('enable_merge_det_boxes', True)
-                   )
+            key = (
+                atom_model_name,
+                kwargs.get('det_db_box_thresh', 0.3),
+                lang,
+                kwargs.get('det_db_unclip_ratio', 1.8),
+                kwargs.get('enable_merge_det_boxes', True)
+            )
         else:
             key = atom_model_name
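
The reformatted keys matter because the singleton caches one model instance per distinct tuple, so every parameter that changes an OCR engine's behavior has to appear in the key. A small sketch of the keyed-cache pattern (names and the object() stand-in are illustrative, not the repo's code):

    _cache = {}

    def get_engine(name, lang=None, det_db_box_thresh=0.3, det_db_unclip_ratio=1.8):
        key = (name, det_db_box_thresh, lang, det_db_unclip_ratio)
        if key not in _cache:
            _cache[key] = object()  # stand-in for constructing the real model
        return _cache[key]

    a = get_engine("ocr", lang="ch")
    b = get_engine("ocr", lang="ch", det_db_box_thresh=0.5)
    assert a is not b                          # different threshold -> new engine
    assert a is get_engine("ocr", lang="ch")   # same args -> cached instance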

+ 2 - 1
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -105,6 +105,7 @@ class PytorchPaddleOCR(TextSystem):
             rec=True,
             mfd_res=None,
             tqdm_enable=False,
+            tqdm_desc="OCR-rec Predict",
             ):
         assert isinstance(img, (np.ndarray, list, str, bytes))
         if isinstance(img, list) and det == True:
@@ -149,7 +150,7 @@ class PytorchPaddleOCR(TextSystem):
                     if not isinstance(img, list):
                         img = preprocess_image(img)
                         img = [img]
-                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable)
+                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
                     # logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
                     # logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
                     ocr_res.append(rec_res)
                     ocr_res.append(rec_res)
                 return ocr_res
                 return ocr_res
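
With tqdm_desc threaded through, each caller can label its own recognition progress bar, as the table path does above. A hedged usage sketch, assuming an initialized PytorchPaddleOCR instance named ocr_engine and a list of pre-cropped images:

    # rec-only call over cropped images, with a stage-specific progress label
    ocr_res = ocr_engine.ocr(cropped_img_list, det=False,
                             tqdm_enable=True, tqdm_desc="Table-ocr rec")[0]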

+ 2 - 2
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py

@@ -288,7 +288,7 @@ class TextRecognizer(BaseOCRV20):

         return img

-    def __call__(self, img_list, tqdm_enable=False):
+    def __call__(self, img_list, tqdm_enable=False, tqdm_desc="OCR-rec Predict"):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
         width_list = []
@@ -302,7 +302,7 @@ class TextRecognizer(BaseOCRV20):
         batch_num = self.rec_batch_num
         elapse = 0
         # for beg_img_no in range(0, img_num, batch_num):
-        with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+        with tqdm(total=img_num, desc=tqdm_desc, disable=not tqdm_enable) as pbar:
             index = 0
             for beg_img_no in range(0, img_num, batch_num):
                 end_img_no = min(img_num, beg_img_no + batch_num)

+ 153 - 1
mineru/model/ori_cls/paddle_ori_cls.py

@@ -2,10 +2,12 @@
 import os

 from PIL import Image
+from collections import defaultdict
+from typing import List, Dict
+from tqdm import tqdm
 import cv2
 import numpy as np
 import onnxruntime
-from loguru import logger

 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -111,3 +113,153 @@ class PaddleOrientationClsModel:
                     # logger.debug(f"Orientation classification result: {label}")
                     # logger.debug(f"Orientation classification result: {label}")
 
 
         return rotate_label
         return rotate_label
+
+    def list_2_batch(self, img_list, batch_size=16):
+        """
+        Split a list of arbitrary length into batches of the given size.
+
+        Args:
+            img_list: the input list
+            batch_size: size of each batch, 16 by default
+
+        Returns:
+            A list of batches, each a sub-list of the original list.
+        """
+        batches = []
+        for i in range(0, len(img_list), batch_size):
+            batch = img_list[i : min(i + batch_size, len(img_list))]
+            batches.append(batch)
+        return batches
+
+    def batch_preprocess(self, imgs):
+        res_imgs = []
+        for img_info in imgs:
+            img = np.asarray(img_info["table_img"])
+            # Resize the image so its shorter side is 256
+            h, w = img.shape[:2]
+            scale = 256 / min(h, w)
+            h_resize = round(h * scale)
+            w_resize = round(w * scale)
+            img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+            # Center-crop to a 224x224 square
+            h, w = img.shape[:2]
+            cw, ch = 224, 224
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+            x2 = min(w, x1 + cw)
+            y2 = min(h, y1 + ch)
+            if w < cw or h < ch:
+                raise ValueError(
+                    f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+                )
+            img = img[y1:y2, x1:x2, ...]
+            # Normalize with ImageNet mean/std
+            split_im = list(cv2.split(img))
+            std = [0.229, 0.224, 0.225]
+            scale = 0.00392156862745098
+            mean = [0.485, 0.456, 0.406]
+            alpha = [scale / std[i] for i in range(len(std))]
+            beta = [-mean[i] / std[i] for i in range(len(std))]
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= alpha[c]
+                split_im[c] += beta[c]
+            img = cv2.merge(split_im)
+            # Convert to CHW layout
+            img = img.transpose((2, 0, 1))
+            res_imgs.append(img)
+        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
+    def batch_predict(
+        self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
+    ) -> None:
+        """
+        Batch-predict the rotation of the given image-info dicts and rotate rotated table images back upright, in place.
+        """
+        RESOLUTION_GROUP_STRIDE = 128
+        # Skip images whose height/width ratio is 1.2 or less
+        resolution_groups = defaultdict(list)
+        for img in imgs:
+            # Convert the RGB image to BGR
+            bgr_img: np.ndarray = cv2.cvtColor(np.asarray(img["table_img"]), cv2.COLOR_RGB2BGR)
+            img["table_img_bgr"] = bgr_img
+            img_height, img_width = bgr_img.shape[:2]
+            img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+            if img_aspect_ratio > 1.2:
+                # Normalize dimensions to multiples of RESOLUTION_GROUP_STRIDE
+                normalized_h = ((img_height + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # round up to a multiple of RESOLUTION_GROUP_STRIDE
+                normalized_w = ((img_width + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+                group_key = (normalized_h, normalized_w)
+                resolution_groups[group_key].append(img)
+
+        # Batch-process each resolution group
+        rotated_imgs = []
+        for group_key, group_imgs in tqdm(resolution_groups.items(), desc="Table-ori cls stage1 predict"):
+            # Compute the target size (group max, rounded up to a multiple of RESOLUTION_GROUP_STRIDE)
+            max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
+            max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
+            target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+            target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+
+            # Pad every image to the uniform size
+            batch_images = []
+            for img in group_imgs:
+                bgr_img = img["table_img_bgr"]
+                h, w = bgr_img.shape[:2]
+                # Create a white canvas of the target size
+                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                # Paste the original image at the top-left corner
+                padded_img[:h, :w] = bgr_img
+                batch_images.append(padded_img)
+
+            # Batched detection
+            batch_results = self.ocr_engine.text_detector.batch_predict(
+                batch_images, min(len(batch_images), det_batch_size)
+            )
+
+            # From the batch results, decide whether each image is rotated; rotated images are queued for angle prediction
+
+            for index, (img_info, (dt_boxes, elapse)) in enumerate(
+                zip(group_imgs, batch_results)
+            ):
+                vertical_count = 0
+                for box_ocr_res in dt_boxes:
+                    p1, p2, p3, p4 = box_ocr_res
+
+                    # Calculate width and height
+                    width = p3[0] - p1[0]
+                    height = p3[1] - p1[1]
+
+                    aspect_ratio = width / height if height > 0 else 1.0
+
+                    # Count vertical text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1
+
+                if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
+                    rotated_imgs.append(img_info)
+
+        # Predict the rotation angle of the rotated images
+        if len(rotated_imgs) > 0:
+            imgs = self.list_2_batch(rotated_imgs, batch_size=batch_size)
+            with tqdm(total=len(rotated_imgs), desc="Table-ori cls stage2 predict") as pbar:
+                for img_batch in imgs:
+                    x = self.batch_preprocess(img_batch)
+                    results = self.sess.run(None, {"x": x})
+                    for img_info, res in zip(img_batch, results[0]):
+                        label = self.labels[np.argmax(res)]
+                        if label == "270":
+                            img_info["table_img"] = cv2.rotate(
+                                np.asarray(img_info["table_img"]),
+                                cv2.ROTATE_90_CLOCKWISE,
+                            )
+                        elif label == "90":
+                            img_info["table_img"] = cv2.rotate(
+                                np.asarray(img_info["table_img"]),
+                                cv2.ROTATE_90_COUNTERCLOCKWISE,
+                            )
+                        else:
+                            # 180 and 0 degrees need no handling
+                            pass
+                        pbar.update(1)
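
Stage 1 batches the detector by padding each resolution group onto a shared white canvas whose size is the group maximum rounded up to the stride. A self-contained sketch of that rounding and padding, with the stride mirroring RESOLUTION_GROUP_STRIDE:

    import numpy as np

    STRIDE = 128

    def round_up(x, stride=STRIDE):
        return ((x + stride - 1) // stride) * stride

    def pad_to(img, target_h, target_w):
        """Paste img top-left onto a white canvas of the target size."""
        canvas = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
        h, w = img.shape[:2]
        canvas[:h, :w] = img
        return canvas

    imgs = [np.zeros((300, 180, 3), np.uint8), np.zeros((350, 200, 3), np.uint8)]
    th = round_up(max(i.shape[0] for i in imgs))  # 350 -> 384
    tw = round_up(max(i.shape[1] for i in imgs))  # 200 -> 256
    batch = np.stack([pad_to(i, th, tw) for i in imgs])
    assert batch.shape == (2, 384, 256, 3)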

+ 78 - 0
mineru/model/table/cls/paddle_table_cls.py

@@ -1,10 +1,12 @@
 import os
+from pathlib import Path

 from PIL import Image
 import cv2
 import numpy as np
 import onnxruntime
 from loguru import logger
+from tqdm import tqdm

 from mineru.backend.pipeline.model_list import AtomicModel
 from mineru.utils.enum_class import ModelPath
@@ -75,3 +77,79 @@ class PaddleTableClsModel:
         if idx == 0 and conf < 0.8:
             idx = 1
         return self.labels[idx], conf
+
+    def list_2_batch(self, img_list, batch_size=16):
+        """
+        Split a list of arbitrary length into batches of the given size.
+
+        Args:
+            img_list: the input list
+            batch_size: size of each batch, 16 by default
+
+        Returns:
+            A list of batches, each a sub-list of the original list.
+        """
+        batches = []
+        for i in range(0, len(img_list), batch_size):
+            batch = img_list[i : min(i + batch_size, len(img_list))]
+            batches.append(batch)
+        return batches
+
+    def batch_preprocess(self, imgs):
+        res_imgs = []
+        for img in imgs:
+            img = np.asarray(img)
+            # Resize the image so its shorter side is 256
+            h, w = img.shape[:2]
+            scale = 256 / min(h, w)
+            h_resize = round(h * scale)
+            w_resize = round(w * scale)
+            img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+            # Center-crop to a 224x224 square
+            h, w = img.shape[:2]
+            cw, ch = 224, 224
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+            x2 = min(w, x1 + cw)
+            y2 = min(h, y1 + ch)
+            if w < cw or h < ch:
+                raise ValueError(
+                    f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+                )
+            img = img[y1:y2, x1:x2, ...]
+            # Normalize with ImageNet mean/std
+            split_im = list(cv2.split(img))
+            std = [0.229, 0.224, 0.225]
+            scale = 0.00392156862745098
+            mean = [0.485, 0.456, 0.406]
+            alpha = [scale / std[i] for i in range(len(std))]
+            beta = [-mean[i] / std[i] for i in range(len(std))]
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= alpha[c]
+                split_im[c] += beta[c]
+            img = cv2.merge(split_im)
+            # Convert to CHW layout
+            img = img.transpose((2, 0, 1))
+            res_imgs.append(img)
+        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
+    def batch_predict(self, img_info_list, batch_size=16):
+        imgs = [item["table_img"] for item in img_info_list]
+        imgs = self.list_2_batch(imgs, batch_size=batch_size)
+        label_res = []
+        with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict") as pbar:
+            for img_batch in imgs:
+                x = self.batch_preprocess(img_batch)
+                result = self.sess.run(None, {"x": x})
+                for img_res in result[0]:
+                    idx = np.argmax(img_res)
+                    conf = float(np.max(img_res))
+                    # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
+                    if idx == 0 and conf < 0.8:
+                        idx = 1
+                    label_res.append((self.labels[idx], conf))
+                pbar.update(len(img_batch))
+            for img_info, (label, conf) in zip(img_info_list, label_res):
+                img_info['table_res']["cls_label"] = label
+                img_info['table_res']["cls_score"] = round(conf, 3)

+ 0 - 46
mineru/model/table/rec/rapid_table.py

@@ -1,46 +0,0 @@
-import os
-import html
-import cv2
-import numpy as np
-from loguru import logger
-from rapid_table import RapidTable, RapidTableInput
-
-from mineru.utils.enum_class import ModelPath
-from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-
-
-def escape_html(input_string):
-    """Escape HTML Entities."""
-    return html.escape(input_string)
-
-
-class RapidTableModel(object):
-    def __init__(self, ocr_engine):
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.table_model = RapidTable(input_args)
-        self.ocr_engine = ocr_engine
-
-
-    def predict(self, image, table_cls_score):
-        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-        # Continue with OCR on potentially rotated image
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-        if ocr_result:
-            ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
-                      len(item) == 2 and isinstance(item[1], tuple)]
-        else:
-            ocr_result = None
-
-        if ocr_result:
-            try:
-                table_results = self.table_model(np.asarray(image), ocr_result)
-                html_code = table_results.pred_html
-                table_cell_bboxes = table_results.cell_bboxes
-                logic_points = table_results.logic_points
-                elapse = table_results.elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
-            except Exception as e:
-                logger.exception(e)
-
-        return None, None, None, None

+ 0 - 0
mineru/model/table/rec/slanet_plus/__init__.py


+ 266 - 0
mineru/model/table/rec/slanet_plus/main.py

@@ -0,0 +1,266 @@
+import os
+import argparse
+import copy
+import importlib
+import time
+import html
+from dataclasses import asdict, dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from .matcher import TableMatch
+from .table_structure import TableStructurer
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+root_dir = Path(__file__).resolve().parent
+
+
+class ModelType(Enum):
+    PPSTRUCTURE_EN = "ppstructure_en"
+    PPSTRUCTURE_ZH = "ppstructure_zh"
+    SLANETPLUS = "slanet_plus"
+    UNITABLE = "unitable"
+
+
+ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
+KEY_TO_MODEL_URL = {
+    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
+    ModelType.UNITABLE.value: {
+        "encoder": f"{ROOT_URL}/unitable/encoder.pth",
+        "decoder": f"{ROOT_URL}/unitable/decoder.pth",
+        "vocab": f"{ROOT_URL}/unitable/vocab.json",
+    },
+}
+
+
+@dataclass
+class RapidTableInput:
+    model_type: Optional[str] = ModelType.SLANETPLUS.value
+    model_path: Union[str, Path, None, Dict[str, str]] = None
+    use_cuda: bool = False
+    device: str = "cpu"
+
+
+@dataclass
+class RapidTableOutput:
+    pred_html: Optional[str] = None
+    cell_bboxes: Optional[np.ndarray] = None
+    logic_points: Optional[np.ndarray] = None
+    elapse: Optional[float] = None
+
+
+class RapidTable:
+    def __init__(self, config: RapidTableInput):
+        self.model_type = config.model_type
+        if self.model_type not in KEY_TO_MODEL_URL:
+            model_list = ",".join(KEY_TO_MODEL_URL)
+            raise ValueError(
+                f"{self.model_type} is not supported. The currently supported models are {model_list}."
+            )
+
+        config.model_path = config.model_path
+        if self.model_type == ModelType.SLANETPLUS.value:
+            self.table_structure = TableStructurer(asdict(config))
+        else:
+            raise ValueError(f"{self.model_type} is not supported.")
+        self.table_matcher = TableMatch()
+
+    def predict(
+        self,
+        img: np.ndarray,
+        ocr_result: List[Union[List[List[float]], str, str]] = None,
+    ) -> RapidTableOutput:
+        if ocr_result is None:
+            raise ValueError("OCR result is None")
+
+        s = time.perf_counter()
+        h, w = img.shape[:2]
+
+        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
+
+        pred_structures, cell_bboxes, _ = self.table_structure.process(
+            copy.deepcopy(img)
+        )
+
+        # Rescale boxes from the slanet-plus output space back to the original image
+        cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+
+        pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
+
+        # Filter out placeholder bboxes
+        mask = ~np.all(cell_bboxes == 0, axis=1)
+        cell_bboxes = cell_bboxes[mask]
+
+        logic_points = self.table_matcher.decode_logic_points(pred_structures)
+        elapse = time.perf_counter() - s
+        return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
+
+    def batch_predict(
+        self,
+        images: List[np.ndarray],
+        ocr_results: List[List[Union[List[List[float]], str, str]]],
+        batch_size: int = 4,
+    ) -> List[RapidTableOutput]:
+        """批量处理图像"""
+        s = time.perf_counter()
+
+        batch_dt_boxes = []
+        batch_rec_res = []
+
+        for i, img in enumerate(images):
+            h, w = img.shape[:2]
+            dt_boxes, rec_res = self.get_boxes_recs(ocr_results[i], h, w)
+            batch_dt_boxes.append(dt_boxes)
+            batch_rec_res.append(rec_res)
+
+        # Batched table-structure recognition
+        batch_results = self.table_structure.batch_process(images)
+
+        output_results = []
+        for i, (img, ocr_result, (pred_structures, cell_bboxes, _)) in enumerate(
+            zip(images, ocr_results, batch_results)
+        ):
+            # Rescale boxes from the slanet-plus output space back to the original image
+            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+            pred_html = self.table_matcher(
+                pred_structures, cell_bboxes, batch_dt_boxes[i], batch_rec_res[i]
+            )
+            # Filter out placeholder bboxes
+            mask = ~np.all(cell_bboxes == 0, axis=1)
+            cell_bboxes = cell_bboxes[mask]
+
+            logic_points = self.table_matcher.decode_logic_points(pred_structures)
+            result = RapidTableOutput(pred_html, cell_bboxes, logic_points, 0)
+            output_results.append(result)
+
+        total_elapse = time.perf_counter() - s
+        for result in output_results:
+            result.elapse = total_elapse / len(output_results)
+
+        return output_results
+
+    def get_boxes_recs(
+        self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
+    ) -> Tuple[np.ndarray, Tuple[str, str]]:
+        dt_boxes, rec_res, scores = list(zip(*ocr_result))
+        rec_res = list(zip(rec_res, scores))
+
+        r_boxes = []
+        for box in dt_boxes:
+            box = np.array(box)
+            x_min = max(0, box[:, 0].min() - 1)
+            x_max = min(w, box[:, 0].max() + 1)
+            y_min = max(0, box[:, 1].min() - 1)
+            y_max = min(h, box[:, 1].max() + 1)
+            box = [x_min, y_min, x_max, y_max]
+            r_boxes.append(box)
+        dt_boxes = np.array(r_boxes)
+        return dt_boxes, rec_res
+
+    def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
+        h, w = img.shape[:2]
+        resized = 488
+        ratio = min(resized / h, resized / w)
+        w_ratio = resized / (w * ratio)
+        h_ratio = resized / (h * ratio)
+        cell_bboxes[:, 0::2] *= w_ratio
+        cell_bboxes[:, 1::2] *= h_ratio
+        return cell_bboxes
+
+
+def parse_args(arg_list: Optional[List[str]] = None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-v",
+        "--vis",
+        action="store_true",
+        default=False,
+        help="Wheter to visualize the layout results.",
+    )
+    parser.add_argument(
+        "-img", "--img_path", type=str, required=True, help="Path to image for layout."
+    )
+    parser.add_argument(
+        "-m",
+        "--model_type",
+        type=str,
+        default=ModelType.SLANETPLUS.value,
+        choices=list(KEY_TO_MODEL_URL),
+    )
+    args = parser.parse_args(arg_list)
+    return args
+
+
+def escape_html(input_string):
+    """Escape HTML Entities."""
+    return html.escape(input_string)
+
+
+class RapidTableModel(object):
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(
+            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
+            ModelPath.slanet_plus,
+        )
+        input_args = RapidTableInput(
+            model_type="slanet_plus", model_path=slanet_plus_model_path
+        )
+        self.table_model = RapidTable(input_args)
+        self.ocr_engine = ocr_engine
+
+    def predict(self, image, table_cls_score):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        # Continue with OCR on potentially rotated image
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        if ocr_result:
+            ocr_result = [
+                [item[0], escape_html(item[1][0]), item[1][1]]
+                for item in ocr_result
+                if len(item) == 2 and isinstance(item[1], tuple)
+            ]
+        else:
+            ocr_result = None
+
+        if ocr_result:
+            try:
+                table_results = self.table_model.predict(np.asarray(image), ocr_result)
+                html_code = table_results.pred_html
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+
+        return None, None, None, None
+
+    def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
+        """对传入的字典列表进行批量预测,无返回值"""
+        with tqdm(total=len(table_res_list), desc="Table-wireless Predict") as pbar:
+            for index in range(0, len(table_res_list), batch_size):
+                batch_imgs = [
+                    cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
+                    for i in range(index, min(index + batch_size, len(table_res_list)))
+                ]
+                batch_ocrs = [
+                    table_res_list[i]["ocr_result"]
+                    for i in range(index, min(index + batch_size, len(table_res_list)))
+                ]
+                results = self.table_model.batch_predict(
+                    batch_imgs, batch_ocrs, batch_size=batch_size
+                )
+                for i, result in enumerate(results):
+                    if result.pred_html:
+                        table_res_list[index + i]['table_res']['html'] = result.pred_html
+
+                # Update the progress bar
+                pbar.update(len(results))
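
adapt_slanet_plus undoes the model's fixed 488-pixel letterbox: each coordinate is multiplied by resized / (dim * ratio), mapping the network's output space back onto the original image. A worked check of that arithmetic (pure NumPy, no model required):

    import numpy as np

    def adapt_slanet_plus(h, w, cell_bboxes, resized=488):
        ratio = min(resized / h, resized / w)
        cell_bboxes = cell_bboxes.astype(float)
        cell_bboxes[:, 0::2] *= resized / (w * ratio)  # x coordinates
        cell_bboxes[:, 1::2] *= resized / (h * ratio)  # y coordinates
        return cell_bboxes

    # For a 976x1952 image, ratio = 0.25, so x scales by 1.0 and y by 2.0.
    boxes = np.array([[100, 100, 200, 200]])
    print(adapt_slanet_plus(976, 1952, boxes))  # [[100. 200. 200. 400.]]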

+ 198 - 0
mineru/model/table/rec/slanet_plus/matcher.py

@@ -0,0 +1,198 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from .matcher_utils import compute_iou, distance
+
+
+class TableMatch:
+    def __init__(self, filter_ocr_result=True, use_master=False):
+        self.filter_ocr_result = filter_ocr_result
+        self.use_master = use_master
+
+    def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
+        if self.filter_ocr_result:
+            dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
+        matched_index = self.match_result(dt_boxes, cell_bboxes)
+        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
+        return pred_html
+
+    def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
+        matched = {}
+        for i, gt_box in enumerate(dt_boxes):
+            distances = []
+            for j, pred_box in enumerate(cell_bboxes):
+                if len(pred_box) == 8:
+                    pred_box = [
+                        np.min(pred_box[0::2]),
+                        np.min(pred_box[1::2]),
+                        np.max(pred_box[0::2]),
+                        np.max(pred_box[1::2]),
+                    ]
+                distances.append(
+                    (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
+                )  # compute iou and l1 distance
+            sorted_distances = distances.copy()
+            # select det box by iou and l1 distance
+            sorted_distances = sorted(
+                sorted_distances, key=lambda item: (item[1], item[0])
+            )
+            # must > min_iou
+            if sorted_distances[0][1] >= 1 - min_iou:
+                continue
+
+            if distances.index(sorted_distances[0]) not in matched:
+                matched[distances.index(sorted_distances[0])] = [i]
+            else:
+                matched[distances.index(sorted_distances[0])].append(i)
+        return matched
+
+    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+        end_html = []
+        td_index = 0
+        for tag in pred_structures:
+            if "</td>" not in tag:
+                end_html.append(tag)
+                continue
+
+            if "<td></td>" == tag:
+                end_html.extend("<td>")
+
+            if td_index in matched_index.keys():
+                b_with = False
+                if (
+                    "<b>" in ocr_contents[matched_index[td_index][0]]
+                    and len(matched_index[td_index]) > 1
+                ):
+                    b_with = True
+                    end_html.extend("<b>")
+
+                for i, td_index_index in enumerate(matched_index[td_index]):
+                    content = ocr_contents[td_index_index][0]
+                    if len(matched_index[td_index]) > 1:
+                        if len(content) == 0:
+                            continue
+
+                        if content[0] == " ":
+                            content = content[1:]
+
+                        if "<b>" in content:
+                            content = content[3:]
+
+                        if "</b>" in content:
+                            content = content[:-4]
+
+                        if len(content) == 0:
+                            continue
+
+                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+                            content += " "
+                    end_html.extend(content)
+
+                if b_with:
+                    end_html.extend("</b>")
+
+            if "<td></td>" == tag:
+                end_html.append("</td>")
+            else:
+                end_html.append(tag)
+
+            td_index += 1
+
+        # Filter <thead></thead><tbody></tbody> elements
+        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
+        end_html = [v for v in end_html if v not in filter_elements]
+        return "".join(end_html), end_html
+
+    def decode_logic_points(self, pred_structures):
+        logic_points = []
+        current_row = 0
+        current_col = 0
+        max_rows = 0
+        max_cols = 0
+        occupied_cells = {}  # track cells that are already occupied
+
+        def is_occupied(row, col):
+            return (row, col) in occupied_cells
+
+        def mark_occupied(row, col, rowspan, colspan):
+            for r in range(row, row + rowspan):
+                for c in range(col, col + colspan):
+                    occupied_cells[(r, c)] = True
+
+        i = 0
+        while i < len(pred_structures):
+            token = pred_structures[i]
+
+            if token == "<tr>":
+                current_col = 0  # reset the current column at each <tr>
+            elif token == "</tr>":
+                current_row += 1  # row finished; advance the row index
+            elif token.startswith("<td"):
+                colspan = 1
+                rowspan = 1
+                j = i
+                if token != "<td></td>":
+                    j += 1
+                    # Extract the colspan and rowspan attributes
+                    while j < len(pred_structures) and not pred_structures[
+                        j
+                    ].startswith(">"):
+                        if "colspan=" in pred_structures[j]:
+                            colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        elif "rowspan=" in pred_structures[j]:
+                            rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        j += 1
+
+                # Skip past the attribute tokens already consumed
+                i = j
+
+                # Find the next unoccupied column
+                while is_occupied(current_row, current_col):
+                    current_col += 1
+
+                # Compute the logical coordinates
+                r_start = current_row
+                r_end = current_row + rowspan - 1
+                col_start = current_col
+                col_end = current_col + colspan - 1
+
+                # Record the logical coordinates
+                logic_points.append([r_start, r_end, col_start, col_end])
+
+                # Mark the occupied cells
+                mark_occupied(r_start, col_start, rowspan, colspan)
+
+                # Advance the current column
+                current_col += colspan
+
+                # Update the max row/column counts
+                max_rows = max(max_rows, r_end + 1)
+                max_cols = max(max_cols, col_end + 1)
+
+            i += 1
+
+        return logic_points
+
+    def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
+        y1 = cell_bboxes[:, 1::2].min()
+        new_dt_boxes = []
+        new_rec_res = []
+
+        for box, rec in zip(dt_boxes, rec_res):
+            if np.max(box[1::2]) < y1:
+                continue
+            new_dt_boxes.append(box)
+            new_rec_res.append(rec)
+        return new_dt_boxes, new_rec_res
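
decode_logic_points walks the structure tokens and emits [row_start, row_end, col_start, col_end] per cell, using the occupancy map to place cells under row/col spans. A worked example on a 2x2 table whose first cell spans both rows (assuming the new module is importable from the installed package):

    from mineru.model.table.rec.slanet_plus.matcher import TableMatch

    tokens = ["<tr>", "<td", ' rowspan="2"', ">", "</td>", "<td></td>", "</tr>",
              "<tr>", "<td></td>", "</tr>"]
    print(TableMatch().decode_logic_points(tokens))
    # [[0, 1, 0, 0], [0, 0, 1, 1], [1, 1, 1, 1]]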

+ 246 - 0
mineru/model/table/rec/slanet_plus/matcher_utils.py

@@ -0,0 +1,246 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import re
+
+
+def deal_isolate_span(thead_part):
+    """
+    Deal with isolated span cases.
+    These are caused by wrong predictions from the structure recognition model,
+    e.g. predicting <td rowspan="2"></td> as <td></td> rowspan="2"></b></td>.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out isolate span tokens.
+    isolate_pattern = (
+        r"<td></td> rowspan='(\d)+' colspan='(\d)+'></b></td>|"
+        r"<td></td> colspan='(\d)+' rowspan='(\d)+'></b></td>|"
+        r"<td></td> rowspan='(\d)+'></b></td>|"
+        r"<td></td> colspan='(\d)+'></b></td>"
+    )
+    isolate_iter = re.finditer(isolate_pattern, thead_part)
+    isolate_list = [i.group() for i in isolate_iter]
+
+    # 2. find out span number, by step 1 result.
+    span_pattern = (
+        r" rowspan='(\d)+' colspan='(\d)+'|"
+        r" colspan='(\d)+' rowspan='(\d)+'|"
+        r" rowspan='(\d)+'|"
+        r" colspan='(\d)+'"
+    )
+    corrected_list = []
+    for isolate_item in isolate_list:
+        span_part = re.search(span_pattern, isolate_item)
+        spanStr_in_isolateItem = span_part.group()
+        # 3. merge the span number into the span token format string.
+        if spanStr_in_isolateItem is not None:
+            corrected_item = f"<td{spanStr_in_isolateItem}></td>"
+            corrected_list.append(corrected_item)
+        else:
+            corrected_list.append(None)
+
+    # 4. replace original isolated token.
+    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
+        if corrected_item is not None:
+            thead_part = thead_part.replace(isolate_item, corrected_item)
+        else:
+            pass
+    return thead_part
+
+
+def deal_duplicate_bb(thead_part):
+    """
+    Deal with duplicated <b> or </b> tags after replacement.
+    Keep one <b></b> in a <td></td> token.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out <td></td> in <thead></thead>.
+    td_pattern = (
+        r"<td rowspan='(\d)+' colspan='(\d)+'>(.+?)</td>|"
+        r"<td colspan='(\d)+' rowspan='(\d)+'>(.+?)</td>|"
+        r"<td rowspan='(\d)+'>(.+?)</td>|"
+        r"<td colspan='(\d)+'>(.+?)</td>|"
+        r"<td>(.*?)</td>"
+    )
+    td_iter = re.finditer(td_pattern, thead_part)
+    td_list = [t.group() for t in td_iter]
+
+    # 2. check whether there are multiple <b></b> in a <td></td>
+    new_td_list = []
+    for td_item in td_list:
+        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
+            # multiple <b></b> in <td></td> case.
+            # 1. remove all <b></b>
+            td_item = td_item.replace("<b>", "").replace("</b>", "")
+            # 2. replace <td> -> <td><b>, </td> -> </b></td>.
+            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
+            new_td_list.append(td_item)
+        else:
+            new_td_list.append(td_item)
+
+    # 3. replace original thead part.
+    for td_item, new_td_item in zip(td_list, new_td_list):
+        thead_part = thead_part.replace(td_item, new_td_item)
+    return thead_part
+
+
+def deal_bb(result_token):
+    """
+    In our opinion, <b></b> always occurs in <thead></thead> text's context.
+    This function finds all tokens in <thead></thead> and inserts <b></b> manually.
+    :param result_token:
+    :return:
+    """
+    # find out <thead></thead> parts.
+    thead_pattern = "<thead>(.*?)</thead>"
+    if re.search(thead_pattern, result_token) is None:
+        return result_token
+    thead_part = re.search(thead_pattern, result_token).group()
+    origin_thead_part = copy.deepcopy(thead_part)
+
+    # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
+    span_pattern = r"<td rowspan='(\d)+' colspan='(\d)+'>|<td colspan='(\d)+' rowspan='(\d)+'>|<td rowspan='(\d)+'>|<td colspan='(\d)+'>"
+    span_iter = re.finditer(span_pattern, thead_part)
+    span_list = [s.group() for s in span_iter]
+    has_span_in_head = True if len(span_list) > 0 else False
+
+    if not has_span_in_head:
+        # <thead></thead> not include "rowspan" or "colspan" branch 1.
+        # 1. replace <td> to <td><b>, and </td> to </b></td>
+        # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, and </b></b> to </b>
+        thead_part = (
+            thead_part.replace("<td>", "<td><b>")
+            .replace("</td>", "</b></td>")
+            .replace("<b><b>", "<b>")
+            .replace("</b></b>", "</b>")
+        )
+    else:
+        # <thead></thead> include "rowspan" or "colspan" branch 2.
+        # Firstly, we deal rowspan or colspan cases.
+        # 1. replace > to ><b>
+        # 2. replace </td> to </b></td>
+        # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, and </b><b> to </b>
+
+        # Secondly, deal ordinary cases like branch 1
+
+        # replace ">" to "<b>"
+        replaced_span_list = []
+        for sp in span_list:
+            replaced_span_list.append(sp.replace(">", "><b>"))
+        for sp, rsp in zip(span_list, replaced_span_list):
+            thead_part = thead_part.replace(sp, rsp)
+
+        # replace "</td>" to "</b></td>"
+        thead_part = thead_part.replace("</td>", "</b></td>")
+
+        # remove duplicated <b> by re.sub
+        mb_pattern = "(<b>)+"
+        single_b_string = "<b>"
+        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
+
+        mgb_pattern = "(</b>)+"
+        single_gb_string = "</b>"
+        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
+
+        # ordinary cases like branch 1
+        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
+
+    # convert <td><b></b></td> back to <td></td>; an empty cell has no <b></b>,
+    # but a space cell (<td> </td>) is suitable for <td><b> </b></td>
+    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
+    # deal with duplicated <b></b>
+    thead_part = deal_duplicate_bb(thead_part)
+    # deal with isolated span tokens, caused by wrong structure predictions.
+    # eg.PMC5994107_011_00.png
+    thead_part = deal_isolate_span(thead_part)
+    # replace original result with new thead part.
+    result_token = result_token.replace(origin_thead_part, thead_part)
+    return result_token
+
+
+def deal_eb_token(master_token):
+    """
+    post process with <eb></eb>, <eb1></eb1>, ...
+    emptyBboxTokenDict = {
+        "[]": '<eb></eb>',
+        "[' ']": '<eb1></eb1>',
+        "['<b>', ' ', '</b>']": '<eb2></eb2>',
+        "['\\u2028', '\\u2028']": '<eb3></eb3>',
+        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+        "['<b>', '</b>']": '<eb5></eb5>',
+        "['<i>', ' ', '</i>']": '<eb6></eb6>',
+        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+        "['<i>', '</i>']": '<eb9></eb9>',
+        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
+    }
+    :param master_token:
+    :return:
+    """
+    master_token = master_token.replace("<eb></eb>", "<td></td>")
+    master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
+    master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
+    master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
+    master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
+    master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
+    master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
+    master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
+    master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
+    master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
+    master_token = master_token.replace(
+        "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
+    )
+    return master_token
+
+
+def distance(box_1, box_2):
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    :param rec1: (y0, x0, y1, x1), which reflects
+            (top, left, bottom, right)
+    :param rec2: (y0, x0, y1, x1)
+    :return: scalar value of IoU
+    """
+    # compute the area of each rectangle
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find each edge of the intersection rectangle
+    left_line = max(rec1[1], rec2[1])
+    right_line = min(rec1[3], rec2[3])
+    top_line = max(rec1[0], rec2[0])
+    bottom_line = min(rec1[2], rec2[2])
+
+    # check whether the rectangles intersect
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+
+    intersect = (right_line - left_line) * (bottom_line - top_line)
+    return (intersect / (sum_area - intersect)) * 1.0

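A minimal sketch of how the helpers above behave, with invented inputs:

    # deal_eb_token expands the empty-bbox placeholder tokens back into real cells
    assert deal_eb_token("<tr><eb></eb><eb1></eb1><eb5></eb5></tr>") == \
        "<tr><td></td><td> </td><td><b></b></td></tr>"

    # compute_iou expects (y0, x0, y1, x1) boxes; distance expects (x1, y1, x2, y2)
    a, b = (0, 0, 10, 10), (5, 5, 15, 15)
    print(compute_iou(a, b))  # 25 / (100 + 100 - 25) ≈ 0.1429
    print(distance(a, b))     # 20 + min(10, 10) = 30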
+ 109 - 0
mineru/model/table/rec/slanet_plus/table_structure.py

@@ -0,0 +1,109 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+from .table_structure_utils import (
+    OrtInferSession,
+    TableLabelDecode,
+    TablePreprocess,
+    BatchTablePreprocess,
+)
+
+
+class TableStructurer:
+    def __init__(self, config: Dict[str, Any]):
+        self.preprocess_op = TablePreprocess()
+        self.batch_preprocess_op = BatchTablePreprocess()
+
+        self.session = OrtInferSession(config)
+
+        self.character = self.session.get_metadata()
+        self.postprocess_op = TableLabelDecode(self.character)
+
+    def process(self, img):
+        starttime = time.time()
+        data = {"image": img}
+        data = self.preprocess_op(data)
+        img = data[0]
+        if img is None:
+            return None, None, 0
+        img = np.expand_dims(img, axis=0)
+        img = img.copy()
+
+        outputs = self.session([img])
+
+        preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
+
+        shape_list = np.expand_dims(data[-1], axis=0)
+        post_result = self.postprocess_op(preds, [shape_list])
+
+        bbox_list = post_result["bbox_batch_list"][0]
+
+        structure_str_list = post_result["structure_batch_list"][0]
+        structure_str_list = structure_str_list[0]
+        structure_str_list = (
+            ["<html>", "<body>", "<table>"]
+            + structure_str_list
+            + ["</table>", "</body>", "</html>"]
+        )
+        elapse = time.time() - starttime
+        return structure_str_list, bbox_list, elapse
+
+    def batch_process(
+        self, img_list: List[np.ndarray]
+    ) -> List[Tuple[List[str], np.ndarray, float]]:
+        """批量处理图像列表
+        Args:
+            img_list: 图像列表
+        Returns:
+            结果列表,每个元素包含 (table_struct_str, cell_bboxes, elapse)
+        """
+        starttime = time.perf_counter()
+
+        batch_data = self.batch_preprocess_op(img_list)
+        preprocessed_images = batch_data[0]
+        shape_lists = batch_data[1]
+
+        preprocessed_images = np.array(preprocessed_images)
+        bbox_preds, struct_probs = self.session([preprocessed_images])
+
+        batch_size = preprocessed_images.shape[0]
+        results = []
+        for bbox_pred, struct_prob, shape_list in zip(
+            bbox_preds, struct_probs, shape_lists
+        ):
+            preds = {
+                "loc_preds": np.expand_dims(bbox_pred, axis=0),
+                "structure_probs": np.expand_dims(struct_prob, axis=0),
+            }
+            shape_list = np.expand_dims(shape_list, axis=0)
+            post_result = self.postprocess_op(preds, [shape_list])
+            bbox_list = post_result["bbox_batch_list"][0]
+            structure_str_list = post_result["structure_batch_list"][0]
+            structure_str_list = structure_str_list[0]
+            structure_str_list = (
+                ["<html>", "<body>", "<table>"]
+                + structure_str_list
+                + ["</table>", "</body>", "</html>"]
+            )
+            results.append((structure_str_list, bbox_list, 0))
+
+        total_elapse = time.perf_counter() - starttime
+        for i in range(len(results)):
+            results[i] = (results[i][0], results[i][1], total_elapse / batch_size)
+
+        return results

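A hypothetical use of the new batch API (the model path and image names are illustrative; the config keys follow the OrtInferSession constructor in table_structure_utils.py below):

    import cv2

    structurer = TableStructurer({"model_path": "models/slanet_plus.onnx", "use_cuda": False})
    imgs = [cv2.imread("table_a.png"), cv2.imread("table_b.png")]
    for tokens, cell_bboxes, elapse in structurer.batch_process(imgs):
        html = "".join(tokens)  # tokens already carry the <html>/<body>/<table> wrappers
        print(len(cell_bboxes), f"cells, {elapse:.3f}s (batch time averaged per image)")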
+ 570 - 0
mineru/model/table/rec/slanet_plus/table_structure_utils.py

@@ -0,0 +1,570 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
+
+from loguru import logger
+
+
+class EP(Enum):
+    CPU_EP = "CPUExecutionProvider"
+    CUDA_EP = "CUDAExecutionProvider"
+    DIRECTML_EP = "DmlExecutionProvider"
+
+
+class OrtInferSession:
+    def __init__(self, config: Dict[str, Any]):
+        self.logger = logger
+
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.cfg_use_cuda = config.get("use_cuda", None)
+        self.cfg_use_dml = config.get("use_dml", None)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            model_path,
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+        self._verify_providers()
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+        sess_opt = SessionOptions()
+        sess_opt.log_severity_level = 4
+        sess_opt.enable_cpu_mem_arena = False
+        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        cpu_nums = os.cpu_count()
+        intra_op_num_threads = config.get("intra_op_num_threads", -1)
+        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+            sess_opt.intra_op_num_threads = intra_op_num_threads
+
+        inter_op_num_threads = config.get("inter_op_num_threads", -1)
+        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def get_metadata(self, key: str = "character") -> list:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        content_list = meta_dict[key].splitlines()
+        return content_list
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, try to use DirectML as primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, you must do:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime pakcages in current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
+        )
+        self.logger.info(
+            "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_window_version = int(platform.release().split(".")[0])
+        if cur_window_version < 10:
+            self.logger.warning(
+                "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
+                cur_window_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, you must do:")
+        self.logger.info(
+            "First, uninstall all onnxruntime pakcages in current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
+        )
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+        input_dict = dict(zip(self.get_input_names(), input_content))
+        try:
+            return self.session.run(None, input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        if key in meta_dict.keys():
+            return True
+        return False
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
+
+
+class TableLabelDecode:
+    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
+        if merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+        self.td_token = ["<td>", "<td", "<td></td>"]
+
+    def __call__(self, preds, batch=None):
+        structure_probs = preds["structure_probs"]
+        bbox_preds = preds["loc_preds"]
+        shape_list = batch[-1]
+        result = self.decode(structure_probs, bbox_preds, shape_list)
+        if len(batch) == 1:  # only contains shape
+            return result
+
+        label_decode_result = self.decode_label(batch)
+        return result, label_decode_result
+
+    def decode(self, structure_probs, bbox_preds, shape_list):
+        """convert text-label into text-index."""
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_idx = structure_probs.argmax(axis=2)
+        structure_probs = structure_probs.max(axis=2)
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            score_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                text = self.character[char_idx]
+                if text in self.td_token:
+                    bbox = bbox_preds[batch_idx, idx]
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+                structure_list.append(text)
+                score_list.append(structure_probs[batch_idx, idx])
+            structure_batch_list.append([structure_list, np.mean(score_list)])
+            bbox_batch_list.append(np.array(bbox_list))
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def decode_label(self, batch):
+        """convert text-label into text-index."""
+        structure_idx = batch[1]
+        gt_bbox_list = batch[2]
+        shape_list = batch[-1]
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                structure_list.append(self.character[char_idx])
+
+                bbox = gt_bbox_list[batch_idx][idx]
+                if bbox.sum() != 0:
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+
+            structure_batch_list.append(structure_list)
+            bbox_batch_list.append(bbox_list)
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def _bbox_decode(self, bbox, shape):
+        h, w = shape[:2]
+        bbox[0::2] *= w
+        bbox[1::2] *= h
+        return bbox
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            return np.array(self.dict[self.beg_str])
+
+        if beg_or_end == "end":
+            return np.array(self.dict[self.end_str])
+
+        raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+
+class TablePreprocess:
+    def __init__(self):
+        self.table_max_len = 488
+        self.build_pre_process_list()
+        self.ops = self.create_operators()
+
+    def __call__(self, data):
+        """transform"""
+        if self.ops is None:
+            self.ops = []
+
+        for op in self.ops:
+            data = op(data)
+            if data is None:
+                return None
+        return data
+
+    def create_operators(
+        self,
+    ):
+        """
+        create operators based on the config
+
+        Args:
+            params(list): a dict list, used to create some operators
+        """
+        assert isinstance(
+            self.pre_process_list, list
+        ), "operator config should be a list"
+        ops = []
+        for operator in self.pre_process_list:
+            assert (
+                isinstance(operator, dict) and len(operator) == 1
+            ), "yaml format error"
+            op_name = list(operator)[0]
+            param = {} if operator[op_name] is None else operator[op_name]
+            op = eval(op_name)(**param)
+            ops.append(op)
+        return ops
+
+    def build_pre_process_list(self):
+        resize_op = {
+            "ResizeTableImage": {
+                "max_len": self.table_max_len,
+            }
+        }
+        pad_op = {
+            "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
+        }
+        normalize_op = {
+            "NormalizeImage": {
+                "std": [0.229, 0.224, 0.225],
+                "mean": [0.485, 0.456, 0.406],
+                "scale": "1./255.",
+                "order": "hwc",
+            }
+        }
+        to_chw_op = {"ToCHWImage": None}
+        keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
+        self.pre_process_list = [
+            resize_op,
+            normalize_op,
+            pad_op,
+            to_chw_op,
+            keep_keys_op,
+        ]
+
+
+class BatchTablePreprocess:
+
+    def __init__(self):
+        self.preprocess = TablePreprocess()
+
+    def __call__(
+        self, img_list: List[np.ndarray]
+    ) -> Tuple[List[np.ndarray], List[List[float]]]:
+        """批量处理图像
+
+        Args:
+            img_list: 图像列表
+
+        Returns:
+            预处理后的图像列表和形状信息列表
+        """
+        processed_imgs = []
+        shape_lists = []
+
+        for img in img_list:
+            if img is None:
+                continue
+            data = {"image": img}
+            img_processed, shape_list = self.preprocess(data)
+            processed_imgs.append(img_processed)
+            shape_lists.append(shape_list)
+        return processed_imgs, shape_lists
+
+
+class ResizeTableImage:
+    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
+        super(ResizeTableImage, self).__init__()
+        self.max_len = max_len
+        self.resize_bboxes = resize_bboxes
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        img = data["image"]
+        height, width = img.shape[0:2]
+        ratio = self.max_len / (max(height, width) * 1.0)
+        resize_h = int(height * ratio)
+        resize_w = int(width * ratio)
+        resize_img = cv2.resize(img, (resize_w, resize_h))
+        if self.resize_bboxes and not self.infer_mode:
+            data["bboxes"] = data["bboxes"] * ratio
+        data["image"] = resize_img
+        data["src_img"] = img
+        data["shape"] = np.array([height, width, ratio, ratio])
+        data["max_len"] = self.max_len
+        return data
+
+
+class PaddingTableImage:
+    def __init__(self, size, **kwargs):
+        super(PaddingTableImage, self).__init__()
+        self.size = size
+
+    def __call__(self, data):
+        img = data["image"]
+        pad_h, pad_w = self.size
+        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+        height, width = img.shape[0:2]
+        padding_img[0:height, 0:width, :] = img.copy()
+        data["image"] = padding_img
+        shape = data["shape"].tolist()
+        shape.extend([pad_h, pad_w])
+        data["shape"] = np.array(shape)
+        return data
+
+
+class NormalizeImage:
+    """normalize image such as substract mean, divide std"""
+
+    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+        return data
+
+
+class ToCHWImage:
+    """convert hwc image to chw image"""
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        data["image"] = img.transpose((2, 0, 1))
+        return data
+
+
+class KeepKeys:
+    def __init__(self, keep_keys, **kwargs):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        data_list = []
+        for key in self.keep_keys:
+            data_list.append(data[key])
+        return data_list
+
+
+def trans_char_ocr_res(ocr_res):
+    word_result = []
+    for res in ocr_res:
+        score = res[2]
+        for word_box, word in zip(res[3], res[4]):
+            word_res = []
+            word_res.append(word_box)
+            word_res.append(word)
+            word_res.append(score)
+            word_result.append(word_res)
+    return word_result

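For orientation, a sketch of how the pieces in this file fit together, assuming an invented input image and an illustrative model path:

    import numpy as np

    bgr_img = np.zeros((240, 320, 3), dtype=np.uint8)  # stand-in table crop

    pre = TablePreprocess()
    img, shape = pre({"image": bgr_img})  # KeepKeys yields [image, shape]
    # img: float32 CHW tensor padded to 488x488
    # shape: np.array([h, w, ratio, ratio, 488, 488])

    # OrtInferSession prefers DirectML, then CUDA, then CPU, depending on the
    # use_dml/use_cuda flags and on which providers onnxruntime actually offers
    sess = OrtInferSession({"model_path": "models/slanet_plus.onnx"})
    loc_preds, structure_probs = sess([np.array([img])])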
+ 42 - 65
mineru/model/table/rec/unet_table/main.py

@@ -10,7 +10,6 @@ import numpy as np
 import cv2
 from PIL import Image
 from loguru import logger
-from rapid_table import RapidTableInput, RapidTable
 
 from .table_structure_unet import TSRUnet
 
@@ -188,7 +187,7 @@ class WiredTableRecognition:
                 continue
             # check the aspect ratio
             if (x2 - x1) / (y2 - y1) > 20 or (y2 - y1) / (x2 - x1) > 20:
-                logger.warning(f"Box {i} has invalid aspect ratio: {x1, y1, x2, y2}")
+                # logger.warning(f"Box {i} has invalid aspect ratio: {x1, y1, x2, y2}")
                 continue
             img_crop = bgr_img[int(y1):int(y2), int(x1):int(x2)]
             img_crop_list.append(img_crop)
@@ -243,12 +242,9 @@ class UnetTableModel:
         model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
         wired_input_args = WiredTableInput(model_path=model_path)
         self.wired_table_model = WiredTableRecognition(wired_input_args, ocr_engine)
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, input_img, table_cls_score):
+    def predict(self, input_img, ocr_result, wireless_html_code):
         if isinstance(input_img, Image.Image):
             np_img = np.asarray(input_img)
         elif isinstance(input_img, np.ndarray):
@@ -256,67 +252,48 @@ class UnetTableModel:
         else:
             raise ValueError("Input must be a pillow object or a numpy array.")
         bgr_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
-        ocr_result = self.ocr_engine.ocr(bgr_img)[0]
-        if ocr_result:
+
+        if ocr_result is None:
+            ocr_result = self.ocr_engine.ocr(bgr_img)[0]
             ocr_result = [
                 [item[0], escape_html(item[1][0]), item[1][1]]
                 for item in ocr_result
                 if len(item) == 2 and isinstance(item[1], tuple)
             ]
-        else:
-            ocr_result = None
-        if ocr_result:
-            try:
-                wired_table_results = self.wired_table_model(np_img, ocr_result)
-
-                # viser = VisTable()
-                # save_html_path = f"outputs/output.html"
-                # save_drawed_path = f"outputs/output_table_vis.jpg"
-                # save_logic_path = (
-                #     f"outputs/output_table_vis_logic.jpg"
-                # )
-                # vis_imged = viser(
-                #     np_img, wired_table_results, save_html_path, save_drawed_path, save_logic_path
-                # )
-
-                wired_html_code = wired_table_results.pred_html
-                wired_table_cell_bboxes = wired_table_results.cell_bboxes
-                wired_logic_points = wired_table_results.logic_points
-                wired_elapse = wired_table_results.elapse
-
-                wireless_table_results = self.wireless_table_model(np_img, ocr_result)
-                wireless_html_code = wireless_table_results.pred_html
-                wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
-                wireless_logic_points = wireless_table_results.logic_points
-                wireless_elapse = wireless_table_results.elapse
-
-                # wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
-                # wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
-
-                wired_len = count_table_cells_physical(wired_html_code)
-                wireless_len = count_table_cells_physical(wireless_html_code)
-
-                # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
-                # 计算两种模型检测的单元格数量差异
-                gap_of_len = wireless_len - wired_len
-                # 判断是否使用无线表格模型的结果
-                if (
-                    wired_len <= int(wireless_len * 0.55)+1  # 有线模型检测到的单元格数太少(低于无线模型的50%)
-                    # or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # 有线模型检测到的单元格数反而更多
-                    or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
-                    or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
-                ):
-                    # logger.debug("fall back to wireless table model")
-                    html_code = wireless_html_code
-                    table_cell_bboxes = wireless_table_cell_bboxes
-                    logic_points = wireless_logic_points
-                else:
-                    html_code = wired_html_code
-                    table_cell_bboxes = wired_table_cell_bboxes
-                    logic_points = wired_logic_points
-
-                elapse = wired_elapse + wireless_elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
-            except Exception as e:
-                logger.exception(e)
-        return None, None, None, None
+
+        try:
+            wired_table_results = self.wired_table_model(np_img, ocr_result)
+
+            # viser = VisTable()
+            # save_html_path = f"outputs/output.html"
+            # save_drawed_path = f"outputs/output_table_vis.jpg"
+            # save_logic_path = (
+            #     f"outputs/output_table_vis_logic.jpg"
+            # )
+            # vis_imged = viser(
+            #     np_img, wired_table_results, save_html_path, save_drawed_path, save_logic_path
+            # )
+
+            wired_html_code = wired_table_results.pred_html
+
+            wired_len = count_table_cells_physical(wired_html_code)
+            wireless_len = count_table_cells_physical(wireless_html_code)
+
+            # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
+            # difference between the cell counts detected by the two models
+            gap_of_len = wireless_len - wired_len
+            # decide whether to fall back to the wireless table model's result
+            if (
+                wired_len <= int(wireless_len * 0.55)+1  # the wired model found far fewer cells (below ~55% of the wireless count)
+                or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # counts are close, but the wired result is noticeably smaller
+                or (gap_of_len == 0 and wired_len <= 4)  # counts are identical and the table has at most 4 cells
+            ):
+                # logger.debug("fall back to wireless table model")
+                html_code = wireless_html_code
+            else:
+                html_code = wired_html_code
+
+            return html_code
+        except Exception as e:
+            logger.exception(e)
+            return None

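With this refactor, UnetTableModel.predict handles only the wired model and expects the wireless (slanet_plus) HTML to be precomputed by the batch pipeline. A hedged sketch of the new call shape, assuming the constructor seen in the hunk above and an ocr_engine built elsewhere:

    model = UnetTableModel(ocr_engine)   # assumed: __init__(self, ocr_engine)
    html = model.predict(
        table_img,                       # PIL.Image or np.ndarray
        ocr_result=None,                 # None -> predict() runs self.ocr_engine itself
        wireless_html_code=slanet_html,  # slanet_plus HTML from the batch pass
    )
    # html is the wired result unless the cell-count heuristics fall back to slanet_html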
+ 1 - 1
mineru/model/table/rec/unet_table/utils.py

@@ -429,7 +429,7 @@ class VisTable:
         :return:
         """
         # read the original image
-        img = cv2.imread(img_path)
+        img = img_path
         img = cv2.copyMakeBorder(
             img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
         )

+ 2 - 1
mineru/utils/ocr_utils.py

@@ -437,6 +437,7 @@ def get_rotate_crop_image(img, points):
         borderMode=cv2.BORDER_REPLICATE,
         flags=cv2.INTER_CUBIC)
     dst_img_height, dst_img_width = dst_img.shape[0:2]
-    if dst_img_height * 1.0 / dst_img_width >= 1.5:
+    rotate_ratio = 2
+    if dst_img_height * 1.0 / dst_img_width >= rotate_ratio:
         dst_img = np.rot90(dst_img)
     return dst_img

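In effect, a warped text crop is now rotated 90° only when it is at least twice as tall as it is wide (the previous threshold was 1.5), so moderately tall crops, e.g. 300x180 (h/w ≈ 1.67), are left as-is.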
+ 2 - 3
pyproject.toml

@@ -36,6 +36,7 @@ dependencies = [
     "opencv-python>=4.11.0.86",
     "opencv-python>=4.11.0.86",
     "fast-langdetect>=0.2.3,<0.3.0",
     "fast-langdetect>=0.2.3,<0.3.0",
     "scikit-image>=0.25.0,<1.0.0",
     "scikit-image>=0.25.0,<1.0.0",
+    "openai>=1.70.0,<2",
 ]
 
 [project.optional-dependencies]
@@ -49,7 +50,7 @@ test = [
 ]
 vlm = [
     "transformers>=4.51.1",
-    "torch>=2.6.0,<2.8.0",
+    "torch>=2.6.0",
     "accelerate>=1.5.1",
     "accelerate>=1.5.1",
     "pydantic",
     "pydantic",
 ]
 ]
@@ -61,10 +62,8 @@ pipeline = [
     "ultralytics>=8.3.48,<9",
     "ultralytics>=8.3.48,<9",
     "doclayout_yolo==0.0.4",
     "doclayout_yolo==0.0.4",
     "dill>=0.3.8,<1",
     "dill>=0.3.8,<1",
-    "rapid_table>=1.0.5,<2.0.0",
     "PyYAML>=6.0.2,<7",
     "PyYAML>=6.0.2,<7",
     "ftfy>=6.3.1,<7",
     "ftfy>=6.3.1,<7",
-    "openai>=1.70.0,<2",
     "shapely>=2.0.7,<3",
     "shapely>=2.0.7,<3",
     "pyclipper>=1.3.0,<2",
     "pyclipper>=1.3.0,<2",
     "omegaconf>=2.3.0,<3",
     "omegaconf>=2.3.0,<3",