Browse source

feat: replace rapid-table with local code

Sidney233 3 months ago
parent commit
f0126cfc23

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
-from ...model.table.rec.rapid_table import RapidTableModel
+from ...model.table.rec.slanet_plus.rapid_table import RapidTableModel
 from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
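
For orientation: this hunk only redirects the import, so downstream callers keep using the same class. A minimal sketch of the resulting import path (assumed usage):

# The old path resolved to a wrapper around the external rapid-table package;
# the new path resolves to the vendored copy under slanet_plus/.
from mineru.model.table.rec.slanet_plus.rapid_table import RapidTableModel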

+ 0 - 0
mineru/model/table/rec/slanet_plus/__init__.py


+ 235 - 0
mineru/model/table/rec/slanet_plus/main.py

@@ -0,0 +1,235 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import argparse
+import copy
+import importlib
+import time
+from dataclasses import asdict, dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+
+from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
+
+from .matcher import TableMatch
+from .table_structure import TableStructurer
+from .table_structure_unitable import TableStructureUnitable
+
+logger = Logger(logger_name=__name__).get_log()
+root_dir = Path(__file__).resolve().parent
+
+
+class ModelType(Enum):
+    PPSTRUCTURE_EN = "ppstructure_en"
+    PPSTRUCTURE_ZH = "ppstructure_zh"
+    SLANETPLUS = "slanet_plus"
+    UNITABLE = "unitable"
+
+
+ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
+KEY_TO_MODEL_URL = {
+    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
+    ModelType.UNITABLE.value: {
+        "encoder": f"{ROOT_URL}/unitable/encoder.pth",
+        "decoder": f"{ROOT_URL}/unitable/decoder.pth",
+        "vocab": f"{ROOT_URL}/unitable/vocab.json",
+    },
+}
+
+
+@dataclass
+class RapidTableInput:
+    model_type: Optional[str] = ModelType.SLANETPLUS.value
+    model_path: Union[str, Path, None, Dict[str, str]] = None
+    use_cuda: bool = False
+    device: str = "cpu"
+
+
+@dataclass
+class RapidTableOutput:
+    pred_html: Optional[str] = None
+    cell_bboxes: Optional[np.ndarray] = None
+    logic_points: Optional[np.ndarray] = None
+    elapse: Optional[float] = None
+
+
+class RapidTable:
+    def __init__(self, config: RapidTableInput):
+        self.model_type = config.model_type
+        if self.model_type not in KEY_TO_MODEL_URL:
+            model_list = ",".join(KEY_TO_MODEL_URL)
+            raise ValueError(
+                f"{self.model_type} is not supported. The currently supported models are {model_list}."
+            )
+
+        config.model_path = self.get_model_path(config.model_type, config.model_path)
+        if self.model_type == ModelType.UNITABLE.value:
+            self.table_structure = TableStructureUnitable(asdict(config))
+        else:
+            self.table_structure = TableStructurer(asdict(config))
+
+        self.table_matcher = TableMatch()
+
+        try:
+            self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
+        except ModuleNotFoundError:
+            self.ocr_engine = None
+
+        self.load_img = LoadImage()
+
+    def __call__(
+        self,
+        img_content: Union[str, np.ndarray, bytes, Path],
+        ocr_result: Optional[List[Tuple[List[List[float]], str, float]]] = None,
+    ) -> RapidTableOutput:
+        if self.ocr_engine is None and ocr_result is None:
+            raise ValueError(
+                "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
+            )
+
+        img = self.load_img(img_content)
+
+        s = time.perf_counter()
+        h, w = img.shape[:2]
+
+        if ocr_result is None:
+            ocr_result = self.ocr_engine(img)
+            ocr_result = list(
+                zip(
+                    ocr_result.boxes,
+                    ocr_result.txts,
+                    ocr_result.scores,
+                )
+            )
+        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
+
+        pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
+
+        # Rescale boxes back to the original image scale for the slanet-plus model
+        if self.model_type == ModelType.SLANETPLUS.value:
+            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+
+        pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
+
+        # Filter out placeholder (all-zero) bboxes
+        mask = ~np.all(cell_bboxes == 0, axis=1)
+        cell_bboxes = cell_bboxes[mask]
+
+        logic_points = self.table_matcher.decode_logic_points(pred_structures)
+        elapse = time.perf_counter() - s
+        return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
+
+    def get_boxes_recs(
+        self, ocr_result: List[Tuple[List[List[float]], str, float]], h: int, w: int
+    ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
+        dt_boxes, rec_res, scores = list(zip(*ocr_result))
+        rec_res = list(zip(rec_res, scores))
+
+        r_boxes = []
+        for box in dt_boxes:
+            box = np.array(box)
+            x_min = max(0, box[:, 0].min() - 1)
+            x_max = min(w, box[:, 0].max() + 1)
+            y_min = max(0, box[:, 1].min() - 1)
+            y_max = min(h, box[:, 1].max() + 1)
+            box = [x_min, y_min, x_max, y_max]
+            r_boxes.append(box)
+        dt_boxes = np.array(r_boxes)
+        return dt_boxes, rec_res
+
+    def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
+        h, w = img.shape[:2]
+        resized = 488
+        ratio = min(resized / h, resized / w)
+        w_ratio = resized / (w * ratio)
+        h_ratio = resized / (h * ratio)
+        cell_bboxes[:, 0::2] *= w_ratio
+        cell_bboxes[:, 1::2] *= h_ratio
+        return cell_bboxes
+
+    @staticmethod
+    def get_model_path(
+        model_type: str, model_path: Union[str, Path, None]
+    ) -> Union[str, Dict[str, str]]:
+        if model_path is not None:
+            return model_path
+
+        model_url = KEY_TO_MODEL_URL.get(model_type, None)
+        if isinstance(model_url, str):
+            model_path = DownloadModel.download(model_url)
+            return model_path
+
+        if isinstance(model_url, dict):
+            model_paths = {}
+            for k, url in model_url.items():
+                model_paths[k] = DownloadModel.download(
+                    url, save_model_name=f"{model_type}_{Path(url).name}"
+                )
+            return model_paths
+
+        raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
+
+
+def parse_args(arg_list: Optional[List[str]] = None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-v",
+        "--vis",
+        action="store_true",
+        default=False,
+        help="Wheter to visualize the layout results.",
+    )
+    parser.add_argument(
+        "-img", "--img_path", type=str, required=True, help="Path to image for layout."
+    )
+    parser.add_argument(
+        "-m",
+        "--model_type",
+        type=str,
+        default=ModelType.SLANETPLUS.value,
+        choices=list(KEY_TO_MODEL_URL),
+    )
+    args = parser.parse_args(arg_list)
+    return args
+
+
+def main(arg_list: Optional[List[str]] = None):
+    args = parse_args(arg_list)
+
+    try:
+        ocr_engine = importlib.import_module("rapidocr").RapidOCR()
+    except ModuleNotFoundError as exc:
+        raise ModuleNotFoundError(
+            "Please install the rapidocr by pip install rapidocr"
+        ) from exc
+
+    input_args = RapidTableInput(model_type=args.model_type)
+    table_engine = RapidTable(input_args)
+
+    img = cv2.imread(args.img_path)
+
+    rapid_ocr_output = ocr_engine(img)
+    ocr_result = list(
+        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+    )
+    table_results = table_engine(img, ocr_result)
+    print(table_results.pred_html)
+
+    viser = VisTable()
+    if args.vis:
+        img_path = Path(args.img_path)
+
+        save_dir = img_path.resolve().parent
+        save_html_path = save_dir / f"{img_path.stem}.html"
+        save_drawed_path = save_dir / f"vis_{img_path.name}"
+        viser(img_path, table_results, save_html_path, save_drawed_path)
+
+
+if __name__ == "__main__":
+    main()
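
As a quick orientation, a minimal sketch of using the vendored engine directly, mirroring main() above (the image path is a placeholder):

import cv2
from mineru.model.table.rec.slanet_plus.main import ModelType, RapidTable, RapidTableInput

# slanet_plus is the default; its ONNX weights are downloaded on first use
engine = RapidTable(RapidTableInput(model_type=ModelType.SLANETPLUS.value))
img = cv2.imread("table.png")  # placeholder path
# ocr_result is a list of (box, text, score) triples; if rapidocr is
# installed it can be omitted and is computed internally
output = engine(img)
print(output.pred_html, output.elapse)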

+ 199 - 0
mineru/model/table/rec/slanet_plus/matcher.py

@@ -0,0 +1,199 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+import numpy as np
+
+from .matcher_utils import compute_iou, distance
+
+
+class TableMatch:
+    def __init__(self, filter_ocr_result=True, use_master=False):
+        self.filter_ocr_result = filter_ocr_result
+        self.use_master = use_master
+
+    def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
+        if self.filter_ocr_result:
+            dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
+        matched_index = self.match_result(dt_boxes, cell_bboxes)
+        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
+        return pred_html
+
+    def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
+        matched = {}
+        for i, gt_box in enumerate(dt_boxes):
+            distances = []
+            for j, pred_box in enumerate(cell_bboxes):
+                if len(pred_box) == 8:
+                    pred_box = [
+                        np.min(pred_box[0::2]),
+                        np.min(pred_box[1::2]),
+                        np.max(pred_box[0::2]),
+                        np.max(pred_box[1::2]),
+                    ]
+                distances.append(
+                    (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
+                )  # compute iou and l1 distance
+            # select det box by iou and l1 distance
+            sorted_distances = sorted(distances, key=lambda item: (item[1], item[0]))
+            # the IoU must exceed min_iou
+            if sorted_distances[0][1] >= 1 - min_iou:
+                continue
+
+            if distances.index(sorted_distances[0]) not in matched:
+                matched[distances.index(sorted_distances[0])] = [i]
+            else:
+                matched[distances.index(sorted_distances[0])].append(i)
+        return matched
+
+    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+        end_html = []
+        td_index = 0
+        for tag in pred_structures:
+            if "</td>" not in tag:
+                end_html.append(tag)
+                continue
+
+            if "<td></td>" == tag:
+                end_html.extend("<td>")
+
+            if td_index in matched_index.keys():
+                b_with = False
+                if (
+                    "<b>" in ocr_contents[matched_index[td_index][0]]
+                    and len(matched_index[td_index]) > 1
+                ):
+                    b_with = True
+                    end_html.extend("<b>")
+
+                for i, td_index_index in enumerate(matched_index[td_index]):
+                    content = ocr_contents[td_index_index][0]
+                    if len(matched_index[td_index]) > 1:
+                        if len(content) == 0:
+                            continue
+
+                        if content[0] == " ":
+                            content = content[1:]
+
+                        if "<b>" in content:
+                            content = content[3:]
+
+                        if "</b>" in content:
+                            content = content[:-4]
+
+                        if len(content) == 0:
+                            continue
+
+                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+                            content += " "
+                    end_html.extend(content)
+
+                if b_with:
+                    end_html.extend("</b>")
+
+            if "<td></td>" == tag:
+                end_html.append("</td>")
+            else:
+                end_html.append(tag)
+
+            td_index += 1
+
+        # Filter <thead></thead><tbody></tbody> elements
+        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
+        end_html = [v for v in end_html if v not in filter_elements]
+        return "".join(end_html), end_html
+
+    def decode_logic_points(self, pred_structures):
+        logic_points = []
+        current_row = 0
+        current_col = 0
+        max_rows = 0
+        max_cols = 0
+        occupied_cells = {}  # records cells already occupied by a rowspan/colspan
+
+        def is_occupied(row, col):
+            return (row, col) in occupied_cells
+
+        def mark_occupied(row, col, rowspan, colspan):
+            for r in range(row, row + rowspan):
+                for c in range(col, col + colspan):
+                    occupied_cells[(r, c)] = True
+
+        i = 0
+        while i < len(pred_structures):
+            token = pred_structures[i]
+
+            if token == "<tr>":
+                current_col = 0  # reset the column index at the start of each row
+            elif token == "</tr>":
+                current_row += 1  # row finished; advance the row index
+            elif token.startswith("<td"):
+                colspan = 1
+                rowspan = 1
+                j = i
+                if token != "<td></td>":
+                    j += 1
+                    # extract the colspan and rowspan attributes
+                    while j < len(pred_structures) and not pred_structures[
+                        j
+                    ].startswith(">"):
+                        if "colspan=" in pred_structures[j]:
+                            colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        elif "rowspan=" in pred_structures[j]:
+                            rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        j += 1
+
+                # skip the attribute tokens already consumed
+                i = j
+
+                # find the next unoccupied column
+                while is_occupied(current_row, current_col):
+                    current_col += 1
+
+                # compute the logical coordinates
+                r_start = current_row
+                r_end = current_row + rowspan - 1
+                col_start = current_col
+                col_end = current_col + colspan - 1
+
+                # record the logical coordinates
+                logic_points.append([r_start, r_end, col_start, col_end])
+
+                # mark the occupied cells
+                mark_occupied(r_start, col_start, rowspan, colspan)
+
+                # advance the current column
+                current_col += colspan
+
+                # update the maximum row and column counts
+                max_rows = max(max_rows, r_end + 1)
+                max_cols = max(max_cols, col_end + 1)
+
+            i += 1
+
+        return logic_points
+
+    def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
+        y1 = cell_bboxes[:, 1::2].min()
+        new_dt_boxes = []
+        new_rec_res = []
+
+        for box, rec in zip(dt_boxes, rec_res):
+            if np.max(box[1::2]) < y1:
+                continue
+            new_dt_boxes.append(box)
+            new_rec_res.append(rec)
+        return new_dt_boxes, new_rec_res
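
To make the rowspan/colspan bookkeeping in decode_logic_points concrete, a small worked example (the token list is hand-written for illustration):

from mineru.model.table.rec.slanet_plus.matcher import TableMatch

tokens = [
    "<tr>", "<td></td>", "<td", ' colspan="2"', ">", "</td>", "</tr>",
    "<tr>", "<td></td>", "<td></td>", "<td></td>", "</tr>",
]
# Each entry is [row_start, row_end, col_start, col_end]:
# [[0, 0, 0, 0], [0, 0, 1, 2], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 2, 2]]
print(TableMatch().decode_logic_points(tokens))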

+ 249 - 0
mineru/model/table/rec/slanet_plus/matcher_utils.py

@@ -0,0 +1,249 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import copy
+import re
+
+
+def deal_isolate_span(thead_part):
+    """
+    Deal with isolated span cases.
+    They are caused by wrong predictions from the structure recognition model,
+    e.g. predicting <td rowspan="2"></td> as <td></td> rowspan="2"></b></td>.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out isolate span tokens.
+    isolate_pattern = (
+        r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
+        r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
+        r'<td></td> rowspan="(\d)+"></b></td>|'
+        r'<td></td> colspan="(\d)+"></b></td>'
+    )
+    isolate_iter = re.finditer(isolate_pattern, thead_part)
+    isolate_list = [i.group() for i in isolate_iter]
+
+    # 2. find out span number, by step 1 results.
+    span_pattern = (
+        r' rowspan="(\d)+" colspan="(\d)+"|'
+        r' colspan="(\d)+" rowspan="(\d)+"|'
+        r' rowspan="(\d)+"|'
+        r' colspan="(\d)+"'
+    )
+    corrected_list = []
+    for isolate_item in isolate_list:
+        span_part = re.search(span_pattern, isolate_item)
+        # 3. merge the span number into the span token format string.
+        if span_part is not None:
+            corrected_item = f"<td{span_part.group()}></td>"
+            corrected_list.append(corrected_item)
+        else:
+            corrected_list.append(None)
+
+    # 4. replace original isolated token.
+    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
+        if corrected_item is not None:
+            thead_part = thead_part.replace(isolate_item, corrected_item)
+        else:
+            pass
+    return thead_part
+
+
+def deal_duplicate_bb(thead_part):
+    """
+    Deal with duplicated <b> or </b> tags after replacement.
+    Keep at most one <b></b> pair in each <td></td> token.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out <td></td> in <thead></thead>.
+    td_pattern = (
+        r'<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
+        r'<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
+        r'<td rowspan="(\d)+">(.+?)</td>|'
+        r'<td colspan="(\d)+">(.+?)</td>|'
+        r"<td>(.*?)</td>"
+    )
+    td_iter = re.finditer(td_pattern, thead_part)
+    td_list = [t.group() for t in td_iter]
+
+    # 2. check whether multiple <b></b> appear in one <td></td>.
+    new_td_list = []
+    for td_item in td_list:
+        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
+            # multiple <b></b> in <td></td> case.
+            # 1. remove all <b></b>
+            td_item = td_item.replace("<b>", "").replace("</b>", "")
+            # 2. replace <td> -> <td><b>, </td> -> </b></td>.
+            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
+            new_td_list.append(td_item)
+        else:
+            new_td_list.append(td_item)
+
+    # 3. replace original thead part.
+    for td_item, new_td_item in zip(td_list, new_td_list):
+        thead_part = thead_part.replace(td_item, new_td_item)
+    return thead_part
+
+
+def deal_bb(result_token):
+    """
+    In our opinion, <b></b> always occurs in <thead></thead> text's context.
+    This function finds all tokens in <thead></thead> and inserts <b></b> manually.
+    :param result_token:
+    :return:
+    """
+    # find out <thead></thead> parts.
+    thead_pattern = "<thead>(.*?)</thead>"
+    if re.search(thead_pattern, result_token) is None:
+        return result_token
+    thead_part = re.search(thead_pattern, result_token).group()
+    origin_thead_part = copy.deepcopy(thead_part)
+
+    # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
+    span_pattern = '<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
+    span_iter = re.finditer(span_pattern, thead_part)
+    span_list = [s.group() for s in span_iter]
+    has_span_in_head = len(span_list) > 0
+
+    if not has_span_in_head:
+        # <thead></thead> not include "rowspan" or "colspan" branch 1.
+        # 1. replace <td> to <td><b>, and </td> to </b></td>
+        # 2. the text-line recognizer may predict text that already includes <b> or </b>,
+        #    so we replace <b><b> with <b>, and </b></b> with </b>
+        thead_part = (
+            thead_part.replace("<td>", "<td><b>")
+            .replace("</td>", "</b></td>")
+            .replace("<b><b>", "<b>")
+            .replace("</b></b>", "</b>")
+        )
+    else:
+        # <thead></thead> include "rowspan" or "colspan" branch 2.
+        # Firstly, we deal rowspan or colspan cases.
+        # 1. replace > to ><b>
+        # 2. replace </td> to </b></td>
+        # 3. the text-line recognizer may predict text that already includes <b> or </b>,
+        #    so duplicated <b> and </b> runs are collapsed with re.sub below
+
+        # Secondly, deal ordinary cases like branch 1
+
+        # replace ">" to "<b>"
+        replaced_span_list = []
+        for sp in span_list:
+            replaced_span_list.append(sp.replace(">", "><b>"))
+        for sp, rsp in zip(span_list, replaced_span_list):
+            thead_part = thead_part.replace(sp, rsp)
+
+        # replace "</td>" to "</b></td>"
+        thead_part = thead_part.replace("</td>", "</b></td>")
+
+        # remove duplicated <b> by re.sub
+        mb_pattern = "(<b>)+"
+        single_b_string = "<b>"
+        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
+
+        mgb_pattern = "(</b>)+"
+        single_gb_string = "</b>"
+        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
+
+        # ordinary cases like branch 1
+        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
+
+    # convert <td><b></b></td> back to <td></td>: an empty cell has no <b></b>,
+    # but a space cell (<td> </td>) keeps the bold wrapper as <td><b> </b></td>
+    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
+    # deal with duplicated <b></b>
+    thead_part = deal_duplicate_bb(thead_part)
+    # deal with isolated span tokens, which are caused by wrong structure predictions.
+    # eg.PMC5994107_011_00.png
+    thead_part = deal_isolate_span(thead_part)
+    # replace original result with new thead part.
+    result_token = result_token.replace(origin_thead_part, thead_part)
+    return result_token
+
+
+def deal_eb_token(master_token):
+    """
+    post process with <eb></eb>, <eb1></eb1>, ...
+    emptyBboxTokenDict = {
+        "[]": '<eb></eb>',
+        "[' ']": '<eb1></eb1>',
+        "['<b>', ' ', '</b>']": '<eb2></eb2>',
+        "['\\u2028', '\\u2028']": '<eb3></eb3>',
+        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+        "['<b>', '</b>']": '<eb5></eb5>',
+        "['<i>', ' ', '</i>']": '<eb6></eb6>',
+        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+        "['<i>', '</i>']": '<eb9></eb9>',
+        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
+    }
+    :param master_token:
+    :return:
+    """
+    master_token = master_token.replace("<eb></eb>", "<td></td>")
+    master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
+    master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
+    master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
+    master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
+    master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
+    master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
+    master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
+    master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
+    master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
+    master_token = master_token.replace(
+        "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
+    )
+    return master_token
+
+
+def distance(box_1, box_2):
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    :param rec1: (y0, x0, y1, x1), which reflects
+            (top, left, bottom, right)
+    :param rec2: (y0, x0, y1, x1)
+    :return: scalar value of IoU
+    """
+    # computing area of each rectangles
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find the each edge of intersect rectangle
+    left_line = max(rec1[1], rec2[1])
+    right_line = min(rec1[3], rec2[3])
+    top_line = max(rec1[0], rec2[0])
+    bottom_line = min(rec1[2], rec2[2])
+
+    # judge if there is an intersect
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+
+    intersect = (right_line - left_line) * (bottom_line - top_line)
+    return (intersect / (sum_area - intersect)) * 1.0
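
A quick numeric check of the two matching metrics used by match_result (values computed by hand):

from mineru.model.table.rec.slanet_plus.matcher_utils import compute_iou, distance

box_a = (0, 0, 10, 10)
box_b = (5, 5, 15, 15)
# intersection is 5 * 5 = 25, union is 100 + 100 - 25 = 175
print(compute_iou(box_a, box_b))  # 25 / 175 ≈ 0.1429
# L1 distance over all four corner coordinates (20) plus the closer corner pair (10)
print(distance(box_a, box_b))  # 30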

+ 1 - 1
mineru/model/table/rec/rapid_table.py → mineru/model/table/rec/slanet_plus/rapid_table.py

@@ -3,7 +3,7 @@ import html
 import cv2
 import numpy as np
 from loguru import logger
-from rapid_table import RapidTable, RapidTableInput
+from .main import RapidTable, RapidTableInput
 
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path

+ 58 - 0
mineru/model/table/rec/slanet_plus/table_structure.py

@@ -0,0 +1,58 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from typing import Any, Dict
+
+import numpy as np
+
+from .table_stucture_utils import OrtInferSession, TableLabelDecode, TablePreprocess
+
+
+class TableStructurer:
+    def __init__(self, config: Dict[str, Any]):
+        self.preprocess_op = TablePreprocess()
+
+        self.session = OrtInferSession(config)
+
+        self.character = self.session.get_metadata()
+        self.postprocess_op = TableLabelDecode(self.character)
+
+    def __call__(self, img):
+        starttime = time.time()
+        data = {"image": img}
+        data = self.preprocess_op(data)
+        img = data[0]
+        if img is None:
+            return None, None, 0
+        img = np.expand_dims(img, axis=0)
+        img = img.copy()
+
+        outputs = self.session([img])
+
+        preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
+
+        shape_list = np.expand_dims(data[-1], axis=0)
+        post_result = self.postprocess_op(preds, [shape_list])
+
+        bbox_list = post_result["bbox_batch_list"][0]
+
+        structure_str_list = post_result["structure_batch_list"][0]
+        structure_str_list = structure_str_list[0]
+        structure_str_list = (
+            ["<html>", "<body>", "<table>"]
+            + structure_str_list
+            + ["</table>", "</body>", "</html>"]
+        )
+        elapse = time.time() - starttime
+        return structure_str_list, bbox_list, elapse
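
A minimal sketch of driving TableStructurer on its own; the ONNX path and image path are placeholders (RapidTable normally builds this config from RapidTableInput):

import cv2
from mineru.model.table.rec.slanet_plus.table_structure import TableStructurer

config = {"model_path": "slanet-plus.onnx", "use_cuda": False}  # placeholder path
structurer = TableStructurer(config)
img = cv2.imread("table.png")  # placeholder path
tokens, cell_bboxes, elapse = structurer(img)
print(tokens[:5], cell_bboxes.shape, elapse)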

+ 229 - 0
mineru/model/table/rec/slanet_plus/table_structure_unitable.py

@@ -0,0 +1,229 @@
+import re
+import time
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from tokenizers import Tokenizer
+from torchvision import transforms
+
+from .unitable_modules import Encoder, GPTFastDecoder
+
+IMG_SIZE = 448
+EOS_TOKEN = "<eos>"
+BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
+
+HTML_BBOX_HTML_TOKENS = [
+    "<td></td>",
+    "<td>[",
+    "]</td>",
+    "<td",
+    ">[",
+    "></td>",
+    "<tr>",
+    "</tr>",
+    "<tbody>",
+    "</tbody>",
+    "<thead>",
+    "</thead>",
+    ' rowspan="2"',
+    ' rowspan="3"',
+    ' rowspan="4"',
+    ' rowspan="5"',
+    ' rowspan="6"',
+    ' rowspan="7"',
+    ' rowspan="8"',
+    ' rowspan="9"',
+    ' rowspan="10"',
+    ' rowspan="11"',
+    ' rowspan="12"',
+    ' rowspan="13"',
+    ' rowspan="14"',
+    ' rowspan="15"',
+    ' rowspan="16"',
+    ' rowspan="17"',
+    ' rowspan="18"',
+    ' rowspan="19"',
+    ' colspan="2"',
+    ' colspan="3"',
+    ' colspan="4"',
+    ' colspan="5"',
+    ' colspan="6"',
+    ' colspan="7"',
+    ' colspan="8"',
+    ' colspan="9"',
+    ' colspan="10"',
+    ' colspan="11"',
+    ' colspan="12"',
+    ' colspan="13"',
+    ' colspan="14"',
+    ' colspan="15"',
+    ' colspan="16"',
+    ' colspan="17"',
+    ' colspan="18"',
+    ' colspan="19"',
+    ' colspan="25"',
+]
+
+VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
+TASK_TOKENS = [
+    "[table]",
+    "[html]",
+    "[cell]",
+    "[bbox]",
+    "[cell+bbox]",
+    "[html+bbox]",
+]
+
+
+class TableStructureUnitable:
+    def __init__(self, config):
+        # encoder_path: str, decoder_path: str, vocab_path: str, device: str
+        vocab_path = config["model_path"]["vocab"]
+        encoder_path = config["model_path"]["encoder"]
+        decoder_path = config["model_path"]["decoder"]
+        device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu"
+
+        self.vocab = Tokenizer.from_file(vocab_path)
+        self.token_white_list = [
+            self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
+        ]
+        self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
+        self.bbox_close_html_token = self.vocab.token_to_id("]</td>")
+        self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
+        self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
+        self.max_seq_len = 1024
+        self.device = device
+        self.img_size = IMG_SIZE
+
+        # init encoder
+        encoder_state_dict = torch.load(encoder_path, map_location=device)
+        self.encoder = Encoder()
+        self.encoder.load_state_dict(encoder_state_dict)
+        self.encoder.eval().to(device)
+
+        # init decoder
+        decoder_state_dict = torch.load(decoder_path, map_location=device)
+        self.decoder = GPTFastDecoder()
+        self.decoder.load_state_dict(decoder_state_dict)
+        self.decoder.eval().to(device)
+
+        # define img transform
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((448, 448)),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.86597056, 0.88463002, 0.87491087],
+                    std=[0.20686628, 0.18201602, 0.18485524],
+                ),
+            ]
+        )
+
+    @torch.inference_mode()
+    def __call__(self, image: np.ndarray):
+        start_time = time.time()
+        ori_h, ori_w = image.shape[:2]
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        image = Image.fromarray(image)
+        image = self.transform(image).unsqueeze(0).to(self.device)
+        self.decoder.setup_caches(
+            max_batch_size=1,
+            max_seq_length=self.max_seq_len,
+            dtype=image.dtype,
+            device=self.device,
+        )
+        context = (
+            torch.tensor([self.prefix_token_id], dtype=torch.int32)
+            .repeat(1, 1)
+            .to(self.device)
+        )
+        eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device)
+        memory = self.encoder(image)
+        context = self.loop_decode(context, eos_id_tensor, memory)
+        bboxes, html_tokens = self.decode_tokens(context)
+        bboxes = bboxes.astype(np.float32)
+
+        # rescale boxes
+        scale_h = ori_h / self.img_size
+        scale_w = ori_w / self.img_size
+        bboxes[:, 0::2] *= scale_w  # scale x coordinates
+        bboxes[:, 1::2] *= scale_h  # scale y coordinates
+        bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
+        bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)
+        structure_str_list = (
+            ["<html>", "<body>", "<table>"]
+            + html_tokens
+            + ["</table>", "</body>", "</html>"]
+        )
+        return structure_str_list, bboxes, time.time() - start_time
+
+    def decode_tokens(self, context):
+        pred_html = context[0]
+        pred_html = pred_html.detach().cpu().numpy()
+        pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)
+        seq = pred_html.split("<eos>")[0]
+        token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
+        for i in token_black_list:
+            seq = seq.replace(i, "")
+
+        tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
+        td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
+        bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
+
+        decoded_list = []
+        bbox_coords = []
+
+        # find all <tr> tags
+        for tr_match in tr_pattern.finditer(pred_html):
+            tr_content = tr_match.group(1)
+            decoded_list.append("<tr>")
+
+            # find all <td> tags
+            for td_match in td_pattern.finditer(tr_content):
+                td_attrs = td_match.group(1).strip()
+                td_content = td_match.group(2).strip()
+                if td_attrs:
+                    decoded_list.append("<td")
+                    # rowspan and colspan may both be present; add every attribute
+                    attrs_list = td_attrs.split()
+                    for attr in attrs_list:
+                        decoded_list.append(" " + attr)
+                    decoded_list.append(">")
+                    decoded_list.append("</td>")
+                else:
+                    decoded_list.append("<td></td>")
+
+                # look for bbox coordinates
+                bbox_match = bbox_pattern.search(td_content)
+                if bbox_match:
+                    xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
+                    # convert to 4-point coordinates, clockwise from the top-left corner
+                    coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
+                    bbox_coords.append(coords)
+                else:
+                    # pad with a placeholder bbox so the downstream flow stays uniform
+                    bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
+            decoded_list.append("</tr>")
+
+        bbox_coords_array = np.array(bbox_coords)
+        return bbox_coords_array, decoded_list
+
+    def loop_decode(self, context, eos_id_tensor, memory):
+        box_token_count = 0
+        for _ in range(self.max_seq_len):
+            eos_flag = (context == eos_id_tensor).any(dim=1)
+            if torch.all(eos_flag):
+                break
+
+            next_tokens = self.decoder(memory, context)
+            if next_tokens[0] in self.bbox_token_ids:
+                box_token_count += 1
+                if box_token_count > 4:
+                    next_tokens = torch.tensor(
+                        [self.bbox_close_html_token], dtype=torch.int32
+                    )
+                    box_token_count = 0
+            context = torch.cat([context, next_tokens], dim=1)
+        return context
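
To illustrate the bbox-token grammar that decode_tokens parses, a hand-written fragment (the regex and corner ordering match the code above):

import re
import numpy as np

bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
m = bbox_pattern.search("[ bbox-10 bbox-20 bbox-110 bbox-120 ]")
xmin, ymin, xmax, ymax = map(int, m.groups())
# clockwise from the top-left corner, as in decode_tokens
coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
print(coords)  # [ 10  20 110  20 110 120  10 120]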

+ 544 - 0
mineru/model/table/rec/slanet_plus/table_stucture_utils.py

@@ -0,0 +1,544 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    get_available_providers,
+    get_device,
+)
+
+from rapid_table.utils import Logger
+
+
+class EP(Enum):
+    CPU_EP = "CPUExecutionProvider"
+    CUDA_EP = "CUDAExecutionProvider"
+    DIRECTML_EP = "DmlExecutionProvider"
+
+
+class OrtInferSession:
+    def __init__(self, config: Dict[str, Any]):
+        self.logger = Logger(logger_name=__name__).get_log()
+
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.cfg_use_cuda = config.get("use_cuda", None)
+        self.cfg_use_dml = config.get("use_dml", None)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            model_path,
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+        self._verify_providers()
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+        sess_opt = SessionOptions()
+        sess_opt.log_severity_level = 4
+        sess_opt.enable_cpu_mem_arena = False
+        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        cpu_nums = os.cpu_count()
+        intra_op_num_threads = config.get("intra_op_num_threads", -1)
+        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+            sess_opt.intra_op_num_threads = intra_op_num_threads
+
+        inter_op_num_threads = config.get("inter_op_num_threads", -1)
+        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def get_metadata(self, key: str = "character") -> list:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        content_list = meta_dict[key].splitlines()
+        return content_list
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
+            "arena_extend_strategy": "kSameAsRequested",
+        }
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, try to use DirectML as primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, you must do:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime pakcages in current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
+        )
+        self.logger.info(
+            "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_window_version = int(platform.release().split(".")[0])
+        if cur_window_version < 10:
+            self.logger.warning(
+                "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
+                cur_window_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, you must do:")
+        self.logger.info(
+            "First, uninstall all onnxruntime pakcages in current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
+        )
+        self.logger.info(
+            "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+        input_dict = dict(zip(self.get_input_names(), input_content))
+        try:
+            return self.session.run(None, input_dict)
+        except Exception as e:
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
+
+    def get_input_names(self) -> List[str]:
+        return [v.name for v in self.session.get_inputs()]
+
+    def get_output_names(self) -> List[str]:
+        return [v.name for v in self.session.get_outputs()]
+
+    def get_character_list(self, key: str = "character") -> List[str]:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        if key in meta_dict.keys():
+            return True
+        return False
+
+    @staticmethod
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
+        model_path = Path(model_path)
+        if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exists.")
+
+        if not model_path.is_file():
+            raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+    pass
+
+
+class TableLabelDecode:
+    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
+        if merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+        self.td_token = ["<td>", "<td", "<td></td>"]
+
+    def __call__(self, preds, batch=None):
+        structure_probs = preds["structure_probs"]
+        bbox_preds = preds["loc_preds"]
+        shape_list = batch[-1]
+        result = self.decode(structure_probs, bbox_preds, shape_list)
+        if len(batch) == 1:  # only contains shape
+            return result
+
+        label_decode_result = self.decode_label(batch)
+        return result, label_decode_result
+
+    def decode(self, structure_probs, bbox_preds, shape_list):
+        """convert text-label into text-index."""
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_idx = structure_probs.argmax(axis=2)
+        structure_probs = structure_probs.max(axis=2)
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            score_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                text = self.character[char_idx]
+                if text in self.td_token:
+                    bbox = bbox_preds[batch_idx, idx]
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+                structure_list.append(text)
+                score_list.append(structure_probs[batch_idx, idx])
+            structure_batch_list.append([structure_list, np.mean(score_list)])
+            bbox_batch_list.append(np.array(bbox_list))
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def decode_label(self, batch):
+        """convert text-label into text-index."""
+        structure_idx = batch[1]
+        gt_bbox_list = batch[2]
+        shape_list = batch[-1]
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+
+                if char_idx in ignored_tokens:
+                    continue
+
+                structure_list.append(self.character[char_idx])
+
+                bbox = gt_bbox_list[batch_idx][idx]
+                if bbox.sum() != 0:
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox)
+
+            structure_batch_list.append(structure_list)
+            bbox_batch_list.append(bbox_list)
+        result = {
+            "bbox_batch_list": bbox_batch_list,
+            "structure_batch_list": structure_batch_list,
+        }
+        return result
+
+    def _bbox_decode(self, bbox, shape):
+        h, w = shape[:2]
+        bbox[0::2] *= w
+        bbox[1::2] *= h
+        return bbox
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            return np.array(self.dict[self.beg_str])
+
+        if beg_or_end == "end":
+            return np.array(self.dict[self.end_str])
+
+        raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+
+class TablePreprocess:
+    def __init__(self):
+        self.table_max_len = 488
+        self.build_pre_process_list()
+        self.ops = self.create_operators()
+
+    def __call__(self, data):
+        """transform"""
+        if self.ops is None:
+            self.ops = []
+
+        for op in self.ops:
+            data = op(data)
+            if data is None:
+                return None
+        return data
+
+    def create_operators(self):
+        """
+        create operators from self.pre_process_list, a list of
+        single-key dicts mapping operator names to their parameters
+        """
+        assert isinstance(
+            self.pre_process_list, list
+        ), "operator config should be a list"
+        ops = []
+        for operator in self.pre_process_list:
+            assert (
+                isinstance(operator, dict) and len(operator) == 1
+            ), "yaml format error"
+            op_name = list(operator)[0]
+            param = {} if operator[op_name] is None else operator[op_name]
+            op = eval(op_name)(**param)
+            ops.append(op)
+        return ops
+
+    def build_pre_process_list(self):
+        resize_op = {
+            "ResizeTableImage": {
+                "max_len": self.table_max_len,
+            }
+        }
+        pad_op = {
+            "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
+        }
+        normalize_op = {
+            "NormalizeImage": {
+                "std": [0.229, 0.224, 0.225],
+                "mean": [0.485, 0.456, 0.406],
+                "scale": "1./255.",
+                "order": "hwc",
+            }
+        }
+        to_chw_op = {"ToCHWImage": None}
+        keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
+        self.pre_process_list = [
+            resize_op,
+            normalize_op,
+            pad_op,
+            to_chw_op,
+            keep_keys_op,
+        ]
+
+
+class ResizeTableImage:
+    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
+        super(ResizeTableImage, self).__init__()
+        self.max_len = max_len
+        self.resize_bboxes = resize_bboxes
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        img = data["image"]
+        height, width = img.shape[0:2]
+        ratio = self.max_len / (max(height, width) * 1.0)
+        resize_h = int(height * ratio)
+        resize_w = int(width * ratio)
+        resize_img = cv2.resize(img, (resize_w, resize_h))
+        if self.resize_bboxes and not self.infer_mode:
+            data["bboxes"] = data["bboxes"] * ratio
+        data["image"] = resize_img
+        data["src_img"] = img
+        data["shape"] = np.array([height, width, ratio, ratio])
+        data["max_len"] = self.max_len
+        return data
+
+
+class PaddingTableImage:
+    def __init__(self, size, **kwargs):
+        super(PaddingTableImage, self).__init__()
+        self.size = size
+
+    def __call__(self, data):
+        img = data["image"]
+        pad_h, pad_w = self.size
+        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+        height, width = img.shape[0:2]
+        padding_img[0:height, 0:width, :] = img.copy()
+        data["image"] = padding_img
+        shape = data["shape"].tolist()
+        shape.extend([pad_h, pad_w])
+        data["shape"] = np.array(shape)
+        return data
+
+
+class NormalizeImage:
+    """normalize image such as substract mean, divide std"""
+
+    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+        return data
+
+
+class ToCHWImage:
+    """convert hwc image to chw image"""
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        img = np.array(data["image"])
+        data["image"] = img.transpose((2, 0, 1))
+        return data
+
+
+class KeepKeys:
+    def __init__(self, keep_keys, **kwargs):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        data_list = []
+        for key in self.keep_keys:
+            data_list.append(data[key])
+        return data_list
+
+
+def trans_char_ocr_res(ocr_res):
+    word_result = []
+    for res in ocr_res:
+        score = res[2]
+        for word_box, word in zip(res[3], res[4]):
+            word_res = []
+            word_res.append(word_box)
+            word_res.append(word)
+            word_res.append(score)
+            word_result.append(word_res)
+    return word_result
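
A small sanity check of the preprocessing chain above (resize so the longer side is 488, normalize, pad to 488x488, HWC to CHW, keep [image, shape]); the dummy image is arbitrary:

import numpy as np
from mineru.model.table.rec.slanet_plus.table_stucture_utils import TablePreprocess

preprocess = TablePreprocess()
image, shape = preprocess({"image": np.zeros((300, 600, 3), dtype=np.uint8)})
print(image.shape)  # (3, 488, 488): CHW after padding to 488x488
print(shape)        # [300, 600, ratio, ratio, 488, 488] with ratio = 488/600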

+ 911 - 0
mineru/model/table/rec/slanet_plus/unitable_modules.py

@@ -0,0 +1,911 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.modules.transformer import _get_activation_fn
+
+TOKEN_WHITE_LIST = [1, *range(12, 510)]
+
+
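+# ViT-style patch embedding: a strided Conv2d projects each patch_size x
+# patch_size patch to a d_model vector; the result is flattened into a
+# (batch, num_patches, d_model) sequence.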
+class ImgLinearBackbone(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        patch_size: int,
+        in_chan: int = 3,
+    ) -> None:
+        super().__init__()
+
+        self.conv_proj = nn.Conv2d(
+            in_chan,
+            out_channels=d_model,
+            kernel_size=patch_size,
+            stride=patch_size,
+        )
+        self.d_model = d_model
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv_proj(x)
+        x = x.flatten(start_dim=-2).transpose(1, 2)
+        return x
+
+
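+# 12-layer transformer encoder over image patches (patch 16, width 768,
+# 12 heads) that produces the memory sequence consumed by the decoder.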
+class Encoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.patch_size = 16
+        self.d_model = 768
+        self.dropout = 0
+        self.activation = "gelu"
+        self.norm_first = True
+        self.ff_ratio = 4
+        self.nhead = 12
+        self.max_seq_len = 1024
+        self.n_encoder_layer = 12
+        encoder_layer = nn.TransformerEncoderLayer(
+            self.d_model,
+            nhead=self.nhead,
+            dim_feedforward=self.ff_ratio * self.d_model,
+            dropout=self.dropout,
+            activation=self.activation,
+            batch_first=True,
+            norm_first=self.norm_first,
+        )
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.norm = norm_layer(self.d_model)
+        self.backbone = ImgLinearBackbone(
+            d_model=self.d_model, patch_size=self.patch_size
+        )
+        self.pos_embed = PositionEmbedding(
+            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+        )
+        self.encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        src_feature = self.backbone(x)
+        src_feature = self.pos_embed(src_feature)
+        memory = self.encoder(src_feature)
+        memory = self.norm(memory)
+        return memory
+
+
+class PositionEmbedding(nn.Module):
+    def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
+        super().__init__()
+        self.embedding = nn.Embedding(max_seq_len, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        # assume x is batch first
+        if input_pos is None:
+            _pos = torch.arange(x.shape[1], device=x.device)
+        else:
+            _pos = input_pos
+        out = self.embedding(_pos)
+        return self.dropout(out + x)
+
+
+class TokenEmbedding(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int,
+        padding_idx: int,
+    ) -> None:
+        super().__init__()
+        assert vocab_size > 0
+        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.embedding(x)
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+@dataclass
+class ModelArgs:
+    n_layer: int = 4
+    n_head: int = 12
+    dim: int = 768
+    intermediate_size: Optional[int] = None
+    head_dim: int = 64
+    activation: str = "gelu"
+    norm_first: bool = True
+
+    def __post_init__(self):
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        self.head_dim = self.dim // self.n_head
+
+
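+# Preallocated static key/value cache, written in place at input_pos on each
+# decoding step so self-attention only computes the new token's projections.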
+class KVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size,
+        max_seq_length,
+        n_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        device="cpu",
+    ):
+        super().__init__()
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.register_buffer(
+            "k_cache",
+            torch.zeros(cache_shape, dtype=dtype, device=device),
+            persistent=False,
+        )
+        self.register_buffer(
+            "v_cache",
+            torch.zeros(cache_shape, dtype=dtype, device=device),
+            persistent=False,
+        )
+
+    def update(self, input_pos, k_val, v_val):
+        # input_pos: [S], k_val: [B, H, S, D]
+        # assert input_pos.shape[0] == k_val.shape[2]
+
+        bs = k_val.shape[0]
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:bs, :, input_pos] = k_val
+        v_out[:bs, :, input_pos] = v_val
+
+        return k_out[:bs], v_out[:bs]
+
+
+class GPTFastDecoder(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.vocab_size = 960
+        self.padding_idx = 2
+        self.prefix_token_id = 11
+        self.eos_id = 1
+        self.max_seq_len = 1024
+        self.dropout = 0
+        self.d_model = 768
+        self.nhead = 12
+        self.activation = "gelu"
+        self.norm_first = True
+        self.n_decoder_layer = 4
+        config = ModelArgs(
+            n_layer=self.n_decoder_layer,
+            n_head=self.nhead,
+            dim=self.d_model,
+            intermediate_size=self.d_model * 4,
+            activation=self.activation,
+            norm_first=self.norm_first,
+        )
+        self.config = config
+        self.layers = nn.ModuleList(
+            TransformerBlock(config) for _ in range(config.n_layer)
+        )
+        self.token_embed = TokenEmbedding(
+            vocab_size=self.vocab_size,
+            d_model=self.d_model,
+            padding_idx=self.padding_idx,
+        )
+        self.pos_embed = PositionEmbedding(
+            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+        )
+        self.generator = nn.Linear(self.d_model, self.vocab_size)
+        self.token_white_list = TOKEN_WHITE_LIST
+        self.mask_cache: Optional[Tensor] = None
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
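+    # (Re)allocates the per-layer self-attention KV caches and the causal
+    # mask; cross-attention caches are cleared so they are rebuilt for the
+    # new encoder memory.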
+    def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
+        for b in self.layers:
+            b.multihead_attn.k_cache = None
+            b.multihead_attn.v_cache = None
+
+        if (
+            self.max_seq_length >= max_seq_length
+            and self.max_batch_size >= max_batch_size
+        ):
+            return
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+
+        for b in self.layers:
+            b.self_attn.kv_cache = KVCache(
+                max_batch_size,
+                max_seq_length,
+                self.config.n_head,
+                head_dim,
+                dtype,
+                device,
+            )
+            b.multihead_attn.k_cache = None
+            b.multihead_attn.v_cache = None
+
+        self.causal_mask = torch.tril(
+            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+        ).to(device)
+
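+    # One incremental decoding step: embed only the last target token, run it
+    # through the decoder layers against the caches, then pick the greedy next
+    # token from the whitelisted logits.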
+    def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
+        input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
+        tgt = tgt[:, -1:]
+        tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
+        # tgt = self.decoder(tgt_feature, memory, input_pos)
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=False, enable_mem_efficient=False, enable_math=True
+        ):
+            logits = tgt_feature
+            tgt_mask = self.causal_mask[None, None, input_pos]
+            for layer in self.layers:
+                logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
+        # return output
+        logits = self.generator(logits)[:, -1, :]
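+        # Mask every token outside TOKEN_WHITE_LIST with a large negative
+        # logit so the argmax can only pick structurally valid tokens.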
+        total = set(range(logits.shape[-1]))
+        black_list = list(total.difference(self.token_white_list))
+        logits[..., black_list] = -1e9
+        probs = F.softmax(logits, dim=-1)
+        _, next_tokens = probs.topk(1)
+        return next_tokens
+
+
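+# Pre-norm (or post-norm) decoder block: masked self-attention, cross-attention
+# over the encoder memory, then a feed-forward sublayer.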
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.self_attn = Attention(config)
+        self.multihead_attn = CrossAttention(config)
+
+        layer_norm_eps = 1e-5
+
+        d_model = config.dim
+        dim_feedforward = config.intermediate_size
+
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm_first = config.norm_first
+        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+
+        self.activation = _get_activation_fn(config.activation)
+
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor,
+        input_pos: Tensor,
+    ) -> Tensor:
+        if self.norm_first:
+            x = tgt
+            x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
+            x = x + self.multihead_attn(self.norm2(x), memory)
+            x = x + self._ff_block(self.norm3(x))
+        else:
+            x = tgt
+            x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
+            x = self.norm2(x + self.multihead_attn(x, memory))
+            x = self.norm3(x + self._ff_block(x))
+        return x
+
+    def _ff_block(self, x: Tensor) -> Tensor:
+        x = self.linear2(self.activation(self.linear1(x)))
+        return x
+
+
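+# Decoder self-attention with a fused QKV projection and an in-place KV cache
+# for single-token incremental decoding.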
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        # key, query, value projections for all heads, but in a batch
+        self.wqkv = nn.Linear(config.dim, 3 * config.dim)
+        self.wo = nn.Linear(config.dim, config.dim)
+
+        self.kv_cache: Optional[KVCache] = None
+
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.dim = config.dim
+
+    def forward(
+        self,
+        x: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_head * self.head_dim
+        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_head, self.head_dim)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+        y = self.wo(y)
+        return y
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        self.query = nn.Linear(config.dim, config.dim)
+        self.key = nn.Linear(config.dim, config.dim)
+        self.value = nn.Linear(config.dim, config.dim)
+        self.out = nn.Linear(config.dim, config.dim)
+
+        self.k_cache = None
+        self.v_cache = None
+
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+
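+    # The encoder memory is fixed during decoding, so its key/value projections
+    # are computed once on the first step and served from the cache afterwards.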
+    def get_kv(self, xa: torch.Tensor):
+        if self.k_cache is not None and self.v_cache is not None:
+            return self.k_cache, self.v_cache
+
+        k = self.key(xa)
+        v = self.value(xa)
+
+        # Split heads: (B, S, dim) -> (B, S, n_head, head_dim)
+        batch_size, source_seq_len, _ = k.shape
+        k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
+        v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
+
+        if self.k_cache is None:
+            self.k_cache = k
+        if self.v_cache is None:
+            self.v_cache = v
+
+        return k, v
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Tensor,
+    ):
+        q = self.query(x)
+        batch_size, target_seq_len, _ = q.shape
+        q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
+        k, v = self.get_kv(xa)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        wv = F.scaled_dot_product_attention(
+            query=q,
+            key=k,
+            value=v,
+            is_causal=False,
+        )
+        wv = wv.transpose(1, 2).reshape(
+            batch_size,
+            target_seq_len,
+            self.n_head * self.head_dim,
+        )
+
+        return self.out(wv)