File overview

Merge pull request #3371 from Sidney233/dev

Feat: add batch predict for table rec
Xiaomeng Zhao 2 months ago
parent
commit
98c8761361

+ 103 - 59
mineru/backend/pipeline/batch_analyze.py

@@ -1,3 +1,5 @@
+import html
+
 import cv2
 from loguru import logger
 from tqdm import tqdm
@@ -8,13 +10,15 @@ from .model_init import AtomModelSingleton
 from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
-from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
+from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
+from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
 from ...utils.pdf_image_tools import get_crop_np_img
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 OCR_DET_BASE_BATCH_SIZE = 16
+ORI_TAB_CLS_BATCH_SIZE = 16
 
 
 class BatchAnalyze:
@@ -189,9 +193,6 @@ class BatchAnalyze:
 
                         if dt_boxes is not None and len(dt_boxes) > 0:
                             # Apply the key processing steps from the original OCR pipeline directly
-                            from mineru.utils.ocr_utils import (
-                                merge_det_boxes, update_det_boxes, sorted_boxes
-                            )
 
                             # 1. Sort the detection boxes
                             if len(dt_boxes) > 0:
@@ -255,72 +256,115 @@ class BatchAnalyze:
 
         # Table recognition
         if self.table_enable:
-            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
-                _lang = table_res_dict['lang']
+            # Batch image-orientation handling
 
-                # 调整图片方向
-                img_orientation_cls_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.ImgOrientationCls,
+            img_orientation_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.ImgOrientationCls,
+            )
+            try:
+                img_orientation_cls_model.batch_predict(table_res_list_all_page, batch_size=self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
+            except Exception as e:
+                logger.warning(
+                    f"Image orientation classification failed: {e}, using original image"
                 )
-                try:
-                    rotate_label = img_orientation_cls_model.predict(
-                        table_res_dict["table_img"]
-                    )
-                except Exception as e:
-                    logger.warning(
-                        f"Image orientation classification failed: {e}, using original image"
-                    )
-                    rotate_label = "0"
-
-                np_table_img = table_res_dict["table_img"]
-                if rotate_label == "270":
-                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_CLOCKWISE)
-                elif rotate_label == "90":
-                    np_table_img = cv2.rotate(np_table_img, cv2.ROTATE_90_COUNTERCLOCKWISE)
-                else:
-                    pass
-
-                # 有线表/无线表分类
-                table_cls_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.TableCls,
+            # Table classification
+            table_cls_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.TableCls,
+            )
+            try:
+                table_cls_model.batch_predict(table_res_list_all_page)
+            except Exception as e:
+                logger.warning(
+                    f"Table classification failed: {e}, using default model"
                 )
-                table_cls_score = 0.5
-                try:
-                    table_label, table_cls_score = table_cls_model.predict(np_table_img)
-                except Exception as e:
-                    table_label = AtomicModel.WirelessTable
-                    logger.warning(
-                        f"Table classification failed: {e}, using default model {table_label}"
-                    )
-                # table_label = AtomicModel.WirelessTable
-                # logger.debug(f"Table classification result: {table_label}")
-                if table_label not in [AtomicModel.WirelessTable, AtomicModel.WiredTable]:
-                    raise ValueError(
-                        "Table classification failed, please check the model"
+            rec_img_lang_group = defaultdict(list)
+            # OCR det stage, run sequentially
+            for index, table_res_dict in enumerate(
+                tqdm(table_res_list_all_page, desc="Table OCR det")
+            ):
+                _lang = table_res_dict["lang"]
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.OCR,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
+                    lang=_lang,
+                    enable_merge_det_boxes=False,
+                )
+                bgr_image = cv2.cvtColor(
+                    np.asarray(table_res_dict["table_img"]), cv2.COLOR_RGB2BGR
+                )
+                ocr_result = ocr_engine.ocr(bgr_image, det=True, rec=False)[0]
+                # Build dicts of crops that need OCR rec (cropped_img, dt_box, table_id), grouped by language
+                for dt_box in ocr_result:
+                    rec_img_lang_group[_lang].append(
+                        {
+                            "cropped_img": get_rotate_crop_image(
+                                bgr_image, np.asarray(dt_box, dtype=np.float32)
+                            ),
+                            "dt_box": np.asarray(dt_box, dtype=np.float32),
+                            "table_id": index,
+                        }
                     )
-
-                # 根据表格分类结果选择有线表格识别模型和无线表格识别模型
-                table_model = atom_model_manager.get_atom_model(
-                    atom_model_name=table_label,
+            # OCR rec, batched per language
+            for _lang, rec_img_list in rec_img_lang_group.items():
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.OCR,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
                     lang=_lang,
+                    enable_merge_det_boxes=False,
                 )
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(np_table_img, table_cls_score)
-                # 判断是否返回正常
-                if html_code:
+                cropped_img_list = [item["cropped_img"] for item in rec_img_list]
+                ocr_res_list = ocr_engine.ocr(
+                    cropped_img_list, det=False, rec=True, tqdm_enable=True
+                )[0]
+                # Write the recognition results back to their table by table_id
+                for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
+                    if table_res_list_all_page[img_dict["table_id"]].get("ocr_result"):
+                        table_res_list_all_page[img_dict["table_id"]]["ocr_result"].append(
+                            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
+                        )
+                    else:
+                        table_res_list_all_page[img_dict["table_id"]]["ocr_result"] = [
+                            [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
+                        ]
+
+            # Run the wireless table model on all tables first, then the wired model on those classified as wired
+            wireless_table_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.WirelessTable,
+            )
+            wireless_table_model.batch_predict(table_res_list_all_page)
+
+            # Collect the wired tables for separate prediction
+            wired_table_res_list = []
+            for table_res_dict in table_res_list_all_page:
+                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                    wired_table_res_list.append(table_res_dict)
+            for table_res_dict in tqdm(
+                wired_table_res_list, desc="Wired Table Predict"
+            ):
+                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                    wired_table_model = atom_model_manager.get_atom_model(
+                        atom_model_name=AtomicModel.WiredTable,
+                        lang=table_res_dict["lang"],
+                    )
+                    html_code = wired_table_model.predict(
+                        table_res_dict["table_img"],
+                        table_res_dict["ocr_result"],
+                        table_res_dict["table_res"].get("html", None)
+                    )
                     # Check whether html_code contains '<table>' and '</table>'
-                    if '<table>' in html_code and '</table>' in html_code:
+                    if "<table>" in html_code and "</table>" in html_code:
                         # Keep the content from <table> to </table> and store it in table_res_dict['table_res']['html']
-                        start_index = html_code.find('<table>')
-                        end_index = html_code.rfind('</table>') + len('</table>')
-                        table_res_dict['table_res']['html'] = html_code[start_index:end_index]
+                        start_index = html_code.find("<table>")
+                        end_index = html_code.rfind("</table>") + len("</table>")
+                        table_res_dict["table_res"]["html"] = html_code[
+                            start_index:end_index
+                        ]
                     else:
                         logger.warning(
-                            'table recognition processing fails, not found expected HTML table end'
+                            "wired table recognition failed: expected HTML table end tag not found"
                         )
-                else:
-                    logger.warning(
-                        'table recognition processing fails, not get html return'
-                    )
 
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language
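
The per-language grouping above is the heart of the new batching scheme: detection runs per table, the resulting crops are pooled by language so recognition can run in large homogeneous batches, and results are written back to their source table through `table_id`. A minimal standalone sketch of that round trip, with a stubbed rec function standing in for the real OCR engine:

```python
from collections import defaultdict

def batch_rec_by_lang(tables, rec_fn):
    """Group cropped text regions by language, run rec once per group,
    and write the results back to their source table via table_id."""
    groups = defaultdict(list)
    for table_id, table in enumerate(tables):
        for crop in table["crops"]:  # det output, one entry per text box
            groups[table["lang"]].append({"crop": crop, "table_id": table_id})
    for lang, items in groups.items():
        texts = rec_fn([it["crop"] for it in items], lang)  # one batched call
        for it, text in zip(items, texts):
            tables[it["table_id"]].setdefault("ocr_result", []).append(text)

# Stubbed rec function standing in for ocr_engine.ocr(..., det=False, rec=True)
tables = [{"lang": "ch", "crops": ["a", "b"]}, {"lang": "en", "crops": ["c"]}]
batch_rec_by_lang(tables, lambda crops, lang: [f"{lang}:{c}" for c in crops])
print(tables[0]["ocr_result"], tables[1]["ocr_result"])  # ['ch:a', 'ch:b'] ['en:c']
```

The `setdefault` call is equivalent to the explicit if/else back-fill in the diff.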

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
-from ...model.table.rec.rapid_table import RapidTableModel
+from ...model.table.rec.slanet_plus.main import RapidTableModel
 from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path

+ 150 - 1
mineru/model/ori_cls/paddle_ori_cls.py

@@ -2,10 +2,12 @@
 import os
 
 from PIL import Image
+from collections import defaultdict
+from typing import List, Dict
+from tqdm import tqdm
 import cv2
 import numpy as np
 import onnxruntime
-from loguru import logger
 
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -111,3 +113,150 @@ class PaddleOrientationClsModel:
                     # logger.debug(f"Orientation classification result: {label}")
 
         return rotate_label
+
+    def list_2_batch(self, img_list, batch_size=16):
+        """
+        Split a list of arbitrary length into batches of the given batch size.
+
+        Args:
+            img_list: the input list
+            batch_size: size of each batch, 16 by default
+
+        Returns:
+            A list of batches, each batch being a sub-list of the original list
+        """
+        batches = []
+        for i in range(0, len(img_list), batch_size):
+            batch = img_list[i : min(i + batch_size, len(img_list))]
+            batches.append(batch)
+        return batches
+
+    def batch_preprocess(self, imgs):
+        res_imgs = []
+        for img_info in imgs:
+            # PIL image to cv2
+            img = cv2.cvtColor(np.asarray(img_info["table_img"]), cv2.COLOR_RGB2BGR)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            # Resize so the shorter side is 256
+            h, w = img.shape[:2]
+            scale = 256 / min(h, w)
+            h_resize = round(h * scale)
+            w_resize = round(w * scale)
+            img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+            # Center-crop to a 224*224 square
+            h, w = img.shape[:2]
+            cw, ch = 224, 224
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+            x2 = min(w, x1 + cw)
+            y2 = min(h, y1 + ch)
+            if w < cw or h < ch:
+                raise ValueError(
+                    f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+                )
+            img = img[y1:y2, x1:x2, ...]
+            # Normalize
+            split_im = list(cv2.split(img))
+            std = [0.229, 0.224, 0.225]
+            scale = 0.00392156862745098
+            mean = [0.485, 0.456, 0.406]
+            alpha = [scale / std[i] for i in range(len(std))]
+            beta = [-mean[i] / std[i] for i in range(len(std))]
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= alpha[c]
+                split_im[c] += beta[c]
+            img = cv2.merge(split_im)
+            # Convert to CHW format
+            img = img.transpose((2, 0, 1))
+            res_imgs.append(img)
+        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
+    def batch_predict(
+        self, imgs: List[Dict], batch_size: int
+    ) -> None:
+        """
+        Batch-predict the orientation of the images in the given list of image-info dicts, and rotate any rotated images back to upright.
+        """
+        RESOLUTION_GROUP_STRIDE = 64
+        # Skip images whose aspect ratio is below 1.2
+        resolution_groups = defaultdict(list)
+        for img in imgs:
+            # Convert RGB image to BGR
+            table_img: np.ndarray = cv2.cvtColor(img["table_img"], cv2.COLOR_RGB2BGR)
+            img["table_img_bgr"] = table_img
+            img_height, img_width = table_img.shape[:2]
+            img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+            img_is_portrait = img_aspect_ratio > 1.2
+            if img_is_portrait:
+                h, w = img["table_img_bgr"].shape[:2]
+                normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # round up to a multiple of RESOLUTION_GROUP_STRIDE
+                normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+                group_key = (normalized_h, normalized_w)
+                resolution_groups[group_key].append(img)
+
+        # Batch-process each resolution group
+        for group_key, group_imgs in tqdm(
+            resolution_groups.items(), desc="ORI CLS OCR-det"
+        ):
+
+            # Compute the target size (group max, rounded up to a multiple of RESOLUTION_GROUP_STRIDE)
+            max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
+            max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
+            target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+            target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+
+            # Pad all images to the uniform target size
+            batch_images = []
+            for img in group_imgs:
+                table_img_ndarray = img["table_img_bgr"]
+                h, w = table_img_ndarray.shape[:2]
+                # Create a white background at the target size
+                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                # Paste the original image into the top-left corner
+                padded_img[:h, :w] = table_img_ndarray
+                batch_images.append(padded_img)
+
+            # Batch detection
+            det_batch_size = min(len(batch_images), batch_size)  # cap the batch size at the group size
+            batch_results = self.ocr_engine.text_detector.batch_predict(
+                batch_images, det_batch_size
+            )
+
+            rotated_imgs = []
+            # Use the detection results to decide which images are rotated; collect those and predict their rotation angle
+            for index, (img_info, (dt_boxes, elapse)) in enumerate(
+                zip(group_imgs, batch_results)
+            ):
+                vertical_count = 0
+                for box_ocr_res in dt_boxes:
+                    p1, p2, p3, p4 = box_ocr_res
+
+                    # Calculate width and height
+                    width = p3[0] - p1[0]
+                    height = p3[1] - p1[1]
+
+                    aspect_ratio = width / height if height > 0 else 1.0
+
+                    # Count vertical text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1
+
+                if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
+                    rotated_imgs.append(img_info)
+            if len(rotated_imgs) > 0:
+                x = self.batch_preprocess(rotated_imgs)
+                results = self.sess.run(None, {"x": x})
+                for img_info, res in zip(rotated_imgs, results[0]):
+                    label = self.labels[np.argmax(res)]
+                    if label == "270":
+                        img_info["table_img"] = cv2.rotate(
+                            np.asarray(img_info["table_img"]),
+                            cv2.ROTATE_90_CLOCKWISE,
+                        )
+                    elif label == "90":
+                        img_info["table_img"] = cv2.rotate(
+                            np.asarray(img_info["table_img"]),
+                            cv2.ROTATE_90_COUNTERCLOCKWISE,
+                        )
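
The resolution grouping in `batch_predict` buckets portrait images by stride-rounded size so that each detector batch shares a single padded shape. A small sketch of the two rounding formulas it uses (helper names here are mine, not the PR's); note the group key adds a full stride even when the side is already an exact multiple, while the padding target uses a true ceiling:

```python
STRIDE = 64  # RESOLUTION_GROUP_STRIDE in the diff

def group_key(h: int, w: int, stride: int = STRIDE):
    # Bucketing formula from batch_predict: (x + stride) // stride * stride
    return (((h + stride) // stride) * stride,
            ((w + stride) // stride) * stride)

def pad_target(sizes, stride: int = STRIDE):
    # Padding target: per-group max, rounded up with the usual ceiling trick
    ceil = lambda v: ((v + stride - 1) // stride) * stride
    return (ceil(max(h for h, _ in sizes)), ceil(max(w for _, w in sizes)))

print(group_key(640, 512))                   # (704, 576): exact multiples still move up a bucket
print(pad_target([(640, 512), (600, 470)]))  # (640, 512): the ceiling keeps exact multiples
```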

+ 77 - 0
mineru/model/table/cls/paddle_table_cls.py

@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 
 from PIL import Image
 import cv2
@@ -75,3 +76,79 @@ class PaddleTableClsModel:
         if idx == 0 and conf < 0.8:
             idx = 1
         return self.labels[idx], conf
+
+    def list_2_batch(self, img_list, batch_size=16):
+        """
+        Split a list of arbitrary length into batches of the given batch size.
+
+        Args:
+            img_list: the input list
+            batch_size: size of each batch, 16 by default
+
+        Returns:
+            A list of batches, each batch being a sub-list of the original list
+        """
+        batches = []
+        for i in range(0, len(img_list), batch_size):
+            batch = img_list[i : min(i + batch_size, len(img_list))]
+            batches.append(batch)
+        return batches
+
+    def batch_preprocess(self, imgs):
+        res_imgs = []
+        for img in imgs:
+            # PIL image to cv2
+            img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            # Resize so the shorter side is 256
+            h, w = img.shape[:2]
+            scale = 256 / min(h, w)
+            h_resize = round(h * scale)
+            w_resize = round(w * scale)
+            img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
+            # Center-crop to a 224*224 square
+            h, w = img.shape[:2]
+            cw, ch = 224, 224
+            x1 = max(0, (w - cw) // 2)
+            y1 = max(0, (h - ch) // 2)
+            x2 = min(w, x1 + cw)
+            y2 = min(h, y1 + ch)
+            if w < cw or h < ch:
+                raise ValueError(
+                    f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
+                )
+            img = img[y1:y2, x1:x2, ...]
+            # Normalize
+            split_im = list(cv2.split(img))
+            std = [0.229, 0.224, 0.225]
+            scale = 0.00392156862745098
+            mean = [0.485, 0.456, 0.406]
+            alpha = [scale / std[i] for i in range(len(std))]
+            beta = [-mean[i] / std[i] for i in range(len(std))]
+            for c in range(img.shape[2]):
+                split_im[c] = split_im[c].astype(np.float32)
+                split_im[c] *= alpha[c]
+                split_im[c] += beta[c]
+            img = cv2.merge(split_im)
+            # Convert to CHW format
+            img = img.transpose((2, 0, 1))
+            res_imgs.append(img)
+        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
+        return x
+
+    def batch_predict(self, img_info_list, batch_size=16):
+        imgs = [item["table_img"] for item in img_info_list]
+        imgs = self.list_2_batch(imgs, batch_size=batch_size)
+        label_res = []
+        for img_batch in imgs:
+            x = self.batch_preprocess(img_batch)
+            result = self.sess.run(None, {"x": x})
+            for img_res in result[0]:
+                idx = np.argmax(img_res)
+                conf = float(np.max(img_res))
+                # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
+                if idx == 0 and conf < 0.8:
+                    idx = 1
+                label_res.append((self.labels[idx], conf))
+        for img_info, (label, conf) in zip(img_info_list, label_res):
+            img_info["table_res"]["cls_label"] = label
+            img_info["table_res"]["cls_score"] = conf
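
The batch path keeps the single-image fallback rule: a class-0 ("wired") prediction below 0.8 confidence is demoted to class 1 ("wireless"), the safer default. A standalone sketch of that decode step; the label strings here are placeholders for the model's actual `self.labels`:

```python
import numpy as np

LABELS = ["wired_table", "wireless_table"]  # placeholder for self.labels

def decode_cls(probs: np.ndarray, threshold: float = 0.8):
    """Mirror of the per-image decode in batch_predict: low-confidence
    'wired' predictions fall back to the wireless model."""
    idx = int(np.argmax(probs))
    conf = float(np.max(probs))
    if idx == 0 and conf < threshold:
        idx = 1
    return LABELS[idx], conf

print(decode_cls(np.array([0.75, 0.25])))  # ('wireless_table', 0.75), demoted
print(decode_cls(np.array([0.90, 0.10])))  # ('wired_table', 0.9)
```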

+ 0 - 46
mineru/model/table/rec/rapid_table.py

@@ -1,46 +0,0 @@
-import os
-import html
-import cv2
-import numpy as np
-from loguru import logger
-from rapid_table import RapidTable, RapidTableInput
-
-from mineru.utils.enum_class import ModelPath
-from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-
-
-def escape_html(input_string):
-    """Escape HTML Entities."""
-    return html.escape(input_string)
-
-
-class RapidTableModel(object):
-    def __init__(self, ocr_engine):
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.table_model = RapidTable(input_args)
-        self.ocr_engine = ocr_engine
-
-
-    def predict(self, image, table_cls_score):
-        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-        # Continue with OCR on potentially rotated image
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-        if ocr_result:
-            ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
-                      len(item) == 2 and isinstance(item[1], tuple)]
-        else:
-            ocr_result = None
-
-        if ocr_result:
-            try:
-                table_results = self.table_model(np.asarray(image), ocr_result)
-                html_code = table_results.pred_html
-                table_cell_bboxes = table_results.cell_bboxes
-                logic_points = table_results.logic_points
-                elapse = table_results.elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
-            except Exception as e:
-                logger.exception(e)
-
-        return None, None, None, None

+ 0 - 0
mineru/model/table/rec/slanet_plus/__init__.py


+ 278 - 0
mineru/model/table/rec/slanet_plus/main.py

@@ -0,0 +1,278 @@
+import os
+import argparse
+import copy
+import importlib
+import time
+import html
+from dataclasses import asdict, dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from .matcher import TableMatch
+from .table_structure import TableStructurer
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+root_dir = Path(__file__).resolve().parent
+
+
+class ModelType(Enum):
+    PPSTRUCTURE_EN = "ppstructure_en"
+    PPSTRUCTURE_ZH = "ppstructure_zh"
+    SLANETPLUS = "slanet_plus"
+    UNITABLE = "unitable"
+
+
+ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
+KEY_TO_MODEL_URL = {
+    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
+    ModelType.UNITABLE.value: {
+        "encoder": f"{ROOT_URL}/unitable/encoder.pth",
+        "decoder": f"{ROOT_URL}/unitable/decoder.pth",
+        "vocab": f"{ROOT_URL}/unitable/vocab.json",
+    },
+}
+
+
+@dataclass
+class RapidTableInput:
+    model_type: Optional[str] = ModelType.SLANETPLUS.value
+    model_path: Union[str, Path, None, Dict[str, str]] = None
+    use_cuda: bool = False
+    device: str = "cpu"
+
+
+@dataclass
+class RapidTableOutput:
+    pred_html: Optional[str] = None
+    cell_bboxes: Optional[np.ndarray] = None
+    logic_points: Optional[np.ndarray] = None
+    elapse: Optional[float] = None
+
+
+class RapidTable:
+    def __init__(self, config: RapidTableInput):
+        self.model_type = config.model_type
+        if self.model_type not in KEY_TO_MODEL_URL:
+            model_list = ",".join(KEY_TO_MODEL_URL)
+            raise ValueError(
+                f"{self.model_type} is not supported. The currently supported models are {model_list}."
+            )
+
+        config.model_path = config.model_path
+        if self.model_type == ModelType.SLANETPLUS.value:
+            self.table_structure = TableStructurer(asdict(config))
+        else:
+            raise ValueError(f"{self.model_type} is not supported.")
+        self.table_matcher = TableMatch()
+
+    def predict(
+        self,
+        img: np.ndarray,
+        ocr_result: List[Union[List[List[float]], str, str]] = None,
+    ) -> RapidTableOutput:
+        if ocr_result is None:
+            raise ValueError("OCR result is None")
+
+        s = time.perf_counter()
+        h, w = img.shape[:2]
+
+        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
+
+        pred_structures, cell_bboxes, _ = self.table_structure.process(
+            copy.deepcopy(img)
+        )
+
+        # Rescale the boxes to undo the slanet-plus model's resize
+        cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+
+        pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
+
+        # Filter out placeholder bboxes
+        mask = ~np.all(cell_bboxes == 0, axis=1)
+        cell_bboxes = cell_bboxes[mask]
+
+        logic_points = self.table_matcher.decode_logic_points(pred_structures)
+        elapse = time.perf_counter() - s
+        return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
+
+    def batch_predict(
+        self,
+        images: List[np.ndarray],
+        ocr_results: List[List[Union[List[List[float]], str, str]]],
+        batch_size: int = 4,
+    ) -> List[RapidTableOutput]:
+        """Batch-process images."""
+        s = time.perf_counter()
+
+        batch_dt_boxes = []
+        batch_rec_res = []
+
+        for i, img in enumerate(images):
+            h, w = img.shape[:2]
+            dt_boxes, rec_res = self.get_boxes_recs(ocr_results[i], h, w)
+            batch_dt_boxes.append(dt_boxes)
+            batch_rec_res.append(rec_res)
+
+        # Batch table structure recognition
+        batch_results = self.table_structure.batch_process(images)
+
+        output_results = []
+        for i, (img, ocr_result, (pred_structures, cell_bboxes, _)) in enumerate(
+            zip(images, ocr_results, batch_results)
+        ):
+            # Rescale the boxes to undo the slanet-plus model's resize
+            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+            pred_html = self.table_matcher(
+                pred_structures, cell_bboxes, batch_dt_boxes[i], batch_rec_res[i]
+            )
+            # Filter out placeholder bboxes
+            mask = ~np.all(cell_bboxes == 0, axis=1)
+            cell_bboxes = cell_bboxes[mask]
+
+            logic_points = self.table_matcher.decode_logic_points(pred_structures)
+            result = RapidTableOutput(pred_html, cell_bboxes, logic_points, 0)
+            output_results.append(result)
+
+        total_elapse = time.perf_counter() - s
+        for result in output_results:
+            result.elapse = total_elapse / len(output_results)
+
+        return output_results
+
+    def get_boxes_recs(
+        self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
+    ) -> Tuple[np.ndarray, Tuple[str, str]]:
+        dt_boxes, rec_res, scores = list(zip(*ocr_result))
+        rec_res = list(zip(rec_res, scores))
+
+        r_boxes = []
+        for box in dt_boxes:
+            box = np.array(box)
+            x_min = max(0, box[:, 0].min() - 1)
+            x_max = min(w, box[:, 0].max() + 1)
+            y_min = max(0, box[:, 1].min() - 1)
+            y_max = min(h, box[:, 1].max() + 1)
+            box = [x_min, y_min, x_max, y_max]
+            r_boxes.append(box)
+        dt_boxes = np.array(r_boxes)
+        return dt_boxes, rec_res
+
+    def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
+        h, w = img.shape[:2]
+        resized = 488
+        ratio = min(resized / h, resized / w)
+        w_ratio = resized / (w * ratio)
+        h_ratio = resized / (h * ratio)
+        cell_bboxes[:, 0::2] *= w_ratio
+        cell_bboxes[:, 1::2] *= h_ratio
+        return cell_bboxes
+
+
+def parse_args(arg_list: Optional[List[str]] = None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-v",
+        "--vis",
+        action="store_true",
+        default=False,
+        help="Whether to visualize the layout results.",
+    )
+    parser.add_argument(
+        "-img", "--img_path", type=str, required=True, help="Path to image for layout."
+    )
+    parser.add_argument(
+        "-m",
+        "--model_type",
+        type=str,
+        default=ModelType.SLANETPLUS.value,
+        choices=list(KEY_TO_MODEL_URL),
+    )
+    args = parser.parse_args(arg_list)
+    return args
+
+
+def escape_html(input_string):
+    """Escape HTML Entities."""
+    return html.escape(input_string)
+
+
+class RapidTableModel(object):
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(
+            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
+            ModelPath.slanet_plus,
+        )
+        input_args = RapidTableInput(
+            model_type="slanet_plus", model_path=slanet_plus_model_path
+        )
+        self.table_model = RapidTable(input_args)
+        self.ocr_engine = ocr_engine
+
+    def predict(self, image, table_cls_score):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        # Continue with OCR on potentially rotated image
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        if ocr_result:
+            ocr_result = [
+                [item[0], escape_html(item[1][0]), item[1][1]]
+                for item in ocr_result
+                if len(item) == 2 and isinstance(item[1], tuple)
+            ]
+        else:
+            ocr_result = None
+
+        if ocr_result:
+            try:
+                table_results = self.table_model.predict(np.asarray(image), ocr_result)
+                html_code = table_results.pred_html
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+
+        return None, None, None, None
+
+    def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
+        """Batch-predict over the given list of dicts in place; no return value."""
+        for index in tqdm(
+            range(0, len(table_res_list), batch_size),
+            desc=f"Wireless Table Batch Predict, total={len(table_res_list)}, batch_size={batch_size}",
+        ):
+            batch_imgs = [
+                cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
+                for i in range(index, min(index + batch_size, len(table_res_list)))
+            ]
+            batch_ocrs = [
+                table_res_list[i]["ocr_result"]
+                for i in range(index, min(index + batch_size, len(table_res_list)))
+            ]
+            results = self.table_model.batch_predict(
+                batch_imgs, batch_ocrs, batch_size=batch_size
+            )
+            for i, result in enumerate(results):
+                if result.pred_html:
+                    # Check whether html_code contains '<table>' and '</table>'
+                    if '<table>' in result.pred_html and '</table>' in result.pred_html:
+                        # Keep the content from <table> to </table> and store it in table_res_dict['table_res']['html']
+                        start_index = result.pred_html.find('<table>')
+                        end_index = result.pred_html.rfind('</table>') + len('</table>')
+                        table_res_list[index + i]['table_res']['html'] = result.pred_html[start_index:end_index]
+                    else:
+                        logger.warning(
+                            'wireless table recognition failed: expected HTML table end tag not found'
+                        )
+                else:
+                    logger.warning(
+                        "wireless table recognition failed: no HTML returned"
+                    )
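
Both the wired branch in `batch_analyze.py` and the wireless branch above end with the same guard: keep only the span from the first `<table>` to the last `</table>`, otherwise warn. That pattern, factored into a helper for clarity (not part of the PR):

```python
from typing import Optional

def extract_table_html(html_code: Optional[str]) -> Optional[str]:
    """Return the <table>...</table> span, or None when either tag is missing."""
    if not html_code:
        return None
    start = html_code.find("<table>")
    end = html_code.rfind("</table>")
    if start == -1 or end == -1:
        return None
    return html_code[start : end + len("</table>")]

print(extract_table_html("<html><body><table><tr><td>1</td></tr></table></body></html>"))
print(extract_table_html("<table><tr>truncated"))  # None
```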

+ 198 - 0
mineru/model/table/rec/slanet_plus/matcher.py

@@ -0,0 +1,198 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from .matcher_utils import compute_iou, distance
+
+
+class TableMatch:
+    def __init__(self, filter_ocr_result=True, use_master=False):
+        self.filter_ocr_result = filter_ocr_result
+        self.use_master = use_master
+
+    def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
+        if self.filter_ocr_result:
+            dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
+        matched_index = self.match_result(dt_boxes, cell_bboxes)
+        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
+        return pred_html
+
+    def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
+        matched = {}
+        for i, gt_box in enumerate(dt_boxes):
+            distances = []
+            for j, pred_box in enumerate(cell_bboxes):
+                if len(pred_box) == 8:
+                    pred_box = [
+                        np.min(pred_box[0::2]),
+                        np.min(pred_box[1::2]),
+                        np.max(pred_box[0::2]),
+                        np.max(pred_box[1::2]),
+                    ]
+                distances.append(
+                    (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
+                )  # compute iou and l1 distance
+            sorted_distances = distances.copy()
+            # select det box by iou and l1 distance
+            sorted_distances = sorted(
+                sorted_distances, key=lambda item: (item[1], item[0])
+            )
+            # must > min_iou
+            if sorted_distances[0][1] >= 1 - min_iou:
+                continue
+
+            if distances.index(sorted_distances[0]) not in matched:
+                matched[distances.index(sorted_distances[0])] = [i]
+            else:
+                matched[distances.index(sorted_distances[0])].append(i)
+        return matched
+
+    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+        end_html = []
+        td_index = 0
+        for tag in pred_structures:
+            if "</td>" not in tag:
+                end_html.append(tag)
+                continue
+
+            if "<td></td>" == tag:
+                end_html.extend("<td>")
+
+            if td_index in matched_index.keys():
+                b_with = False
+                if (
+                    "<b>" in ocr_contents[matched_index[td_index][0]]
+                    and len(matched_index[td_index]) > 1
+                ):
+                    b_with = True
+                    end_html.extend("<b>")
+
+                for i, td_index_index in enumerate(matched_index[td_index]):
+                    content = ocr_contents[td_index_index][0]
+                    if len(matched_index[td_index]) > 1:
+                        if len(content) == 0:
+                            continue
+
+                        if content[0] == " ":
+                            content = content[1:]
+
+                        if "<b>" in content:
+                            content = content[3:]
+
+                        if "</b>" in content:
+                            content = content[:-4]
+
+                        if len(content) == 0:
+                            continue
+
+                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+                            content += " "
+                    end_html.extend(content)
+
+                if b_with:
+                    end_html.extend("</b>")
+
+            if "<td></td>" == tag:
+                end_html.append("</td>")
+            else:
+                end_html.append(tag)
+
+            td_index += 1
+
+        # Filter <thead></thead><tbody></tbody> elements
+        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
+        end_html = [v for v in end_html if v not in filter_elements]
+        return "".join(end_html), end_html
+
+    def decode_logic_points(self, pred_structures):
+        logic_points = []
+        current_row = 0
+        current_col = 0
+        max_rows = 0
+        max_cols = 0
+        occupied_cells = {}  # records cells that are already occupied
+
+        def is_occupied(row, col):
+            return (row, col) in occupied_cells
+
+        def mark_occupied(row, col, rowspan, colspan):
+            for r in range(row, row + rowspan):
+                for c in range(col, col + colspan):
+                    occupied_cells[(r, c)] = True
+
+        i = 0
+        while i < len(pred_structures):
+            token = pred_structures[i]
+
+            if token == "<tr>":
+                current_col = 0  # reset the current column index at each <tr>
+            elif token == "</tr>":
+                current_row += 1  # end of row, increment the row index
+            elif token.startswith("<td"):
+                colspan = 1
+                rowspan = 1
+                j = i
+                if token != "<td></td>":
+                    j += 1
+                    # extract the colspan and rowspan attributes
+                    while j < len(pred_structures) and not pred_structures[
+                        j
+                    ].startswith(">"):
+                        if "colspan=" in pred_structures[j]:
+                            colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        elif "rowspan=" in pred_structures[j]:
+                            rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        j += 1
+
+                # skip the attribute tokens just consumed
+                i = j
+
+                # find the next unoccupied column
+                while is_occupied(current_row, current_col):
+                    current_col += 1
+
+                # compute the logical coordinates
+                r_start = current_row
+                r_end = current_row + rowspan - 1
+                col_start = current_col
+                col_end = current_col + colspan - 1
+
+                # record the logical coordinates
+                logic_points.append([r_start, r_end, col_start, col_end])
+
+                # mark the occupied cells
+                mark_occupied(r_start, col_start, rowspan, colspan)
+
+                # advance the current column index
+                current_col += colspan
+
+                # update the max row and column counts
+                max_rows = max(max_rows, r_end + 1)
+                max_cols = max(max_cols, col_end + 1)
+
+            i += 1
+
+        return logic_points
+
+    def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
+        y1 = cell_bboxes[:, 1::2].min()
+        new_dt_boxes = []
+        new_rec_res = []
+
+        for box, rec in zip(dt_boxes, rec_res):
+            if np.max(box[1::2]) < y1:
+                continue
+            new_dt_boxes.append(box)
+            new_rec_res.append(rec)
+        return new_dt_boxes, new_rec_res
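
`decode_logic_points` walks the structure-token stream and assigns each cell a logical span, honoring rowspan/colspan occupancy. A worked example for a 2x2 table whose first cell spans both rows; the tokens are split the way the parser expects (attribute tokens between `<td` and `>`), and the import path is the one this PR adds:

```python
from mineru.model.table.rec.slanet_plus.matcher import TableMatch

tokens = [
    "<tr>", "<td", " rowspan='2'", ">", "</td>", "<td></td>", "</tr>",
    "<tr>", "<td></td>", "</tr>",
]
points = TableMatch().decode_logic_points(tokens)
# Each entry is [row_start, row_end, col_start, col_end]:
print(points)  # [[0, 1, 0, 0], [0, 0, 1, 1], [1, 1, 1, 1]]
```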

+ 246 - 0
mineru/model/table/rec/slanet_plus/matcher_utils.py

@@ -0,0 +1,246 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import re
+
+
+def deal_isolate_span(thead_part):
+    """
+    Deal with isolated span cases in this function.
+    They are caused by wrong predictions from the structure recognition model,
+    e.g. predicting <td rowspan="2"></td> as <td></td> rowspan="2"></b></td>.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out isolated span tokens.
+    isolate_pattern = (
+        r"<td></td> rowspan='(\d)+' colspan='(\d)+'></b></td>|"
+        r"<td></td> colspan='(\d)+' rowspan='(\d)+'></b></td>|"
+        r"<td></td> rowspan='(\d)+'></b></td>|"
+        r"<td></td> colspan='(\d)+'></b></td>"
+    )
+    isolate_iter = re.finditer(isolate_pattern, thead_part)
+    isolate_list = [i.group() for i in isolate_iter]
+
+    # 2. find out span number, by step 1 result.
+    span_pattern = (
+        r" rowspan='(\d)+' colspan='(\d)+'|"
+        r" colspan='(\d)+' rowspan='(\d)+'|"
+        r" rowspan='(\d)+'|"
+        r" colspan='(\d)+'"
+    )
+    corrected_list = []
+    for isolate_item in isolate_list:
+        span_part = re.search(span_pattern, isolate_item)
+        spanStr_in_isolateItem = span_part.group()
+        # 3. merge the span number into the span token format string.
+        if spanStr_in_isolateItem is not None:
+            corrected_item = f"<td{spanStr_in_isolateItem}></td>"
+            corrected_list.append(corrected_item)
+        else:
+            corrected_list.append(None)
+
+    # 4. replace original isolated token.
+    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
+        if corrected_item is not None:
+            thead_part = thead_part.replace(isolate_item, corrected_item)
+        else:
+            pass
+    return thead_part
+
+
+def deal_duplicate_bb(thead_part):
+    """
+    Deal with duplicate <b> or </b> tags after replacement.
+    Keep one <b></b> in a <td></td> token.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out <td></td> in <thead></thead>.
+    td_pattern = (
+        r"<td rowspan='(\d)+' colspan='(\d)+'>(.+?)</td>|"
+        r"<td colspan='(\d)+' rowspan='(\d)+'>(.+?)</td>|"
+        r"<td rowspan='(\d)+'>(.+?)</td>|"
+        r"<td colspan='(\d)+'>(.+?)</td>|"
+        r"<td>(.*?)</td>"
+    )
+    td_iter = re.finditer(td_pattern, thead_part)
+    td_list = [t.group() for t in td_iter]
+
+    # 2. are there multiple <b></b> in <td></td> or not?
+    new_td_list = []
+    for td_item in td_list:
+        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
+            # multiple <b></b> in <td></td> case.
+            # 1. remove all <b></b>
+            td_item = td_item.replace("<b>", "").replace("</b>", "")
+            # 2. replace <td> -> <td><b>, </td> -> </b></td>.
+            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
+            new_td_list.append(td_item)
+        else:
+            new_td_list.append(td_item)
+
+    # 3. replace original thead part.
+    for td_item, new_td_item in zip(td_list, new_td_list):
+        thead_part = thead_part.replace(td_item, new_td_item)
+    return thead_part
+
+
+def deal_bb(result_token):
+    """
+    In our opinion, <b></b> always occurs in <thead></thead> text's context.
+    This function finds all tokens in <thead></thead> and inserts <b></b> manually.
+    :param result_token:
+    :return:
+    """
+    # find out <thead></thead> parts.
+    thead_pattern = "<thead>(.*?)</thead>"
+    if re.search(thead_pattern, result_token) is None:
+        return result_token
+    thead_part = re.search(thead_pattern, result_token).group()
+    origin_thead_part = copy.deepcopy(thead_part)
+
+    # check whether "rowspan" or "colspan" occurs in the <thead></thead> part.
+    span_pattern = r"<td rowspan='(\d)+' colspan='(\d)+'>|<td colspan='(\d)+' rowspan='(\d)+'>|<td rowspan='(\d)+'>|<td colspan='(\d)+'>"
+    span_iter = re.finditer(span_pattern, thead_part)
+    span_list = [s.group() for s in span_iter]
+    has_span_in_head = True if len(span_list) > 0 else False
+
+    if not has_span_in_head:
+        # branch 1: <thead></thead> does not include "rowspan" or "colspan".
+        # 1. replace <td> with <td><b>, and </td> with </b></td>
+        # 2. the text-line recognizer may predict text that includes <b> or </b>,
+        #    so we replace <b><b> with <b>, and </b></b> with </b>
+        thead_part = (
+            thead_part.replace("<td>", "<td><b>")
+            .replace("</td>", "</b></td>")
+            .replace("<b><b>", "<b>")
+            .replace("</b></b>", "</b>")
+        )
+    else:
+        # branch 2: <thead></thead> includes "rowspan" or "colspan".
+        # Firstly, we deal with the rowspan or colspan cases:
+        # 1. replace > with ><b>
+        # 2. replace </td> with </b></td>
+        # 3. the text-line recognizer may predict text that includes <b> or </b>,
+        #    so we replace <b><b> with <b>, and </b></b> with </b>
+
+        # Secondly, deal ordinary cases like branch 1
+
+        # replace ">" to "<b>"
+        replaced_span_list = []
+        for sp in span_list:
+            replaced_span_list.append(sp.replace(">", "><b>"))
+        for sp, rsp in zip(span_list, replaced_span_list):
+            thead_part = thead_part.replace(sp, rsp)
+
+        # replace "</td>" to "</b></td>"
+        thead_part = thead_part.replace("</td>", "</b></td>")
+
+        # remove duplicated <b> by re.sub
+        mb_pattern = "(<b>)+"
+        single_b_string = "<b>"
+        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
+
+        mgb_pattern = "(</b>)+"
+        single_gb_string = "</b>"
+        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
+
+        # ordinary cases like branch 1
+        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
+
+    # convert <td><b></b></td> back to <td></td>: an empty cell has no <b></b>,
+    # but a space cell (<td> </td>) keeps <td><b> </b></td>
+    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
+    # deal with duplicated <b></b>
+    thead_part = deal_duplicate_bb(thead_part)
+    # deal with isolated span tokens, caused by wrong predictions from the structure model.
+    # eg.PMC5994107_011_00.png
+    thead_part = deal_isolate_span(thead_part)
+    # replace original result with new thead part.
+    result_token = result_token.replace(origin_thead_part, thead_part)
+    return result_token
+
+
+def deal_eb_token(master_token):
+    """
+    post process with <eb></eb>, <eb1></eb1>, ...
+    emptyBboxTokenDict = {
+        "[]": '<eb></eb>',
+        "[' ']": '<eb1></eb1>',
+        "['<b>', ' ', '</b>']": '<eb2></eb2>',
+        "['\\u2028', '\\u2028']": '<eb3></eb3>',
+        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+        "['<b>', '</b>']": '<eb5></eb5>',
+        "['<i>', ' ', '</i>']": '<eb6></eb6>',
+        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+        "['<i>', '</i>']": '<eb9></eb9>',
+        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
+    }
+    :param master_token:
+    :return:
+    """
+    master_token = master_token.replace("<eb></eb>", "<td></td>")
+    master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
+    master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
+    master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
+    master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
+    master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
+    master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
+    master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
+    master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
+    master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
+    master_token = master_token.replace(
+        "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
+    )
+    return master_token
+
+
+def distance(box_1, box_2):
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    :param rec1: (y0, x0, y1, x1), which reflects
+            (top, left, bottom, right)
+    :param rec2: (y0, x0, y1, x1)
+    :return: scala value of IoU
+    """
+    # computing area of each rectangles
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find the each edge of intersect rectangle
+    left_line = max(rec1[1], rec2[1])
+    right_line = min(rec1[3], rec2[3])
+    top_line = max(rec1[0], rec2[0])
+    bottom_line = min(rec1[2], rec2[2])
+
+    # judge if there is an intersect
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+
+    intersect = (right_line - left_line) * (bottom_line - top_line)
+    return (intersect / (sum_area - intersect)) * 1.0
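
`match_result` ranks candidate cells per OCR box by the pair `(1 - IoU, L1 distance)`, so high overlap wins and the distance breaks ties. A quick check of the two metrics (import path as added by this PR; boxes are `(x1, y1, x2, y2)` as `distance` reads them):

```python
from mineru.model.table.rec.slanet_plus.matcher_utils import compute_iou, distance

cell  = [10, 10, 110, 60]   # predicted table cell
box_a = [12, 12, 108, 58]   # OCR box inside the cell
box_b = [150, 10, 250, 60]  # OCR box from a neighboring cell

print(compute_iou(cell, box_a), distance(cell, box_a))  # ~0.883, 12
print(compute_iou(cell, box_b), distance(cell, box_b))  # 0.0, 420
```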

+ 109 - 0
mineru/model/table/rec/slanet_plus/table_structure.py

@@ -0,0 +1,109 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+from .table_structure_utils import (
+    OrtInferSession,
+    TableLabelDecode,
+    TablePreprocess,
+    BatchTablePreprocess,
+)
+
+
+class TableStructurer:
+    def __init__(self, config: Dict[str, Any]):
+        self.preprocess_op = TablePreprocess()
+        self.batch_preprocess_op = BatchTablePreprocess()
+
+        self.session = OrtInferSession(config)
+
+        self.character = self.session.get_metadata()
+        self.postprocess_op = TableLabelDecode(self.character)
+
+    def process(self, img):
+        starttime = time.time()
+        data = {"image": img}
+        data = self.preprocess_op(data)
+        img = data[0]
+        if img is None:
+            return None, 0
+        img = np.expand_dims(img, axis=0)
+        img = img.copy()
+
+        outputs = self.session([img])
+
+        preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
+
+        shape_list = np.expand_dims(data[-1], axis=0)
+        post_result = self.postprocess_op(preds, [shape_list])
+
+        bbox_list = post_result["bbox_batch_list"][0]
+
+        structure_str_list = post_result["structure_batch_list"][0]
+        structure_str_list = structure_str_list[0]
+        structure_str_list = (
+            ["<html>", "<body>", "<table>"]
+            + structure_str_list
+            + ["</table>", "</body>", "</html>"]
+        )
+        elapse = time.time() - starttime
+        return structure_str_list, bbox_list, elapse
+
+    def batch_process(
+        self, img_list: List[np.ndarray]
+    ) -> List[Tuple[List[str], np.ndarray, float]]:
+        """Batch-process a list of images.
+        Args:
+            img_list: list of images
+        Returns:
+            A list of results, each element being (table_struct_str, cell_bboxes, elapse)
+        """
+        starttime = time.perf_counter()
+
+        batch_data = self.batch_preprocess_op(img_list)
+        preprocessed_images = batch_data[0]
+        shape_lists = batch_data[1]
+
+        preprocessed_images = np.array(preprocessed_images)
+        bbox_preds, struct_probs = self.session([preprocessed_images])
+
+        batch_size = preprocessed_images.shape[0]
+        results = []
+        for bbox_pred, struct_prob, shape_list in zip(
+            bbox_preds, struct_probs, shape_lists
+        ):
+            preds = {
+                "loc_preds": np.expand_dims(bbox_pred, axis=0),
+                "structure_probs": np.expand_dims(struct_prob, axis=0),
+            }
+            shape_list = np.expand_dims(shape_list, axis=0)
+            post_result = self.postprocess_op(preds, [shape_list])
+            bbox_list = post_result["bbox_batch_list"][0]
+            structure_str_list = post_result["structure_batch_list"][0]
+            structure_str_list = structure_str_list[0]
+            structure_str_list = (
+                ["<html>", "<body>", "<table>"]
+                + structure_str_list
+                + ["</table>", "</body>", "</html>"]
+            )
+            results.append((structure_str_list, bbox_list, 0))
+
+        total_elapse = time.perf_counter() - starttime
+        for i in range(len(results)):
+            results[i] = (results[i][0], results[i][1], total_elapse / batch_size)
+
+        return results
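
A usage sketch for the batch path, assuming a downloaded slanet-plus ONNX model (in the pipeline the real path comes from `auto_download_and_get_model_root_path`); `batch_process` returns one `(structure_tokens, cell_bboxes, elapse)` triple per image, with the total wall time split evenly across the batch:

```python
import cv2
from mineru.model.table.rec.slanet_plus.table_structure import TableStructurer

# Placeholder model path; the pipeline resolves it automatically.
structurer = TableStructurer({"model_path": "/path/to/slanet-plus.onnx"})

imgs = [cv2.imread(p) for p in ["table1.png", "table2.png"]]  # BGR table crops
for tokens, cell_bboxes, elapse in structurer.batch_process(imgs):
    # tokens are wrapped as <html><body><table> ... </table></body></html>
    print(len(tokens), cell_bboxes.shape, round(elapse, 3))
```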

+ 570 - 0
mineru/model/table/rec/slanet_plus/table_structure_utils.py

@@ -0,0 +1,570 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
+
+from loguru import logger
+
+
+class EP(Enum):
+    CPU_EP = "CPUExecutionProvider"
+    CUDA_EP = "CUDAExecutionProvider"
+    DIRECTML_EP = "DmlExecutionProvider"
+
+
+class OrtInferSession:
+    def __init__(self, config: Dict[str, Any]):
+        self.logger = logger
+
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.cfg_use_cuda = config.get("use_cuda", None)
+        self.cfg_use_dml = config.get("use_dml", None)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            model_path,
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+        self._verify_providers()
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+        sess_opt = SessionOptions()
+        sess_opt.log_severity_level = 4
+        sess_opt.enable_cpu_mem_arena = False
+        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        cpu_nums = os.cpu_count()
+        intra_op_num_threads = config.get("intra_op_num_threads", -1)
+        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+            sess_opt.intra_op_num_threads = intra_op_num_threads
+
+        inter_op_num_threads = config.get("inter_op_num_threads", -1)
+        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def get_metadata(self, key: str = "character") -> list:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        content_list = meta_dict[key].splitlines()
+        return content_list
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, try to use DirectML as primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
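+    # Resulting priority order, e.g. [DirectML, CUDA, CPU] on a DML-enabled
+    # Windows machine with CUDA available, or just the CPU provider otherwise.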
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "{} is not in available providers ({}). Use {} inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, you must do:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote: the onnxruntime-gpu version must match your CUDA and cuDNN versions."
+        )
+        self.logger.info(
+            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure {} is in the available providers list, e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'].",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported on Windows. The current OS is {}. Use {} inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_windows_version = int(platform.release().split(".")[0])
+        if cur_windows_version < 10:
+            self.logger.warning(
+                "DirectML is only supported on Windows 10 and above. The current Windows version is {}. Use {} inference by default.",
+                cur_windows_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "{} is not in available providers ({}). Use {} inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, you must do:")
+        self.logger.info(
+            "First, uninstall all onnxruntime pakcages in current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
+        )
+        self.logger.info(
+            "Third, ensure {} is in the available providers list, e.g. ['DmlExecutionProvider', 'CPUExecutionProvider'].",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "{} is not available in the current environment; inference will run under {} instead.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "{} is not available in the current environment; inference will run under {} instead.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
+        input_dict = dict(zip(self.get_input_names(), input_content))
+        try:
+            return self.session.run(None, input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return key in meta_dict
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
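+
+# Illustrative usage of OrtInferSession (the model path and input shape are
+# assumed for this sketch, not part of this PR):
+#   session = OrtInferSession({"model_path": "slanet-plus.onnx", "use_cuda": False})
+#   outputs = session([np.zeros((1, 3, 488, 488), dtype=np.float32)])
+#   # -> list of output arrays, e.g. [bbox_preds, structure_probs]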
+
+
+class TableLabelDecode:
+    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
+        if merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+        self.td_token = ["<td>", "<td", "<td></td>"]
+
+    def __call__(self, preds, batch=None):
+        structure_probs = preds["structure_probs"]
+        bbox_preds = preds["loc_preds"]
+        shape_list = batch[-1]
+        result = self.decode(structure_probs, bbox_preds, shape_list)
+        if len(batch) == 1:  # only contains shape
+            return result
+
+        label_decode_result = self.decode_label(batch)
+        return result, label_decode_result
+
+    def decode(self, structure_probs, bbox_preds, shape_list):
+        """convert text-label into text-index."""
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_idx = structure_probs.argmax(axis=2)
+        structure_probs = structure_probs.max(axis=2)
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            score_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                text = self.character[char_idx]
+                if text in self.td_token:
+                    bbox = bbox_preds[batch_idx, idx]
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+                structure_list.append(text)
+                score_list.append(structure_probs[batch_idx, idx])
+            structure_batch_list.append([structure_list, np.mean(score_list)])
+            bbox_batch_list.append(np.array(bbox_list))
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def decode_label(self, batch):
+        """convert text-label into text-index."""
+        structure_idx = batch[1]
+        gt_bbox_list = batch[2]
+        shape_list = batch[-1]
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                structure_list.append(self.character[char_idx])
+
+                bbox = gt_bbox_list[batch_idx][idx]
+                if bbox.sum() != 0:
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+
+            structure_batch_list.append(structure_list)
+            bbox_batch_list.append(bbox_list)
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def _bbox_decode(self, bbox, shape):
+        h, w = shape[:2]
+        bbox[0::2] *= w
+        bbox[1::2] *= h
+        return bbox
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            return np.array(self.dict[self.beg_str])
+
+        if beg_or_end == "end":
+            return np.array(self.dict[self.end_str])
+
+        raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
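+
+    # Illustrative decode flow (shapes assumed): structure_probs is (N, T, C),
+    # loc_preds is (N, T, 4 or 8) with coordinates normalized to the input, and
+    # shape_list carries [h, w, ratio, ratio, pad_h, pad_w] per image so that
+    # _bbox_decode can map boxes back to source pixels:
+    #   decoder = TableLabelDecode(dict_character=character_list)
+    #   out = decoder({"structure_probs": probs, "loc_preds": locs}, [shape_list])
+    #   tokens, mean_score = out["structure_batch_list"][0]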
+
+
+class TablePreprocess:
+    def __init__(self):
+        self.table_max_len = 488
+        self.build_pre_process_list()
+        self.ops = self.create_operators()
+
+    def __call__(self, data):
+        """transform"""
+        if self.ops is None:
+            self.ops = []
+
+        for op in self.ops:
+            data = op(data)
+            if data is None:
+                return None
+        return data
+
+    def create_operators(
+        self,
+    ):
+        """
+        create operators based on the config
+
+        Args:
+            params(list): a dict list, used to create some operators
+        """
+        assert isinstance(
+            self.pre_process_list, list
+        ), "operator config should be a list"
+        ops = []
+        for operator in self.pre_process_list:
+            assert (
+                isinstance(operator, dict) and len(operator) == 1
+            ), "yaml format error"
+            op_name = list(operator)[0]
+            param = {} if operator[op_name] is None else operator[op_name]
+            op = eval(op_name)(**param)
+            ops.append(op)
+        return ops
+
+    def build_pre_process_list(self):
+        resize_op = {
+            "ResizeTableImage": {
+                "max_len": self.table_max_len,
+            }
+        }
+        pad_op = {
+            "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
+        }
+        normalize_op = {
+            "NormalizeImage": {
+                "std": [0.229, 0.224, 0.225],
+                "mean": [0.485, 0.456, 0.406],
+                "scale": "1./255.",
+                "order": "hwc",
+            }
+        }
+        to_chw_op = {"ToCHWImage": None}
+        keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
+        self.pre_process_list = [
+            resize_op,
+            normalize_op,
+            pad_op,
+            to_chw_op,
+            keep_keys_op,
+        ]
+
+
+class BatchTablePreprocess:
+
+    def __init__(self):
+        self.preprocess = TablePreprocess()
+
+    def __call__(
+        self, img_list: List[np.ndarray]
+    ) -> Tuple[List[np.ndarray], List[List[float]]]:
+        """批量处理图像
+
+        Args:
+            img_list: 图像列表
+
+        Returns:
+            预处理后的图像列表和形状信息列表
+        """
+        processed_imgs = []
+        shape_lists = []
+
+        for img in img_list:
+            if img is None:
+                continue
+            data = {"image": img}
+            img_processed, shape_list = self.preprocess(data)
+            processed_imgs.append(img_processed)
+            shape_lists.append(shape_list)
+        return processed_imgs, shape_lists
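+
+# Illustrative usage (inputs assumed to be HWC uint8 arrays):
+#   imgs, shapes = BatchTablePreprocess()([table_img1, table_img2])
+#   # imgs: (3, 488, 488) float32 arrays ready for np.array() stacking;
+#   # shapes: one [h, w, ratio, ratio, 488, 488] entry per image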
+
+
+class ResizeTableImage:
+    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
+        super(ResizeTableImage, self).__init__()
+        self.max_len = max_len
+        self.resize_bboxes = resize_bboxes
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        img = data["image"]
+        height, width = img.shape[0:2]
+        ratio = self.max_len / (max(height, width) * 1.0)
+        resize_h = int(height * ratio)
+        resize_w = int(width * ratio)
+        resize_img = cv2.resize(img, (resize_w, resize_h))
+        if self.resize_bboxes and not self.infer_mode:
+            data["bboxes"] = data["bboxes"] * ratio
+        data["image"] = resize_img
+        data["src_img"] = img
+        data["shape"] = np.array([height, width, ratio, ratio])
+        data["max_len"] = self.max_len
+        return data
+
+
+class PaddingTableImage:
+    def __init__(self, size, **kwargs):
+        super(PaddingTableImage, self).__init__()
+        self.size = size
+
+    def __call__(self, data):
+        img = data["image"]
+        pad_h, pad_w = self.size
+        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+        height, width = img.shape[0:2]
+        padding_img[0:height, 0:width, :] = img.copy()
+        data["image"] = padding_img
+        shape = data["shape"].tolist()
+        shape.extend([pad_h, pad_w])
+        data["shape"] = np.array(shape)
+        return data
+
+
+class NormalizeImage:
+    """normalize image such as substract mean, divide std"""
+
+    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+        return data
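+
+# Worked example: with scale="1./255.", a white pixel (255) in the R channel
+# maps to (255 / 255 - 0.485) / 0.229 ≈ 2.25 under the ImageNet statistics above.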
+
+
+class ToCHWImage:
+    """convert hwc image to chw image"""
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        data["image"] = img.transpose((2, 0, 1))
+        return data
+
+
+class KeepKeys:
+    def __init__(self, keep_keys, **kwargs):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        data_list = []
+        for key in self.keep_keys:
+            data_list.append(data[key])
+        return data_list
+
+
+def trans_char_ocr_res(ocr_res):
+    word_result = []
+    for res in ocr_res:
+        score = res[2]
+        for word_box, word in zip(res[3], res[4]):
+            word_res = []
+            word_res.append(word_box)
+            word_res.append(word)
+            word_res.append(score)
+            word_result.append(word_res)
+    return word_result
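+
+# Illustrative input/output (per-entry layout assumed to be
+# [line_box, text, score, word_boxes, words] from char-level OCR):
+#   trans_char_ocr_res([[line_box, "ab", 0.98, [box_a, box_b], ["a", "b"]]])
+#   # -> [[box_a, "a", 0.98], [box_b, "b", 0.98]]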

+ 33 - 65
mineru/model/table/rec/unet_table/main.py

@@ -10,14 +10,13 @@ import numpy as np
 import cv2
 from PIL import Image
 from loguru import logger
-from rapid_table import RapidTableInput, RapidTable
 
 from .table_structure_unet import TSRUnet
 
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 from .table_recover import TableRecover
-from .utils import InputType, LoadImage, VisTable
+from .utils import InputType, LoadImage
 from .utils_table_recover import (
     match_ocr_cell,
     plot_html_table,
@@ -243,12 +242,9 @@ class UnetTableModel:
         model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
         wired_input_args = WiredTableInput(model_path=model_path)
         self.wired_table_model = WiredTableRecognition(wired_input_args, ocr_engine)
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, input_img, table_cls_score):
+    def predict(self, input_img, ocr_result, wireless_html_code):
         if isinstance(input_img, Image.Image):
             np_img = np.asarray(input_img)
         elif isinstance(input_img, np.ndarray):
@@ -256,67 +252,39 @@ class UnetTableModel:
         else:
             raise ValueError("Input must be a pillow object or a numpy array.")
         bgr_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
-        ocr_result = self.ocr_engine.ocr(bgr_img)[0]
-        if ocr_result:
+
+        if ocr_result is None:
+            ocr_result = self.ocr_engine.ocr(bgr_img)[0] or []
             ocr_result = [
                 [item[0], escape_html(item[1][0]), item[1][1]]
                 for item in ocr_result
                 if len(item) == 2 and isinstance(item[1], tuple)
             ]
-        else:
-            ocr_result = None
-        if ocr_result:
-            try:
-                wired_table_results = self.wired_table_model(np_img, ocr_result)
-
-                # viser = VisTable()
-                # save_html_path = f"outputs/output.html"
-                # save_drawed_path = f"outputs/output_table_vis.jpg"
-                # save_logic_path = (
-                #     f"outputs/output_table_vis_logic.jpg"
-                # )
-                # vis_imged = viser(
-                #     np_img, wired_table_results, save_html_path, save_drawed_path, save_logic_path
-                # )
-
-                wired_html_code = wired_table_results.pred_html
-                wired_table_cell_bboxes = wired_table_results.cell_bboxes
-                wired_logic_points = wired_table_results.logic_points
-                wired_elapse = wired_table_results.elapse
-
-                wireless_table_results = self.wireless_table_model(np_img, ocr_result)
-                wireless_html_code = wireless_table_results.pred_html
-                wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
-                wireless_logic_points = wireless_table_results.logic_points
-                wireless_elapse = wireless_table_results.elapse
-
-                # wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
-                # wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
-
-                wired_len = count_table_cells_physical(wired_html_code)
-                wireless_len = count_table_cells_physical(wireless_html_code)
-
-                # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
-                # 计算两种模型检测的单元格数量差异
-                gap_of_len = wireless_len - wired_len
-                # 判断是否使用无线表格模型的结果
-                if (
-                    wired_len <= int(wireless_len * 0.55)+1  # 有线模型检测到的单元格数太少(低于无线模型的50%)
-                    # or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # 有线模型检测到的单元格数反而更多
-                    or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
-                    or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
-                ):
-                    # logger.debug("fall back to wireless table model")
-                    html_code = wireless_html_code
-                    table_cell_bboxes = wireless_table_cell_bboxes
-                    logic_points = wireless_logic_points
-                else:
-                    html_code = wired_html_code
-                    table_cell_bboxes = wired_table_cell_bboxes
-                    logic_points = wired_logic_points
-
-                elapse = wired_elapse + wireless_elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
-            except Exception as e:
-                logger.exception(e)
-        return None, None, None, None
+
+        try:
+            wired_table_results = self.wired_table_model(np_img, ocr_result)
+
+            wired_html_code = wired_table_results.pred_html
+
+            wired_len = count_table_cells_physical(wired_html_code)
+            wireless_len = count_table_cells_physical(wireless_html_code)
+
+            # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
+            # Compute the difference in cell counts between the two models
+            gap_of_len = wireless_len - wired_len
+            # Decide whether to fall back to the wireless-table model's result
+            if (
+                wired_len <= int(wireless_len * 0.55) + 1  # wired model found far fewer cells (at most ~55% of the wireless count)
+                # or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # wired model found noticeably more cells instead
+                or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # counts are close but the wired result is still smaller
+                or (gap_of_len == 0 and wired_len <= 4)  # cell counts are equal and the total is at most 4
+            ):
+                # logger.debug("fall back to wireless table model")
+                html_code = wireless_html_code
+            else:
+                html_code = wired_html_code
+
+            return html_code
+        except Exception as e:
+            logger.exception(e)
+            return None
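+
+# Worked example of the fallback heuristic (numbers illustrative): with
+# wired_len=10 and wireless_len=20, 10 <= int(20 * 0.55) + 1 == 12 holds, so
+# the wireless result is used; with wired_len=18 none of the three conditions
+# fire and the wired HTML wins.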

+ 0 - 1
pyproject.toml

@@ -61,7 +61,6 @@ pipeline = [
     "ultralytics>=8.3.48,<9",
     "doclayout_yolo==0.0.4",
     "dill>=0.3.8,<1",
-    "rapid_table>=1.0.5,<2.0.0",
     "PyYAML>=6.0.2,<7",
     "ftfy>=6.3.1,<7",
     "openai>=1.70.0,<2",