Browse Source

fix: remove unitable

Sidney233 3 months ago
Parent
commit 58cccf0825

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
-from ...model.table.rec.slanet_plus.rapid_table import RapidTableModel
+from ...model.table.rec.slanet_plus.main import RapidTableModel
 from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
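
Downstream code only needs the re-pointed import; a minimal sketch of how the pipeline wires it up (assuming PytorchPaddleOCR is default-constructible here; illustrative only, not part of the commit):

from mineru.model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from mineru.model.table.rec.slanet_plus.main import RapidTableModel

# RapidTableModel now lives in slanet_plus/main.py; its constructor signature is unchanged.
ocr_engine = PytorchPaddleOCR()            # assumption: default construction works in this context
table_model = RapidTableModel(ocr_engine)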

+ 50 - 37
mineru/model/table/rec/slanet_plus/main.py

@@ -1,10 +1,9 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
+import os
 import argparse
 import copy
 import importlib
 import time
+import html
 from dataclasses import asdict, dataclass
 from enum import Enum
 from pathlib import Path
@@ -12,10 +11,11 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import cv2
 import numpy as np
-
+from loguru import logger
 from .matcher import TableMatch
 from .table_structure import TableStructurer
-from .table_structure_unitable import TableStructureUnitable
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 root_dir = Path(__file__).resolve().parent
 
@@ -66,11 +66,10 @@ class RapidTable:
             )
 
         config.model_path = config.model_path
-        if self.model_type == ModelType.UNITABLE.value:
-            self.table_structure = TableStructureUnitable(asdict(config))
-        else:
+        if self.model_type == ModelType.SLANETPLUS.value:
             self.table_structure = TableStructurer(asdict(config))
-
+        else:
+            raise ValueError(f"{self.model_type} is not supported.")
         self.table_matcher = TableMatch()
 
         try:
@@ -78,7 +77,6 @@ class RapidTable:
         except ModuleNotFoundError:
             self.ocr_engine = None
 
-
     def __call__(
         self,
         img: np.ndarray,
@@ -148,7 +146,6 @@ class RapidTable:
         return cell_bboxes
 
 
-
 def parse_args(arg_list: Optional[List[str]] = None):
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -172,29 +169,45 @@ def parse_args(arg_list: Optional[List[str]] = None):
     return args
 
 
-def main(arg_list: Optional[List[str]] = None):
-    args = parse_args(arg_list)
-
-    try:
-        ocr_engine = importlib.import_module("rapidocr").RapidOCR()
-    except ModuleNotFoundError as exc:
-        raise ModuleNotFoundError(
-            "Please install the rapidocr by pip install rapidocr"
-        ) from exc
-
-    input_args = RapidTableInput(model_type=args.model_type)
-    table_engine = RapidTable(input_args)
-
-    img = cv2.imread(args.img_path)
-
-    rapid_ocr_output = ocr_engine(img)
-    ocr_result = list(
-        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
-    )
-    table_results = table_engine(img, ocr_result)
-    print(table_results.pred_html)
-
-
-
-if __name__ == "__main__":
-    main()
+def escape_html(input_string):
+    """Escape HTML Entities."""
+    return html.escape(input_string)
+
+
+class RapidTableModel(object):
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(
+            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
+            ModelPath.slanet_plus,
+        )
+        input_args = RapidTableInput(
+            model_type="slanet_plus", model_path=slanet_plus_model_path
+        )
+        self.table_model = RapidTable(input_args)
+        self.ocr_engine = ocr_engine
+
+    def predict(self, image, table_cls_score):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        # Continue with OCR on potentially rotated image
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        if ocr_result:
+            ocr_result = [
+                [item[0], escape_html(item[1][0]), item[1][1]]
+                for item in ocr_result
+                if len(item) == 2 and isinstance(item[1], tuple)
+            ]
+        else:
+            ocr_result = None
+
+        if ocr_result:
+            try:
+                table_results = self.table_model(np.asarray(image), ocr_result)
+                html_code = table_results.pred_html
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+
+        return None, None, None, None
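
A minimal usage sketch for the consolidated RapidTableModel (the input file name is hypothetical; the OCR engine is assumed to be the pipeline's PytorchPaddleOCR or any object exposing .ocr(bgr_image)):

from PIL import Image
from mineru.model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from mineru.model.table.rec.slanet_plus.main import RapidTableModel

model = RapidTableModel(PytorchPaddleOCR())      # assumption: default OCR construction
image = Image.open("table.png").convert("RGB")   # hypothetical input image
html_code, cell_bboxes, logic_points, elapse = model.predict(image, table_cls_score=1.0)
if html_code is None:
    # predict() returns four Nones when OCR finds no text or table recognition raises.
    print("table recognition failed")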

+ 0 - 51
mineru/model/table/rec/slanet_plus/rapid_table.py

@@ -1,51 +0,0 @@
-import os
-import html
-from typing import List
-
-import cv2
-import numpy as np
-from loguru import logger
-from .main import RapidTable, RapidTableInput
-
-from mineru.utils.enum_class import ModelPath
-from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
-
-
-def escape_html(input_string):
-    """Escape HTML Entities."""
-    return html.escape(input_string)
-
-
-class RapidTableModel(object):
-    def __init__(self, ocr_engine):
-        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
-        input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
-        self.table_model = RapidTable(input_args)
-        self.ocr_engine = ocr_engine
-
-    def predict(self, image, table_cls_score):
-        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-        # Continue with OCR on potentially rotated image
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-        if ocr_result:
-            ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
-                      len(item) == 2 and isinstance(item[1], tuple)]
-        else:
-            ocr_result = None
-
-        if ocr_result:
-            try:
-                table_results = self.table_model(np.asarray(image), ocr_result)
-                html_code = table_results.pred_html
-                table_cell_bboxes = table_results.cell_bboxes
-                logic_points = table_results.logic_points
-                elapse = table_results.elapse
-                return html_code, table_cell_bboxes, logic_points, elapse
-            except Exception as e:
-                logger.exception(e)
-
-        return None, None, None, None
-
-    def batch_predict(self, images: List[np.ndarray], batch_size: int = 1):
-        # TODO: ocr也需要batch
-        pass

+ 0 - 229
mineru/model/table/rec/slanet_plus/table_structure_unitable.py

@@ -1,229 +0,0 @@
-import re
-import time
-
-import cv2
-import numpy as np
-import torch
-from PIL import Image
-from tokenizers import Tokenizer
-from torchvision import transforms
-
-from .unitable_modules import Encoder, GPTFastDecoder
-
-IMG_SIZE = 448
-EOS_TOKEN = "<eos>"
-BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
-
-HTML_BBOX_HTML_TOKENS = [
-    "<td></td>",
-    "<td>[",
-    "]</td>",
-    "<td",
-    ">[",
-    "></td>",
-    "<tr>",
-    "</tr>",
-    "<tbody>",
-    "</tbody>",
-    "<thead>",
-    "</thead>",
-    ' rowspan="2"',
-    ' rowspan="3"',
-    ' rowspan="4"',
-    ' rowspan="5"',
-    ' rowspan="6"',
-    ' rowspan="7"',
-    ' rowspan="8"',
-    ' rowspan="9"',
-    ' rowspan="10"',
-    ' rowspan="11"',
-    ' rowspan="12"',
-    ' rowspan="13"',
-    ' rowspan="14"',
-    ' rowspan="15"',
-    ' rowspan="16"',
-    ' rowspan="17"',
-    ' rowspan="18"',
-    ' rowspan="19"',
-    ' colspan="2"',
-    ' colspan="3"',
-    ' colspan="4"',
-    ' colspan="5"',
-    ' colspan="6"',
-    ' colspan="7"',
-    ' colspan="8"',
-    ' colspan="9"',
-    ' colspan="10"',
-    ' colspan="11"',
-    ' colspan="12"',
-    ' colspan="13"',
-    ' colspan="14"',
-    ' colspan="15"',
-    ' colspan="16"',
-    ' colspan="17"',
-    ' colspan="18"',
-    ' colspan="19"',
-    ' colspan="25"',
-]
-
-VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
-TASK_TOKENS = [
-    "[table]",
-    "[html]",
-    "[cell]",
-    "[bbox]",
-    "[cell+bbox]",
-    "[html+bbox]",
-]
-
-
-class TableStructureUnitable:
-    def __init__(self, config):
-        # encoder_path: str, decoder_path: str, vocab_path: str, device: str
-        vocab_path = config["model_path"]["vocab"]
-        encoder_path = config["model_path"]["encoder"]
-        decoder_path = config["model_path"]["decoder"]
-        device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu"
-
-        self.vocab = Tokenizer.from_file(vocab_path)
-        self.token_white_list = [
-            self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
-        ]
-        self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
-        self.bbox_close_html_token = self.vocab.token_to_id("]</td>")
-        self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
-        self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
-        self.max_seq_len = 1024
-        self.device = device
-        self.img_size = IMG_SIZE
-
-        # init encoder
-        encoder_state_dict = torch.load(encoder_path, map_location=device)
-        self.encoder = Encoder()
-        self.encoder.load_state_dict(encoder_state_dict)
-        self.encoder.eval().to(device)
-
-        # init decoder
-        decoder_state_dict = torch.load(decoder_path, map_location=device)
-        self.decoder = GPTFastDecoder()
-        self.decoder.load_state_dict(decoder_state_dict)
-        self.decoder.eval().to(device)
-
-        # define img transform
-        self.transform = transforms.Compose(
-            [
-                transforms.Resize((448, 448)),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.86597056, 0.88463002, 0.87491087],
-                    std=[0.20686628, 0.18201602, 0.18485524],
-                ),
-            ]
-        )
-
-    @torch.inference_mode()
-    def __call__(self, image: np.ndarray):
-        start_time = time.time()
-        ori_h, ori_w = image.shape[:2]
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        image = Image.fromarray(image)
-        image = self.transform(image).unsqueeze(0).to(self.device)
-        self.decoder.setup_caches(
-            max_batch_size=1,
-            max_seq_length=self.max_seq_len,
-            dtype=image.dtype,
-            device=self.device,
-        )
-        context = (
-            torch.tensor([self.prefix_token_id], dtype=torch.int32)
-            .repeat(1, 1)
-            .to(self.device)
-        )
-        eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device)
-        memory = self.encoder(image)
-        context = self.loop_decode(context, eos_id_tensor, memory)
-        bboxes, html_tokens = self.decode_tokens(context)
-        bboxes = bboxes.astype(np.float32)
-
-        # rescale boxes
-        scale_h = ori_h / self.img_size
-        scale_w = ori_w / self.img_size
-        bboxes[:, 0::2] *= scale_w  # 缩放 x 坐标
-        bboxes[:, 1::2] *= scale_h  # 缩放 y 坐标
-        bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
-        bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)
-        structure_str_list = (
-            ["<html>", "<body>", "<table>"]
-            + html_tokens
-            + ["</table>", "</body>", "</html>"]
-        )
-        return structure_str_list, bboxes, time.time() - start_time
-
-    def decode_tokens(self, context):
-        pred_html = context[0]
-        pred_html = pred_html.detach().cpu().numpy()
-        pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)
-        seq = pred_html.split("<eos>")[0]
-        token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
-        for i in token_black_list:
-            seq = seq.replace(i, "")
-
-        tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
-        td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
-        bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
-
-        decoded_list = []
-        bbox_coords = []
-
-        # 查找所有的 <tr> 标签
-        for tr_match in tr_pattern.finditer(pred_html):
-            tr_content = tr_match.group(1)
-            decoded_list.append("<tr>")
-
-            # 查找所有的 <td> 标签
-            for td_match in td_pattern.finditer(tr_content):
-                td_attrs = td_match.group(1).strip()
-                td_content = td_match.group(2).strip()
-                if td_attrs:
-                    decoded_list.append("<td")
-                    # 可能同时存在行列合并,需要都添加
-                    attrs_list = td_attrs.split()
-                    for attr in attrs_list:
-                        decoded_list.append(" " + attr)
-                    decoded_list.append(">")
-                    decoded_list.append("</td>")
-                else:
-                    decoded_list.append("<td></td>")
-
-                # 查找 bbox 坐标
-                bbox_match = bbox_pattern.search(td_content)
-                if bbox_match:
-                    xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
-                    # 将坐标转换为从左上角开始顺时针到左下角的点的坐标
-                    coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
-                    bbox_coords.append(coords)
-                else:
-                    # 填充占位的bbox,保证后续流程统一
-                    bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
-            decoded_list.append("</tr>")
-
-        bbox_coords_array = np.array(bbox_coords)
-        return bbox_coords_array, decoded_list
-
-    def loop_decode(self, context, eos_id_tensor, memory):
-        box_token_count = 0
-        for _ in range(self.max_seq_len):
-            eos_flag = (context == eos_id_tensor).any(dim=1)
-            if torch.all(eos_flag):
-                break
-
-            next_tokens = self.decoder(memory, context)
-            if next_tokens[0] in self.bbox_token_ids:
-                box_token_count += 1
-                if box_token_count > 4:
-                    next_tokens = torch.tensor(
-                        [self.bbox_close_html_token], dtype=torch.int32
-                    )
-                    box_token_count = 0
-            context = torch.cat([context, next_tokens], dim=1)
-        return context

+ 0 - 911
mineru/model/table/rec/slanet_plus/unitable_modules.py

@@ -1,911 +0,0 @@
-from dataclasses import dataclass
-from functools import partial
-from typing import Optional
-
-import torch
-import torch.nn as nn
-from torch import Tensor
-from torch.nn import functional as F
-from torch.nn.modules.transformer import _get_activation_fn
-
-TOKEN_WHITE_LIST = [
-    1,
-    12,
-    13,
-    14,
-    15,
-    16,
-    17,
-    18,
-    19,
-    20,
-    21,
-    22,
-    23,
-    24,
-    25,
-    26,
-    27,
-    28,
-    29,
-    30,
-    31,
-    32,
-    33,
-    34,
-    35,
-    36,
-    37,
-    38,
-    39,
-    40,
-    41,
-    42,
-    43,
-    44,
-    45,
-    46,
-    47,
-    48,
-    49,
-    50,
-    51,
-    52,
-    53,
-    54,
-    55,
-    56,
-    57,
-    58,
-    59,
-    60,
-    61,
-    62,
-    63,
-    64,
-    65,
-    66,
-    67,
-    68,
-    69,
-    70,
-    71,
-    72,
-    73,
-    74,
-    75,
-    76,
-    77,
-    78,
-    79,
-    80,
-    81,
-    82,
-    83,
-    84,
-    85,
-    86,
-    87,
-    88,
-    89,
-    90,
-    91,
-    92,
-    93,
-    94,
-    95,
-    96,
-    97,
-    98,
-    99,
-    100,
-    101,
-    102,
-    103,
-    104,
-    105,
-    106,
-    107,
-    108,
-    109,
-    110,
-    111,
-    112,
-    113,
-    114,
-    115,
-    116,
-    117,
-    118,
-    119,
-    120,
-    121,
-    122,
-    123,
-    124,
-    125,
-    126,
-    127,
-    128,
-    129,
-    130,
-    131,
-    132,
-    133,
-    134,
-    135,
-    136,
-    137,
-    138,
-    139,
-    140,
-    141,
-    142,
-    143,
-    144,
-    145,
-    146,
-    147,
-    148,
-    149,
-    150,
-    151,
-    152,
-    153,
-    154,
-    155,
-    156,
-    157,
-    158,
-    159,
-    160,
-    161,
-    162,
-    163,
-    164,
-    165,
-    166,
-    167,
-    168,
-    169,
-    170,
-    171,
-    172,
-    173,
-    174,
-    175,
-    176,
-    177,
-    178,
-    179,
-    180,
-    181,
-    182,
-    183,
-    184,
-    185,
-    186,
-    187,
-    188,
-    189,
-    190,
-    191,
-    192,
-    193,
-    194,
-    195,
-    196,
-    197,
-    198,
-    199,
-    200,
-    201,
-    202,
-    203,
-    204,
-    205,
-    206,
-    207,
-    208,
-    209,
-    210,
-    211,
-    212,
-    213,
-    214,
-    215,
-    216,
-    217,
-    218,
-    219,
-    220,
-    221,
-    222,
-    223,
-    224,
-    225,
-    226,
-    227,
-    228,
-    229,
-    230,
-    231,
-    232,
-    233,
-    234,
-    235,
-    236,
-    237,
-    238,
-    239,
-    240,
-    241,
-    242,
-    243,
-    244,
-    245,
-    246,
-    247,
-    248,
-    249,
-    250,
-    251,
-    252,
-    253,
-    254,
-    255,
-    256,
-    257,
-    258,
-    259,
-    260,
-    261,
-    262,
-    263,
-    264,
-    265,
-    266,
-    267,
-    268,
-    269,
-    270,
-    271,
-    272,
-    273,
-    274,
-    275,
-    276,
-    277,
-    278,
-    279,
-    280,
-    281,
-    282,
-    283,
-    284,
-    285,
-    286,
-    287,
-    288,
-    289,
-    290,
-    291,
-    292,
-    293,
-    294,
-    295,
-    296,
-    297,
-    298,
-    299,
-    300,
-    301,
-    302,
-    303,
-    304,
-    305,
-    306,
-    307,
-    308,
-    309,
-    310,
-    311,
-    312,
-    313,
-    314,
-    315,
-    316,
-    317,
-    318,
-    319,
-    320,
-    321,
-    322,
-    323,
-    324,
-    325,
-    326,
-    327,
-    328,
-    329,
-    330,
-    331,
-    332,
-    333,
-    334,
-    335,
-    336,
-    337,
-    338,
-    339,
-    340,
-    341,
-    342,
-    343,
-    344,
-    345,
-    346,
-    347,
-    348,
-    349,
-    350,
-    351,
-    352,
-    353,
-    354,
-    355,
-    356,
-    357,
-    358,
-    359,
-    360,
-    361,
-    362,
-    363,
-    364,
-    365,
-    366,
-    367,
-    368,
-    369,
-    370,
-    371,
-    372,
-    373,
-    374,
-    375,
-    376,
-    377,
-    378,
-    379,
-    380,
-    381,
-    382,
-    383,
-    384,
-    385,
-    386,
-    387,
-    388,
-    389,
-    390,
-    391,
-    392,
-    393,
-    394,
-    395,
-    396,
-    397,
-    398,
-    399,
-    400,
-    401,
-    402,
-    403,
-    404,
-    405,
-    406,
-    407,
-    408,
-    409,
-    410,
-    411,
-    412,
-    413,
-    414,
-    415,
-    416,
-    417,
-    418,
-    419,
-    420,
-    421,
-    422,
-    423,
-    424,
-    425,
-    426,
-    427,
-    428,
-    429,
-    430,
-    431,
-    432,
-    433,
-    434,
-    435,
-    436,
-    437,
-    438,
-    439,
-    440,
-    441,
-    442,
-    443,
-    444,
-    445,
-    446,
-    447,
-    448,
-    449,
-    450,
-    451,
-    452,
-    453,
-    454,
-    455,
-    456,
-    457,
-    458,
-    459,
-    460,
-    461,
-    462,
-    463,
-    464,
-    465,
-    466,
-    467,
-    468,
-    469,
-    470,
-    471,
-    472,
-    473,
-    474,
-    475,
-    476,
-    477,
-    478,
-    479,
-    480,
-    481,
-    482,
-    483,
-    484,
-    485,
-    486,
-    487,
-    488,
-    489,
-    490,
-    491,
-    492,
-    493,
-    494,
-    495,
-    496,
-    497,
-    498,
-    499,
-    500,
-    501,
-    502,
-    503,
-    504,
-    505,
-    506,
-    507,
-    508,
-    509,
-]
-
-
-class ImgLinearBackbone(nn.Module):
-    def __init__(
-        self,
-        d_model: int,
-        patch_size: int,
-        in_chan: int = 3,
-    ) -> None:
-        super().__init__()
-
-        self.conv_proj = nn.Conv2d(
-            in_chan,
-            out_channels=d_model,
-            kernel_size=patch_size,
-            stride=patch_size,
-        )
-        self.d_model = d_model
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.conv_proj(x)
-        x = x.flatten(start_dim=-2).transpose(1, 2)
-        return x
-
-
-class Encoder(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.patch_size = 16
-        self.d_model = 768
-        self.dropout = 0
-        self.activation = "gelu"
-        self.norm_first = True
-        self.ff_ratio = 4
-        self.nhead = 12
-        self.max_seq_len = 1024
-        self.n_encoder_layer = 12
-        encoder_layer = nn.TransformerEncoderLayer(
-            self.d_model,
-            nhead=self.nhead,
-            dim_feedforward=self.ff_ratio * self.d_model,
-            dropout=self.dropout,
-            activation=self.activation,
-            batch_first=True,
-            norm_first=self.norm_first,
-        )
-        norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm = norm_layer(self.d_model)
-        self.backbone = ImgLinearBackbone(
-            d_model=self.d_model, patch_size=self.patch_size
-        )
-        self.pos_embed = PositionEmbedding(
-            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
-        )
-        self.encoder = nn.TransformerEncoder(
-            encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        src_feature = self.backbone(x)
-        src_feature = self.pos_embed(src_feature)
-        memory = self.encoder(src_feature)
-        memory = self.norm(memory)
-        return memory
-
-
-class PositionEmbedding(nn.Module):
-    def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
-        super().__init__()
-        self.embedding = nn.Embedding(max_seq_len, d_model)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
-        # assume x is batch first
-        if input_pos is None:
-            _pos = torch.arange(x.shape[1], device=x.device)
-        else:
-            _pos = input_pos
-        out = self.embedding(_pos)
-        return self.dropout(out + x)
-
-
-class TokenEmbedding(nn.Module):
-    def __init__(
-        self,
-        vocab_size: int,
-        d_model: int,
-        padding_idx: int,
-    ) -> None:
-        super().__init__()
-        assert vocab_size > 0
-        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.embedding(x)
-
-
-def find_multiple(n: int, k: int) -> int:
-    if n % k == 0:
-        return n
-    return n + k - (n % k)
-
-
-@dataclass
-class ModelArgs:
-    n_layer: int = 4
-    n_head: int = 12
-    dim: int = 768
-    intermediate_size: int = None
-    head_dim: int = 64
-    activation: str = "gelu"
-    norm_first: bool = True
-
-    def __post_init__(self):
-        if self.intermediate_size is None:
-            hidden_dim = 4 * self.dim
-            n_hidden = int(2 * hidden_dim / 3)
-            self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
-
-
-class KVCache(nn.Module):
-    def __init__(
-        self,
-        max_batch_size,
-        max_seq_length,
-        n_heads,
-        head_dim,
-        dtype=torch.bfloat16,
-        device="cpu",
-    ):
-        super().__init__()
-        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-        self.register_buffer(
-            "k_cache",
-            torch.zeros(cache_shape, dtype=dtype, device=device),
-            persistent=False,
-        )
-        self.register_buffer(
-            "v_cache",
-            torch.zeros(cache_shape, dtype=dtype, device=device),
-            persistent=False,
-        )
-
-    def update(self, input_pos, k_val, v_val):
-        # input_pos: [S], k_val: [B, H, S, D]
-        # assert input_pos.shape[0] == k_val.shape[2]
-
-        bs = k_val.shape[0]
-        k_out = self.k_cache
-        v_out = self.v_cache
-        k_out[:bs, :, input_pos] = k_val
-        v_out[:bs, :, input_pos] = v_val
-
-        return k_out[:bs], v_out[:bs]
-
-
-class GPTFastDecoder(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.vocab_size = 960
-        self.padding_idx = 2
-        self.prefix_token_id = 11
-        self.eos_id = 1
-        self.max_seq_len = 1024
-        self.dropout = 0
-        self.d_model = 768
-        self.nhead = 12
-        self.activation = "gelu"
-        self.norm_first = True
-        self.n_decoder_layer = 4
-        config = ModelArgs(
-            n_layer=self.n_decoder_layer,
-            n_head=self.nhead,
-            dim=self.d_model,
-            intermediate_size=self.d_model * 4,
-            activation=self.activation,
-            norm_first=self.norm_first,
-        )
-        self.config = config
-        self.layers = nn.ModuleList(
-            TransformerBlock(config) for _ in range(config.n_layer)
-        )
-        self.token_embed = TokenEmbedding(
-            vocab_size=self.vocab_size,
-            d_model=self.d_model,
-            padding_idx=self.padding_idx,
-        )
-        self.pos_embed = PositionEmbedding(
-            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
-        )
-        self.generator = nn.Linear(self.d_model, self.vocab_size)
-        self.token_white_list = TOKEN_WHITE_LIST
-        self.mask_cache: Optional[Tensor] = None
-        self.max_batch_size = -1
-        self.max_seq_length = -1
-
-    def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
-        for b in self.layers:
-            b.multihead_attn.k_cache = None
-            b.multihead_attn.v_cache = None
-
-        if (
-            self.max_seq_length >= max_seq_length
-            and self.max_batch_size >= max_batch_size
-        ):
-            return
-        head_dim = self.config.dim // self.config.n_head
-        max_seq_length = find_multiple(max_seq_length, 8)
-        self.max_seq_length = max_seq_length
-        self.max_batch_size = max_batch_size
-
-        for b in self.layers:
-            b.self_attn.kv_cache = KVCache(
-                max_batch_size,
-                max_seq_length,
-                self.config.n_head,
-                head_dim,
-                dtype,
-                device,
-            )
-            b.multihead_attn.k_cache = None
-            b.multihead_attn.v_cache = None
-
-        self.causal_mask = torch.tril(
-            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
-        ).to(device)
-
-    def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
-        input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
-        tgt = tgt[:, -1:]
-        tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
-        # tgt = self.decoder(tgt_feature, memory, input_pos)
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=False, enable_mem_efficient=False, enable_math=True
-        ):
-            logits = tgt_feature
-            tgt_mask = self.causal_mask[None, None, input_pos]
-            for i, layer in enumerate(self.layers):
-                logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
-        # return output
-        logits = self.generator(logits)[:, -1, :]
-        total = set([i for i in range(logits.shape[-1])])
-        black_list = list(total.difference(set(self.token_white_list)))
-        logits[..., black_list] = -1e9
-        probs = F.softmax(logits, dim=-1)
-        _, next_tokens = probs.topk(1)
-        return next_tokens
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
-        super().__init__()
-        self.self_attn = Attention(config)
-        self.multihead_attn = CrossAttention(config)
-
-        layer_norm_eps = 1e-5
-
-        d_model = config.dim
-        dim_feedforward = config.intermediate_size
-
-        self.linear1 = nn.Linear(d_model, dim_feedforward)
-        self.linear2 = nn.Linear(dim_feedforward, d_model)
-
-        self.norm_first = config.norm_first
-        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
-        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
-        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
-
-        self.activation = _get_activation_fn(config.activation)
-
-    def forward(
-        self,
-        tgt: Tensor,
-        memory: Tensor,
-        tgt_mask: Tensor,
-        input_pos: Tensor,
-    ) -> Tensor:
-        if self.norm_first:
-            x = tgt
-            x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
-            x = x + self.multihead_attn(self.norm2(x), memory)
-            x = x + self._ff_block(self.norm3(x))
-        else:
-            x = tgt
-            x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
-            x = self.norm2(x + self.multihead_attn(x, memory))
-            x = self.norm3(x + self._ff_block(x))
-        return x
-
-    def _ff_block(self, x: Tensor) -> Tensor:
-        x = self.linear2(self.activation(self.linear1(x)))
-        return x
-
-
-class Attention(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        assert config.dim % config.n_head == 0
-
-        # key, query, value projections for all heads, but in a batch
-        self.wqkv = nn.Linear(config.dim, 3 * config.dim)
-        self.wo = nn.Linear(config.dim, config.dim)
-
-        self.kv_cache: Optional[KVCache] = None
-
-        self.n_head = config.n_head
-        self.head_dim = config.head_dim
-        self.dim = config.dim
-
-    def forward(
-        self,
-        x: Tensor,
-        mask: Tensor,
-        input_pos: Optional[Tensor] = None,
-    ) -> Tensor:
-        bsz, seqlen, _ = x.shape
-
-        kv_size = self.n_head * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
-
-        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
-        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
-        v = v.view(bsz, seqlen, self.n_head, self.head_dim)
-
-        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
-
-        if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
-
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-        y = self.wo(y)
-        return y
-
-
-class CrossAttention(nn.Module):
-    def __init__(self, config: ModelArgs):
-        super().__init__()
-        assert config.dim % config.n_head == 0
-
-        self.query = nn.Linear(config.dim, config.dim)
-        self.key = nn.Linear(config.dim, config.dim)
-        self.value = nn.Linear(config.dim, config.dim)
-        self.out = nn.Linear(config.dim, config.dim)
-
-        self.k_cache = None
-        self.v_cache = None
-
-        self.n_head = config.n_head
-        self.head_dim = config.head_dim
-
-    def get_kv(self, xa: torch.Tensor):
-        if self.k_cache is not None and self.v_cache is not None:
-            return self.k_cache, self.v_cache
-
-        k = self.key(xa)
-        v = self.value(xa)
-
-        # Reshape for correct format
-        batch_size, source_seq_len, _ = k.shape
-        k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
-        v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
-
-        if self.k_cache is None:
-            self.k_cache = k
-        if self.v_cache is None:
-            self.v_cache = v
-
-        return k, v
-
-    def forward(
-        self,
-        x: Tensor,
-        xa: Tensor,
-    ):
-        q = self.query(x)
-        batch_size, target_seq_len, _ = q.shape
-        q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
-        k, v = self.get_kv(xa)
-
-        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
-
-        wv = F.scaled_dot_product_attention(
-            query=q,
-            key=k,
-            value=v,
-            is_causal=False,
-        )
-        wv = wv.transpose(1, 2).reshape(
-            batch_size,
-            target_seq_len,
-            self.n_head * self.head_dim,
-        )
-
-        return self.out(wv)

+ 1 - 1
mineru/model/table/rec/unet_table/main.py

@@ -10,7 +10,7 @@ import numpy as np
 import cv2
 from PIL import Image
 from loguru import logger
-from  ..slanet_plus.rapid_table import RapidTableInput, RapidTable
+from  ..slanet_plus.main import RapidTableInput, RapidTable
 
 from .table_structure_unet import TSRUnet
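
For the unet_table module only the import path changes; RapidTableInput and RapidTable keep their interface. A sketch of constructing the engine, mirroring the model-path resolution used in slanet_plus/main.py above (illustrative only):

import os

from mineru.model.table.rec.slanet_plus.main import RapidTable, RapidTableInput
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path

# Resolve the slanet_plus weights directory, then build the engine as RapidTableModel does.
model_path = os.path.join(
    auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus
)
table_engine = RapidTable(RapidTableInput(model_type="slanet_plus", model_path=model_path))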