
Merge pull request #3779 from myhloli/dev

mfr add paddle
Xiaomeng Zhao 4 weeks ago
parent
commit
a513357607
75 changed files with 4,837 additions and 15 deletions
  1. mineru/backend/pipeline/model_init.py (+19, -3)
  2. mineru/backend/pipeline/model_list.py (+0, -1)
  3. mineru/backend/vlm/vlm_analyze.py (+2, -2)
  4. mineru/model/mfr/pp_formulanet_plus_m/__init__.py (+0, -0)
  5. mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py (+141, -0)
  6. mineru/model/mfr/pp_formulanet_plus_m/processors.py (+634, -0)
  7. mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py (+3, -3)
  8. mineru/model/utils/__init__.py (+0, -0)
  9. mineru/model/utils/pytorchocr/__init__.py (+0, -0)
  10. mineru/model/utils/pytorchocr/base_ocr_v20.py (+0, -0)
  11. mineru/model/utils/pytorchocr/data/__init__.py (+0, -0)
  12. mineru/model/utils/pytorchocr/data/imaug/__init__.py (+0, -0)
  13. mineru/model/utils/pytorchocr/data/imaug/operators.py (+0, -0)
  14. mineru/model/utils/pytorchocr/modeling/__init__.py (+0, -0)
  15. mineru/model/utils/pytorchocr/modeling/architectures/__init__.py (+0, -0)
  16. mineru/model/utils/pytorchocr/modeling/architectures/base_model.py (+0, -0)
  17. mineru/model/utils/pytorchocr/modeling/backbones/__init__.py (+2, -1)
  18. mineru/model/utils/pytorchocr/modeling/backbones/det_mobilenet_v3.py (+0, -0)
  19. mineru/model/utils/pytorchocr/modeling/backbones/rec_donut_swin.py (+0, -0)
  20. mineru/model/utils/pytorchocr/modeling/backbones/rec_hgnet.py (+0, -0)
  21. mineru/model/utils/pytorchocr/modeling/backbones/rec_lcnetv3.py (+0, -0)
  22. mineru/model/utils/pytorchocr/modeling/backbones/rec_mobilenet_v3.py (+0, -0)
  23. mineru/model/utils/pytorchocr/modeling/backbones/rec_mv1_enhance.py (+0, -0)
  24. mineru/model/utils/pytorchocr/modeling/backbones/rec_pphgnetv2.py (+2, -2)
  25. mineru/model/utils/pytorchocr/modeling/backbones/rec_svtrnet.py (+0, -0)
  26. mineru/model/utils/pytorchocr/modeling/common.py (+0, -0)
  27. mineru/model/utils/pytorchocr/modeling/heads/__init__.py (+2, -0)
  28. mineru/model/utils/pytorchocr/modeling/heads/cls_head.py (+0, -0)
  29. mineru/model/utils/pytorchocr/modeling/heads/det_db_head.py (+0, -0)
  30. mineru/model/utils/pytorchocr/modeling/heads/rec_ctc_head.py (+0, -0)
  31. mineru/model/utils/pytorchocr/modeling/heads/rec_multi_head.py (+0, -0)
  32. mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py (+1379, -0)
  33. mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py (+2624, -0)
  34. mineru/model/utils/pytorchocr/modeling/necks/__init__.py (+0, -0)
  35. mineru/model/utils/pytorchocr/modeling/necks/db_fpn.py (+0, -0)
  36. mineru/model/utils/pytorchocr/modeling/necks/intracl.py (+0, -0)
  37. mineru/model/utils/pytorchocr/modeling/necks/rnn.py (+0, -0)
  38. mineru/model/utils/pytorchocr/postprocess/__init__.py (+0, -0)
  39. mineru/model/utils/pytorchocr/postprocess/cls_postprocess.py (+0, -0)
  40. mineru/model/utils/pytorchocr/postprocess/db_postprocess.py (+0, -0)
  41. mineru/model/utils/pytorchocr/postprocess/rec_postprocess.py (+0, -0)
  42. mineru/model/utils/pytorchocr/utils/__init__.py (+0, -0)
  43. mineru/model/utils/pytorchocr/utils/resources/arch_config.yaml (+0, -0)
  44. mineru/model/utils/pytorchocr/utils/resources/dict/arabic_dict.txt (+0, -0)
  45. mineru/model/utils/pytorchocr/utils/resources/dict/chinese_cht_dict.txt (+0, -0)
  46. mineru/model/utils/pytorchocr/utils/resources/dict/cyrillic_dict.txt (+0, -0)
  47. mineru/model/utils/pytorchocr/utils/resources/dict/devanagari_dict.txt (+0, -0)
  48. mineru/model/utils/pytorchocr/utils/resources/dict/en_dict.txt (+0, -0)
  49. mineru/model/utils/pytorchocr/utils/resources/dict/japan_dict.txt (+0, -0)
  50. mineru/model/utils/pytorchocr/utils/resources/dict/ka_dict.txt (+0, -0)
  51. mineru/model/utils/pytorchocr/utils/resources/dict/korean_dict.txt (+0, -0)
  52. mineru/model/utils/pytorchocr/utils/resources/dict/latin_dict.txt (+0, -0)
  53. mineru/model/utils/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt (+0, -0)
  54. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt (+0, -0)
  55. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_dict.txt (+0, -0)
  56. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt (+0, -0)
  57. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt (+0, -0)
  58. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt (+0, -0)
  59. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt (+0, -0)
  60. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt (+0, -0)
  61. mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt (+0, -0)
  62. mineru/model/utils/pytorchocr/utils/resources/dict/ta_dict.txt (+0, -0)
  63. mineru/model/utils/pytorchocr/utils/resources/dict/te_dict.txt (+0, -0)
  64. mineru/model/utils/pytorchocr/utils/resources/models_config.yml (+0, -0)
  65. mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml (+24, -0)
  66. mineru/model/utils/tools/__init__.py (+0, -0)
  67. mineru/model/utils/tools/infer/__init__.py (+1, -0)
  68. mineru/model/utils/tools/infer/predict_cls.py (+0, -0)
  69. mineru/model/utils/tools/infer/predict_det.py (+0, -0)
  70. mineru/model/utils/tools/infer/predict_rec.py (+0, -0)
  71. mineru/model/utils/tools/infer/predict_system.py (+0, -0)
  72. mineru/model/utils/tools/infer/pytorchocr_utility.py (+0, -0)
  73. mineru/model/vlm_vllm_model/server.py (+1, -1)
  74. mineru/utils/enum_class.py (+1, -0)
  75. pyproject.toml (+2, -2)

+ 19 - 3
mineru/backend/pipeline/model_init.py

@@ -7,6 +7,7 @@ from .model_list import AtomicModel
 from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
+from ...model.mfr.pp_formulanet_plus_m.predict_formula import FormulaRecognizer
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
@@ -16,6 +17,9 @@ from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
+MFR_MODEL = "unimernet_small"
+# MFR_MODEL = "pp_formulanet_plus_m"
+
 
 def img_orientation_cls_model_init():
     atom_model_manager = AtomModelSingleton()
@@ -68,7 +72,13 @@ def mfd_model_init(weight, device='cpu'):
 
 
 def mfr_model_init(weight_dir, device='cpu'):
-    mfr_model = UnimernetModel(weight_dir, device)
+    if MFR_MODEL == "unimernet_small":
+        mfr_model = UnimernetModel(weight_dir, device)
+    elif MFR_MODEL == "pp_formulanet_plus_m":
+        mfr_model = FormulaRecognizer(weight_dir, device)
+    else:
+        logger.error('MFR model name not allowed')
+        exit(1)
     return mfr_model
 
 
@@ -205,11 +215,17 @@ class MineruPipelineModel:
             )
 
             # Initialize the formula recognition model
-            mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
+            if MFR_MODEL == "unimernet_small":
+                mfr_model_path = ModelPath.unimernet_small
+            elif MFR_MODEL == "pp_formulanet_plus_m":
+                mfr_model_path = ModelPath.pp_formulanet_plus_m
+            else:
+                logger.error('MFR model name not allowed')
+                exit(1)
 
             self.mfr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
-                mfr_weight_dir=mfr_weight_dir,
+                mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
                 device=self.device,
             )
 

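The MFR backend is now chosen by the module-level MFR_MODEL switch shown above. A minimal sketch of opting into the new PP-FormulaNet backend before building the pipeline (the weight directory path is a placeholder and the weights are assumed to be downloaded already):

    from mineru.backend.pipeline import model_init

    # Sketch only: the shipped default remains "unimernet_small".
    model_init.MFR_MODEL = "pp_formulanet_plus_m"
    mfr_model = model_init.mfr_model_init(
        weight_dir="/path/to/PP-FormulaNet_plus-M",  # hypothetical local weight dir
        device="cuda",
    )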
+ 0 - 1
mineru/backend/pipeline/model_list.py

@@ -7,4 +7,3 @@ class AtomicModel:
     WiredTable = "wired_table"
     TableCls = "table_cls"
     ImgOrientationCls = "img_ori_cls"
-

+ 2 - 2
mineru/backend/vlm/vlm_analyze.py

@@ -96,7 +96,7 @@ class ModelSingleton:
                         except ImportError:
                             raise ImportError("Please install vllm to use the vllm-engine backend.")
                         if "gpu_memory_utilization" not in kwargs:
-                            kwargs["gpu_memory_utilization"] = 0.5
+                            kwargs["gpu_memory_utilization"] = 0.7
                         if "model" not in kwargs:
                             kwargs["model"] = model_path
                         if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
@@ -111,7 +111,7 @@ class ModelSingleton:
                         except ImportError:
                             raise ImportError("Please install vllm to use the vllm-async-engine backend.")
                         if "gpu_memory_utilization" not in kwargs:
-                            kwargs["gpu_memory_utilization"] = 0.5
+                            kwargs["gpu_memory_utilization"] = 0.7
                         if "model" not in kwargs:
                             kwargs["model"] = model_path
                         if enable_custom_logits_processors() and ("logits_processors" not in kwargs):

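The bump from 0.5 to 0.7 only changes the fallback; a caller that passes gpu_memory_utilization keeps its own value. A hedged sketch of that default-filling behaviour (kwargs stands in for whatever the caller forwards to the vLLM engine):

    # Sketch only: an explicit value always wins over the new 0.7 default.
    kwargs = {"gpu_memory_utilization": 0.5}
    if "gpu_memory_utilization" not in kwargs:
        kwargs["gpu_memory_utilization"] = 0.7  # applied only when unset
    assert kwargs["gpu_memory_utilization"] == 0.5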
+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/__init__.py → mineru/model/mfr/pp_formulanet_plus_m/__init__.py


+ 141 - 0
mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py

@@ -0,0 +1,141 @@
+import os
+import torch
+import yaml
+from pathlib import Path
+from tqdm import tqdm
+from mineru.model.utils.tools.infer import pytorchocr_utility
+from mineru.model.utils.pytorchocr.base_ocr_v20 import BaseOCRV20
+from .processors import (
+    UniMERNetImgDecode,
+    UniMERNetTestTransform,
+    LatexImageFormat,
+    ToBatch,
+    UniMERNetDecode,
+)
+
+
+class FormulaRecognizer(BaseOCRV20):
+    def __init__(
+        self,
+        weight_dir,
+        device="cpu",
+    ):
+        self.weights_path = os.path.join(
+            weight_dir,
+            "PP-FormulaNet_plus-M.pth",
+        )
+        self.yaml_path = os.path.join(
+            Path(__file__).parent.parent.parent,
+            "utils",
+            "pytorchocr",
+            "utils",
+            "resources",
+            "pp_formulanet_arch_config.yaml"
+        )
+        self.infer_yaml_path = os.path.join(
+            weight_dir,
+            "PP-FormulaNet_plus-M_inference.yml",
+        )
+
+        network_config = pytorchocr_utility.AnalysisConfig(
+            self.weights_path, self.yaml_path
+        )
+        weights = self.read_pytorch_weights(self.weights_path)
+
+        super(FormulaRecognizer, self).__init__(network_config)
+
+        self.load_state_dict(weights)
+        # device = "cpu"
+        self.device = torch.device(device) if isinstance(device, str) else device
+        self.net.to(self.device)
+        self.net.eval()
+
+        with open(self.infer_yaml_path, "r", encoding="utf-8") as yaml_file:
+            data = yaml.load(yaml_file, Loader=yaml.FullLoader)
+
+        self.pre_tfs = {
+            "UniMERNetImgDecode": UniMERNetImgDecode(input_size=(384, 384)),
+            "UniMERNetTestTransform": UniMERNetTestTransform(),
+            "LatexImageFormat": LatexImageFormat(),
+            "ToBatch": ToBatch(),
+        }
+
+        self.post_op = UniMERNetDecode(
+            character_list=data["PostProcess"]["character_dict"]
+        )
+
+    def predict(self, img_list, batch_size: int = 64):
+        batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=img_list)
+        batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
+        batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
+        x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
+        x = torch.from_numpy(x[0]).to(self.device)
+        rec_formula = []
+        with torch.no_grad():
+            with tqdm(total=len(x), desc="Formula Predict") as pbar:
+                for index in range(0, len(x), batch_size):
+                    batch_data = x[index: index + batch_size]
+                    batch_preds = [self.net(batch_data)]
+                    batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
+                    rec_formula += self.post_op(batch_preds)
+                    pbar.update(len(batch_preds))
+        return rec_formula
+
+    def batch_predict(
+        self, images_mfd_res: list, images: list, batch_size: int = 64
+    ) -> list:
+        images_formula_list = []
+        mf_image_list = []
+        backfill_list = []
+        image_info = []  # Store (area, original_index, image) tuples
+
+        # Collect images with their original indices
+        for image_index in range(len(images_mfd_res)):
+            mfd_res = images_mfd_res[image_index]
+            image = images[image_index]
+            formula_list = []
+
+            for idx, (xyxy, conf, cla) in enumerate(
+                zip(mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls)
+            ):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    "category_id": 13 + int(cla.item()),
+                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    "score": round(float(conf.item()), 2),
+                    "latex": "",
+                }
+                formula_list.append(new_item)
+                bbox_img = image[ymin:ymax, xmin:xmax]
+                area = (xmax - xmin) * (ymax - ymin)
+
+                curr_idx = len(mf_image_list)
+                image_info.append((area, curr_idx, bbox_img))
+                mf_image_list.append(bbox_img)
+
+            images_formula_list.append(formula_list)
+            backfill_list += formula_list
+
+        # Stable sort by area
+        image_info.sort(key=lambda x: x[0])  # sort by area
+        sorted_indices = [x[1] for x in image_info]
+        sorted_images = [x[2] for x in image_info]
+
+        # Create mapping for results
+        index_mapping = {
+            new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)
+        }
+
+        # Run prediction
+        rec_formula = self.predict(sorted_images, batch_size)
+
+        # Restore original order
+        unsorted_results = [""] * len(rec_formula)
+        for new_idx, latex in enumerate(rec_formula):
+            original_idx = index_mapping[new_idx]
+            unsorted_results[original_idx] = latex
+
+        for res, latex in zip(backfill_list, unsorted_results):
+            res["latex"] = latex
+
+        return images_formula_list

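For reference, a hedged usage sketch of the new recognizer on already-cropped formula images (the weight directory and image files are assumptions; predict returns one LaTeX string per crop):

    import cv2
    from mineru.model.mfr.pp_formulanet_plus_m.predict_formula import FormulaRecognizer

    # Sketch only: paths below are placeholders.
    recognizer = FormulaRecognizer(weight_dir="/path/to/PP-FormulaNet_plus-M", device="cpu")
    crops = [cv2.imread("formula_1.png"), cv2.imread("formula_2.png")]  # HxWx3 numpy crops
    latex_list = recognizer.predict(crops, batch_size=64)
    print(latex_list)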
+ 634 - 0
mineru/model/mfr/pp_formulanet_plus_m/processors.py

@@ -0,0 +1,634 @@
+import json
+import numpy as np
+import cv2
+import math
+import re
+
+from PIL import Image, ImageOps
+from typing import List, Optional, Tuple, Union, Dict, Any
+from tokenizers import AddedToken
+from tokenizers import Tokenizer as TokenizerFast
+
+
+class UniMERNetImgDecode(object):
+    """Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""
+
+    def __init__(
+            self, input_size: Tuple[int, int], random_padding: bool = False, **kwargs
+    ) -> None:
+        """Initializes the UniMERNetImgDecode class with input size and random padding options.
+
+        Args:
+            input_size (tuple): The desired input size for the images (height, width).
+            random_padding (bool): Whether to use random padding for resizing.
+            **kwargs: Additional keyword arguments."""
+        self.input_size = input_size
+        self.random_padding = random_padding
+
+    def crop_margin(self, img: Image.Image) -> Image.Image:
+        """Crops the margin of the image based on grayscale thresholding.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+
+        Returns:
+            PIL.Image.Image: The cropped image."""
+        data = np.array(img.convert("L"))
+        data = data.astype(np.uint8)
+        max_val = data.max()
+        min_val = data.min()
+        if max_val == min_val:
+            return img
+        data = (data - min_val) / (max_val - min_val) * 255
+        gray = 255 * (data < 200).astype(np.uint8)
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        return img.crop((a, b, w + a, h + b))
+
+    def get_dimensions(self, img: Union[Image.Image, np.ndarray]) -> List[int]:
+        """Gets the dimensions of the image.
+
+        Args:
+            img (PIL.Image.Image or numpy.ndarray): The input image.
+
+        Returns:
+            list: A list containing the number of channels, height, and width."""
+        if hasattr(img, "getbands"):
+            channels = len(img.getbands())
+        else:
+            channels = img.channels
+        width, height = img.size
+        return [channels, height, width]
+
+    def _compute_resized_output_size(
+            self,
+            image_size: Tuple[int, int],
+            size: Union[int, Tuple[int, int]],
+            max_size: Optional[int] = None,
+    ) -> List[int]:
+        """Computes the resized output size of the image.
+
+        Args:
+            image_size (tuple): The original size of the image (height, width).
+            size (int or tuple): The desired size for the smallest edge or both height and width.
+            max_size (int, optional): The maximum allowed size for the longer edge.
+
+        Returns:
+            list: A list containing the new height and width."""
+        if len(size) == 1:  # specified size only for the smallest edge
+            h, w = image_size
+            short, long = (w, h) if w <= h else (h, w)
+            requested_new_short = size if isinstance(size, int) else size[0]
+
+            new_short, new_long = requested_new_short, int(
+                requested_new_short * long / short
+            )
+
+            if max_size is not None:
+                if max_size <= requested_new_short:
+                    raise ValueError(
+                        f"max_size = {max_size} must be strictly greater than the requested "
+                        f"size for the smaller edge size = {size}"
+                    )
+                if new_long > max_size:
+                    new_short, new_long = int(max_size * new_short / new_long), max_size
+
+            new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+        else:  # specified both h and w
+            new_w, new_h = size[1], size[0]
+        return [new_h, new_w]
+
+    def resize(
+            self, img: Image.Image, size: Union[int, Tuple[int, int]]
+    ) -> Image.Image:
+        """Resizes the image to the specified size.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+            size (int or tuple): The desired size for the smallest edge or both height and width.
+
+        Returns:
+            PIL.Image.Image: The resized image."""
+        _, image_height, image_width = self.get_dimensions(img)
+        if isinstance(size, int):
+            size = [size]
+        max_size = None
+        output_size = self._compute_resized_output_size(
+            (image_height, image_width), size, max_size
+        )
+        img = img.resize(tuple(output_size[::-1]), resample=2)
+        return img
+
+    def img_decode(self, img: np.ndarray) -> Optional[np.ndarray]:
+        """Decodes the image by cropping margins, resizing, and adding padding.
+
+        Args:
+            img (numpy.ndarray): The input image array.
+
+        Returns:
+            numpy.ndarray: The decoded image array."""
+        try:
+            img = self.crop_margin(Image.fromarray(img).convert("RGB"))
+        except OSError:
+            return
+        if img.height == 0 or img.width == 0:
+            return
+        img = self.resize(img, min(self.input_size))
+        img.thumbnail((self.input_size[1], self.input_size[0]))
+        delta_width = self.input_size[1] - img.width
+        delta_height = self.input_size[0] - img.height
+        if self.random_padding:
+            pad_width = np.random.randint(low=0, high=delta_width + 1)
+            pad_height = np.random.randint(low=0, high=delta_height + 1)
+        else:
+            pad_width = delta_width // 2
+            pad_height = delta_height // 2
+        padding = (
+            pad_width,
+            pad_height,
+            delta_width - pad_width,
+            delta_height - pad_height,
+        )
+        return np.array(ImageOps.expand(img, padding))
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[Optional[np.ndarray]]:
+        """Calls the img_decode method on a list of images.
+
+        Args:
+            imgs (list of numpy.ndarray): The list of input image arrays.
+
+        Returns:
+            list of numpy.ndarray: The list of decoded image arrays."""
+        return [self.img_decode(img) for img in imgs]
+
+
+class UniMERNetTestTransform:
+    """
+    A class for transforming images according to UniMERNet test specifications.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        """
+        Initializes the UniMERNetTestTransform class.
+        """
+        super().__init__()
+        self.num_output_channels = 3
+
+    def transform(self, img: np.ndarray) -> np.ndarray:
+        """
+        Transforms a single image for UniMERNet testing.
+
+        Args:
+            img (numpy.ndarray): The input image.
+
+        Returns:
+            numpy.ndarray: The transformed image.
+        """
+        mean = [0.7931, 0.7931, 0.7931]
+        std = [0.1738, 0.1738, 0.1738]
+        scale = float(1 / 255.0)
+        shape = (1, 1, 3)
+        mean = np.array(mean).reshape(shape).astype("float32")
+        std = np.array(std).reshape(shape).astype("float32")
+        img = (img.astype("float32") * scale - mean) / std
+        grayscale_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        squeezed = np.squeeze(grayscale_image)
+        img = cv2.merge([squeezed] * self.num_output_channels)
+        return img
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """
+        Applies the transform to a list of images.
+
+        Args:
+            imgs (list of numpy.ndarray): The list of input images.
+
+        Returns:
+            list of numpy.ndarray: The list of transformed images.
+        """
+        return [self.transform(img) for img in imgs]
+
+
+class LatexImageFormat:
+    """Class for formatting images to a specific format suitable for LaTeX."""
+
+    def __init__(self, **kwargs) -> None:
+        """Initializes the LatexImageFormat class with optional keyword arguments."""
+        super().__init__()
+
+    def format(self, img: np.ndarray) -> np.ndarray:
+        """Formats a single image to the LaTeX-compatible format.
+
+        Args:
+            img (numpy.ndarray): The input image as a numpy array.
+
+        Returns:
+            numpy.ndarray: The formatted image as a numpy array with an added dimension for color.
+        """
+        im_h, im_w = img.shape[:2]
+        divide_h = math.ceil(im_h / 16) * 16
+        divide_w = math.ceil(im_w / 16) * 16
+        img = img[:, :, 0]
+        img = np.pad(
+            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
+        )
+        img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
+        return img_expanded[np.newaxis, :]
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Applies the format method to a list of images.
+
+        Args:
+            imgs (list of numpy.ndarray): A list of input images as numpy arrays.
+
+        Returns:
+            list of numpy.ndarray: A list of formatted images as numpy arrays.
+        """
+        return [self.format(img) for img in imgs]
+
+
+class ToBatch(object):
+    """A class for batching images."""
+
+    def __init__(self, **kwargs) -> None:
+        """Initializes the ToBatch object."""
+        super(ToBatch, self).__init__()
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Concatenates a list of images into a single batch.
+
+        Args:
+            imgs (list): A list of image arrays to be concatenated.
+
+        Returns:
+            list: A list containing the concatenated batch of images wrapped in another list (to comply with common batch processing formats).
+        """
+        batch_imgs = np.concatenate(imgs)
+        batch_imgs = batch_imgs.copy()
+        x = [batch_imgs]
+        return x
+
+
+class UniMERNetDecode(object):
+    """Class for decoding tokenized inputs using UniMERNet tokenizer.
+
+    Attributes:
+        SPECIAL_TOKENS_ATTRIBUTES (List[str]): List of special token attributes.
+        model_input_names (List[str]): List of model input names.
+        max_seq_len (int): Maximum sequence length.
+        pad_token_id (int): ID for the padding token.
+        bos_token_id (int): ID for the beginning-of-sequence token.
+        eos_token_id (int): ID for the end-of-sequence token.
+        padding_side (str): Padding side, either 'left' or 'right'.
+        pad_token (str): Padding token.
+        pad_token_type_id (int): Type ID for the padding token.
+        pad_to_multiple_of (Optional[int]): If set, pad to a multiple of this value.
+        tokenizer (TokenizerFast): Fast tokenizer instance.
+
+    Args:
+        character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
+        **kwargs: Additional keyword arguments.
+    """
+
+    SPECIAL_TOKENS_ATTRIBUTES = [
+        "bos_token",
+        "eos_token",
+        "unk_token",
+        "sep_token",
+        "pad_token",
+        "cls_token",
+        "mask_token",
+        "additional_special_tokens",
+    ]
+
+    def __init__(
+            self,
+            character_list: Dict[str, Any],
+            **kwargs,
+    ) -> None:
+        """Initializes the UniMERNetDecode class.
+
+        Args:
+            character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
+            **kwargs: Additional keyword arguments.
+        """
+
+        self._unk_token = "<unk>"
+        self._bos_token = "<s>"
+        self._eos_token = "</s>"
+        self._pad_token = "<pad>"
+        self._sep_token = None
+        self._cls_token = None
+        self._mask_token = None
+        self._additional_special_tokens = []
+        self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
+        self.max_seq_len = 2048
+        self.pad_token_id = 1
+        self.bos_token_id = 0
+        self.eos_token_id = 2
+        self.padding_side = "right"
+        self.pad_token_id = 1
+        self.pad_token = "<pad>"
+        self.pad_token_type_id = 0
+        self.pad_to_multiple_of = None
+
+        fast_tokenizer_str = json.dumps(character_list["fast_tokenizer_file"])
+        fast_tokenizer_buffer = fast_tokenizer_str.encode("utf-8")
+        self.tokenizer = TokenizerFast.from_buffer(fast_tokenizer_buffer)
+        tokenizer_config = (
+            character_list["tokenizer_config_file"]
+            if "tokenizer_config_file" in character_list
+            else None
+        )
+        added_tokens_decoder = {}
+        added_tokens_map = {}
+        if tokenizer_config is not None:
+            init_kwargs = tokenizer_config
+            if "added_tokens_decoder" in init_kwargs:
+                for idx, token in init_kwargs["added_tokens_decoder"].items():
+                    if isinstance(token, dict):
+                        token = AddedToken(**token)
+                    if isinstance(token, AddedToken):
+                        added_tokens_decoder[int(idx)] = token
+                        added_tokens_map[str(token)] = token
+                    else:
+                        raise ValueError(
+                            f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
+                        )
+            init_kwargs["added_tokens_decoder"] = added_tokens_decoder
+            added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
+            tokens_to_add = [
+                token
+                for index, token in sorted(
+                    added_tokens_decoder.items(), key=lambda x: x[0]
+                )
+                if token not in added_tokens_decoder
+            ]
+            added_tokens_encoder = self.added_tokens_encoder(added_tokens_decoder)
+            encoder = list(added_tokens_encoder.keys()) + [
+                str(token) for token in tokens_to_add
+            ]
+            tokens_to_add += [
+                token
+                for token in self.all_special_tokens_extended
+                if token not in encoder and token not in tokens_to_add
+            ]
+            if len(tokens_to_add) > 0:
+                is_last_special = None
+                tokens = []
+                special_tokens = self.all_special_tokens
+                for token in tokens_to_add:
+                    is_special = (
+                        (token.special or str(token) in special_tokens)
+                        if isinstance(token, AddedToken)
+                        else str(token) in special_tokens
+                    )
+                    if is_last_special is None or is_last_special == is_special:
+                        tokens.append(token)
+                    else:
+                        self._add_tokens(tokens, special_tokens=is_last_special)
+                        tokens = [token]
+                    is_last_special = is_special
+                if tokens:
+                    self._add_tokens(tokens, special_tokens=is_last_special)
+
+    def _add_tokens(
+            self, new_tokens: "List[Union[AddedToken, str]]", special_tokens: bool = False
+    ) -> "List[Union[AddedToken, str]]":
+        """Adds new tokens to the tokenizer.
+
+        Args:
+            new_tokens (List[Union[AddedToken, str]]): Tokens to be added.
+            special_tokens (bool): Indicates whether the tokens are special tokens.
+
+        Returns:
+            List[Union[AddedToken, str]]: added tokens.
+        """
+        if special_tokens:
+            return self.tokenizer.add_special_tokens(new_tokens)
+
+        return self.tokenizer.add_tokens(new_tokens)
+
+    def added_tokens_encoder(
+            self, added_tokens_decoder: "Dict[int, AddedToken]"
+    ) -> Dict[str, int]:
+        """Creates an encoder dictionary from added tokens.
+
+        Args:
+            added_tokens_decoder (Dict[int, AddedToken]): Dictionary mapping token IDs to tokens.
+
+        Returns:
+            Dict[str, int]: Dictionary mapping token strings to IDs.
+        """
+        return {
+            k.content: v
+            for v, k in sorted(added_tokens_decoder.items(), key=lambda item: item[0])
+        }
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        """Retrieves all special tokens.
+
+        Returns:
+            List[str]: List of all special tokens as strings.
+        """
+        all_toks = [str(s) for s in self.all_special_tokens_extended]
+        return all_toks
+
+    @property
+    def all_special_tokens_extended(self) -> "List[Union[str, AddedToken]]":
+        """Retrieves all special tokens, including extended ones.
+
+        Returns:
+            List[Union[str, AddedToken]]: List of all special tokens.
+        """
+        all_tokens = []
+        seen = set()
+        for value in self.special_tokens_map_extended.values():
+            if isinstance(value, (list, tuple)):
+                tokens_to_add = [token for token in value if str(token) not in seen]
+            else:
+                tokens_to_add = [value] if str(value) not in seen else []
+            seen.update(map(str, tokens_to_add))
+            all_tokens.extend(tokens_to_add)
+        return all_tokens
+
+    @property
+    def special_tokens_map_extended(self) -> Dict[str, Union[str, List[str]]]:
+        """Retrieves the extended map of special tokens.
+
+        Returns:
+            Dict[str, Union[str, List[str]]]: Dictionary mapping special token attributes to their values.
+        """
+        set_attr = {}
+        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+            attr_value = getattr(self, "_" + attr)
+            if attr_value:
+                set_attr[attr] = attr_value
+        return set_attr
+
+    def convert_ids_to_tokens(
+            self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+    ) -> Union[str, List[str]]:
+        """Converts token IDs to token strings.
+
+        Args:
+            ids (Union[int, List[int]]): Token ID(s) to convert.
+            skip_special_tokens (bool): Whether to skip special tokens during conversion.
+
+        Returns:
+            Union[str, List[str]]: Converted token string(s).
+        """
+        if isinstance(ids, int):
+            return self.tokenizer.id_to_token(ids)
+        tokens = []
+        for index in ids:
+            index = int(index)
+            if skip_special_tokens and index in self.all_special_ids:
+                continue
+            tokens.append(self.tokenizer.id_to_token(index))
+        return tokens
+
+    def detokenize(self, tokens: List[List[int]]) -> List[List[str]]:
+        """Detokenizes a list of token IDs back into strings.
+
+        Args:
+            tokens (List[List[int]]): List of token ID lists.
+
+        Returns:
+            List[List[str]]: List of detokenized strings.
+        """
+        self.tokenizer.bos_token = "<s>"
+        self.tokenizer.eos_token = "</s>"
+        self.tokenizer.pad_token = "<pad>"
+        toks = [self.convert_ids_to_tokens(tok) for tok in tokens]
+        for b in range(len(toks)):
+            for i in reversed(range(len(toks[b]))):
+                if toks[b][i] is None:
+                    toks[b][i] = ""
+                toks[b][i] = toks[b][i].replace("Ġ", " ").strip()
+                if toks[b][i] in (
+                        [
+                            self.tokenizer.bos_token,
+                            self.tokenizer.eos_token,
+                            self.tokenizer.pad_token,
+                        ]
+                ):
+                    del toks[b][i]
+        return toks
+
+    def token2str(self, token_ids: List[List[int]]) -> List[str]:
+        """Converts a list of token IDs to strings.
+
+        Args:
+            token_ids (List[List[int]]): List of token ID lists.
+
+        Returns:
+            List[str]: List of converted strings.
+        """
+        generated_text = []
+        for tok_id in token_ids:
+            end_idx = np.argwhere(tok_id == 2)
+            if len(end_idx) > 0:
+                end_idx = int(end_idx[0][0])
+                tok_id = tok_id[: end_idx + 1]
+            generated_text.append(
+                self.tokenizer.decode(tok_id, skip_special_tokens=True)
+            )
+        generated_text = [self.post_process(text) for text in generated_text]
+        return generated_text
+
+    def normalize(self, s: str) -> str:
+        """Normalizes a string by removing unnecessary spaces.
+
+        Args:
+            s (str): String to normalize.
+
+        Returns:
+            str: Normalized string.
+        """
+        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+        letter = "[a-zA-Z]"
+        noletter = r"[\W_^\d]"
+        names = []
+        for x in re.findall(text_reg, s):
+            pattern = r"(\\[a-zA-Z]+)\s(?=\w)|\\[a-zA-Z]+\s(?=})"
+            matches = re.findall(pattern, x[0])
+            for m in matches:
+                if (
+                        m
+                        not in [
+                    "\\operatorname",
+                    "\\mathrm",
+                    "\\text",
+                    "\\mathbf",
+                ]
+                        and m.strip() != ""
+                ):
+                    s = s.replace(m, m + "XXXXXXX")
+                    s = s.replace(" ", "")
+                    names.append(s)
+        if len(names) > 0:
+            s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+        news = s
+        while True:
+            s = news
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
+            if news == s:
+                break
+        return s.replace("XXXXXXX", " ")
+
+    def remove_chinese_text_wrapping(self, formula):
+        pattern = re.compile(r"\\text\s*{\s*([^}]*?[\u4e00-\u9fff]+[^}]*?)\s*}")
+
+        def replacer(match):
+            return match.group(1)
+
+        replaced_formula = pattern.sub(replacer, formula)
+        return replaced_formula.replace('"', "")
+
+    def post_process(self, text: str) -> str:
+        """Post-processes a string by fixing text and normalizing it.
+
+        Args:
+            text (str): String to post-process.
+
+        Returns:
+            str: Post-processed string.
+        """
+        from ftfy import fix_text
+
+        text = self.remove_chinese_text_wrapping(text)
+        text = fix_text(text)
+        text = self.normalize(text)
+        return text
+
+    def __call__(
+            self,
+            preds: np.ndarray,
+            label: Optional[np.ndarray] = None,
+            mode: str = "eval",
+            *args,
+            **kwargs,
+    ) -> Union[List[str], tuple]:
+        """Processes predictions and optionally labels, returning the decoded text.
+
+        Args:
+            preds (np.ndarray): Model predictions.
+            label (Optional[np.ndarray]): True labels, if available.
+            mode (str): Mode of operation, either 'train' or 'eval'.
+
+        Returns:
+            Union[List[str], tuple]: Decoded text, optionally with labels.
+        """
+        if mode == "train":
+            preds_idx = np.array(preds.argmax(axis=2))
+            text = self.token2str(preds_idx)
+        else:
+            text = self.token2str(np.array(preds))
+        if label is None:
+            return text
+        label = self.token2str(np.array(label))
+        return text, label

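The four preprocessing classes are chained in the order used by FormulaRecognizer.predict above. A small sketch with a dummy image (the input shape is an assumption; the comments describe what each step produces):

    import numpy as np
    from mineru.model.mfr.pp_formulanet_plus_m.processors import (
        UniMERNetImgDecode, UniMERNetTestTransform, LatexImageFormat, ToBatch,
    )

    img = np.full((60, 200, 3), 255, dtype=np.uint8)              # dummy white formula crop
    imgs = UniMERNetImgDecode(input_size=(384, 384))(imgs=[img])  # crop margins, resize, pad to 384x384x3
    imgs = UniMERNetTestTransform()(imgs=imgs)                    # normalize, collapse to pseudo-grayscale 3 channels
    imgs = LatexImageFormat()(imgs=imgs)                          # keep one channel, pad H/W to multiples of 16 -> (1, 1, H, W)
    batch = ToBatch()(imgs=imgs)[0]                               # (N, 1, H, W) numpy batch, ready for torch.from_numpy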
+ 3 - 3
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -13,8 +13,8 @@ from mineru.utils.config_reader import get_device
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 from mineru.utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
-from mineru.model.ocr.paddleocr2pytorch.tools.infer.predict_system import TextSystem
-from mineru.model.ocr.paddleocr2pytorch.tools.infer import pytorchocr_utility as utility
+from mineru.model.utils.tools.infer.predict_system import TextSystem
+from mineru.model.utils.tools.infer import pytorchocr_utility as utility
 import argparse
 
 
@@ -47,7 +47,7 @@ def get_model_params(lang, config):
         raise Exception (f'Language {lang} not supported')
 
 
-root_dir = Path(__file__).resolve().parent
+root_dir = os.path.join(Path(__file__).resolve().parent.parent.parent, 'utils')
 
 
 class PytorchPaddleOCR(TextSystem):

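The renames below consolidate the shared OCR runtime under mineru/model/utils, which is why the imports above change. Any downstream code that imported the old paddleocr2pytorch paths would need the same adjustment, e.g.:

    # New locations after this PR (the old mineru/model/ocr/paddleocr2pytorch paths no longer apply).
    from mineru.model.utils.tools.infer.predict_system import TextSystem
    from mineru.model.utils.pytorchocr.base_ocr_v20 import BaseOCRV20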
+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/tools/infer/__init__.py → mineru/model/utils/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/__init__.py → mineru/model/utils/pytorchocr/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py → mineru/model/utils/pytorchocr/base_ocr_v20.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/data/__init__.py → mineru/model/utils/pytorchocr/data/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/data/imaug/__init__.py → mineru/model/utils/pytorchocr/data/imaug/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/data/imaug/operators.py → mineru/model/utils/pytorchocr/data/imaug/operators.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/__init__.py → mineru/model/utils/pytorchocr/modeling/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/__init__.py → mineru/model/utils/pytorchocr/modeling/architectures/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/architectures/base_model.py → mineru/model/utils/pytorchocr/modeling/architectures/base_model.py


+ 2 - 1
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py → mineru/model/utils/pytorchocr/modeling/backbones/__init__.py

@@ -37,7 +37,7 @@ def build_backbone(config, model_type):
         from .rec_mobilenet_v3 import MobileNetV3
         from .rec_svtrnet import SVTRNet
         from .rec_mv1_enhance import MobileNetV1Enhance
-        from .rec_pphgnetv2 import PPHGNetV2_B4
+        from .rec_pphgnetv2 import PPHGNetV2_B4, PPHGNetV2_B6_Formula
         support_dict = [
             "MobileNetV1Enhance",
             "MobileNetV3",
@@ -51,6 +51,7 @@ def build_backbone(config, model_type):
             "PPLCNetV3",
             "PPHGNet_small",
             "PPHGNetV2_B4",
+            "PPHGNetV2_B6_Formula"
         ]
     else:
         raise NotImplementedError

+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/det_mobilenet_v3.py → mineru/model/utils/pytorchocr/modeling/backbones/det_mobilenet_v3.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_donut_swin.py → mineru/model/utils/pytorchocr/modeling/backbones/rec_donut_swin.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_hgnet.py → mineru/model/utils/pytorchocr/modeling/backbones/rec_hgnet.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_lcnetv3.py → mineru/model/utils/pytorchocr/modeling/backbones/rec_lcnetv3.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mobilenet_v3.py → mineru/model/utils/pytorchocr/modeling/backbones/rec_mobilenet_v3.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_mv1_enhance.py → mineru/model/utils/pytorchocr/modeling/backbones/rec_mv1_enhance.py


+ 2 - 2
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_pphgnetv2.py → mineru/model/utils/pytorchocr/modeling/backbones/rec_pphgnetv2.py

@@ -1626,8 +1626,8 @@ class PPHGNetV2_B6_Formula(nn.Module):
             pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
         pphgnet_b6_output = self.pphgnet_b6(pixel_values)
         b, c, h, w = pphgnet_b6_output.shape
-        pphgnet_b6_output = pphgnet_b6_output.reshape([b, c, h * w]).transpose(
-            [0, 2, 1]
+        pphgnet_b6_output = pphgnet_b6_output.reshape([b, c, h * w]).permute(
+            0, 2, 1
         )
         pphgnet_b6_output = DonutSwinModelOutput(
             last_hidden_state=pphgnet_b6_output,

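This hunk swaps a Paddle-style transpose([0, 2, 1]) for PyTorch's permute: torch.Tensor.transpose exchanges exactly two dimensions, so reordering three axes requires permute. A tiny sketch of the flattening it performs (the feature-map shape is an assumption):

    import torch

    x = torch.randn(2, 512, 12, 12)                     # assumed (B, C, H, W) backbone output
    b, c, h, w = x.shape
    tokens = x.reshape([b, c, h * w]).permute(0, 2, 1)  # -> (B, H*W, C) token sequence
    assert tokens.shape == (2, 144, 512)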
+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_svtrnet.py → mineru/model/utils/pytorchocr/modeling/backbones/rec_svtrnet.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/common.py → mineru/model/utils/pytorchocr/modeling/common.py


+ 2 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/__init__.py → mineru/model/utils/pytorchocr/modeling/heads/__init__.py

@@ -22,6 +22,7 @@ def build_head(config, **kwargs):
     # rec head
     from .rec_ctc_head import CTCHead
     from .rec_multi_head import MultiHead
+    from .rec_ppformulanet_head import PPFormulaNet_Head
 
     # cls head
     from .cls_head import ClsHead
@@ -32,6 +33,7 @@ def build_head(config, **kwargs):
         "ClsHead",
         "MultiHead",
         "PFHeadLocal",
+        "PPFormulaNet_Head",
     ]
 
     module_name = config.pop("name")

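build_head resolves heads by the "name" key, so the registration above makes the new head selectable from a config dict. Illustrative only; every key except "name" is an assumption, since the remaining constructor kwargs come from pp_formulanet_arch_config.yaml, which is not reproduced in this hunk:

    # Illustrative entry resolvable by build_head after this change.
    head_config = {
        "name": "PPFormulaNet_Head",  # must match the newly added support_dict entry
        # ...remaining kwargs are defined by rec_ppformulanet_head.PPFormulaNet_Head
    }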
+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/cls_head.py → mineru/model/utils/pytorchocr/modeling/heads/cls_head.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/det_db_head.py → mineru/model/utils/pytorchocr/modeling/heads/det_db_head.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_ctc_head.py → mineru/model/utils/pytorchocr/modeling/heads/rec_ctc_head.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/heads/rec_multi_head.py → mineru/model/utils/pytorchocr/modeling/heads/rec_multi_head.py


+ 1379 - 0
mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py

@@ -0,0 +1,1379 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import re
+import numpy as np
+import inspect
+import torch
+import torch.nn as nn
+from typing import Optional, Tuple, Union, List, Dict, Any
+from dataclasses import dataclass, fields, is_dataclass
+
+from sympy import totient
+
+from .rec_unimernet_head import (
+    MBartForCausalLM,
+    MBartDecoder,
+    MBartConfig,
+    ModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    CausalLMOutputWithCrossAttentions,
+    LogitsProcessorList,
+    ForcedEOSTokenLogitsProcessor,
+    UniMERNetHead,
+)
+
+
+@dataclass
+class AttentionMaskConverter:
+    """
+    A class to convert attention masks based on specific configurations.
+
+    This class is designed to handle the conversion of attention masks with options for causal masking
+    and sliding window attention, which are commonly used in transformer models.
+
+    Attributes:
+        is_causal (bool): Flag indicating whether the attention mask should enforce causal masking,
+                          which ensures each position can only attend to previous positions.
+        sliding_window (int, optional): Size of the sliding window for local attention. If set,
+                                        attention is restricted to a local window of this size.
+
+    """
+
+    is_causal: bool
+    sliding_window: int
+
+    def __init__(self, is_causal: bool, sliding_window=None):
+        self.is_causal = is_causal
+        self.sliding_window = sliding_window
+
+        if self.sliding_window is not None and self.sliding_window <= 0:
+            raise ValueError(
+                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
+            )
+
+    @staticmethod
+    def _make_causal_mask(
+            input_ids_shape,
+            dtype,
+            past_key_values_length=0,
+            sliding_window=None,
+            is_export=False,
+    ):
+        """
+        Make causal mask used for bi-directional self-attention.
+        """
+        bsz, tgt_len = input_ids_shape
+        if is_export:
+            mask = torch.full(
+                (tgt_len, tgt_len), torch.finfo(dtype).min, dtype=torch.float64
+            )
+            mask_cond = torch.arange(mask.shape[-1])
+            mask.masked_fill_(
+                mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
+            )
+        else:
+            mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
+            mask_cond = torch.arange(mask.shape[-1])
+            mask.masked_fill_(
+                mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
+            )
+            mask = mask.to(dtype)
+
+        if past_key_values_length > 0:
+            mask = torch.concat(
+                [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask],
+                dim=-1,
+            )
+
+        # add lower triangular sliding window mask if necessary
+        if sliding_window is not None:
+            diagonal = past_key_values_length - sliding_window - 1
+
+            context_mask = torch.tril(
+                torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
+            )
+            mask.masked_fill_(context_mask, torch.finfo(dtype).min)
+
+        return mask[None, None, :, :].expand(
+            [bsz, 1, tgt_len, tgt_len + past_key_values_length]
+        )
+
+    @staticmethod
+    def _make_causal_mask_parallel(
+            input_ids_shape,
+            dtype,
+            past_key_values_length=0,
+            sliding_window=None,
+            parallel_step=1,
+            is_export=False,
+    ):
+        """
+        Make causal mask used for bi-directional self-attention.
+        """
+        bsz, tgt_len = input_ids_shape
+        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
+        mask_cond = torch.arange(mask.shape[-1])
+        mask_cond_parallel = torch.arange(mask.shape[-1])
+
+        mask_parallel = torch.arange(0, tgt_len, step=parallel_step).reshape([1, -1])
+        mask_parallel = torch.repeat_interleave(mask_parallel, parallel_step, 1)[
+            :, :tgt_len
+        ]
+        mask.masked_fill_(
+            mask_cond < (mask_parallel + parallel_step).reshape([mask.shape[-1], 1]), 0
+        )
+        mask = mask.to(dtype)
+
+        if past_key_values_length > 0:
+            mask = torch.concat(
+                [torch.zeros([tgt_len, past_key_values_length], dtype=dtype), mask],
+                dim=-1,
+            )
+
+        # add lower triangular sliding window mask if necessary
+        if sliding_window is not None:
+            diagonal = past_key_values_length - sliding_window - 1
+
+            context_mask = torch.tril(
+                torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
+            )
+            mask.masked_fill_(context_mask, torch.finfo(dtype).min)
+
+        return mask[None, None, :, :].expand(
+            [bsz, 1, tgt_len, tgt_len + past_key_values_length]
+        )
+
+    def to_4d(
+            self,
+            attention_mask_2d,
+            query_length,
+            dtype,
+            key_value_length,
+            use_parallel=False,
+            parallel_step=3,
+            is_export=False,
+    ):
+        """
+        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
+        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
+        causal, a causal mask will be added.
+        """
+        input_shape = (attention_mask_2d.shape[0], query_length)
+
+        causal_4d_mask = None
+        if use_parallel:
+            step = parallel_step
+        else:
+            step = 1
+        if (
+                input_shape[-1] > step or self.sliding_window is not None
+        ) and self.is_causal:
+
+            if key_value_length is None:
+                raise ValueError(
+                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+                )
+
+            past_key_values_length = key_value_length - query_length
+
+            if use_parallel:
+                causal_4d_mask = self._make_causal_mask_parallel(
+                    input_shape,
+                    dtype,
+                    past_key_values_length=past_key_values_length,
+                    sliding_window=self.sliding_window,
+                    parallel_step=parallel_step,
+                    is_export=is_export,
+                )
+            else:
+                causal_4d_mask = self._make_causal_mask(
+                    input_shape,
+                    dtype,
+                    past_key_values_length=past_key_values_length,
+                    sliding_window=self.sliding_window,
+                    is_export=is_export,
+                )
+
+        elif self.sliding_window is not None:
+            raise NotImplementedError(
+                "Sliding window is currently only implemented for causal masking"
+            )
+
+        expanded_attn_mask = self._expand_mask(
+            attention_mask_2d, dtype, tgt_len=input_shape[-1]
+        )
+
+        if causal_4d_mask is not None:
+            expanded_attn_mask = causal_4d_mask.masked_fill_(
+                expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
+            )
+
+        expanded_4d_mask = expanded_attn_mask
+        return expanded_4d_mask
+
+    def to_4d_export(
+            self,
+            attention_mask_2d,
+            query_length,
+            dtype,
+            key_value_length,
+            use_parallel=False,
+            parallel_step=3,
+            is_export=False,
+    ):
+        input_shape = (attention_mask_2d.shape[0], query_length)
+
+        expanded_attn_mask = self._expand_mask_export(
+            attention_mask_2d, dtype, tgt_len=input_shape[-1]
+        )
+        expanded_4d_mask = expanded_attn_mask
+
+        return expanded_4d_mask
+
+    def _expand_mask(self, mask, dtype, tgt_len=None):
+        """
+        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+        """
+        bsz, src_len = mask.shape
+        tgt_len = tgt_len if tgt_len is not None else src_len
+        expanded_mask = (
+            mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
+        )
+
+        inverted_mask = 1.0 - expanded_mask
+
+        return inverted_mask.masked_fill_(
+            inverted_mask.to(torch.bool), torch.finfo(dtype).min
+        )
+
+    def _expand_mask_export(self, mask, dtype, tgt_len=None):
+        """
+        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+        """
+        bsz, src_len = mask.shape
+        expanded_mask = (
+            mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
+        )
+        inverted_mask = 1.0 - expanded_mask
+        return inverted_mask.masked_fill_(
+            inverted_mask.to(torch.bool), torch.finfo(dtype).min
+        )
+
+
+def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
+    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _prepare_4d_causal_attention_mask(
+        attention_mask,
+        input_shape,
+        inputs_embeds,
+        past_key_values_length,
+        sliding_window=None,
+        use_parallel=False,
+        parallel_step=3,
+        is_export=False,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`
+
+    Args:
+        attention_mask (`paddle.Tensor` or `None`):
+            A 2D attention mask of shape `(batch_size, key_value_length)`
+        input_shape (`tuple(int)` or `list(int)` or `paddle.Size`):
+            The input shape should be a tuple that defines `(batch_size, query_length)`.
+        inputs_embeds (`paddle.Tensor`):
+            The embedded inputs as a paddle Tensor.
+        past_key_values_length (`int`):
+            The length of the key value cache.
+        sliding_window (`int`, *optional*):
+            If the model uses windowed attention, a sliding window should be passed.
+    """
+    attn_mask_converter = AttentionMaskConverter(
+        is_causal=True, sliding_window=sliding_window
+    )
+
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    # 4d mask is passed through the layers
+    if attention_mask is not None and len(attention_mask.shape) == 2:
+        attention_mask = attn_mask_converter.to_4d(
+            attention_mask,
+            input_shape[-1],
+            key_value_length=key_value_length,
+            dtype=inputs_embeds.dtype,
+            use_parallel=use_parallel,
+            parallel_step=parallel_step,
+            is_export=is_export,
+        )
+    elif attention_mask is not None and len(attention_mask.shape) == 4:
+        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+        if tuple(attention_mask.shape) != expected_shape:
+            raise ValueError(
+                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+            )
+        else:
+            # if the 4D mask has correct shape - invert it and fill with negative infinity
+            inverted_mask = 1.0 - attention_mask
+            attention_mask = inverted_mask.masked_fill_(
+                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+            )
+    else:
+        attention_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0],
+            input_shape[-1],
+            key_value_length,
+            dtype=inputs_embeds.dtype,
+        )
+
+    return attention_mask
+
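The causal part that `_prepare_4d_causal_attention_mask` stacks on top of the padding mask can be sketched on its own, as below (a hedged illustration with no past key values and no sliding window): the upper triangle is filled with the dtype minimum so a position cannot attend to later positions. Note that the `_export` variant defined next delegates to `to_4d_export`, which returns only the expanded padding mask.

import torch

bsz, tgt_len, dtype = 1, 4, torch.float32
causal = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
cond = torch.arange(tgt_len)
causal = causal.masked_fill(cond < (cond + 1).reshape(tgt_len, 1), 0.0)  # keep diagonal and below
causal_4d = causal[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
# row i carries 0.0 for columns 0..i and the dtype minimum for columns i+1..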
+
+def _prepare_4d_causal_attention_mask_export(
+        attention_mask,
+        input_shape,
+        inputs_embeds,
+        past_key_values_length,
+        sliding_window=None,
+        use_parallel=False,
+        parallel_step=3,
+        is_export=False,
+):
+    """
+    Prepare a 4D causal attention mask for export.
+
+    This function prepares a 4-dimensional causal attention mask, which is used to ensure that each position in the
+    sequence can only attend to previous positions. It is specifically designed to handle scenarios where the model
+    is being exported, potentially with additional options like sliding window or parallel processing.
+
+    Args:
+        attention_mask: The initial attention mask, typically used to avoid attending to padding tokens.
+        input_shape: Shape of the input tensor, usually in the form (batch_size, sequence_length).
+        inputs_embeds: Embeddings of the input sequence, used to derive certain dimensions if needed.
+        past_key_values_length: Length of past key values, used in contexts like transformer decoders with caching.
+        sliding_window: Optional parameter. If provided, specifies the size of a sliding window for local attention.
+        use_parallel: Flag indicating whether to use parallel processing for attention computation.
+        parallel_step: Number of steps to use in parallel processing, relevant if `use_parallel` is True.
+        is_export: Flag indicating whether the attention mask is being prepared for model export.
+
+    Returns:
+        A 4D causal attention mask suitable for use in transformer models, ensuring correct causal masking.
+    """
+    attn_mask_converter = AttentionMaskConverter(
+        is_causal=True, sliding_window=sliding_window
+    )
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    shape = attention_mask.shape
+    len_shape = len(shape)
+
+    attention_mask = attn_mask_converter.to_4d_export(
+        attention_mask,
+        input_shape[-1],
+        key_value_length=key_value_length,
+        dtype=inputs_embeds.dtype,
+        use_parallel=use_parallel,
+        parallel_step=parallel_step,
+        is_export=is_export,
+    )
+    return attention_mask
+
+
+class CustomMBartDecoder(MBartDecoder):
+    def __init__(self, config):
+        super().__init__(config)
+        hidden_size = config.d_model
+        self.is_export = config.is_export
+        self.config_decoder = config
+
+    def forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            head_mask=None,
+            cross_attn_head_mask=None,
+            past_key_values=None,
+            inputs_embeds=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+    ):
+        self.is_export = not self.training
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.reshape([-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )
+
+        # past_key_values_length
+        past_key_values_length = (
+            past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = (
+                attention_mask
+                if (attention_mask is not None and 0 in attention_mask)
+                else None
+            )
+        else:
+            # 4d mask is passed through the layers
+            if self.is_export:
+                attention_mask = _prepare_4d_causal_attention_mask_export(
+                    attention_mask,
+                    input_shape,
+                    inputs_embeds,
+                    past_key_values_length,
+                    use_parallel=self.config_decoder.use_parallel,
+                    parallel_step=self.config_decoder.parallel_step,
+                    is_export=self.is_export,
+                )
+            else:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    input_shape,
+                    inputs_embeds,
+                    past_key_values_length,
+                    use_parallel=self.config_decoder.use_parallel,
+                    parallel_step=self.config_decoder.parallel_step,
+                    is_export=self.is_export,
+                )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = (
+                    encoder_attention_mask if 0 in encoder_attention_mask else None
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                print(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = (
+            () if (output_attentions and encoder_hidden_states is not None) else None
+        )
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip(
+                [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
+        ):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = (
+                past_key_values[idx] if past_key_values is not None else None
+            )
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    (
+                        cross_attn_head_mask[idx]
+                        if cross_attn_head_mask is not None
+                        else None
+                    ),
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx]
+                        if cross_attn_head_mask is not None
+                        else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if self.is_export:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+            else:
+                if use_cache:
+                    next_decoder_cache += (
+                        layer_outputs[3 if output_attentions else 1],
+                    )
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.is_export:
+            next_cache = next_decoder_cache
+        else:
+            next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class CustomMBartForCausalLM(MBartForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        # Modify the decoder within MBartDecoderWrapper
+        self.model.decoder = CustomMBartDecoder(config)
+
+    def forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            head_mask=None,
+            cross_attn_head_mask=None,
+            past_key_values=None,
+            inputs_embeds=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+    ):
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = self.lm_head(outputs[0])
+
+        return CausalLMOutputWithCrossAttentions(
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+class PPFormulaNet_Head(UniMERNetHead):
+    """
+    PPFormulaNet_Head
+    Args:
+        max_new_tokens (int): Maximum number of new tokens to generate. Default is 1536.
+        decoder_start_token_id (int): Start token ID for the decoder. Default is 0.
+        temperature (float): Temperature parameter for controlling randomness in sampling. Default is 0.2.
+        do_sample (bool): Flag to determine whether to use sampling for generation. Default is False.
+        top_p (float): Top-p (nucleus) sampling parameter for controlling diversity. Default is 0.95.
+        in_channels (int): Number of input channels for the model. Default is 1024.
+        decoder_layers (int): Number of layers in the decoder. Default is 8.
+        encoder_hidden_size (int): Size of the hidden layer in the encoder. Default is 1024.
+        decoder_ffn_dim (int): Dimension of the feed-forward network in the decoder. Default is 4096.
+        decoder_hidden_size (int): Size of the hidden layer in the decoder. Default is 1024.
+        is_export (bool): Flag indicating whether the model is to be exported. Default is False.
+        length_aware (bool): Flag to determine if the model should be aware of input sequence length. Default is True.
+        use_parallel (bool): Flag to enable or disable parallel processing. Default is False.
+        parallel_step (int): Number of steps to use in parallel processing. Default is 3.
+    """
+
+    def __init__(
+            self,
+            max_new_tokens=1536,
+            decoder_start_token_id=0,
+            temperature=0.2,
+            do_sample=False,
+            top_p=0.95,
+            in_channels=1024,
+            decoder_layers=8,
+            encoder_hidden_size=1024,
+            decoder_ffn_dim=4096,
+            decoder_hidden_size=1024,
+            is_export=False,
+            length_aware=True,
+            use_parallel=False,
+            parallel_step=3,
+    ):
+
+        super().__init__()
+
+        mbart_config_dict = {
+            "activation_dropout": 0.0,
+            "activation_function": "gelu",
+            "add_cross_attention": True,
+            "add_final_layer_norm": True,
+            "attention_dropout": 0.0,
+            "bos_token_id": 0,
+            "classifier_dropout": 0.0,
+            "d_model": decoder_hidden_size,
+            "decoder_attention_heads": 16,
+            "decoder_ffn_dim": decoder_ffn_dim,
+            "decoder_layerdrop": 0.0,
+            "decoder_layers": decoder_layers,
+            "dropout": 0.1,
+            "encoder_attention_heads": 16,
+            "encoder_ffn_dim": 4096,
+            "encoder_layerdrop": 0.0,
+            "encoder_layers": 12,
+            "eos_token_id": 2,
+            "forced_eos_token_id": 2,
+            "init_std": 0.02,
+            "is_decoder": True,
+            "is_encoder_decoder": False,
+            "output_hidden_states": False,
+            "max_position_embeddings": (
+                max_new_tokens + parallel_step if use_parallel else max_new_tokens
+            ),
+            "model_type": "mbart",
+            "num_hidden_layers": 12,
+            "pad_token_id": 1,
+            "scale_embedding": True,
+            "tie_word_embeddings": False,
+            "transformers_version": "4.40.0",
+            "use_cache": True,
+            "use_return_dict": True,
+            "vocab_size": 50000,
+            "_attn_implementation": "eager",
+            "hidden_size": decoder_hidden_size,
+            "use_parallel": use_parallel,
+            "parallel_step": int(parallel_step),
+            "is_export": is_export,
+        }
+        self.decoder_start_token_id = decoder_start_token_id
+        self.temperature = temperature
+        self.do_sample = do_sample
+        self.top_p = top_p
+        self.is_export = is_export
+        self.max_seq_len = max_new_tokens
+        self.config_decoder = MBartConfig(**mbart_config_dict)
+        self.encoder_hidden_size = encoder_hidden_size
+        self.decoder = CustomMBartForCausalLM(self.config_decoder)
+        if self.config_decoder.hidden_size != self.encoder_hidden_size:
+            self.enc_to_dec_proj = nn.Linear(
+                self.encoder_hidden_size, self.config_decoder.hidden_size
+            )
+        generation_config = {
+            "max_length": 1537,
+            "forced_eos_token_id": 2,
+        }
+        self.eos_token_id = generation_config["forced_eos_token_id"]
+        self.pad_token_id = self.config_decoder.pad_token_id
+        self.logits_processor = LogitsProcessorList()
+        self.logits_processor.append(
+            ForcedEOSTokenLogitsProcessor(
+                generation_config["max_length"],
+                generation_config["forced_eos_token_id"],
+            )
+        )
+
+    def prepare_inputs_for_generation(
+            self,
+            input_ids,
+            past_key_values=None,
+            attention_mask=None,
+            use_cache=None,
+            encoder_outputs=None,
+            **kwargs,
+    ):
+        decoder_inputs = self.prepare_inputs_for_generation_mbart(
+            input_ids, past_key_values=past_key_values
+        )
+        decoder_attention_mask = (
+            decoder_inputs["attention_mask"]
+            if "attention_mask" in decoder_inputs
+            else None
+        )
+        input_dict = {
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_input_ids": decoder_inputs["input_ids"],
+            "past_key_values": decoder_inputs["past_key_values"],
+            "use_cache": use_cache,
+        }
+        return input_dict
+
+    def _extract_past_from_model_output(
+            self, outputs: ModelOutput, standardize_cache_format: bool = False
+    ):
+        past_key_values = None
+        if "past_key_values" in outputs:
+            past_key_values = outputs.past_key_values
+        elif "mems" in outputs:
+            past_key_values = outputs.mems
+        elif "past_buckets_states" in outputs:
+            past_key_values = outputs.past_buckets_states
+        return past_key_values
+
+    def _update_model_kwargs_for_generation(
+            self,
+            outputs: ModelOutput,
+            model_kwargs: Dict[str, Any],
+            is_encoder_decoder: bool = False,
+            standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.concat(
+                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
+            )
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.concat(
+                    [
+                        attention_mask,
+                        attention_mask.new_ones((attention_mask.shape[0], 1)),
+                    ],
+                    dim=-1,
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.concat(
+                    [
+                        decoder_attention_mask,
+                        decoder_attention_mask.new_ones(
+                            (decoder_attention_mask.shape[0], 1)
+                        ),
+                    ],
+                    dim=-1,
+                )
+
+        if (
+                "cache_position" in model_kwargs
+                and model_kwargs["cache_position"] is not None
+        ):
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+        return model_kwargs
+
+    def stopping_criteria(self, input_ids):
+        if self.is_export:
+            return input_ids[:, -1] == torch.Tensor([self.eos_token_id])
+        is_done = torch.isin(input_ids[:, -1], torch.Tensor([self.eos_token_id]))
+        return is_done
+
+    def stopping_criteria_parallel(self, input_ids):
+        parallel_step = self.config_decoder.parallel_step
+
+        if self.is_export:
+            is_done_list = []
+            for i in range(parallel_step, 0, -1):
+                cur_is_done = input_ids[:, -i] == torch.Tensor([self.eos_token_id])
+                is_done_list.append(cur_is_done)
+            is_done_list = torch.Tensor(is_done_list).permute([1, 0])
+            return is_done_list
+        else:
+            is_done = torch.isin(
+                input_ids[:, -parallel_step:],
+                torch.Tensor([self.eos_token_id]).reshape([1, 1]),
+            )
+            return torch.Tensor(is_done)
+
+    def generate_single_iter(
+            self,
+            decoder_input_ids=None,
+            decoder_attention_mask=None,
+            encoder_outputs=None,
+            past_key_values=None,
+            decoder_inputs_embeds=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            **kwargs,
+    ):
+
+        encoder_hidden_states = encoder_outputs[0]
+        if self.config_decoder.hidden_size != self.encoder_hidden_size:
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+        kwargs_decoder = {}
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=None,
+            inputs_embeds=None,
+            output_attentions=False,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        return Seq2SeqLMOutput(
+            loss=None,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def _prepare_decoder_input_ids_for_generation(
+            self,
+            batch_size,
+            model_kwargs,
+            decoder_start_token_id=None,
+            bos_token_id=None,
+    ):
+
+        # 1. Check whether the user has defined `decoder_input_ids` manually. For convenience, we also allow the
+        # user to pass it under `input_ids` if the encoder does not use it as the main input.
+        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+        elif "input_ids" in model_kwargs:
+            decoder_input_ids = model_kwargs.pop("input_ids")
+        else:
+            decoder_input_ids = None
+
+        # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+        decoder_start_token_id = self._get_decoder_start_token_id(
+            decoder_start_token_id, bos_token_id
+        )
+
+        if isinstance(decoder_start_token_id, list):
+            if len(decoder_start_token_id) != batch_size:
+                raise ValueError(
+                    f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
+                )
+            decoder_input_ids_start = torch.Tensor(
+                decoder_start_token_id
+            ).to(torch.int64)
+            decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
+        else:
+            use_parallel = self.config_decoder.use_parallel
+            parallel_step = self.config_decoder.parallel_step
+
+            if use_parallel:
+                decoder_input_ids_start = (
+                        torch.ones(
+                            (batch_size, parallel_step),
+                            dtype=torch.int64,
+                        )
+                        * decoder_start_token_id
+                )
+            else:
+                decoder_input_ids_start = (
+                        torch.ones(
+                            (batch_size, 1),
+                            dtype=torch.int64,
+                        )
+                        * decoder_start_token_id
+                )
+        # no user input -> use decoder_start_token_id as decoder_input_ids
+        if decoder_input_ids is None:
+            decoder_input_ids = decoder_input_ids_start
+        # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
+        elif (
+                self.config.model_type == "vision-encoder-decoder"
+                and "donut" in self.name_or_path.lower()
+        ):
+            pass
+        elif self.config.model_type in ["whisper"]:
+            pass
+        # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+        # decoder_attention_mask if provided)
+        elif (
+                isinstance(decoder_start_token_id, int)
+                and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
+        ) or (
+                isinstance(decoder_start_token_id, torch.Tensor)
+                and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
+        ):
+            decoder_input_ids = torch.concat(
+                [decoder_input_ids_start, decoder_input_ids], dim=-1
+            )
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                decoder_attention_mask = torch.cat(
+                    (
+                        torch.ones_like(decoder_attention_mask)[:, :1],
+                        decoder_attention_mask,
+                    ),
+                    dim=-1,
+                )
+                model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+        return decoder_input_ids, model_kwargs
+
+    @torch.no_grad()
+    def generate_export(
+            self,
+            encoder_outputs,
+            model_kwargs,
+    ):
+        use_parallel = self.config_decoder.use_parallel
+        parallel_step = self.config_decoder.parallel_step
+        batch_size = encoder_outputs["last_hidden_state"].shape[0]
+        generation_config = {
+            "decoder_start_token_id": 0,
+            "bos_token_id": 0,
+        }
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config["decoder_start_token_id"],
+            bos_token_id=generation_config["bos_token_id"],
+        )
+        if not use_parallel:
+            input_ids = input_ids.reshape([-1, 1])
+        decoder_input_ids = input_ids
+        model_kwargs["key use_cache"] = True
+        batch_size, cur_len = input_ids.shape
+
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+
+        cache_position = torch.arange(cur_len)
+        pad_token_id = self.pad_token_id
+        eos_token_id = [self.eos_token_id]
+        eos_token = self.eos_token_id
+        if use_parallel:
+            unfinished_sequences = torch.ones(
+                [batch_size, parallel_step], dtype=torch.int64
+            )
+            parallel_length = math.ceil(self.max_seq_len / parallel_step)
+        else:
+            unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
+            parallel_length = self.max_seq_len
+
+        i_idx = 0
+        past_key_values = []
+        decoder_attention_heads = self.config_decoder.decoder_attention_heads
+        decoder_attention_heads_dim = int(
+            self.config_decoder.d_model / decoder_attention_heads
+        )
+        for i in range(self.config_decoder.decoder_layers):
+            init_arr = torch.zeros(
+                [batch_size, decoder_attention_heads, 0, decoder_attention_heads_dim]
+            )
+            cache = (init_arr, init_arr, init_arr, init_arr)
+            past_key_values.append(cache)
+
+        while i_idx < parallel_length:
+
+            model_inputs = self.prepare_inputs_for_generation_export(
+                past_key_values=past_key_values, **model_kwargs
+            )
+            decoder_attention_mask = torch.ones(input_ids.shape)
+
+            outputs = self.generate_single_iter(
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                encoder_outputs=encoder_outputs,
+                past_key_values=past_key_values,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+
+            if use_parallel:
+                next_token_logits = outputs.logits[:, -parallel_step:, :]
+            else:
+                next_token_logits = outputs.logits[:, -1, :]
+            next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+
+            if eos_token_id is not None:
+                # pad_token_id is required below to pad out sequences that have already finished
+                if pad_token_id is None:
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                        1 - unfinished_sequences
+                )
+            if use_parallel:
+                input_ids = torch.concat([input_ids, next_tokens], dim=-1)
+                decoder_input_ids = next_tokens
+            else:
+                input_ids = torch.concat(
+                    [input_ids, next_tokens.unsqueeze(1)], dim=-1
+                )
+                decoder_input_ids = next_tokens.unsqueeze(1)
+
+            past_length = past_key_values[0][0].shape[2]
+
+            past_key_values = outputs.past_key_values
+            cache_position = cache_position[-1:] + 1
+            if use_parallel:
+                unfinished_sequences = (
+                        unfinished_sequences
+                        & ~self.stopping_criteria_parallel(input_ids).to(torch.int64)
+                )
+            else:
+                unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
+                    input_ids
+                ).to(torch.int64)
+
+            if (
+                    eos_token is not None
+                    and (
+                    torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
+                    >= 1
+            ).all()
+            ):
+                break
+            i_idx += 1
+            # break
+
+        return input_ids
+
+    @torch.no_grad()
+    def generate(
+            self,
+            encoder_outputs,
+            model_kwargs,
+    ):
+        """
+        Generate sequences from the model without computing gradients.
+
+        This method is used to generate sequences from the model based on the given encoder outputs.
+        It does not compute gradients, making it suitable for inference.
+
+        Args:
+            encoder_outputs: The outputs from the encoder, typically including hidden states necessary for generation.
+            model_kwargs: Additional keyword arguments that may include parameters such as maximum length,
+                        temperature, top-k/top-p sampling parameters, and other generation-specific settings.
+
+        Returns:
+            Generated sequences based on the encoder outputs and specified generation parameters.
+        """
+        use_parallel = self.config_decoder.use_parallel
+        parallel_step = self.config_decoder.parallel_step
+        batch_size = encoder_outputs["last_hidden_state"].shape[0]
+        generation_config = {
+            "decoder_start_token_id": 0,
+            "bos_token_id": 0,
+        }
+
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config["decoder_start_token_id"],
+            bos_token_id=generation_config["bos_token_id"],
+        )
+
+        decoder_input_ids = input_ids
+        model_kwargs["key use_cache"] = True
+        batch_size, cur_len = input_ids.shape
+
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        model_kwargs["cache_position"] = torch.arange(cur_len)
+        pad_token_id = self.pad_token_id
+        eos_token_id = [self.eos_token_id]
+        eos_token = self.eos_token_id
+        if use_parallel:
+            unfinished_sequences = torch.ones(
+                [batch_size, parallel_step], dtype=torch.int64
+            )
+            parallel_length = math.ceil(self.max_seq_len / parallel_step)
+        else:
+            unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
+            parallel_length = self.max_seq_len
+        past_key_values = []
+
+        for idx in range(parallel_length):
+
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            outputs = self.generate_single_iter(
+                **model_inputs,
+                encoder_outputs=encoder_outputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+
+            if use_parallel:
+                next_token_logits = outputs.logits[:, :, :]
+            else:
+                next_token_logits = outputs.logits[:, -1, :]
+
+            next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+            if eos_token_id is not None:
+                # pad_token_id is required below to pad out sequences that have already finished
+                if pad_token_id is None:
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                        1 - unfinished_sequences
+                )
+            if use_parallel:
+                input_ids = torch.concat([input_ids, next_tokens], dim=-1)
+            else:
+                input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config_decoder.is_encoder_decoder,
+            )
+            if use_parallel:
+                unfinished_sequences = (
+                        unfinished_sequences
+                        & ~self.stopping_criteria_parallel(input_ids).to(torch.int64)
+                )
+            else:
+                unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
+                    input_ids
+                ).to(torch.int64)
+
+            if (
+                    eos_token is not None
+                    and (
+                    torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
+                    >= 1
+            ).all()
+            ):
+                break
+        return input_ids
+
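Both `generate_export` and `generate` above are greedy decoding loops with an EOS-aware mask over finished rows. The standalone sketch below is illustrative only; `fake_decoder_step` is a hypothetical stand-in for one call to `generate_single_iter`, and the token ids are arbitrary.

import torch

eos_id, pad_id, vocab_size, max_steps = 2, 1, 50000, 8
fake_decoder_step = lambda ids: torch.randn(ids.shape[0], vocab_size)  # stand-in for the decoder

input_ids = torch.zeros((2, 1), dtype=torch.int64)   # decoder_start_token_id == 0
unfinished = torch.ones(2, dtype=torch.int64)

for _ in range(max_steps):
    next_tokens = torch.argmax(fake_decoder_step(input_ids), dim=-1)
    next_tokens = next_tokens * unfinished + pad_id * (1 - unfinished)  # finished rows keep emitting pad
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    unfinished = unfinished * (input_ids[:, -1] != eos_id).to(torch.int64)
    if (torch.cumsum((input_ids == eos_id).to(torch.int64), 1)[:, -1] >= 1).all():
        break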
+    def forward_train(
+            self,
+            encoder_outputs,
+            decoder_input_ids,
+            decoder_attention_mask,
+            past_key_values=None,
+            decoder_inputs_embeds=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            **kwargs,
+    ):
+        """
+        Forward pass for training the model.
+
+        Args:
+            encoder_outputs: The outputs from the encoder, typically including hidden states.
+            decoder_input_ids: Input IDs for the decoder.
+            decoder_attention_mask: Attention mask for the decoder inputs to avoid attending to padding tokens.
+            past_key_values: Previously computed key and value states for the decoder, used for fast generation.
+            decoder_inputs_embeds: Optional embeddings for decoder inputs, used instead of decoder_input_ids if provided.
+            labels: Labels for computing the training loss.
+            use_cache: Whether to use a cache of past key values for faster generation.
+            output_attentions: Whether to output attention weights.
+            output_hidden_states: Whether to output hidden states of all layers.
+            return_dict: Whether to return the output as a dictionary.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            Depending on the `return_dict` flag, returns either a dictionary of model outputs or a tuple.
+        """
+        if self.config_decoder.use_parallel:
+            batch = decoder_input_ids.shape[0]
+            add_sos_token = self.config_decoder.parallel_step - 1
+            start_token = torch.zeros([batch, add_sos_token]).to(torch.int64)
+            start_mask = torch.ones([batch, add_sos_token]).to(torch.int64)
+            decoder_input_ids = torch.concat([start_token, decoder_input_ids], dim=1)
+            decoder_attention_mask = torch.concat(
+                [start_mask, decoder_attention_mask], dim=1
+            )
+
+        labels = decoder_input_ids * 1
+        labels = labels.masked_fill_(labels == self.pad_token_id, -100)
+        if self.config_decoder.use_parallel:
+            input_decoder_input_ids = decoder_input_ids[
+                :, : -self.config_decoder.parallel_step
+            ]
+            input_decoder_attention_mask = decoder_attention_mask[
+                :, : -self.config_decoder.parallel_step
+            ]
+        else:
+            input_decoder_input_ids = decoder_input_ids[:, :-1]
+            input_decoder_attention_mask = decoder_attention_mask[:, :-1]
+
+        encoder_hidden_states = encoder_outputs[0]
+        kwargs_decoder = {}
+        if self.config_decoder.hidden_size != self.encoder_hidden_size:
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+        decoder_outputs = self.decoder(
+            input_ids=input_decoder_input_ids,
+            attention_mask=input_decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=None,
+            inputs_embeds=None,
+            output_attentions=False,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        logits = decoder_outputs.logits
+        return logits, labels
+
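`forward_train` above keeps the label tensor aligned with the decoder inputs by replacing padding with -100, the default ignore index of torch's `CrossEntropyLoss`, and by feeding the decoder everything except the final step (or the final `parallel_step` steps in parallel mode). A minimal sketch of that bookkeeping for the non-parallel case:

import torch

pad_id = 1
decoder_input_ids = torch.tensor([[0, 17, 23, 2, pad_id]])     # <s> ... </s> <pad>
labels = decoder_input_ids.clone().masked_fill_(decoder_input_ids == pad_id, -100)
teacher_inputs = decoder_input_ids[:, :-1]                     # the decoder consumes all but the last token
# labels == tensor([[  0,  17,  23,   2, -100]])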
+    # forward dispatches to generation at inference/export time and to forward_train during training
+    def forward(self, inputs, targets=None):
+        self.is_export = not self.training
+        if not self.training:
+            encoder_outputs = inputs
+            model_kwargs = {
+                "output_attentions": False,
+                "output_hidden_states": False,
+                "use_cache": True,
+            }
+            if self.is_export:
+                word_pred = self.generate_export(encoder_outputs, model_kwargs)
+            else:
+                word_pred = self.generate(encoder_outputs, model_kwargs)
+
+            return word_pred
+        encoder_outputs, tgt_seq, mask = inputs
+        logits, masked_labels = self.forward_train(encoder_outputs, tgt_seq, mask)
+
+        return logits, masked_labels
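A hedged construction sketch for the head defined above, assuming the package layout added in this pull request is importable; the keyword values simply restate the documented defaults and are not a tuned configuration.

from mineru.model.utils.pytorchocr.modeling.heads.rec_ppformulanet_head import PPFormulaNet_Head

head = PPFormulaNet_Head(
    max_new_tokens=1536,
    decoder_layers=8,
    encoder_hidden_size=1024,
    decoder_hidden_size=1024,
    use_parallel=False,
)
head.eval()
# At inference the head expects encoder outputs whose last_hidden_state has shape
# [batch, seq_len, encoder_hidden_size] and returns the generated token ids.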

+ 2624 - 0
mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py

@@ -0,0 +1,2624 @@
+import copy
+import math
+import re
+import numpy as np
+import inspect
+import warnings
+from collections import OrderedDict
+from typing import Optional, Tuple, Union, List, Dict, Any
+from dataclasses import dataclass, fields, is_dataclass
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+
+
+class ModelOutput(OrderedDict):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def __post_init__(self):
+        class_fields = fields(self)
+
+        if not len(class_fields):
+            raise ValueError(f"{self.__class__.__name__} has no fields.")
+        if not all(field.default is None for field in class_fields[1:]):
+            raise ValueError(
+                f"{self.__class__.__name__} should not have more than one required field."
+            )
+
+        first_field = getattr(self, class_fields[0].name)
+        other_fields_are_none = all(
+            getattr(self, field.name) is None for field in class_fields[1:]
+        )
+        if other_fields_are_none:
+            if isinstance(first_field, dict):
+                iterator = first_field.items()
+                first_field_iterator = True
+            else:
+                try:
+                    iterator = iter(first_field)
+                    first_field_iterator = True
+                except TypeError:
+                    first_field_iterator = False
+
+            if first_field_iterator:
+                for idx, element in enumerate(iterator):
+                    if (
+                            not isinstance(element, (list, tuple))
+                            or not len(element) == 2
+                            or not isinstance(element[0], str)
+                    ):
+                        if idx == 0:
+                            self[class_fields[0].name] = first_field
+                        else:
+                            raise ValueError(
+                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
+                            )
+                        break
+                    setattr(self, element[0], element[1])
+                    if element[1] is not None:
+                        self[element[0]] = element[1]
+            elif first_field is not None:
+                self[class_fields[0].name] = first_field
+        else:
+            for field in class_fields:
+                v = getattr(self, field.name)
+                if v is not None:
+                    self[field.name] = v
+
+    def __delitem__(self, *args, **kwargs):
+        raise Exception(
+            f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
+        )
+
+    def setdefault(self, *args, **kwargs):
+        raise Exception(
+            f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
+        )
+
+    def pop(self, *args, **kwargs):
+        raise Exception(
+            f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
+        )
+
+    def update(self, *args, **kwargs):
+        raise Exception(
+            f"You cannot use ``update`` on a {self.__class__.__name__} instance."
+        )
+
+    def __getitem__(self, k):
+        if isinstance(k, str):
+            inner_dict = dict(self.items())
+            return inner_dict[k]
+        else:
+            return self.to_tuple()[k]
+
+    def __setattr__(self, name, value):
+        if name in self.keys() and value is not None:
+            super().__setitem__(name, value)
+        super().__setattr__(name, value)
+
+    def __setitem__(self, key, value):
+        super().__setitem__(key, value)
+        super().__setattr__(key, value)
+
+    def __reduce__(self):
+        if not is_dataclass(self):
+            return super().__reduce__()
+        callable, _args, *remaining = super().__reduce__()
+        args = tuple(getattr(self, field.name) for field in fields(self))
+        return callable, args, *remaining
+
+    def to_tuple(self):
+        return tuple(self[k] for k in self.keys())
+
+
+@dataclass
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+    last_hidden_state = None
+    past_key_values = None
+    hidden_states = None
+    attentions = None
+    cross_attentions = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+@dataclass
+class Seq2SeqLMOutput(ModelOutput):
+    loss = None
+    logits = None
+    past_key_values = None
+    decoder_hidden_states = None
+    decoder_attentions = None
+    cross_attentions = None
+    encoder_last_hidden_state = None
+    encoder_hidden_states = None
+    encoder_attentions = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class MBartConfig(object):
+    model_type = "mbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {
+        "num_attention_heads": "encoder_attention_heads",
+        "hidden_size": "d_model",
+    }
+
+    def __init__(
+            self,
+            vocab_size=50265,
+            max_position_embeddings=1024,
+            encoder_layers=12,
+            encoder_ffn_dim=4096,
+            encoder_attention_heads=16,
+            decoder_layers=12,
+            decoder_ffn_dim=4096,
+            decoder_attention_heads=16,
+            encoder_layerdrop=0.0,
+            decoder_layerdrop=0.0,
+            use_cache=True,
+            is_encoder_decoder=True,
+            activation_function="gelu",
+            d_model=1024,
+            dropout=0.1,
+            output_hidden_states=False,
+            use_return_dict=True,
+            attention_dropout=0.0,
+            activation_dropout=0.0,
+            init_std=0.02,
+            classifier_dropout=0.0,
+            scale_embedding=False,
+            pad_token_id=1,
+            bos_token_id=0,
+            eos_token_id=2,
+            forced_eos_token_id=2,
+            _attn_implementation="eager",
+            hidden_size=1024,
+            use_parallel=False,
+            parallel_step=2,
+            is_export=False,
+            **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.output_hidden_states = output_hidden_states
+        self.use_return_dict = use_return_dict
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = (
+            scale_embedding  # scale factor will be sqrt(d_model) if True
+        )
+        self.pad_token_id = pad_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.is_encoder_decoder = is_encoder_decoder
+        self.forced_eos_token_id = forced_eos_token_id
+        self._attn_implementation = _attn_implementation
+        self.use_parallel = use_parallel
+        self.parallel_step = parallel_step
+        self.is_export = is_export
+        super().__init__()
+
+
+@dataclass
+class AttentionMaskConverter:
+    """
+    A utility class for converting attention masks used in transformer models.
+
+    This class handles the conversion of attention masks based on whether the
+    attention mechanism is causal (i.e., preventing information flow from future
+    tokens to past tokens) and whether a sliding window approach is used.
+
+    Attributes:
+        is_causal (bool): Indicates if the attention mechanism is causal.
+        sliding_window (Optional[int]): Specifies the size of the sliding window
+                                        for local attention, if applicable.
+
+    Args:
+        is_causal (bool): Determines if the attention mask should enforce causality.
+        sliding_window (Optional[int], optional): The size of the sliding window
+                                                  for local attention. Default is None.
+    """
+
+    is_causal: bool
+    sliding_window: int
+
+    def __init__(self, is_causal: bool, sliding_window=None):
+        self.is_causal = is_causal
+        self.sliding_window = sliding_window
+
+        if self.sliding_window is not None and self.sliding_window <= 0:
+            raise ValueError(
+                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
+            )
+
+    @staticmethod
+    def _make_causal_mask(
+            input_ids_shape,
+            dtype,
+            past_key_values_length=0,
+            sliding_window=None,
+            is_export=False,
+    ):
+        bsz, tgt_len = input_ids_shape
+        if is_export:
+            mask = torch.full(
+                [tgt_len, tgt_len], fill_value=torch.finfo(dtype).min, dtype=torch.float64
+            )
+        else:
+            mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
+        mask_cond = torch.arange(mask.shape[-1])
+        mask = mask.masked_fill_(
+            mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
+        )
+        return mask[None, None, :, :].expand(
+            [bsz, 1, tgt_len, tgt_len + past_key_values_length]
+        )
+
+    def to_4d_export(
+            self,
+            attention_mask_2d,
+            query_length,
+            dtype,
+            key_value_length,
+            is_export=False,
+    ):
+        input_shape = (attention_mask_2d.shape[0], query_length)
+        expanded_attn_mask = self._expand_mask(
+            attention_mask_2d, dtype, tgt_len=input_shape[-1]
+        )
+        expanded_4d_mask = expanded_attn_mask
+
+        return expanded_4d_mask
+
+    def to_4d(
+            self,
+            attention_mask_2d,
+            query_length,
+            dtype,
+            key_value_length,
+            is_export=False,
+    ):
+
+        input_shape = (attention_mask_2d.shape[0], query_length)
+        causal_4d_mask = None
+        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+            if key_value_length is None:
+                raise ValueError(
+                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+                )
+
+            past_key_values_length = key_value_length - query_length
+
+            causal_4d_mask = self._make_causal_mask(
+                input_shape,
+                dtype,
+                past_key_values_length=past_key_values_length,
+                sliding_window=self.sliding_window,
+                is_export=is_export,
+            )
+        elif self.sliding_window is not None:
+            raise NotImplementedError(
+                "Sliding window is currently only implemented for causal masking"
+            )
+
+        expanded_attn_mask = self._expand_mask(
+            attention_mask_2d, dtype, tgt_len=input_shape[-1]
+        )
+
+        if causal_4d_mask is not None:
+            if is_export:
+                expanded_attn_mask = causal_4d_mask
+                return expanded_attn_mask
+            else:
+                expanded_attn_mask = causal_4d_mask.masked_fill(
+                    expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
+                )
+
+        expanded_4d_mask = expanded_attn_mask
+
+        return expanded_4d_mask
+
+    def _expand_mask(self, mask, dtype, tgt_len=None):
+        bsz, src_len = mask.shape
+        tgt_len = tgt_len if tgt_len is not None else src_len
+        expanded_mask = (
+            mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
+        )
+        inverted_mask = 1.0 - expanded_mask
+        return inverted_mask.masked_fill_(
+            inverted_mask.to(torch.bool), torch.finfo(dtype).min
+        )
+
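A hedged end-to-end check of the converter defined above, assuming this module is importable from the path added in the diff: with a batch of one and no padding, `to_4d` reduces to a purely causal mask.

import torch
from mineru.model.utils.pytorchocr.modeling.heads.rec_unimernet_head import AttentionMaskConverter

converter = AttentionMaskConverter(is_causal=True)
mask_2d = torch.ones((1, 6), dtype=torch.int64)    # bsz=1, seq_len=6, nothing padded
mask_4d = converter.to_4d(mask_2d, query_length=6, dtype=torch.float32, key_value_length=6)
# mask_4d has shape (1, 1, 6, 6): zeros on and below the diagonal, the dtype minimum above it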
+
+def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
+    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _prepare_4d_causal_attention_mask_export(
+        attention_mask,
+        input_shape,
+        inputs_embeds,
+        past_key_values_length,
+        sliding_window=None,
+        is_export=False,
+):
+    attn_mask_converter = AttentionMaskConverter(
+        is_causal=True, sliding_window=sliding_window
+    )
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    shape = attention_mask.shape
+    len_shape = len(shape)
+
+    attention_mask = attn_mask_converter.to_4d_export(
+        attention_mask,
+        input_shape[-1],
+        key_value_length=key_value_length,
+        dtype=inputs_embeds.dtype,
+        is_export=is_export,
+    )
+    return attention_mask
+
+
+def _prepare_4d_causal_attention_mask(
+        attention_mask,
+        input_shape,
+        inputs_embeds,
+        past_key_values_length,
+        sliding_window=None,
+        is_export=False,
+):
+    attn_mask_converter = AttentionMaskConverter(
+        is_causal=True, sliding_window=sliding_window
+    )
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    if attention_mask is not None and attention_mask.dim() == 2:
+        attention_mask = attn_mask_converter.to_4d(
+            attention_mask,
+            input_shape[-1],
+            key_value_length=key_value_length,
+            dtype=inputs_embeds.dtype,
+            is_export=is_export,
+        )
+
+        return attention_mask
+    elif attention_mask is not None and len(attention_mask.shape) == 4:
+        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+        if tuple(attention_mask.shape) != expected_shape:
+            raise ValueError(
+                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+            )
+        else:
+            inverted_mask = 1.0 - attention_mask
+            attention_mask = inverted_mask.masked_fill_(
+                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+            )
+    else:
+        attention_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0],
+            input_shape[-1],
+            key_value_length,
+            dtype=inputs_embeds.dtype,
+        )
+
+    return attention_mask
+
+
+class MBartLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings, embedding_dim):
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, input_ids, past_key_values_length=0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.int64
+        ).expand([bsz, -1])
+        return nn.Embedding.forward(self, positions + self.offset)
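+
+
+# Illustrative sketch (not used by the model): the positional table is shifted by
+# `offset = 2`, so positions start at row 2 and continue from `past_key_values_length`
+# during incremental decoding. The sizes below are assumptions for the example only.
+def _demo_learned_positional_embedding():
+    embed_positions = MBartLearnedPositionalEmbedding(num_embeddings=10, embedding_dim=4)
+    input_ids = torch.zeros([1, 3], dtype=torch.int64)
+    prefill = embed_positions(input_ids)  # rows 2, 3, 4 of the table -> [1, 3, 4]
+    step = embed_positions(input_ids[:, :1], past_key_values_length=3)  # row 5 -> [1, 1, 4]
+    return prefill.shape, step.shape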
+
+
+class MBartPreTrainedModel(nn.Module):
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
+    _supports_flash_attn_2 = True
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+    def _initialize_weights(self, module):
+        """
+        Initialize the weights if they are not already initialized.
+        """
+        if getattr(module, "_is_hf_initialized", False):
+            return
+        self._init_weights(module)
+
+    def post_init(self):
+        self.apply(self._initialize_weights)
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                torch.nn.init.constant_(module.bias, val=0.0)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            if module.padding_idx is not None:
+                torch.nn.init.constant_(module.weight[module.padding_idx], val=0.0)
+
+    @property
+    def dummy_inputs(self):
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
+        dummy_inputs = {
+            "attention_mask": input_ids.ne(pad_token),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+
+class MBartAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+            self,
+            embed_dim,
+            num_heads,
+            dropout: float = 0.0,
+            is_decoder: bool = False,
+            bias: bool = True,
+            is_causal: bool = False,
+            config=None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim ** -0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _shape(self, tensor, seq_len, bsz):
+        return tensor.reshape([bsz, seq_len, self.num_heads, self.head_dim]).permute(
+            0, 2, 1, 3
+        )
+
+    def forward(
+            self,
+            hidden_states,
+            key_value_states=None,
+            past_key_value=None,
+            attention_mask=None,
+            layer_head_mask=None,
+            output_attentions=False,
+    ):
+
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.shape
+        query_states = self.q_proj(hidden_states) * self.scaling
+        if (
+                is_cross_attention
+                and past_key_value is not None
+                and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.concat([past_key_value[0], key_states], dim=2)
+            value_states = torch.concat([past_key_value[1], value_states], dim=2)
+        else:
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
+        key_states = key_states.reshape(proj_shape)
+        value_states = value_states.reshape(proj_shape)
+
+        src_len = key_states.shape[1]
+        attn_weights = torch.bmm(query_states, key_states.permute([0, 2, 1]))
+
+        if attention_mask is not None:
+            attn_weights = (
+                    attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
+                    + attention_mask
+            )
+            attn_weights = attn_weights.reshape(
+                [bsz * self.num_heads, tgt_len, src_len]
+            )
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        if layer_head_mask is not None:
+            if tuple(layer_head_mask.shape) != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of shape {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.shape}"
+                )
+            attn_weights = layer_head_mask.reshape(
+                [1, -1, 1, 1]
+            ) * attn_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
+            attn_weights = attn_weights.reshape(
+                [bsz * self.num_heads, tgt_len, src_len]
+            )
+
+        if output_attentions:
+            attn_weights_reshaped = attn_weights.reshape(
+                [bsz, self.num_heads, tgt_len, src_len]
+            )
+            attn_weights = attn_weights_reshaped.reshape(
+                [bsz * self.num_heads, tgt_len, src_len]
+            )
+        else:
+            attn_weights_reshaped = None
+        attn_probs = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        attn_output = attn_output.reshape([bsz, self.num_heads, tgt_len, self.head_dim])
+        attn_output = attn_output.permute([0, 2, 1, 3])
+
+        attn_output = attn_output.reshape([bsz, tgt_len, self.embed_dim])
+        attn_output = self.out_proj(attn_output)
+        return attn_output, attn_weights_reshaped, past_key_value
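+
+
+# Illustrative sketch (not used by the model): one eager self-attention call and the
+# shapes it returns. The dimensions are assumptions chosen for the example.
+def _demo_mbart_attention_shapes():
+    attn = MBartAttention(embed_dim=16, num_heads=4, is_decoder=True)
+    hidden_states = torch.randn([2, 5, 16])
+    attn_output, attn_weights, past_key_value = attn(hidden_states, output_attentions=True)
+    # attn_output: [2, 5, 16]; attn_weights: [2, 4, 5, 5];
+    # past_key_value: (key, value), each [2, 4, 5, 4], reusable at the next decode step.
+    return attn_output.shape, attn_weights.shape, past_key_value[0].shape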
+
+
+MBART_ATTENTION_CLASSES = {
+    "eager": MBartAttention,
+}
+
+
+class MBartDecoderLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            is_causal=True,
+            config=config,
+        )
+        self.is_export = config.is_export
+        self.dropout = config.dropout
+        self.activation_fn = F.gelu
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+            self,
+            hidden_states,
+            attention_mask=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            layer_head_mask=None,
+            cross_attn_layer_head_mask=None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: Optional[bool] = False,
+            use_cache: Optional[bool] = True,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        self_attn_past_key_value = (
+            past_key_value[:2] if past_key_value is not None else None
+        )
+
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+        hidden_states = residual + hidden_states
+
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+            cross_attn_past_key_value = (
+                past_key_value[-2:] if past_key_value is not None else None
+            )
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = (
+                self.encoder_attn(
+                    hidden_states=hidden_states,
+                    key_value_states=encoder_hidden_states,
+                    attention_mask=encoder_attention_mask,
+                    layer_head_mask=cross_attn_layer_head_mask,
+                    past_key_value=cross_attn_past_key_value,
+                    output_attentions=output_attentions,
+                )
+            )
+            hidden_states = nn.functional.dropout(
+                hidden_states, p=self.dropout, training=self.training
+            )
+            hidden_states = residual + hidden_states
+
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.activation_dropout, training=self.training
+        )
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+        hidden_states = residual + hidden_states
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if self.is_export or use_cache:
+            outputs += (present_key_value,)
+        return outputs
+
+
+class MBartForCausalLM(MBartPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        config = copy.deepcopy(config)
+        config.is_decoder = True
+        config.is_encoder_decoder = False
+        super().__init__(config)
+        self.model = MBartDecoderWrapper(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.decoder = decoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    def forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            head_mask=None,
+            cross_attn_head_mask=None,
+            past_key_values=None,
+            inputs_embeds=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+    ):
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0])
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(
+                logits.reshape([-1, self.config.vocab_size]), labels.reshape([-1])
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+            self,
+            input_ids,
+            past_key_values=None,
+            attention_mask=None,
+            use_cache=None,
+            **kwargs,
+    ):
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        if past_key_values:
+            past_length = past_key_values[0][0].shape[2]
+
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx) for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+class myLayerNorm(nn.LayerNorm):
+    """
+    Custom implementation of Layer Normalization, with additional options.
+
+    This class extends the standard LayerNorm to include optional features,
+    such as drop block regularization, which might be used for improving
+    model generalization.
+
+    Args:
+        num_channels (int): The number of features or channels in the input.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
+        affine (bool, optional): If True, this module has learnable affine parameters (gamma and beta). Default is True.
+        drop_block (optional): Accepted for interface compatibility; not used by this implementation. Default is None.
+
+    """
+
+    def __init__(
+            self,
+            num_channels,
+            eps=1e-5,
+            affine=True,
+            drop_block=None,
+    ):
+        # Bypass nn.LayerNorm.__init__ so the affine parameters are managed manually.
+        super(nn.LayerNorm, self).__init__()
+        self._epsilon = eps
+        self.num_channels = num_channels
+        if affine:
+            self.weight = torch.nn.Parameter(torch.ones([num_channels]))
+            self.bias = torch.nn.Parameter(torch.zeros([num_channels]))
+        else:
+            self.weight = None
+            self.bias = None
+
+    def forward(self, x):
+        x = F.layer_norm(
+            x,
+            [self.num_channels],
+            weight=self.weight,
+            bias=self.bias,
+            eps=self._epsilon,
+        )
+        return x
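+
+
+# Illustrative sketch (not used by the model): with `affine=True` this custom layer norm
+# matches `nn.LayerNorm` initialised to weight=1, bias=0 over the last dimension.
+def _demo_my_layer_norm():
+    x = torch.randn([2, 3, 8])
+    return torch.allclose(myLayerNorm(8, affine=True)(x), nn.LayerNorm(8)(x), atol=1e-6)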
+
+
+class MBartDecoder(MBartPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
+
+    Args:
+        config
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config, embed_tokens=None):
+        super().__init__(config)
+        self.dropout = config.dropout
+        self.layerdrop = config.decoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.d_model, self.padding_idx
+        )
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = MBartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+        )
+        self.layers = nn.ModuleList(
+            [MBartDecoderLayer(config) for _ in range(config.decoder_layers)]
+        )
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self.layernorm_embedding = myLayerNorm(config.d_model, affine=True)
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+        self.is_export = config.is_export
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            head_mask=None,
+            cross_attn_head_mask=None,
+            past_key_values=None,
+            inputs_embeds=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+    ):
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.reshape([-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )
+
+        past_key_values_length = (
+            past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self._use_flash_attention_2:
+            attention_mask = (
+                attention_mask
+                if (attention_mask is not None and 0 in attention_mask)
+                else None
+            )
+        else:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                input_shape,
+                inputs_embeds,
+                past_key_values_length,
+                is_export=self.is_export,
+            )
+
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = (
+                    encoder_attention_mask if 0 in encoder_attention_mask else None
+                )
+            else:
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                print(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = (
+            () if (output_attentions and encoder_hidden_states is not None) else None
+        )
+        next_decoder_cache = () if use_cache else None
+
+        for attn_mask, mask_name in zip(
+                [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
+        ):
+            if attn_mask is not None:
+                if attn_mask.shape[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.shape[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = (
+                past_key_values[idx] if past_key_values is not None else None
+            )
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    (
+                        cross_attn_head_mask[idx]
+                        if cross_attn_head_mask is not None
+                        else None
+                    ),
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx]
+                        if cross_attn_head_mask is not None
+                        else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class MBartDecoderWrapper(MBartPreTrainedModel):
+    """
+    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+    used in combination with the [`EncoderDecoderModel`] framework.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.decoder = MBartDecoder(config)
+
+    def forward(self, *args, **kwargs):
+        return self.decoder(*args, **kwargs)
+
+
+def _in_projection(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        w_q: torch.Tensor,
+        w_k: torch.Tensor,
+        w_v: torch.Tensor,
+        b_q: Optional[torch.Tensor] = None,
+        b_k: Optional[torch.Tensor] = None,
+        b_v: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
+    assert w_q.shape == (
+        Eq,
+        Eq,
+    ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
+    assert w_k.shape == (
+        Eq,
+        Ek,
+    ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
+    assert w_v.shape == (
+        Eq,
+        Ev,
+    ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
+    assert b_q is None or b_q.shape == (
+        Eq,
+    ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
+    assert b_k is None or b_k.shape == (
+        Eq,
+    ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
+    assert b_v is None or b_v.shape == (
+        Eq,
+    ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
+    return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
+
+
+def _scaled_dot_product_attention(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    B, Nt, E = q.shape
+    q = q / math.sqrt(E)
+    attn = torch.bmm(q, k.permute([0, 2, 1]))
+    if attn_mask is not None:
+        attn += attn_mask
+    attn = F.softmax(attn, dim=-1)
+    if dropout_p > 0.0:
+        attn = F.dropout(attn, p=dropout_p)
+    output = torch.bmm(attn, v)
+    return output, attn
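+
+
+# Illustrative sketch (not used by the model): the projection helpers and this scaled
+# dot-product attention operate on [batch * heads, seq_len, head_dim] tensors. The
+# sizes below are assumptions for the example only.
+def _demo_scaled_dot_product_attention():
+    q = torch.randn([8, 5, 4])  # [batch * heads, tgt_len, head_dim]
+    k = torch.randn([8, 6, 4])  # [batch * heads, src_len, head_dim]
+    v = torch.randn([8, 6, 4])
+    output, attn = _scaled_dot_product_attention(q, k, v)
+    return output.shape, attn.shape  # [8, 5, 4], [8, 5, 6]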
+
+
+def linear(x, w, b, is_transpose=False):
+    if is_transpose:
+        w = w.T
+    if b is not None:
+        return torch.matmul(x, w) + b
+    else:
+        return torch.matmul(x, w)
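+
+
+# Illustrative sketch (not used by the model): with `is_transpose=True` this helper
+# matches `torch.nn.functional.linear`, which expects a weight of shape
+# [out_features, in_features]; with the default it uses the weight as given.
+def _demo_linear_transpose():
+    x = torch.randn([2, 4])
+    w = torch.randn([3, 4])  # [out_features, in_features]
+    b = torch.zeros([3])
+    return torch.allclose(linear(x, w, b, is_transpose=True), F.linear(x, w, b))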
+
+
+def _in_projection_packed(
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        w: Tensor,
+        b: Optional[Tensor] = None,
+        is_export=False,
+) -> List[Tensor]:
+    E = q.shape[-1]
+    # Note: when k is v, only the self-attention case (q is k is v) is handled in this
+    # port; when k is not v, the weight is chunked into separate q/k/v projections below.
+    if k is v:
+        if q is k:
+            proj = linear(q, w, b, is_transpose=True)
+            if is_export:
+                B, D, L = proj.shape
+                proj = proj.reshape([B, D, 3, E])
+                proj = (
+                    proj.unsqueeze(0)
+                    .permute([3, 1, 2, 0, 4])
+                    .squeeze(-2)
+                    .contiguous()
+                )
+            else:
+                proj = (
+                    proj.unflatten(-1, (3, E))
+                    .unsqueeze(0)
+                    .permute([3, 1, 2, 0, 4])
+                    .squeeze(-2)
+                    .contiguous()
+                )
+            return proj[0], proj[1], proj[2]
+    else:
+        w_q, w_k, w_v = w.chunk(3)
+        if b is None:
+            b_q = b_k = b_v = None
+        else:
+            b_q, b_k, b_v = b.chunk(3)
+        return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+
+def multi_head_attention_forward(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        embed_dim_to_check: int,
+        num_heads: int,
+        in_proj_weight: torch.Tensor,
+        in_proj_bias: Optional[torch.Tensor],
+        bias_k: Optional[torch.Tensor],
+        bias_v: Optional[torch.Tensor],
+        add_zero_attn: bool,
+        dropout_p: float,
+        out_proj_weight: torch.Tensor,
+        out_proj_bias: Optional[torch.Tensor],
+        training: bool = True,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        need_weights: bool = True,
+        attn_mask: Optional[torch.Tensor] = None,
+        use_separate_proj_weight: bool = False,
+        q_proj_weight: Optional[torch.Tensor] = None,
+        k_proj_weight: Optional[torch.Tensor] = None,
+        v_proj_weight: Optional[torch.Tensor] = None,
+        static_k: Optional[torch.Tensor] = None,
+        static_v: Optional[torch.Tensor] = None,
+        is_export=False,
+):
+    tgt_len, bsz, embed_dim = query.shape
+    src_len, _, _ = key.shape
+
+    if isinstance(embed_dim, torch.Tensor):
+        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
+    else:
+        head_dim = embed_dim // num_heads
+    q, k, v = _in_projection_packed(
+        query, key, value, in_proj_weight, in_proj_bias, is_export
+    )
+
+    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+        warnings.warn(
+            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
+        )
+        key_padding_mask = key_padding_mask.to(torch.bool)
+
+    if bias_k is not None and bias_v is not None:  # False
+        assert static_k is None, "bias cannot be added to static key."
+        assert static_v is None, "bias cannot be added to static value."
+        k = torch.concat([k, bias_k.repeat(1, bsz, 1)])
+        v = torch.concat([v, bias_v.repeat(1, bsz, 1)])
+    else:
+        assert bias_k is None
+        assert bias_v is None
+
+    q = q.reshape([tgt_len, bsz * num_heads, head_dim]).permute([1, 0, 2])
+    if static_k is None:  # True
+        k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
+    else:
+        assert (
+                static_k.shape[0] == bsz * num_heads
+        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
+        assert (
+                static_k.shape[2] == head_dim
+        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
+        k = static_k
+    if static_v is None:  # True
+        v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).permute([1, 0, 2])
+    else:
+        assert (
+                static_v.shape[0] == bsz * num_heads
+        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
+        assert (
+                static_v.shape[2] == head_dim
+        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
+        v = static_v
+
+    src_len = k.shape[1]
+
+    if not training:
+        dropout_p = 0.0
+
+    attn_output, attn_output_weights = _scaled_dot_product_attention(
+        q, k, v, attn_mask, dropout_p
+    )
+
+    attn_output = attn_output.permute([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
+    attn_output = linear(
+        attn_output, out_proj_weight, out_proj_bias, is_transpose=False
+    )
+
+    if need_weights:
+        attn_output_weights = attn_output_weights.reshape(
+            [bsz, num_heads, tgt_len, src_len]
+        )
+        return attn_output, attn_output_weights.sum(dim=1) / num_heads
+    else:
+        return attn_output, None
+
+
+class MyMultiheadAttention(nn.Module):
+    """
+    Custom implementation of a multi-head attention layer.
+
+    Attributes:
+        __constants__ (list): List of constant attributes.
+        bias_k (Optional[torch.Tensor]): Optional tensor for key bias.
+        bias_v (Optional[torch.Tensor]): Optional tensor for value bias.
+
+    Args:
+        embed_dim (int): Total dimension of the model. This is the size of the input feature vectors.
+        num_heads (int): Number of parallel attention heads. The input dimension must be divisible by the number of heads.
+        dropout (float, optional): Dropout probability on the attention weights. Default is 0.0.
+        bias (bool, optional): If True, adds a learnable bias to the output. Default is True.
+        add_bias_kv (bool, optional): If True, adds bias to the key and value sequences. Default is False.
+        add_zero_attn (bool, optional): If True, adds a zero attention head. Default is False.
+        kdim (int, optional): Total number of features for keys. If None, defaults to embed_dim.
+        vdim (int, optional): Total number of features for values. If None, defaults to embed_dim.
+        batch_first (bool, optional): Stored for interface compatibility; this port always expects (seq, batch, feature) inputs. Default is False.
+        device (optional): The device on which the layer's parameters should be initialized. Default is None.
+        dtype (optional): The data type for the parameters. Default is None.
+        is_export (bool, optional): If True, the layer is set up for export, potentially changing behavior for compatibility. Default is False.
+    """
+
+    __constants__ = ["batch_first"]
+    bias_k: Optional[torch.Tensor]
+    bias_v: Optional[torch.Tensor]
+
+    def __init__(
+            self,
+            embed_dim,
+            num_heads,
+            dropout=0.0,
+            bias=True,
+            add_bias_kv=False,
+            add_zero_attn=False,
+            kdim=None,
+            vdim=None,
+            batch_first=False,
+            device=None,
+            dtype=None,
+            is_export=False,
+    ) -> None:
+        super(MyMultiheadAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.batch_first = batch_first
+        self.head_dim = embed_dim // num_heads
+        self.is_export = is_export
+        assert (
+                self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        if not self._qkv_same_embed_dim:
+            # Separate q/k/v projection weights are not supported in this port.
+            pass
+        else:
+            if dtype is None:
+                dtype = torch.float32
+            self.in_proj_weight = torch.nn.Parameter(torch.randn(3 * embed_dim, embed_dim) * 0.01)
+            self.q_proj_weight = None
+            self.k_proj_weight = None
+            self.v_proj_weight = None
+
+        if bias:
+            self.in_proj_bias = torch.nn.Parameter(torch.zeros(3 * embed_dim))
+        else:
+            self.in_proj_bias = None
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        if add_bias_kv:
+            # add_bias_kv is not supported in this port; bias_k/bias_v remain unset.
+            pass
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+
+        if self._qkv_same_embed_dim:
+            torch.nn.init.xavier_normal_(self.in_proj_weight)
+        else:
+            torch.nn.init.xavier_normal_(self.q_proj_weight)
+            torch.nn.init.xavier_normal_(self.k_proj_weight)
+            torch.nn.init.xavier_normal_(self.v_proj_weight)
+
+        if self.in_proj_bias is not None:
+            torch.nn.init.zeros_(self.in_proj_bias)
+            torch.nn.init.zeros_(self.out_proj.bias)
+        if self.bias_k is not None:
+            torch.nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            torch.nn.init.xavier_normal_(self.bias_v)
+
+    def forward(
+            self,
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            key_padding_mask: Optional[torch.Tensor] = None,
+            need_weights: bool = True,
+            attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+        attn_output, attn_output_weights = multi_head_attention_forward(
+            query,
+            key,
+            value,
+            self.embed_dim,
+            self.num_heads,
+            self.in_proj_weight,
+            self.in_proj_bias,
+            self.bias_k,
+            self.bias_v,
+            self.add_zero_attn,
+            self.dropout,
+            self.out_proj.weight,
+            self.out_proj.bias,
+            training=self.training,
+            key_padding_mask=key_padding_mask,
+            need_weights=need_weights,
+            attn_mask=attn_mask,
+            is_export=self.is_export,
+        )
+
+        return attn_output, attn_output_weights
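+
+
+# Illustrative sketch (not used by the model): MyMultiheadAttention expects
+# [seq_len, batch, embed_dim] inputs (batch_first is not applied in this port).
+# The sizes below are assumptions for the example only.
+def _demo_my_multihead_attention():
+    attn = MyMultiheadAttention(embed_dim=16, num_heads=4)
+    x = torch.randn([5, 2, 16])  # [seq_len, batch, embed_dim]
+    attn_output, attn_weights = attn(x, x, x)
+    # attn_output: [5, 2, 16]; attn_weights: [2, 5, 5] (averaged over heads).
+    return attn_output.shape, attn_weights.shape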
+
+
+class LogitsProcessorList(list):
+    """
+    A list of logits processors that can be applied sequentially.
+
+    Methods:
+        __call__(input_ids, scores, **kwargs): Apply all processors to the given inputs.
+    """
+
+    def __call__(self, input_ids, scores, **kwargs):
+        for processor in self:
+            function_args = inspect.signature(processor.__call__).parameters
+            if len(function_args) > 2:
+                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
+                    raise ValueError(
+                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
+                        f"{processor.__class__} are passed to the logits processor."
+                    )
+                scores = processor(input_ids, scores, **kwargs)
+            else:
+                scores = processor(input_ids, scores)
+        return scores
+
+
+class ForcedEOSTokenLogitsProcessor(object):
+    """
+    A processor that forces the generation of an end-of-sequence (EOS) token
+    at a specified position in the sequence.
+
+    This is typically used in language generation tasks to ensure that the
+    generated sequence ends properly when it reaches a certain length.
+
+    Args:
+        max_length (int): The maximum length of the sequence. Forces EOS when this length is reached.
+        eos_token_id (Union[int, List[int]]): The ID(s) of the EOS token(s) to be forced in the sequence.
+    """
+
+    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
+        self.max_length = max_length
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids, scores):
+        cur_len = input_ids.shape[-1]
+        scores_processed = scores
+        if cur_len == self.max_length - 1:
+            scores_processed = torch.full_like(scores, -math.inf)
+            scores_processed[:, self.eos_token_id] = 0
+        return scores_processed
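+
+
+# Illustrative sketch (not used by the model): once the running sequence reaches
+# max_length - 1, every logit except the EOS id is set to -inf, so the next sampled or
+# greedy token must be EOS. The ids and sizes below are assumptions for the example.
+def _demo_forced_eos_processor():
+    processors = LogitsProcessorList(
+        [ForcedEOSTokenLogitsProcessor(max_length=4, eos_token_id=2)]
+    )
+    input_ids = torch.zeros([1, 3], dtype=torch.int64)  # cur_len == max_length - 1
+    scores = processors(input_ids, torch.randn([1, 10]))
+    return int(scores.argmax(dim=-1))  # always 2, the EOS id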
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentions(ModelOutput):
+    loss = None
+    logits = None
+    past_key_values = None
+    hidden_states = None
+    attentions = None
+    cross_attentions = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
+    """
+    Causal language model output that additionally carries the predicted token-count distribution.
+    """
+
+    logits = None
+    counting = None
+    past_key_values = None
+    hidden_states = None
+    attentions = None
+    cross_attentions = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class CustomMBartDecoder(MBartDecoder):
+    """
+    A custom MBartDecoder that includes additional processing layers.
+
+    This class extends MBartDecoder with a `counting_context_weight` MLP that projects
+    the predicted token-count distribution back into the decoder hidden size; the
+    result is added to the token embeddings before decoding, making the decoder
+    length-aware.
+
+    Args:
+        config: The configuration object containing model parameters.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        hidden_size = config.d_model
+        self.is_export = config.is_export
+        self.counting_context_weight = nn.Sequential(
+            nn.Linear(config.vocab_size, hidden_size),
+            nn.ReLU(),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ReLU(),
+            nn.Linear(hidden_size, config.d_model),
+        )
+
+    def forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            count_pred=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            head_mask=None,
+            cross_attn_head_mask=None,
+            past_key_values=None,
+            inputs_embeds=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+    ):
+        self.is_export = not self.training
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.reshape([-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )
+
+        past_key_values_length = (
+            past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self._use_flash_attention_2:
+            attention_mask = (
+                attention_mask
+                if (attention_mask is not None and 0 in attention_mask)
+                else None
+            )
+        else:
+            if self.is_export:
+                attention_mask = _prepare_4d_causal_attention_mask_export(
+                    attention_mask,
+                    input_shape,
+                    inputs_embeds,
+                    past_key_values_length,
+                    is_export=self.is_export,
+                ).to(torch.float32)
+            else:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    input_shape,
+                    inputs_embeds,
+                    past_key_values_length,
+                    is_export=self.is_export,
+                )
+
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = (
+                    encoder_attention_mask if 0 in encoder_attention_mask else None
+                )
+            else:
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+
+        # Add the counting context (projected count prediction) to the token embeddings.
+        if count_pred is not None:
+            count_context_weight = self.counting_context_weight(count_pred)
+            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(
+            hidden_states, p=self.dropout, training=self.training
+        )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                print(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = (
+            () if (output_attentions and encoder_hidden_states is not None) else None
+        )
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip(
+                [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
+        ):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = (
+                past_key_values[idx] if past_key_values is not None else None
+            )
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    (
+                        cross_attn_head_mask[idx]
+                        if cross_attn_head_mask is not None
+                        else None
+                    ),
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx]
+                        if cross_attn_head_mask is not None
+                        else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+            if self.is_export or use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        if self.is_export:
+            next_cache = next_decoder_cache
+        else:
+            next_cache = next_decoder_cache if use_cache else None
+        if not self.is_export:
+            if not return_dict:
+                return tuple(
+                    v
+                    for v in [
+                        hidden_states,
+                        next_cache,
+                        all_hidden_states,
+                        all_self_attns,
+                        all_cross_attentions,
+                    ]
+                    if v is not None
+                )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class SelfAttentionBlock(nn.Module):
+    """
+    A self-attention block that implements multi-head self-attention
+    followed by a feed-forward network, typically used in transformer architectures.
+
+    Args:
+        embed_size (int): The size of the embedding vector.
+        num_heads (int): The number of attention heads.
+        is_export (bool): Flag indicating whether to configure the layer for export.
+    """
+
+    def __init__(self, embed_size, num_heads, is_export):
+        super(SelfAttentionBlock, self).__init__()
+        self.self_attention = MyMultiheadAttention(
+            embed_dim=embed_size, num_heads=num_heads, is_export=is_export
+        )
+        self.norm = nn.LayerNorm(embed_size)
+
+    def forward(self, x):
+        attn_output, _ = self.self_attention(x, x, x)
+        x = self.norm(attn_output + x)
+        return x
+
+
+class SeqCountingDecoder(nn.Module):
+    """
+    A sequence counting decoder that stacks self-attention blocks and a feed-forward
+    head to predict token counts from a sequence of features (e.g. counting LaTeX tokens).
+
+    Args:
+        in_features (int): The number of input features.
+        out_features (int): The number of output features.
+        num_heads (int): The number of attention heads. Defaults to 8.
+        num_layers (int): The number of attention layers. Defaults to 4.
+        is_export (bool): Flag indicating whether to configure the layer for export.
+    """
+
+    def __init__(
+            self, in_features, out_features, num_heads=8, num_layers=4, is_export=False
+    ):
+        super(SeqCountingDecoder, self).__init__()
+
+        self.attention_blocks = nn.ModuleList(
+            [
+                SelfAttentionBlock(
+                    embed_size=in_features, num_heads=num_heads, is_export=is_export
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.fc1 = nn.Linear(in_features, in_features // 2)
+        self.relu = nn.ReLU()
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
+        self.fc2 = nn.Linear(in_features // 2, out_features)
+
+    def forward(self, x):
+        for block in self.attention_blocks:
+            x = block(x)
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = x.permute([0, 2, 1])
+        x = self.global_avg_pool(x)
+        x = x.squeeze(-1)
+        x = self.fc2(x)
+        return x
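+
+
+# Illustrative sketch (not used by the model): the counting decoder pools a sequence of
+# encoder features [batch, seq_len, in_features] into one count vector per sample.
+# The sizes below are assumptions for the example only.
+def _demo_seq_counting_decoder():
+    decoder = SeqCountingDecoder(in_features=16, out_features=50, num_heads=4, num_layers=1)
+    features = torch.randn([2, 7, 16])
+    return decoder(features).shape  # [2, 50]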
+
+
+class CustomMBartForCausalLM(MBartForCausalLM):
+    """
+    Custom MBart model for causal language modeling with a custom decoder.
+
+    This class extends the MBartForCausalLM by replacing its decoder with a
+    custom decoder, allowing for additional flexibility and features in the
+    decoding process.
+
+    Args:
+        config: The configuration object containing model parameters.
+        length_aware (bool): If True, a SeqCountingDecoder predicts token counts from
+            the encoder hidden states and the decoder is conditioned on that prediction.
+    """
+
+    def __init__(self, config, length_aware=True):
+        super().__init__(config)
+        self.model.decoder = CustomMBartDecoder(config)
+        self.counting_decoder = SeqCountingDecoder(
+            config.d_model, config.vocab_size, is_export=config.is_export
+        )
+        self.length_aware = length_aware
+
+    def forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            head_mask=None,
+            cross_attn_head_mask=None,
+            past_key_values=None,
+            inputs_embeds=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            count_gt=None,
+    ):
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
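+        # Length-aware branch: predict per-symbol counts from the encoder features and
+        # pass them to the custom decoder as an auxiliary signal (count_pred).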
+        if self.length_aware:
+            count_pred = self.counting_decoder(encoder_hidden_states)
+        else:
+            count_pred = None
+
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            count_pred=count_pred,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        logits = self.lm_head(outputs[0])
+
+        return CausalLMOutputWithCrossAttentionsAndCounting(
+            logits=logits,
+            counting=count_pred,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+class UniMERNetHead(nn.Module):
+    """Implementation of UniMERNetHead decoder.
+
+    Args:
+         max_new_tokens (int): Maximum number of new tokens to generate.
+         decoder_start_token_id (int): ID of the token that starts the decoding.
+         temperature (float): Sampling temperature for generation.
+         do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
+         top_p (float): Top-p (nucleus) sampling parameter.
+         in_channels (int): Number of input channels/features.
+         encoder_hidden_size (int): Hidden size of the encoder.
+         decoder_hidden_size (int): Hidden size of the decoder.
+         decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
+         decoder_layers (int): Number of layers in the decoder.
+         is_export (bool): Flag indicating if the model is being prepared for export.
+         length_aware (bool): Flag to enable length-aware mechanisms.
+    """
+
+    def __init__(
+            self,
+            max_new_tokens=1536,
+            decoder_start_token_id=0,
+            temperature=0.2,
+            do_sample=False,
+            top_p=0.95,
+            in_channels=1024,
+            encoder_hidden_size=1024,
+            decoder_hidden_size=1024,
+            decoder_ffn_dim=4096,
+            decoder_layers=8,
+            is_export=False,
+            length_aware=True,
+    ):
+        super().__init__()
+        mbart_config_dict = {
+            "activation_dropout": 0.0,
+            "activation_function": "gelu",
+            "add_cross_attention": True,
+            "add_final_layer_norm": True,
+            "attention_dropout": 0.0,
+            "bos_token_id": 0,
+            "classifier_dropout": 0.0,
+            "d_model": decoder_hidden_size,
+            "decoder_attention_heads": 16,
+            "decoder_ffn_dim": decoder_ffn_dim,
+            "decoder_layerdrop": 0.0,
+            "decoder_layers": decoder_layers,
+            "dropout": 0.1,
+            "encoder_attention_heads": 16,
+            "encoder_ffn_dim": 4096,
+            "encoder_layerdrop": 0.0,
+            "encoder_layers": 12,
+            "eos_token_id": 2,
+            "forced_eos_token_id": 2,
+            "init_std": 0.02,
+            "is_decoder": True,
+            "is_encoder_decoder": False,
+            "output_hidden_states": False,
+            "max_position_embeddings": max_new_tokens,
+            "model_type": "mbart",
+            "num_hidden_layers": 12,
+            "pad_token_id": 1,
+            "scale_embedding": True,
+            "tie_word_embeddings": False,
+            "transformers_version": "4.40.0",
+            "use_cache": True,
+            "use_return_dict": True,
+            "vocab_size": 50000,
+            "_attn_implementation": "eager",
+            "hidden_size": decoder_hidden_size,
+            "is_export": is_export,
+        }
+
+        self.max_new_tokens = max_new_tokens
+        self.decoder_start_token_id = decoder_start_token_id
+        self.temperature = temperature
+        self.do_sample = do_sample
+        self.top_p = top_p
+        self.max_seq_len = max_new_tokens
+        self.config_decoder = MBartConfig(**mbart_config_dict)
+        self.encoder_hidden_size = encoder_hidden_size
+        self.is_export = self.config_decoder.is_export
+        self.decoder = CustomMBartForCausalLM(
+            self.config_decoder, length_aware=length_aware
+        )
+        if self.config_decoder.hidden_size != self.encoder_hidden_size:
+            self.enc_to_dec_proj = nn.Linear(
+                self.encoder_hidden_size, self.config_decoder.hidden_size
+            )
+        generation_config = {
+            "max_length": 1537,
+            "forced_eos_token_id": 2,
+        }
+        self.eos_token_id = generation_config["forced_eos_token_id"]
+        self.pad_token_id = self.config_decoder.pad_token_id
+        self.logits_processor = LogitsProcessorList()
+        self.logits_processor.append(
+            ForcedEOSTokenLogitsProcessor(
+                generation_config["max_length"],
+                generation_config["forced_eos_token_id"],
+            )
+        )
+
+    def _get_decoder_start_token_id(
+            self, decoder_start_token_id=None, bos_token_id=None
+    ) -> int:
+        decoder_start_token_id = (
+            decoder_start_token_id
+            if decoder_start_token_id is not None
+            else self.generation_config.decoder_start_token_id
+        )
+        bos_token_id = (
+            bos_token_id
+            if bos_token_id is not None
+            else self.generation_config.bos_token_id
+        )
+        if decoder_start_token_id is not None:
+            return decoder_start_token_id
+        elif bos_token_id is not None:
+            return bos_token_id
+        raise ValueError(
+            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
+        )
+
+    def _prepare_decoder_input_ids_for_generation(
+            self,
+            batch_size,
+            model_kwargs,
+            decoder_start_token_id=None,
+            bos_token_id=None,
+    ):
+        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+        elif "input_ids" in model_kwargs:
+            decoder_input_ids = model_kwargs.pop("input_ids")
+        else:
+            decoder_input_ids = None
+
+        decoder_start_token_id = self._get_decoder_start_token_id(
+            decoder_start_token_id, bos_token_id
+        )
+
+        if isinstance(decoder_start_token_id, list):
+            if len(decoder_start_token_id) != batch_size:
+                raise ValueError(
+                    f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
+                )
+            decoder_input_ids_start = torch.LongTensor(decoder_start_token_id)
+            decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
+        else:
+            decoder_input_ids_start = (
+                    torch.ones(
+                        (batch_size, 1),
+                        dtype=torch.int64,
+                    )
+                    * decoder_start_token_id
+            )
+
+        if decoder_input_ids is None:
+            decoder_input_ids = decoder_input_ids_start
+        elif (
+                self.config.model_type == "vision-encoder-decoder"
+                and "donut" in self.name_or_path.lower()
+        ):
+            pass
+        elif self.config.model_type in ["whisper"]:
+            pass
+        elif (
+                isinstance(decoder_start_token_id, int)
+                and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
+        ) or (
+                isinstance(decoder_start_token_id, torch.Tensor)
+                and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
+        ):
+            decoder_input_ids = torch.concat(
+                [decoder_input_ids_start, decoder_input_ids], dim=-1
+            )
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                decoder_attention_mask = torch.cat(
+                    (
+                        torch.ones_like(decoder_attention_mask)[:, :1],
+                        decoder_attention_mask,
+                    ),
+                    dim=-1,
+                )
+                model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+        return decoder_input_ids, model_kwargs
+
+    def prepare_inputs_for_generation_mbart(
+            self,
+            input_ids,
+            past_key_values=None,
+            attention_mask=None,
+            use_cache=None,
+            **kwargs,
+    ):
+
+        if attention_mask is None:
+            attention_mask = torch.ones(input_ids.shape)
+
+        if past_key_values:
+            past_length = past_key_values[0][0].shape[2]
+
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+
+    def prepare_inputs_for_generation(
+            self,
+            input_ids,
+            past_key_values=None,
+            attention_mask=None,
+            use_cache=None,
+            encoder_outputs=None,
+            **kwargs,
+    ):
+        decoder_inputs = self.prepare_inputs_for_generation_mbart(
+            input_ids, past_key_values=past_key_values
+        )
+        decoder_attention_mask = (
+            decoder_inputs["attention_mask"]
+            if "attention_mask" in decoder_inputs
+            else None
+        )
+        input_dict = {
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_input_ids": decoder_inputs["input_ids"],
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": decoder_inputs["past_key_values"],
+            "use_cache": use_cache,
+        }
+        return input_dict
+
+    def prepare_inputs_for_generation_export(
+            self,
+            past_key_values=None,
+            attention_mask=None,
+            use_cache=None,
+            encoder_outputs=None,
+            **kwargs,
+    ):
+
+        input_dict = {
+            "decoder_attention_mask": None,
+            "use_cache": use_cache,
+        }
+        return input_dict
+
+    def _extract_past_from_model_output(
+            self, outputs: ModelOutput, standardize_cache_format: bool = False
+    ):
+        past_key_values = None
+        if "past_key_values" in outputs:
+            past_key_values = outputs.past_key_values
+        elif "mems" in outputs:
+            past_key_values = outputs.mems
+        elif "past_buckets_states" in outputs:
+            past_key_values = outputs.past_buckets_states
+
+        return past_key_values
+
+    def _update_model_kwargs_for_generation(
+            self,
+            outputs: ModelOutput,
+            model_kwargs: Dict[str, Any],
+            is_encoder_decoder: bool = False,
+            standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.concat(
+                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
+            )
+
+        if not is_encoder_decoder:
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.concat(
+                    [
+                        attention_mask,
+                        attention_mask.new_ones((attention_mask.shape[0], 1)),
+                    ],
+                    dim=-1,
+                )
+        else:
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.concat(
+                    [
+                        decoder_attention_mask,
+                        decoder_attention_mask.new_ones(
+                            (decoder_attention_mask.shape[0], 1)
+                        ),
+                    ],
+                    dim=-1,
+                )
+
+        if (
+                "cache_position" in model_kwargs
+                and model_kwargs["cache_position"] is not None
+        ):
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+
+        return model_kwargs
+
+    def stopping_criteria(self, input_ids):
+        if self.is_export:
+            return input_ids[:, -1] == torch.tensor([self.eos_token_id])
+        is_done = torch.isin(input_ids[:, -1], torch.tensor([self.eos_token_id]))
+        return is_done
+
+    def generate_single_iter(
+            self,
+            decoder_input_ids=None,
+            decoder_attention_mask=None,
+            encoder_outputs=None,
+            past_key_values=None,
+            decoder_inputs_embeds=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            **kwargs,
+    ):
+        encoder_hidden_states = encoder_outputs[0]
+        if self.config_decoder.hidden_size != self.encoder_hidden_size:
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+        kwargs_decoder = {}
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=None,
+            inputs_embeds=None,
+            output_attentions=False,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        return Seq2SeqLMOutput(
+            loss=None,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    @torch.no_grad()
+    def generate(
+            self,
+            model_kwargs,
+    ):
+        """
+        Generate sequences using the UniMERNetHead for inference tasks.
+
+        Args:
+            model_kwargs (dict): A dictionary of model configurations and inputs, which typically include:
+                - encoder_outputs: Outputs from the encoder.
+                - use_cache: Boolean flag to indicate if caching should be used.
+                - output_attentions: Boolean flag for outputting attention scores.
+                - output_hidden_states: Boolean flag for outputting hidden states.
+
+        Returns:
+            A tensor containing the generated sequences.
+        """
+        batch_size = model_kwargs["encoder_outputs"]["last_hidden_state"].shape[0]
+        generation_config = {
+            "decoder_start_token_id": 0,
+            "bos_token_id": 0,
+        }
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config["decoder_start_token_id"],
+            bos_token_id=generation_config["bos_token_id"],
+        )
+        model_kwargs["use_cache"] = True
+        batch_size, cur_len = input_ids.shape
+
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        model_kwargs["cache_position"] = torch.arange(cur_len)
+        pad_token_id = self.pad_token_id
+        eos_token_id = [self.eos_token_id]
+        eos_token = self.eos_token_id
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
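+        # Greedy autoregressive loop: run one decoder step, append the argmax token,
+        # overwrite tokens of finished sequences with PAD, and stop once every
+        # sequence has produced an EOS token or max_seq_len is reached.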
+        for idx in range(self.max_seq_len):
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            outputs = self.generate_single_iter(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+            next_token_logits = outputs.logits[:, -1, :]
+
+            next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                        1 - unfinished_sequences
+                )
+            input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config_decoder.is_encoder_decoder,
+            )
+            unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
+                input_ids
+            ).to(torch.int64)
+
+            if (
+                    eos_token is not None
+                    and (
+                    torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
+                    >= 1
+            ).all()
+            ):
+                break
+
+        return input_ids
+
+    @torch.no_grad()
+    def generate_export(
+            self,
+            encoder_outputs,
+            model_kwargs,
+    ):
+        batch_size = encoder_outputs["last_hidden_state"].shape[0]
+        generation_config = {
+            "decoder_start_token_id": 0,
+            "bos_token_id": 0,
+        }
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config["decoder_start_token_id"],
+            bos_token_id=generation_config["bos_token_id"],
+        )
+        input_ids = input_ids.reshape([-1, 1])
+        decoder_input_ids = input_ids
+        model_kwargs["use_cache"] = True
+        batch_size, cur_len = input_ids.shape
+
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        cache_position = torch.arange(cur_len)
+        pad_token_id = self.pad_token_id
+        eos_token_id = [self.eos_token_id]
+        eos_token = self.eos_token_id
+        unfinished_sequences = torch.ones([batch_size], dtype=torch.int64)
+        i_idx = torch.full([], 0)
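+        # Pre-allocate empty KV caches, one (self K, self V, cross K, cross V) tuple per
+        # decoder layer, hard-coded for the default 8-layer / 16-head / 64-dim-per-head setup.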
+        past_key_values = []
+        for i in range(8):
+            init_arr = torch.zeros([batch_size, 16, 0, 64])
+            cache = (init_arr, init_arr, init_arr, init_arr)
+            past_key_values.append(cache)
+        idx = 0
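+        # Export-friendly decode loop: each step feeds only the most recent token
+        # together with the explicit KV cache tensors initialized above.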
+        while i_idx < torch.tensor(self.max_seq_len):
+
+            model_inputs = self.prepare_inputs_for_generation_export(
+                past_key_values=past_key_values, **model_kwargs
+            )
+            decoder_attention_mask = model_inputs["decoder_attention_mask"]
+            decoder_attention_mask = torch.ones(input_ids.shape)
+
+            outputs = self.generate_single_iter(
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                encoder_outputs=encoder_outputs,
+                past_key_values=past_key_values,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
+            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+            if eos_token_id is not None:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                        1 - unfinished_sequences
+                )
+            input_ids = torch.concat([input_ids, next_tokens.unsqueeze(1)], dim=-1)
+            past_length = past_key_values[0][0].shape[2]
+            decoder_input_ids = next_tokens.unsqueeze(1)
+            past_key_values = outputs.past_key_values
+            cache_position = cache_position[-1:] + 1
+            unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
+                input_ids
+            ).to(torch.int64)
+            if (
+                    eos_token is not None
+                    and (
+                    torch.cumsum((input_ids == eos_token).to(torch.int64), 1)[:, -1]
+                    >= 1
+            ).all()
+            ):
+                break
+
+            i_idx += 1
+        return input_ids
+
+    def forward_train(
+            self,
+            encoder_outputs,
+            decoder_input_ids,
+            decoder_attention_mask,
+            past_key_values=None,
+            decoder_inputs_embeds=None,
+            labels=None,
+            use_cache=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+            **kwargs,
+    ):
+        """
+        Training-time forward pass for the UniMERNetHead.
+
+        Args:
+            encoder_outputs: Outputs from the encoder, used as input to the decoder.
+            decoder_input_ids: Input IDs for the decoder.
+            decoder_attention_mask: Attention mask for the decoder inputs.
+            past_key_values: Cached key/values for faster decoding.
+            decoder_inputs_embeds: Optional embeddings for the decoder inputs.
+            labels: Target labels for calculating loss.
+            use_cache: Whether to use cache during decoding.
+            output_attentions: Whether to return attention scores.
+            output_hidden_states: Whether to return hidden states.
+            return_dict: Whether to return a dictionary of outputs.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            logits: The raw, unnormalized predictions from the decoder.
+            count_pred: Token-count prediction from the length-aware branch (None when disabled).
+            masked_labels: Target labels with PAD positions replaced by -100 for loss masking.
+        """
+        labels = decoder_input_ids.clone()
+        labels = labels.masked_fill_(labels == self.pad_token_id, -100)
+        input_decoder_input_ids = decoder_input_ids[:, :-1]
+        input_decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_hidden_states = encoder_outputs[0]
+        if self.config_decoder.hidden_size != self.encoder_hidden_size:
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+        kwargs_decoder = {}
+        decoder_outputs = self.decoder(
+            input_ids=input_decoder_input_ids,
+            attention_mask=input_decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=None,
+            inputs_embeds=None,
+            output_attentions=False,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        logits = decoder_outputs.logits
+        count_pred = decoder_outputs.counting
+        return logits, count_pred, labels
+
+    def forward(self, inputs, targets=None):
+        """
+        Forward pass for the UniMERNetHead, handling both training and inference.
+
+        Args:
+            inputs: The input data, which can vary based on training or inference.
+            targets: The target labels, used only during training.
+
+        Returns:
+            During inference: Returns the generated LaTeX token ids.
+            During training: Returns logits, predicted counts, and masked labels.
+        """
+        self.is_export = not self.training
+        if not self.training:
+            encoder_outputs = inputs
+            if self.is_export:
+                model_kwargs = {
+                    "output_attentions": False,
+                    "output_hidden_states": False,
+                    "use_cache": True,
+                }
+                word_pred = self.generate_export(encoder_outputs, model_kwargs)
+            else:
+                model_kwargs = {
+                    "output_attentions": False,
+                    "output_hidden_states": False,
+                    "use_cache": True,
+                    "encoder_outputs": encoder_outputs,
+                }
+                word_pred = self.generate(model_kwargs)
+
+            return word_pred
+
+        encoder_outputs, tgt_seq, mask = inputs
+        logits, count_pred, masked_labels = self.forward_train(
+            encoder_outputs, tgt_seq, mask
+        )
+        return logits, count_pred, masked_labels
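
For readers skimming the diff, the core of UniMERNetHead.generate above is a standard greedy decoding loop: take the argmax of the last-step logits, overwrite tokens of already-finished sequences with PAD, and stop once every sequence has emitted EOS. Below is a minimal, self-contained sketch of that loop; fake_decoder_step is a stand-in for generate_single_iter and is not part of the repository.

import torch

EOS, PAD, VOCAB, MAX_STEPS = 2, 1, 50000, 32

def fake_decoder_step(input_ids):
    # Stand-in for generate_single_iter(): random next-token logits, for illustration only.
    return torch.randn(input_ids.shape[0], VOCAB)

input_ids = torch.zeros((4, 1), dtype=torch.int64)              # decoder_start_token_id = 0
unfinished = torch.ones(input_ids.shape[0], dtype=torch.int64)

for _ in range(MAX_STEPS):
    next_token_logits = fake_decoder_step(input_ids)
    next_tokens = torch.argmax(next_token_logits, dim=-1)
    # Finished sequences keep emitting PAD instead of new content.
    next_tokens = next_tokens * unfinished + PAD * (1 - unfinished)
    input_ids = torch.concat([input_ids, next_tokens[:, None]], dim=-1)
    # A sequence stays finished once it has emitted EOS.
    unfinished = unfinished * (input_ids[:, -1] != EOS).to(torch.int64)
    # Stop when every sequence contains at least one EOS token.
    if (torch.cumsum((input_ids == EOS).to(torch.int64), 1)[:, -1] >= 1).all():
        break

print(input_ids.shape)  # (4, <= MAX_STEPS + 1)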

+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/__init__.py → mineru/model/utils/pytorchocr/modeling/necks/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/db_fpn.py → mineru/model/utils/pytorchocr/modeling/necks/db_fpn.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/intracl.py → mineru/model/utils/pytorchocr/modeling/necks/intracl.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py → mineru/model/utils/pytorchocr/modeling/necks/rnn.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/__init__.py → mineru/model/utils/pytorchocr/postprocess/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/cls_postprocess.py → mineru/model/utils/pytorchocr/postprocess/cls_postprocess.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py → mineru/model/utils/pytorchocr/postprocess/db_postprocess.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py → mineru/model/utils/pytorchocr/postprocess/rec_postprocess.py


+ 0 - 0
mineru/model/utils/pytorchocr/utils/__init__.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/arch_config.yaml → mineru/model/utils/pytorchocr/utils/resources/arch_config.yaml


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/arabic_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/arabic_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/chinese_cht_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/chinese_cht_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/cyrillic_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/cyrillic_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/devanagari_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/devanagari_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/en_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/en_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/japan_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/japan_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ka_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ka_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/korean_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/korean_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/latin_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/latin_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocr_keys_v1.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv4_doc_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv5_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_el_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_en_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_eslav_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_korean_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_latin_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ppocrv5_th_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/ta_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/ta_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/dict/te_dict.txt → mineru/model/utils/pytorchocr/utils/resources/dict/te_dict.txt


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/utils/resources/models_config.yml → mineru/model/utils/pytorchocr/utils/resources/models_config.yml


+ 24 - 0
mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml

@@ -0,0 +1,24 @@
+Architecture:
+  model_type: rec
+  algorithm: PP-FormulaNet_plus-M
+  in_channels: 3
+  Transform:
+  Backbone:
+    name: PPHGNetV2_B6_Formula
+    class_num: 1024
+
+  Head:
+    name: PPFormulaNet_Head
+    max_new_tokens: 2560
+    decoder_start_token_id: 0
+    decoder_ffn_dim: 2048
+    decoder_hidden_size: 512
+    decoder_layers: 6
+    temperature: 0.2
+    do_sample: False
+    top_p: 0.95 
+    encoder_hidden_size: 2048
+    is_export: False
+    length_aware: False 
+    use_parallel: False
+    parallel_step: 0
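
The YAML above follows the PaddleOCR-style architecture-config convention (Backbone plus Head sections). A minimal sketch of reading it with PyYAML follows; how the repo actually wires these values into PPHGNetV2_B6_Formula and PPFormulaNet_Head (e.g. in base_ocr_v20.py / predict_formula.py) is not shown here, so treat this loader as an assumption.

import yaml

# Path as added in this PR; adjust if the resources are installed elsewhere.
CFG = "mineru/model/utils/pytorchocr/utils/resources/pp_formulanet_arch_config.yaml"

with open(CFG) as f:
    arch = yaml.safe_load(f)["Architecture"]

backbone_cfg = arch["Backbone"]  # {'name': 'PPHGNetV2_B6_Formula', 'class_num': 1024}
head_cfg = arch["Head"]          # decoder hyper-parameters for PPFormulaNet_Head
print(head_cfg["name"], head_cfg["decoder_layers"], head_cfg["max_new_tokens"])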

+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/tools/__init__.py → mineru/model/utils/tools/__init__.py


+ 1 - 0
mineru/model/utils/tools/infer/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_cls.py → mineru/model/utils/tools/infer/predict_cls.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py → mineru/model/utils/tools/infer/predict_det.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py → mineru/model/utils/tools/infer/predict_rec.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_system.py → mineru/model/utils/tools/infer/predict_system.py


+ 0 - 0
mineru/model/ocr/paddleocr2pytorch/tools/infer/pytorchocr_utility.py → mineru/model/utils/tools/infer/pytorchocr_utility.py


+ 1 - 1
mineru/model/vlm_vllm_model/server.py

@@ -43,7 +43,7 @@ def main():
     if not has_port_arg:
         args.extend(["--port", "30000"])
     if not has_gpu_memory_utilization_arg:
-        args.extend(["--gpu-memory-utilization", "0.5"])
+        args.extend(["--gpu-memory-utilization", "0.7"])
     if not model_path:
         model_path = auto_download_and_get_model_root_path("/", "vlm")
     if (not has_logits_processors_arg) and custom_logits_processors:

+ 1 - 0
mineru/utils/enum_class.py

@@ -70,6 +70,7 @@ class ModelPath:
     doclayout_yolo = "models/Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
     yolo_v8_mfd = "models/MFD/YOLO/yolo_v8_ft.pt"
     unimernet_small = "models/MFR/unimernet_hf_small_2503"
+    pp_formulanet_plus_m = "models/MFR/pp_formulanet_plus_m"
     pytorch_paddle = "models/OCR/paddleocr_torch"
     layout_reader = "models/ReadingOrder/layout_reader"
     slanet_plus = "models/TabRec/SlanetPlus/slanet-plus.onnx"

+ 2 - 2
pyproject.toml

@@ -39,7 +39,7 @@ dependencies = [
     "openai>=1.70.0,<3",
     "beautifulsoup4>=4.13.5,<5",
     "magika>=0.6.2,<0.7.0",
-    "mineru-vl-utils>=0.1.13,<1",
+    "mineru-vl-utils>=0.1.14,<1",
 ]
 
 [project.optional-dependencies]
@@ -56,7 +56,7 @@ vlm = [
     "accelerate>=1.5.1",
 ]
 vllm = [
-    "vllm>=0.10.1.1,<0.11",
+    "vllm>=0.10.1.1,<0.12",
 ]
 pipeline = [
     "matplotlib>=3.10,<4",