
Merge remote-tracking branch 'origin/dev' into dev

myhloli 8 months ago
parent commit 9ce72d78e6
39 changed files with 4853 additions and 468 deletions
  1. + 3 - 7  docker/ascend_npu/requirements.txt
  2. + 3 - 7  docker/china/requirements.txt
  3. + 3 - 7  docker/global/requirements.txt
  4. + 1 - 1  magic-pdf.template.json
  5. + 156 - 0  magic_pdf/data/batch_build_dataset.py
  6. + 40 - 23  magic_pdf/data/dataset.py
  7. + 108 - 9  magic_pdf/data/utils.py
  8. + 4 - 3  magic_pdf/dict2md/ocr_mkcontent.py
  9. + 11 - 6  magic_pdf/libs/pdf_image_tools.py
  10. + 5 - 116  magic_pdf/model/batch_analyze.py
  11. + 130 - 28  magic_pdf/model/doc_analyze_by_custom_model.py
  12. + 4 - 29  magic_pdf/model/pdf_extract_kit.py
  13. + 2 - 4  magic_pdf/model/sub_modules/language_detection/utils.py
  14. + 24 - 19  magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
  15. + 20 - 98  magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
  16. + 13 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
  17. + 189 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
  18. + 8 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
  19. + 163 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
  20. + 2351 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py
  21. + 0 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py
  22. + 9 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py
  23. + 132 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py
  24. + 114 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py
  25. + 1084 - 0  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py
  26. + 14 - 12  magic_pdf/model/sub_modules/model_init.py
  27. + 17 - 11  magic_pdf/model/sub_modules/model_utils.py
  28. + 1 - 0  magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
  29. + 1 - 0  magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py
  30. + 1 - 1  magic_pdf/pdf_parse_union_core_v2.py
  31. + 1 - 1  magic_pdf/resources/model_config/model_configs.yaml
  32. + 19 - 10  magic_pdf/tools/cli.py
  33. + 88 - 11  magic_pdf/tools/common.py
  34. + 70 - 45  projects/web_api/app.py
  35. + 3 - 2  requirements.txt
  36. + 13 - 5  scripts/download_models.py
  37. + 13 - 5  scripts/download_models_hf.py
  38. + 11 - 8  setup.py
  39. + 24 - 0  signatures/version1/cla.json

+ 3 - 7
docker/ascend_npu/requirements.txt

@@ -7,19 +7,15 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch==2.3.1
+torchvision==0.18.1
 matplotlib
 ultralytics>=8.3.48
 paddleocr==2.7.3
 paddlepaddle==3.0.0rc1
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
+ftfy
 openai
-detectron2

+ 3 - 7
docker/china/requirements.txt

@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
 matplotlib
 ultralytics>=8.3.48
 paddleocr==2.7.3
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
+ftfy
 openai
-detectron2

+ 3 - 7
docker/global/requirements.txt

@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
 matplotlib
 ultralytics>=8.3.48
 paddleocr==2.7.3
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
+ftfy
 openai
-detectron2

+ 1 - 1
magic-pdf.template.json

@@ -40,5 +40,5 @@
             "enable": false
         }
     },
-    "config_version": "1.1.1"
+    "config_version": "1.2.0"
 }

+ 156 - 0
magic_pdf/data/batch_build_dataset.py

@@ -0,0 +1,156 @@
+import concurrent.futures
+
+import fitz
+
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.utils import fitz_doc_to_image  # PyMuPDF
+
+
+def partition_array_greedy(arr, k):
+    """Partition an array into k parts using a simple greedy approach.
+
+    Parameters:
+    -----------
+    arr : list
+        List of (pdf_path, page_count) tuples; items are weighted by page count
+    k : int
+        Number of partitions to create
+
+    Returns:
+    --------
+    partitions : list of lists
+        The k partitions, each a list of indices into arr
+    """
+    # Handle edge cases
+    if k <= 0:
+        raise ValueError('k must be a positive integer')
+    if k > len(arr):
+        k = len(arr)  # Adjust k if it's too large
+    if k == 1:
+        return [list(range(len(arr)))]
+    if k == len(arr):
+        return [[i] for i in range(len(arr))]
+
+    # Sort indices by page count in descending order
+    sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
+
+    # Initialize k empty partitions
+    partitions = [[] for _ in range(k)]
+    partition_sums = [0] * k
+
+    # Assign each element to the partition with the smallest current sum
+    for idx in sorted_indices:
+        # Find the partition with the smallest sum
+        min_sum_idx = partition_sums.index(min(partition_sums))
+
+        # Add the element to this partition
+        partitions[min_sum_idx].append(idx)  # Store the original index
+        partition_sums[min_sum_idx] += arr[idx][1]
+
+    return partitions
+
+
+def process_pdf_batch(pdf_jobs, idx):
+    """Render every page of a batch of PDFs to images.
+
+    Parameters:
+    -----------
+    pdf_jobs : list of tuples
+        List of (pdf_path, page_count) tuples
+    idx : int
+        Index of this batch, echoed back with the results
+
+    Returns:
+    --------
+    (idx, images) : tuple
+        The batch index and, per PDF, a list of rendered page images
+    """
+    images = []
+
+    for pdf_path, _ in pdf_jobs:
+        doc = fitz.open(pdf_path)
+        tmp = []
+        for page_num in range(len(doc)):
+            page = doc[page_num]
+            tmp.append(fitz_doc_to_image(page))
+        images.append(tmp)
+    return (idx, images)
+
+
+def batch_build_dataset(pdf_paths, k, lang=None):
+    """Build datasets for multiple PDFs by partitioning them into k balanced
+    parts (by page count) and rendering each part in a separate process.
+
+    Parameters:
+    -----------
+    pdf_paths : list
+        List of paths to PDF files
+    k : int
+        Number of partitions / worker processes to create
+    lang : str or None
+        Language hint passed to each PymuDocDataset
+
+    Returns:
+    --------
+    results : list
+        List of PymuDocDataset objects, one per input path, in input order
+    """
+    # Get page counts for each PDF
+    pdf_info = []
+    total_pages = 0
+
+    for pdf_path in pdf_paths:
+        try:
+            doc = fitz.open(pdf_path)
+            num_pages = len(doc)
+            pdf_info.append((pdf_path, num_pages))
+            total_pages += num_pages
+            doc.close()
+        except Exception as e:
+            print(f'Error opening {pdf_path}: {e}')
+
+    # Partition the PDFs into k groups balanced by page count
+    partitions = partition_array_greedy(pdf_info, k)
+
+    # Process each partition in parallel
+    all_images_h = {}
+
+    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
+        # Submit one task per partition
+        futures = []
+        for sn, partition in enumerate(partitions):
+            # Get the jobs for this partition
+            partition_jobs = [pdf_info[idx] for idx in partition]
+
+            # Submit the task
+            future = executor.submit(
+                process_pdf_batch,
+                partition_jobs,
+                sn
+            )
+            futures.append(future)
+        # Process results as they complete
+        for i, future in enumerate(concurrent.futures.as_completed(futures)):
+            try:
+                idx, images = future.result()
+                all_images_h[idx] = images
+            except Exception as e:
+                print(f'Error processing partition: {e}')
+    results = [None] * len(pdf_paths)
+    for i in range(len(partitions)):
+        partition = partitions[i]
+        for j in range(len(partition)):
+            with open(pdf_info[partition[j]][0], 'rb') as f:
+                pdf_bytes = f.read()
+            dataset = PymuDocDataset(pdf_bytes, lang=lang)
+            dataset.set_images(all_images_h[i][j])
+            results[partition[j]] = dataset
+    return results
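
A minimal usage sketch of the new batch_build_dataset entry point and its greedy partitioner (not part of this commit; the PDF paths are placeholders and the printed partition is one illustrative outcome):

from magic_pdf.data.batch_build_dataset import batch_build_dataset, partition_array_greedy

if __name__ == '__main__':
    # The partitioner balances (pdf_path, page_count) tuples by page count
    # and returns lists of indices into the input list.
    jobs = [('a.pdf', 120), ('b.pdf', 30), ('c.pdf', 90)]
    print(partition_array_greedy(jobs, k=2))   # -> [[0], [2, 1]]

    # Build one image-cached PymuDocDataset per PDF using two worker processes.
    datasets = batch_build_dataset(['a.pdf', 'b.pdf', 'c.pdf'], k=2, lang=None)
    print(len(datasets[0]))                    # page count of the first PDF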

+ 40 - 23
magic_pdf/data/dataset.py

@@ -97,10 +97,10 @@ class Dataset(ABC):
 
     @abstractmethod
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
         pass
 
@@ -119,7 +119,7 @@ class Dataset(ABC):
 
     @abstractmethod
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -128,8 +128,7 @@ class Dataset(ABC):
 
     @abstractmethod
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         pass
 
 
@@ -148,12 +147,14 @@ class PymuDocDataset(Dataset):
         if lang == '':
             self._lang = None
         elif lang == 'auto':
-            from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
+            from magic_pdf.model.sub_modules.language_detection.utils import \
+                auto_detect_lang
             self._lang = auto_detect_lang(bits)
-            logger.info(f"lang: {lang}, detect_lang: {self._lang}")
+            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
-            logger.info(f"lang: {lang}")
+            logger.info(f'lang: {lang}')
+
     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)
@@ -186,12 +187,12 @@ class PymuDocDataset(Dataset):
         return self._records[page_id]
 
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
-        
+
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
             os.makedirs(dir_name, exist_ok=True)
@@ -212,7 +213,7 @@ class PymuDocDataset(Dataset):
         return proc(self, *args, **kwargs)
 
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -220,10 +221,12 @@ class PymuDocDataset(Dataset):
         return classify(self._data_bits)
 
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return PymuDocDataset(self._raw_data)
 
+    def set_images(self, images):
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])
 
 class ImageDataset(Dataset):
     def __init__(self, bits: bytes):
@@ -270,10 +273,10 @@ class ImageDataset(Dataset):
         return self._records[page_id]
 
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.
 
-        Args: 
-            file_path (str): the file path 
+        Args:
+            file_path (str): the file path
         """
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
@@ -293,7 +296,7 @@ class ImageDataset(Dataset):
         return proc(self, *args, **kwargs)
 
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset 
+        """classify the dataset.
 
         Returns:
             SupportedPdfParseMethod: _description_
@@ -301,15 +304,19 @@ class ImageDataset(Dataset):
         return SupportedPdfParseMethod.OCR
 
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return ImageDataset(self._raw_data)
 
+    def set_images(self, images):
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])
+
 class Doc(PageableData):
     """Initialized with pymudoc object."""
 
     def __init__(self, doc: fitz.Page):
         self._doc = doc
+        self._img = None
 
     def get_image(self):
         """Return the image info.
@@ -321,7 +328,17 @@ class Doc(PageableData):
                 height: int
             }
         """
-        return fitz_doc_to_image(self._doc)
+        if self._img is None:
+            self._img = fitz_doc_to_image(self._doc)
+        return self._img
+
+    def set_image(self, img):
+        """
+        Args:
+            img (np.ndarray): the image
+        """
+        if self._img is None:
+            self._img = img
 
     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.

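A small sketch of the new image caching on Doc/PymuDocDataset (not from the commit; 'a.pdf' is a placeholder): get_image renders a page once and caches the result, and set_images pre-fills those caches so later calls skip rendering entirely.

from magic_pdf.data.dataset import PymuDocDataset

with open('a.pdf', 'rb') as f:     # placeholder path
    pdf_bytes = f.read()

ds = PymuDocDataset(pdf_bytes)
page = ds.get_page(0)

img0 = page.get_image()            # rendered once, cached on the Doc wrapper
img1 = page.get_image()            # served from the cache, no re-render
assert img0 is img1

# Pre-rendered images (e.g. produced by batch_build_dataset) can be injected
# up front; set_image is a no-op for pages whose cache is already filled.
ds.set_images([ds.get_page(i).get_image() for i in range(len(ds))])
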
+ 108 - 9
magic_pdf/data/utils.py

@@ -1,12 +1,15 @@
 
+import multiprocessing as mp
+import threading
+from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
+                                as_completed)
+
 import fitz
 import numpy as np
 from loguru import logger
 
-from magic_pdf.utils.annotations import ImportPIL
 
 
-@ImportPIL
 def fitz_doc_to_image(doc, dpi=200) -> dict:
     """Convert fitz.Document to image, Then convert the image to numpy array.
 
@@ -17,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     Returns:
         dict:  {'img': numpy array, 'width': width, 'height': height }
     """
-    from PIL import Image
     mat = fitz.Matrix(dpi / 72, dpi / 72)
     pm = doc.get_pixmap(matrix=mat, alpha=False)
 
@@ -25,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     if pm.width > 4500 or pm.height > 4500:
         pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-    img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-    img = np.array(img)
+    # Convert pixmap samples directly to numpy array
+    img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
 
     img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
 
     return img_dict
 
-@ImportPIL
 def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
-    from PIL import Image
     images = []
     with fitz.open('pdf', pdf_bytes) as doc:
         pdf_page_num = doc.page_count
@@ -57,11 +57,110 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
                 if pm.width > 4500 or pm.height > 4500:
                     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-                img = np.array(img)
+                # Convert pixmap samples directly to numpy array
+                img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
+
                 img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
             else:
                 img_dict = {'img': [], 'width': 0, 'height': 0}
 
             images.append(img_dict)
     return images
+
+
+def convert_page(bytes_page):
+    pdfs = fitz.open('pdf', bytes_page)
+    page = pdfs[0]
+    return fitz_doc_to_image(page)
+
+def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
+    """Process PDF pages in parallel with serialization-safe approach."""
+    if num_workers is None:
+        num_workers = mp.cpu_count()
+
+
+    # Process the extracted page data in parallel
+    with ProcessPoolExecutor(max_workers=num_workers) as executor:
+        # Process the page data
+        results = list(
+            executor.map(convert_page, pages)
+        )
+
+    return results
+
+
+def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
+    """Process all pages of a PDF using multiple threads.
+
+    Parameters:
+    -----------
+    pdf_path : str
+        Path to the PDF file
+    num_threads : int
+        Number of threads to use
+    **kwargs :
+        Additional arguments for fitz_doc_to_image
+
+    Returns:
+    --------
+    images : list
+        List of processed images, in page order
+    """
+    # Open the PDF
+    doc = fitz.open(pdf_path)
+    num_pages = len(doc)
+
+    # Create a list to store results in the correct order
+    results = [None] * num_pages
+
+    # Create a thread pool
+    with ThreadPoolExecutor(max_workers=num_threads) as executor:
+        # Submit all tasks
+        futures = {}
+        for page_num in range(num_pages):
+            page = doc[page_num]
+            future = executor.submit(fitz_doc_to_image, page, **kwargs)
+            futures[future] = page_num
+        # Process results as they complete
+        for future in as_completed(futures):
+            page_num = futures[future]
+            try:
+                results[page_num] = future.result()
+            except Exception as e:
+                print(f'Error processing page {page_num}: {e}')
+                results[page_num] = None
+
+    # Close the document
+    doc.close()
+
+    return results
+
+if __name__ == '__main__':
+    pdf = fitz.open('/tmp/[MS-DOC].pdf')
+
+
+    pdf_page = [fitz.open() for i in range(pdf.page_count)]
+    [pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i) for i in range(pdf.page_count)]
+
+    pdf_page = [v.tobytes() for v in pdf_page]
+    results = parallel_process_pdf_safe(pdf_page, num_workers=16)
+
+    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
+
+    """ benchmark results of multi-threaded processing (fitz page to image)
+    total page nums: 578
+    thread nums,    time cost
+    1               7.351 sec
+    2               6.334 sec
+    4               5.968 sec
+    8               6.728 sec
+    16              8.085 sec
+    """
+
+    """ benchmark results of multi-processor processing (fitz page to image)
+    total page nums: 578
+    processor nums,    time cost
+    1                  17.170 sec
+    2                  10.170 sec
+    4                  7.841 sec
+    8                  7.900 sec
+    16                 7.984 sec
+    """
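
The PIL round trip is replaced by reading the pixmap buffer directly. A standalone sketch of that conversion, assuming a placeholder 'a.pdf' and an RGB pixmap (alpha disabled):

import fitz
import numpy as np

doc = fitz.open('a.pdf')           # placeholder path
pm = doc[0].get_pixmap(matrix=fitz.Matrix(200 / 72, 200 / 72), alpha=False)

# An RGB pixmap's samples buffer holds height * width * 3 uint8 values,
# so a single reshape yields the HxWx3 array the models consume.
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
print(img.shape, img.dtype)        # (H, W, 3) uint8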

+ 4 - 3
magic_pdf/dict2md/ocr_mkcontent.py

@@ -208,12 +208,13 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
             'text': merge_para_with_text(para_block),
         }
     elif para_type == BlockType.Title:
-        title_level = get_title_level(para_block)
         para_content = {
             'type': 'text',
             'text': merge_para_with_text(para_block),
-            'text_level': title_level,
         }
+        title_level = get_title_level(para_block)
+        if title_level != 0:
+            para_content['text_level'] = title_level
     elif para_type == BlockType.InterlineEquation:
         para_content = {
             'type': 'equation',
@@ -319,5 +320,5 @@ def get_title_level(block):
     if title_level > 4:
         title_level = 4
     elif title_level < 1:
-        title_level = 1
+        title_level = 0
     return title_level

+ 11 - 6
magic_pdf/libs/pdf_image_tools.py

@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
     # 截取图片
     pix = page.get_pixmap(clip=rect, matrix=zoom)
 
-    # 将字节数据转换为文件对象
-    image_file = BytesIO(pix.tobytes(output='png'))
-    # 使用 Pillow 打开图像
-    pil_image = Image.open(image_file)
     if mode == "cv2":
-        image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
+        # 直接转换为numpy数组供cv2使用
+        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
+        # PyMuPDF使用RGB顺序,而cv2使用BGR顺序
+        if pix.n == 3 or pix.n == 4:
+            image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+        else:
+            image_result = img_array
     elif mode == "pillow":
-        image_result = pil_image
+        # 将字节数据转换为文件对象
+        image_file = BytesIO(pix.tobytes(output='png'))
+        # 使用 Pillow 打开图像
+        image_result = Image.open(image_file)
     else:
         raise ValueError(f"mode: {mode} is not supported.")
 

+ 5 - 116
magic_pdf/model/batch_analyze.py

@@ -1,23 +1,15 @@
 import time
 
 import cv2
-import numpy as np
 import torch
 from loguru import logger
-from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-# from magic_pdf.data.dataset import Dataset
-# from magic_pdf.libs.clean_memory import clean_memory
-# from magic_pdf.libs.config_reader import get_device
-# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-# from magic_pdf.operators.models import InferenceResult
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -31,7 +23,6 @@ class BatchAnalyze:
 
     def __call__(self, images: list) -> list:
         images_layout_res = []
-
         layout_start_time = time.time()
         if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
@@ -41,36 +32,14 @@ class BatchAnalyze:
         elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
             layout_images = []
-            modified_images = []
             for image_index, image in enumerate(images):
-                pil_img = Image.fromarray(image)
-                # width, height = pil_img.size
-                # if height > width:
-                #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                #     new_image, useful_list = crop_img(
-                #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                #     )
-                #     layout_images.append(new_image)
-                #     modified_images.append([image_index, useful_list])
-                # else:
-                layout_images.append(pil_img)
+                layout_images.append(image)
 
             images_layout_res += self.model.layout_model.batch_predict(
                 # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
                 layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )
 
-            for image_index, useful_list in modified_images:
-                for res in images_layout_res[image_index]:
-                    for i in range(len(res['poly'])):
-                        if i % 2 == 0:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[0] + useful_list[2]
-                            )
-                        else:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[1] + useful_list[3]
-                            )
         logger.info(
             f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
         )
@@ -111,7 +80,7 @@ class BatchAnalyze:
         # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
         for index in range(len(images)):
             layout_res = images_layout_res[index]
-            pil_img = Image.fromarray(images[index])
+            np_array_img = images[index]
 
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
@@ -121,14 +90,14 @@ class BatchAnalyze:
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(
-                    res, pil_img, crop_paste_x=50, crop_paste_y=50
+                    res, np_array_img, crop_paste_x=50, crop_paste_y=50
                 )
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                     single_page_mfdetrec_res, useful_list
                 )
 
                 # OCR recognition
-                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
                 if self.model.apply_ocr:
                     ocr_res = self.model.ocr_model.ocr(
@@ -150,7 +119,7 @@ class BatchAnalyze:
             if self.model.apply_table:
                 table_start = time.time()
                 for res in table_res_list:
-                    new_image, _ = crop_img(res, pil_img)
+                    new_image, _ = crop_img(res, np_array_img)
                     single_table_start_time = time.time()
                     html_code = None
                     if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
@@ -197,83 +166,3 @@ class BatchAnalyze:
             logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
 
         return images_layout_res
-
-
-# def doc_batch_analyze(
-#     dataset: Dataset,
-#     ocr: bool = False,
-#     show_log: bool = False,
-#     start_page_id=0,
-#     end_page_id=None,
-#     lang=None,
-#     layout_model=None,
-#     formula_enable=None,
-#     table_enable=None,
-#     batch_ratio: int | None = None,
-# ) -> InferenceResult:
-#     """Perform batch analysis on a document dataset.
-#
-#     Args:
-#         dataset (Dataset): The dataset containing document pages to be analyzed.
-#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-#         show_log (bool, optional): Flag to enable logging. Defaults to False.
-#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-#         lang (str, optional): Language for OCR. Defaults to None.
-#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
-#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
-#         table_enable (optional): Flag to enable table detection. Defaults to None.
-#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-#
-#     Raises:
-#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-#
-#     Returns:
-#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-#     """
-#
-#     if not torch.cuda.is_available():
-#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-#
-#     lang = None if lang == '' else lang
-#     # TODO: auto detect batch size
-#     batch_ratio = 1 if batch_ratio is None else batch_ratio
-#     end_page_id = end_page_id if end_page_id else len(dataset)
-#
-#     model_manager = ModelSingleton()
-#     custom_model: CustomPEKModel = model_manager.get_model(
-#         ocr, show_log, lang, layout_model, formula_enable, table_enable
-#     )
-#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-#
-#     model_json = []
-#
-#     # batch analyze
-#     images = []
-#     for index in range(len(dataset)):
-#         if start_page_id <= index <= end_page_id:
-#             page_data = dataset.get_page(index)
-#             img_dict = page_data.get_image()
-#             images.append(img_dict['img'])
-#     analyze_result = batch_model(images)
-#
-#     for index in range(len(dataset)):
-#         page_data = dataset.get_page(index)
-#         img_dict = page_data.get_image()
-#         page_width = img_dict['width']
-#         page_height = img_dict['height']
-#         if start_page_id <= index <= end_page_id:
-#             result = analyze_result.pop(0)
-#         else:
-#             result = []
-#
-#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-#         page_dict = {'layout_dets': result, 'page_info': page_info}
-#         model_json.append(page_dict)
-#
-#     # TODO: clean memory when gpu memory is not enough
-#     clean_memory_start_time = time.time()
-#     clean_memory(get_device())
-#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-#
-#     return InferenceResult(model_json, dataset)

+ 130 - 28
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,18 +1,19 @@
+import concurrent.futures as fut
+import multiprocessing as mp
 import os
 import time
+
+import numpy as np
 import torch
 
 os.environ['FLAGS_npu_jit_compile'] = '0'  # 关闭paddle的jit编译
 os.environ['FLAGS_use_stride_kernel'] = '0'
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 让mps可以fallback
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
-# 关闭paddle的信号处理
-import paddle
-paddle.disable_signal_handler()
+
 
 from loguru import logger
 
-from magic_pdf.model.batch_analyze import BatchAnalyze
 from magic_pdf.model.sub_modules.model_utils import get_vram
 
 try:
@@ -30,8 +31,8 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-from magic_pdf.operators.models import InferenceResult
 
+# from magic_pdf.operators.models import InferenceResult
 
 class ModelSingleton:
     _instance = None
@@ -72,9 +73,7 @@ def custom_model_init(
     formula_enable=None,
     table_enable=None,
 ):
-
     model = None
-
     if model_config.__model_mode__ == 'lite':
         logger.warning(
             'The Lite mode is provided for developers to conduct testing only, and the output quality is '
@@ -132,7 +131,6 @@ def custom_model_init(
 
     return custom_model
 
-
 def doc_analyze(
     dataset: Dataset,
     ocr: bool = False,
@@ -143,14 +141,112 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-) -> InferenceResult:
-
+):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
 
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    images = []
+    page_wh_list = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+    else:
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
+
+    model_json = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+        else:
+            result = []
+            page_height = 0
+            page_width = 0
+
+        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
+        model_json.append(page_dict)
+
+    from magic_pdf.operators.models import InferenceResult
+    return InferenceResult(model_json, dataset)
+
+def batch_doc_analyze(
+    datasets: list[Dataset],
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    images = []
+    page_wh_list = []
+    for dataset in datasets:
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
+    else:
+        batch_images = [images]
+
+    results = []
+
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
+    infer_results = []
+
+    from magic_pdf.operators.models import InferenceResult
+    for index in range(len(datasets)):
+        dataset = datasets[index]
+        model_json = []
+        for i in range(len(dataset)):
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
+        infer_results.append(InferenceResult(model_json, dataset))
+    return infer_results
+
+
+def may_batch_image_analyze(
+        images: list[np.ndarray],
+        idx: int,
+        ocr: bool = False,
+        show_log: bool = False,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None):
+    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
+    # 关闭paddle的信号处理
+    import paddle
+    paddle.disable_signal_handler()
+    from magic_pdf.model.batch_analyze import BatchAnalyze
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
@@ -160,33 +256,32 @@ def doc_analyze(
     batch_ratio = 1
     device = get_device()
 
-    npu_support = False
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         import torch_npu
         if torch_npu.npu.is_available():
-            npu_support = True
             torch.npu.set_compile_mode(jit_compile=False)
 
-    if torch.cuda.is_available() and device != 'cpu' or npu_support:
-        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
-        if gpu_memory is not None and gpu_memory >= 8:
+    if str(device).startswith('npu') or str(device).startswith('cuda'):
+        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
+        if gpu_memory is not None:
             if gpu_memory >= 20:
                 batch_ratio = 16
             elif gpu_memory >= 15:
                 batch_ratio = 8
             elif gpu_memory >= 10:
                 batch_ratio = 4
-            else:
+            elif gpu_memory >= 7:
                 batch_ratio = 2
-
+            else:
+                batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
             batch_analyze = True
-
-    model_json = []
+    elif str(device).startswith('mps'):
+        batch_analyze = True
     doc_analyze_start = time.time()
 
     if batch_analyze:
-        # batch analyze
+        """# batch analyze
         images = []
         page_wh_list = []
         for index in range(len(dataset)):
@@ -195,9 +290,10 @@ def doc_analyze(
                 img_dict = page_data.get_image()
                 images.append(img_dict['img'])
                 page_wh_list.append((img_dict['width'], img_dict['height']))
+        """
         batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-        analyze_result = batch_model(images)
-
+        results = batch_model(images)
+        """
         for index in range(len(dataset)):
             if start_page_id <= index <= end_page_id:
                 result = analyze_result.pop(0)
@@ -210,10 +306,10 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
-
+        """
     else:
         # single analyze
-
+        """
         for index in range(len(dataset)):
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
@@ -230,6 +326,13 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
+        """
+        results = []
+        for img_idx, img in enumerate(images):
+            inference_start = time.time()
+            result = custom_model(img)
+            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
+            results.append(result)
 
     gc_start = time.time()
     clean_memory(get_device())
@@ -237,10 +340,9 @@ def doc_analyze(
     logger.info(f'gc time: {gc_time}')
 
     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
     logger.info(
         f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
         f' speed: {doc_analyze_speed} pages/second'
     )
-
-    return InferenceResult(model_json, dataset)
+    return (idx, results)
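
A usage sketch of the new batch path (not part of the commit; paths are placeholders): pages from all datasets are flattened into batches of at least MINERU_MIN_BATCH_INFERENCE_SIZE images (100 by default), and one InferenceResult is returned per dataset.

import os
from magic_pdf.data.batch_build_dataset import batch_build_dataset
from magic_pdf.model.doc_analyze_by_custom_model import batch_doc_analyze

if __name__ == '__main__':
    # Optional: lower the batching threshold so small jobs also take the batch path.
    os.environ['MINERU_MIN_BATCH_INFERENCE_SIZE'] = '50'

    datasets = batch_build_dataset(['a.pdf', 'b.pdf'], k=2)    # placeholder paths
    infer_results = batch_doc_analyze(datasets, ocr=True, lang='en')

    for res in infer_results:      # one InferenceResult per input dataset
        print(type(res).__name__)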

+ 4 - 29
magic_pdf/model/pdf_extract_kit.py

@@ -3,11 +3,9 @@ import os
 import time
 
 import cv2
-import numpy as np
 import torch
 import yaml
 from loguru import logger
-from PIL import Image
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 
@@ -120,7 +118,7 @@ class CustomPEKModel:
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
-                device='cpu' if str(self.device).startswith("mps") else self.device,
+                device=self.device,
             )
 
         # 初始化layout模型
@@ -174,11 +172,6 @@ class CustomPEKModel:
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
-
-        pil_img = Image.fromarray(image)
-        width, height = pil_img.size
-        # logger.info(f'width: {width}, height: {height}')
-
         # layout检测
         layout_start = time.time()
         layout_res = []
@@ -186,24 +179,6 @@ class CustomPEKModel:
             # layoutlmv3
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
-            # doclayout_yolo
-            # if height > width:
-            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
-            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-            #     layout_res = self.layout_model.predict(new_image)
-            #     for res in layout_res:
-            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-            #         p1 = p1 - paste_x + xmin
-            #         p2 = p2 - paste_y + ymin
-            #         p3 = p3 - paste_x + xmin
-            #         p4 = p4 - paste_y + ymin
-            #         p5 = p5 - paste_x + xmin
-            #         p6 = p6 - paste_y + ymin
-            #         p7 = p7 - paste_x + xmin
-            #         p8 = p8 - paste_y + ymin
-            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            # else:
             layout_res = self.layout_model.predict(image)
 
         layout_cost = round(time.time() - layout_start, 2)
@@ -234,11 +209,11 @@ class CustomPEKModel:
         ocr_start = time.time()
         # Process each area that requires OCR processing
         for res in ocr_res_list:
-            new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+            new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
             adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
 
             # OCR recognition
-            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
             if self.apply_ocr:
                 ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
@@ -260,7 +235,7 @@ class CustomPEKModel:
         if self.apply_table:
             table_start = time.time()
             for res in table_res_list:
-                new_image, _ = crop_img(res, pil_img)
+                new_image, _ = crop_img(res, image)
                 single_table_start_time = time.time()
                 html_code = None
                 if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:

+ 2 - 4
magic_pdf/model/sub_modules/language_detection/utils.py

@@ -3,8 +3,6 @@ import os
 from pathlib import Path
 
 import yaml
-from PIL import Image
-
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 
 from magic_pdf.config.constants import MODEL_NAME
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
     )
     text_images = []
     for simple_image in simple_images:
-        image = Image.fromarray(simple_image['img'])
+        image = simple_image['img']
         layout_res = temp_layout_model.predict(image)
         # 给textblock截图
         for res in layout_res:
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
                 # 初步清洗(宽和高都小于100)
                 if x2 - x1 < 100 and y2 - y1 < 100:
                     continue
-                text_images.append(image.crop((x1, y1, x2, y2)))
+                text_images.append(image[y1:y2, x1:x2])
     return text_images
 
 

+ 24 - 19
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -2,9 +2,9 @@
 import time
 from collections import Counter
 from uuid import uuid4
-
+import cv2
+import numpy as np
 import torch
-from PIL import Image
 from loguru import logger
 from ultralytics import YOLO
 
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
     if result_images is None:
         result_images = []
 
-    width, height = image.size
+    height, width = image.shape[:2]
     long_side = max(width, height)  # 获取较长边长度
 
     if long_side <= 400:
@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             if x + new_long_side > width:
                 continue
-            box = (x, 0, x + new_long_side, height)
-            sub_image = image.crop(box)
+            sub_image = image[0:height, x:x + new_long_side]
             sub_images.append(sub_image)
     else:  # 如果高度是较长边
         for y in range(0, height, new_long_side):
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             if y + new_long_side > height:
                 continue
-            box = (0, y, width, y + new_long_side)
-            sub_image = image.crop(box)
+            sub_image = image[y:y + new_long_side, 0:width]
             sub_images.append(sub_image)
 
     for sub_image in sub_images:
@@ -64,24 +62,32 @@ def split_images(image, result_images=None):
 
 def resize_images_to_224(image):
     """
-    若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
+    若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
+    Works directly with NumPy arrays.
     """
     try:
-        width, height = image.size
+        height, width = image.shape[:2]
+
         if width < 224 or height < 224:
-            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
-            paste_x = (224 - width) // 2
-            paste_y = (224 - height) // 2
-            new_image.paste(image, (paste_x, paste_y))
+            # Create black background
+            new_image = np.zeros((224, 224, 3), dtype=np.uint8)
+            # Calculate paste position (ensure they're not negative)
+            paste_x = max(0, (224 - width) // 2)
+            paste_y = max(0, (224 - height) // 2)
+            # Make sure we don't exceed the boundaries of new_image
+            paste_width = min(width, 224)
+            paste_height = min(height, 224)
+            # Paste original image onto black background
+            new_image[paste_y:paste_y + paste_height, paste_x:paste_x + paste_width] = image[:paste_height, :paste_width]
             image = new_image
         else:
-            image = image.resize((224, 224), Image.Resampling.LANCZOS)
+            # Resize using cv2
+            image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)
 
-        # uuid = str(uuid4())
-        # image.save(f"/tmp/{uuid}.jpg")
         return image
     except Exception as e:
-        logger.exception(e)
+        logger.exception(f"Error in resize_images_to_224: {e}")
+        return None
 
 
 class YOLOv11LangDetModel(object):
@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
     def do_detect(self, images: list):
         all_images = []
         for image in images:
-            width, height = image.size
-            # logger.info(f"image size: {width} x {height}")
+            height, width = image.shape[:2]
             if width < 100 and height < 100:
                 continue
             temp_images = split_images(image)

+ 20 - 98
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -1,14 +1,5 @@
-import argparse
-import os
-import re
-
 import torch
-import unimernet.tasks as tasks
-from PIL import Image
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from unimernet.common.config import Config
-from unimernet.processors import load_processor
 
 
 class MathDataset(Dataset):
@@ -20,55 +11,25 @@ class MathDataset(Dataset):
         return len(self.image_paths)
 
     def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
+        raw_image = self.image_paths[idx]
         if self.transform:
             image = self.transform(raw_image)
             return image
 
 
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-    letter = "[a-zA-Z]"
-    noletter = "[\W_^\d]"
-    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-        if news == s:
-            break
-    return s
-
-
 class UnimernetModel(object):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
-        args = argparse.Namespace(cfg_path=cfg_path, options=None)
-        cfg = Config(args)
-        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-        cfg.config.model.model_config.model_name = weight_dir
-        cfg.config.model.tokenizer_config.path = weight_dir
-        task = tasks.setup_task(cfg)
-        self.model = task.build_model(cfg)
+        from .unimernet_hf import UnimernetModel
+        if _device_.startswith("mps"):
+            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
+        else:
+            self.model = UnimernetModel.from_pretrained(weight_dir)
         self.device = _device_
         self.model.to(_device_)
+        if not _device_.startswith("cpu"):
+            self.model = self.model.to(dtype=torch.float16)
         self.model.eval()
-        vis_processor = load_processor(
-            "formula_image_eval",
-            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
-        )
-        self.mfr_transform = transforms.Compose(
-            [
-                vis_processor,
-            ]
-        )
+
 
     def predict(self, mfd_res, image):
         formula_list = []
@@ -84,62 +45,22 @@ class UnimernetModel(object):
                 "latex": "",
             }
             formula_list.append(new_item)
-            pil_img = Image.fromarray(image)
-            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            bbox_img = image[ymin:ymax, xmin:xmax]
             mf_image_list.append(bbox_img)
 
-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataset = MathDataset(mf_image_list, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+            res["latex"] = latex
         return formula_list
 
-    # def batch_predict(
-    #     self, images_mfd_res: list, images: list, batch_size: int = 64
-    # ) -> list:
-    #     images_formula_list = []
-    #     mf_image_list = []
-    #     backfill_list = []
-    #     for image_index in range(len(images_mfd_res)):
-    #         mfd_res = images_mfd_res[image_index]
-    #         pil_img = Image.fromarray(images[image_index])
-    #         formula_list = []
-    #
-    #         for xyxy, conf, cla in zip(
-    #             mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
-    #         ):
-    #             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-    #             new_item = {
-    #                 "category_id": 13 + int(cla.item()),
-    #                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-    #                 "score": round(float(conf.item()), 2),
-    #                 "latex": "",
-    #             }
-    #             formula_list.append(new_item)
-    #             bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-    #             mf_image_list.append(bbox_img)
-    #
-    #         images_formula_list.append(formula_list)
-    #         backfill_list += formula_list
-    #
-    #     dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-    #     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
-    #     mfr_res = []
-    #     for mf_img in dataloader:
-    #         mf_img = mf_img.to(self.device)
-    #         with torch.no_grad():
-    #             output = self.model.generate({"image": mf_img})
-    #         mfr_res.extend(output["pred_str"])
-    #     for res, latex in zip(backfill_list, mfr_res):
-    #         res["latex"] = latex_rm_whitespace(latex)
-    #     return images_formula_list
-
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
@@ -149,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            pil_img = Image.fromarray(images[image_index])
+            np_array_image = images[image_index]
             formula_list = []
 
             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -163,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
                 area = (xmax - xmin) * (ymax - ymin)
 
                 curr_idx = len(mf_image_list)
@@ -182,22 +103,23 @@ class UnimernetModel(object):
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
 
         # Create dataset with sorted images
-        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
 
         # Process batches and store results
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
 
         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
             original_idx = index_mapping[new_idx]
-            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+            unsorted_results[original_idx] = latex
 
         # Fill results back
         for res, latex in zip(backfill_list, unsorted_results):

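batch_predict now sorts the cropped formula images by area before batching and restores the original order afterwards. A self-contained sketch of that index bookkeeping (plain Python, independent of the model; the areas are made up):

# Hypothetical crop areas; in the real code each is (xmax - xmin) * (ymax - ymin).
areas = [1200, 300, 800, 50]

# Sort indices by area so similarly sized crops end up in the same batch.
sorted_indices = sorted(range(len(areas)), key=lambda i: areas[i], reverse=True)
sorted_outputs = [f'latex_for_area_{areas[i]}' for i in sorted_indices]  # stand-in model output

# Map results back to their original positions.
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
restored = [''] * len(sorted_outputs)
for new_idx, latex in enumerate(sorted_outputs):
    restored[index_mapping[new_idx]] = latex

assert restored == [f'latex_for_area_{a}' for a in areas]
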
+ 13 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py

@@ -0,0 +1,13 @@
+from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
+from .unimer_mbart import UnimerMBartConfig, UnimerMBartModel, UnimerMBartForCausalLM
+from .modeling_unimernet import UnimernetModel
+
+__all__ = [
+    "UnimerSwinConfig",
+    "UnimerSwinModel",
+    "UnimerSwinImageProcessor",
+    "UnimerMBartConfig",
+    "UnimerMBartModel",
+    "UnimerMBartForCausalLM",
+    "UnimernetModel",
+]

+ 189 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py

@@ -0,0 +1,189 @@
+import os
+import re
+import warnings
+from typing import Optional
+
+import torch
+from ftfy import fix_text
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
+from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
+from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger
+
+from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
+from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
+
+AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
+AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
+AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
+AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
+
+
+# TODO: rewrite tokenizer
+class TokenizerWrapper:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.bos_token_id = self.tokenizer.bos_token_id
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+    def __len__(self):
+        return len(self.tokenizer)
+
+    def tokenize(self, text, **kwargs):
+        return self.tokenizer(
+            text,
+            return_token_type_ids=False,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            **kwargs,
+        )
+
+    def token2str(self, tokens) -> list:
+        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
+        generated_text = [fix_text(text) for text in generated_text]
+        return generated_text
+
+    def detokenize(self, tokens):
+        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
+        for b in range(len(toks)):
+            for i in reversed(range(len(toks[b]))):
+                if toks[b][i] is None:
+                    toks[b][i] = ''
+                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
+                if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
+                    del toks[b][i]
+        return toks
+
+
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code.
+    """
+    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
+    letter = r'[a-zA-Z]'
+    noletter = r'[\W_^\d]'
+    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
+    news = s
+    while True:
+        s = news
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
+        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        if news == s:
+            break
+    return s
+
+
+class UnimernetModel(VisionEncoderDecoderModel):
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[PreTrainedModel] = None,
+    ):
+        # VisionEncoderDecoderModel's config-check logging has a bug; disable it temporarily.
+        base_model_logger.disabled = True
+        try:
+            super().__init__(config, encoder, decoder)
+        finally:
+            base_model_logger.disabled = False
+
+        if not config or not hasattr(config, "_name_or_path"):
+            raise RuntimeError("config._name_or_path is required by UnimernetModel.")
+
+        model_path = config._name_or_path
+        self.transform = UnimerSwinImageProcessor()
+        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
+        self._post_check()
+
+    def _post_check(self):
+        tokenizer = self.tokenizer
+
+        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
+            warnings.warn(
+                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings}," +
+                f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set" +
+                f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
+            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings
+
+        assert self.config.decoder.vocab_size == len(tokenizer)
+        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
+        assert self.config.pad_token_id == tokenizer.pad_token_id
+
+    @classmethod
+    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
+        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
+        config._name_or_path = model_path
+        config.encoder = UnimerSwinConfig(**vars(config.encoder))
+        config.decoder = UnimerMBartConfig(**vars(config.decoder))
+
+        encoder = UnimerSwinModel(config.encoder)
+        decoder = UnimerMBartForCausalLM(config.decoder)
+        model = cls(config, encoder, decoder)
+
+        # load model weights
+        model_file_path = os.path.join(model_path, model_filename)
+        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
+        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
+        if not state_dict:
+            raise RuntimeError("state_dict is empty.")
+        if state_dict_strip_prefix:
+            state_dict = {
+                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
+                for k, v in state_dict.items()
+            }
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        if len(unexpected_keys) > 0:
+            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
+        if len(missing_keys) > 0:
+            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
+        return model
+
+    def forward_bak(self, samples):
+        pixel_values, text = samples["image"], samples["text_input"]
+
+        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
+        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]
+
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        labels = decoder_input_ids * 1
+        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)
+
+        loss = self.model(
+            pixel_values=pixel_values,
+            decoder_input_ids=decoder_input_ids[:, :-1],
+            decoder_attention_mask=decoder_attention_mask[:, :-1],
+            labels=labels[:, 1:],
+        ).loss
+        return {"loss": loss}
+
+    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
+        pixel_values = samples["image"]
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        kwargs = {}
+        if do_sample:
+            kwargs["temperature"] = temperature
+            kwargs["top_p"] = top_p
+
+        outputs = super().generate(
+            pixel_values=pixel_values,
+            max_new_tokens=self.tokenizer.tokenizer.model_max_length, # required
+            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
+            do_sample=do_sample,
+            **kwargs,
+        )
+
+        outputs = outputs[:, 1:].cpu().numpy()
+        pred_tokens = self.tokenizer.detokenize(outputs)
+        pred_str = self.tokenizer.token2str(outputs)
+        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
+        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
+
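
A minimal end-to-end sketch of the API defined above (the checkpoint path and crop size are hypothetical; in the pipeline the input tensors come from `model.transform` and a DataLoader, as in the hunk at the top of this diff):

import torch

from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf import UnimernetModel

model = UnimernetModel.from_checkpoint("/models/unimernet_hf")  # hypothetical local checkpoint dir
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

# Single-channel crop; generate() repeats it to 3 channels. The 192x672 shape is only
# illustrative; real inputs are produced by model.transform (UnimerSwinImageProcessor).
pixel_values = torch.zeros(1, 1, 192, 672, device=model.device, dtype=model.dtype)
with torch.no_grad():
    out = model.generate({"image": pixel_values})

print(out["fixed_str"][0])  # whitespace-normalized LaTeX; the raw decode is in out["pred_str"]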

+ 8 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py

@@ -0,0 +1,8 @@
+from .configuration_unimer_mbart import UnimerMBartConfig
+from .modeling_unimer_mbart import UnimerMBartModel, UnimerMBartForCausalLM
+
+__all__ = [
+    "UnimerMBartConfig",
+    "UnimerMBartModel",
+    "UnimerMBartForCausalLM",
+]

+ 163 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py

@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""UnimerMBART model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class UnimerMBartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the MBART
+    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        qk_squeeze (`int`, *optional*, defaults to 2):
+            Squeeze ratio for query/key's output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
+            Squeeze Attention maps the query and key to a lower-dimensional space without excessive loss of information,
+            thereby accelerating the computation of attention.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/value attentions (not used by all models).
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+
+    Example:
+
+    ```python
+    >>> from transformers import MBartConfig, MBartModel
+
+    >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
+    >>> configuration = MBartConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
+    >>> model = MBartModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "unimer-mbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=50265,
+        max_position_embeddings=1024,
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        is_encoder_decoder=True,
+        activation_function="gelu",
+        d_model=1024,
+        qk_squeeze=2,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        classifier_dropout=0.0,
+        scale_embedding=False,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        forced_eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.qk_squeeze = qk_squeeze
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
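
A minimal sketch of how `qk_squeeze` propagates into the attention dimensions (the numbers mirror the defaults above and the derivation in `UnimerMBartAttention` later in this diff):

from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf import UnimerMBartConfig

cfg = UnimerMBartConfig(d_model=1024, decoder_attention_heads=16, qk_squeeze=2)
squeeze_dim = cfg.d_model // cfg.qk_squeeze                      # 512: q/k are projected to this width
squeeze_head_dim = squeeze_dim // cfg.decoder_attention_heads    # 32 per head, vs. head_dim = 64 for v
assert (squeeze_dim, squeeze_head_dim) == (512, 32)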

+ 2351 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py

@@ -0,0 +1,2351 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch UnimerMBART model."""
+
+import copy
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
+    Seq2SeqSequenceClassifierOutput,
+)
+from transformers import GenerationMixin, PreTrainedModel
+from transformers.utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_unimer_mbart import UnimerMBartConfig
+
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
+_CONFIG_FOR_DOC = "UnimerMBartConfig"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
+    """
+    Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that MBart does not
+    have a single `decoder_start_token_id` in contrast to other Bart-like models.
+    """
+    prev_output_tokens = input_ids.clone()
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
+
+    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
+    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
+    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
+    prev_output_tokens[:, 0] = decoder_start_tokens
+
+    return prev_output_tokens
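
# A worked example of shift_tokens_right, assuming a toy batch with pad_token_id=1 and
# eos_token_id=2 (values chosen only for illustration):
#   input_ids               = [[62, 17, 33,  2,  1]]
#   last non-pad token      = 2 (position 3), reused as the decoder start token
#   shift_tokens_right(...) = [[ 2, 62, 17, 33,  2]]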
+
+@dataclass
+class CausalLMOutputWithCrossAttentionsAndCounting(CausalLMOutputWithCrossAttentions):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
+            value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
+            setting. Only relevant if `config.is_decoder = True`.
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        counting (`torch.FloatTensor`, *optional*):
+            Output of the auxiliary token-counting head, if the model computes one.
+    """
+    counting: Optional[torch.FloatTensor] = None
+
+# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
+class UnimerMBartLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids` shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+        ).expand(bsz, -1)
+
+        return super().forward(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
+class UnimerMBartScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embedding's forward by multiplying with the embedding scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
+class UnimerMBartAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper, with qk_squeeze"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        *,
+        config: UnimerMBartConfig,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+
+        self.squeeze_dim = embed_dim // config.qk_squeeze
+        self.squeeze_head_dim = self.squeeze_dim // num_heads
+        self.scaling = self.squeeze_head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        self.q_proj = nn.Linear(embed_dim, self.squeeze_dim, bias=bias)
+        self.k_proj = nn.Linear(embed_dim, self.squeeze_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()
+
+    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
+        value_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*value_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
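# Shape sketch for the qk_squeeze projections above, assuming the default config
# (d_model=1024, 16 attention heads, qk_squeeze=2):
#   q_proj / k_proj : 1024 -> 512,  reshaped to (bsz, 16, seq, 32)   # squeeze_head_dim = 512 // 16
#   v_proj          : 1024 -> 1024, reshaped to (bsz, 16, seq, 64)   # head_dim = 1024 // 16
#   attn_weights    : (bsz * 16, tgt_len, src_len) from the 32-dim q/k dot products; the output keeps the 64-dim values.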
+
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart
+class UnimerMBartFlashAttention2(UnimerMBartAttention):
+    """
+    MBart flash attention module. This module inherits from `UnimerMBartAttention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    # def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+    #     return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim)
+
+    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # UnimerMBartFlashAttention2 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("UnimerMBartFlashAttention2 attention does not support output_attentions")
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self._shape_qk(self.q_proj(hidden_states), -1, bsz)
+
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0].transpose(1, 2)
+            value_states = past_key_value[1].transpose(1, 2)
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+        else:
+            # self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32, this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        it first unpads the input, then computes the attention scores and pads the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+class UnimerMBartSdpaAttention(UnimerMBartAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning(
+                "UnimerMBartModel is using UnimerMBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape_qk(query_states, tgt_len, bsz)
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+        # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+UNIMER_MBART_ATTENTION_CLASSES = {
+    "eager": UnimerMBartAttention,
+    "flash_attention_2": UnimerMBartFlashAttention2,
+    "sdpa": UnimerMBartSdpaAttention,
+}
+
+
+class UnimerMBartEncoderLayer(nn.Module):
+    def __init__(self, config: UnimerMBartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = UNIMER_MBART_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+            config=config,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class UnimerMBartDecoderLayer(nn.Module):
+    def __init__(self, config: UnimerMBartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = UNIMER_MBART_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            is_causal=True,
+            config=config,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = UNIMER_MBART_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MBart
+class UnimerMBartClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        inner_dim: int,
+        num_classes: int,
+        pooler_dropout: float,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.dropout = nn.Dropout(p=pooler_dropout)
+        self.out_proj = nn.Linear(inner_dim, num_classes)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
+class UnimerMBartPreTrainedModel(PreTrainedModel):
+    config_class = UnimerMBartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MBartDecoderLayer", "MBartSqueezeAttention"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dummy_inputs(self):
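+        # a small fixed batch; an assumption: downstream tooling may use it for tracing or smoke tests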
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+        dummy_inputs = {
+            "attention_mask": input_ids.ne(pad_token),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+
+MBART_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`UnimerMBartConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MBART_GENERATION_EXAMPLE = r"""
+    Translation example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
+
+    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
+
+    >>> example_english_phrase = "42 is the answer"
+    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
+
+    >>> # Translate
+    >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
+    >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    '42 este răspuns'
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
+
+    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+
+    >>> # de_DE is the language symbol id <LID> for German
+    >>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
+
+    >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
+    >>> logits = model(input_ids).logits
+
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+    >>> probs = logits[0, masked_index].softmax(dim=0)
+    >>> values, predictions = probs.topk(5)
+
+    >>> tokenizer.decode(predictions).split()
+    ['nett', 'sehr', 'ganz', 'nicht', 'so']
+    ```
+"""
+
+MBART_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
+            varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For translation and summarization training, `decoder_input_ids` should be provided. If no
+            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+            for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+            of `inputs_embeds`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class UnimerMBartEncoder(UnimerMBartPreTrainedModel):
+    """
+    Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
+    [`UnimerMBartEncoderLayer`].
+
+    Args:
+        config: UnimerMBartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: UnimerMBartConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+
+        self.dropout = config.dropout
+        self.layerdrop = config.encoder_layerdrop
+
+        embed_dim = config.d_model
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        self.embed_tokens = UnimerMBartScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = UnimerMBartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            embed_dim,
+        )
+        self.layers = nn.ModuleList([UnimerMBartEncoderLayer(config) for _ in range(config.encoder_layers)])
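+        # record which attention backend the config selected; forward() prepares the matching attention-mask format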
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+        self.layernorm_embedding = nn.LayerNorm(embed_dim)
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _backward_compatibility_gradient_checkpointing(self):
+        # Override to not delete the attribute from the config
+        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
+            self.gradient_checkpointing_enable()
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input = inputs_embeds[:, :, -1]
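+            # only the shape of this placeholder is consumed below when building the positional embeddings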
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        embed_pos = self.embed_positions(input)
+
+        hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # expand attention_mask
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            if self._use_flash_attention_2:
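+                # flash-attention-2 consumes the 2D mask directly; an all-ones mask masks nothing and can be dropped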
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self._use_sdpa and head_mask is None and not output_attentions:
+                # output_attentions=True & head_mask cannot be supported when using SDPA, so we fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            if head_mask.size()[0] != len(self.layers):
+                raise ValueError(
+                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
+                layer_outputs = (None, None)
+            else:
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = encoder_layer(
+                        hidden_states,
+                        attention_mask,
+                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`UnimerMBartDecoderLayer`].
+
+    Args:
+        config: UnimerMBartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: UnimerMBartConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+        self.dropout = config.dropout
+        self.layerdrop = config.decoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+        self.embed_tokens = UnimerMBartScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = UnimerMBartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+        )
+        self.layers = nn.ModuleList([UnimerMBartDecoderLayer(config) for _ in range(config.decoder_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+        self.layernorm_embedding = nn.LayerNorm(config.d_model)
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+                cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
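+            # as in the encoder, this slice is only a shape placeholder for the positional embeddings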
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+            # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                input_shape,
+                inputs_embeds,
+                past_key_values_length,
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+        # counting-aware decoding: project the predicted token counts and add them (scaled by 0.5) to every decoder position
+        if count_pred is not None:
+            count_context_weight = self.counting_context_weight(count_pred)
+            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
+
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
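+                # the layer returns (hidden_states, [self_attn, cross_attn,] present_key_value), so the cache
+                # sits at index 3 when attentions are also returned, otherwise at index 1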
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare MBART Model outputting raw hidden-states without any specific head on top.",
+    MBART_START_DOCSTRING,
+)
+class UnimerMBartModel(UnimerMBartPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: UnimerMBartConfig):
+        super().__init__(config)
+
+        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+        self.encoder = UnimerMBartEncoder(config, self.shared)
+        self.decoder = UnimerMBartDecoder(config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, value):
+        self.shared = value
+        self.encoder.embed_tokens = self.shared
+        self.decoder.embed_tokens = self.shared
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Seq2SeqModelOutput, Tuple[torch.FloatTensor]]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # unlike other models, MBart automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.",
+    MBART_START_DOCSTRING,
+)
+class UnimerMBartForConditionalGeneration(UnimerMBartPreTrainedModel, GenerationMixin):
+    base_model_prefix = "model"
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
+
+    def __init__(self, config: UnimerMBartConfig):
+        super().__init__(config)
+        self.model = UnimerMBartModel(config)
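+        # final_logits_bias is a non-trainable per-token bias added to the LM logits, kept mainly for
+        # BART-family checkpoint compatibility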
+        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+        return new_embeddings
+
+    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
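+        # keep the bias length in sync with the resized vocabulary; newly added tokens start with a zero bias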
+        old_num_tokens = self.final_logits_bias.shape[-1]
+        if new_num_tokens <= old_num_tokens:
+            new_bias = self.final_logits_bias[:, :new_num_tokens]
+        else:
+            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+        self.register_buffer("final_logits_bias", new_bias)
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_end_docstrings(MBART_GENERATION_EXAMPLE)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # cut decoder_input_ids if past is used
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past_key_values,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id)
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
+            )
+        return reordered_past
+
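+# Usage sketch (an assumption about how this generation head is typically driven from image features,
+# e.g. by the UnimerNet wrapper; the names below are hypothetical and not part of this file's API):
+#
+#     from transformers.modeling_outputs import BaseModelOutput
+#     enc = BaseModelOutput(last_hidden_state=image_features)   # (batch, src_len, d_model) from the vision encoder
+#     out_ids = model.generate(decoder_input_ids=start_ids, encoder_outputs=enc, max_new_tokens=256)
+#
+# Passing `encoder_outputs` lets `generate()` skip the text encoder entirely, which is why
+# `prepare_inputs_for_generation` above sets `input_ids` to None.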
+
+@add_start_docstrings(
+    """
+    MBart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
+    tasks.
+    """,
+    MBART_START_DOCSTRING,
+)
+class UnimerMBartForSequenceClassification(UnimerMBartPreTrainedModel):
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
+    def __init__(self, config: UnimerMBartConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.model = UnimerMBartModel(config)
+        self.classification_head = UnimerMBartClassificationHead(
+            config.d_model,
+            config.d_model,
+            config.num_labels,
+            config.classifier_dropout,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        if input_ids is None and inputs_embeds is not None:
+            raise NotImplementedError(
+                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+            )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]  # last hidden state
+
+        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
+
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+            raise ValueError("All examples must have the same number of <eos> tokens.")
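+        # pool the sequence: take the decoder hidden state at the position of the last <eos> token in each example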
+        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+            :, -1, :
+        ]
+        logits = self.classification_head(sentence_representation)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    MBART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    MBART_START_DOCSTRING,
+)
+class UnimerMBartForQuestionAnswering(UnimerMBartPreTrainedModel):
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        config.num_labels = 2
+        self.num_labels = config.num_labels
+
+        self.model = UnimerMBartModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqQuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if start_positions is not None and end_positions is not None:
+            use_cache = False
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it away
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (
+                start_logits,
+                end_logits,
+            ) + outputs[1:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return Seq2SeqQuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->MBart
+class UnimerMBartDecoderWrapper(UnimerMBartPreTrainedModel):
+    """
+    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+    used in combination with the [`EncoderDecoderModel`] framework.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.decoder = UnimerMBartDecoder(config)
+
+    def forward(self, *args, **kwargs):
+        return self.decoder(*args, **kwargs)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
+class UnimerMBartForCausalLM(UnimerMBartPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        config = copy.deepcopy(config)
+        config.is_decoder = True
+        config.is_encoder_decoder = False
+        super().__init__(config)
+        self.model = UnimerMBartDecoderWrapper(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.decoder = decoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentionsAndCounting, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        count_gt: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentionsAndCounting]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                if the model is configured as a decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MBartForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+        >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25", add_cross_attention=False)
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+        >>> list(logits.shape) == expected_shape
+        True
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        count_pred = None
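+        # Token-count prediction is not computed here; `count_pred` stays None, is passed
+        # through to the decoder, and is returned unchanged in the `counting` field below.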
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            count_pred=count_pred,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0])
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentionsAndCounting(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+            counting=count_pred,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        if past_key_values:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+        # On the first step there is no cache, so the full `input_ids` are passed through.
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
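+        # Reorder every layer's cached key/value states to follow the beam indices chosen
+        # at the current generation step (required for beam search with `past_key_values`).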
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past

+ 0 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py


+ 9 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py

@@ -0,0 +1,9 @@
+from .configuration_unimer_swin import UnimerSwinConfig
+from .modeling_unimer_swin import UnimerSwinModel
+from .image_processing_unimer_swin import UnimerSwinImageProcessor
+
+__all__ = [
+    "UnimerSwinConfig",
+    "UnimerSwinModel",
+    "UnimerSwinImageProcessor",
+]

+ 132 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/configuration_unimer_swin.py

@@ -0,0 +1,132 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Donut Swin Transformer model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class UnimerSwinConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`UnimerSwinModel`]. It is used to instantiate a
+    UnimerSwin model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Donut
+    [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 4):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        embed_dim (`int`, *optional*, defaults to 96):
+            Dimensionality of patch embedding.
+        depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
+            Depth of each layer in the Transformer encoder.
+        num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`):
+            Number of attention heads in each layer of the Transformer encoder.
+        window_size (`int`, *optional*, defaults to 7):
+            Size of windows.
+        mlp_ratio (`float`, *optional*, defaults to 4.0):
+            Ratio of MLP hidden dimensionality to embedding dimensionality.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not a learnable bias should be added to the queries, keys and values.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings and encoder.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            Stochastic depth rate.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether or not to add absolute position embeddings to the patch embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+
+    Example:
+
+    ```python
+    >>> from transformers import UnimerSwinConfig, UnimerSwinModel
+
+    >>> # Initializing a Donut naver-clova-ix/donut-base style configuration
+    >>> configuration = UnimerSwinConfig()
+
+    >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration
+    >>> model = UnimerSwinModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "unimer-swin"
+
+    attribute_map = {
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+    }
+
+    def __init__(
+        self,
+        image_size=224,
+        patch_size=4,
+        num_channels=3,
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        drop_path_rate=0.1,
+        hidden_act="gelu",
+        use_absolute_embeddings=False,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.embed_dim = embed_dim
+        self.depths = depths
+        self.num_layers = len(depths)
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.drop_path_rate = drop_path_rate
+        self.hidden_act = hidden_act
+        self.use_absolute_embeddings = use_absolute_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
+        # this indicates the channel dimension after the last stage of the model
+        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))

+ 114 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py

@@ -0,0 +1,114 @@
+from transformers.image_processing_utils import BaseImageProcessor
+import numpy as np
+import cv2
+import albumentations as alb
+from albumentations.pytorch import ToTensorV2
+
+
+# TODO: drop the cv2 dependency if possible
+class UnimerSwinImageProcessor(BaseImageProcessor):
+    def __init__(
+            self,
+            image_size=(192, 672),
+        ):
+        super().__init__()
+        self.input_size = [int(_) for _ in image_size]
+        assert len(self.input_size) == 2
+
+        self.transform = alb.Compose(
+            [
+                alb.ToGray(always_apply=True),
+                alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
+                # alb.Sharpen()
+                ToTensorV2(),
+            ]
+        )
+
+    def __call__(self, item):
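+        # Crop, resize and pad to `input_size`, then normalize; `[:1]` keeps a single
+        # channel because `ToGray` makes all three channels identical.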
+        image = self.prepare_input(item)
+        return self.transform(image=image)['image'][:1]
+
+    @staticmethod
+    def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
+        """Crop margins of image using NumPy operations"""
+        # Convert to grayscale if it's a color image
+        if len(img.shape) == 3 and img.shape[2] == 3:
+            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+        else:
+            gray = img.copy()
+
+        # Normalize and threshold
+        if gray.max() == gray.min():
+            return img
+
+        normalized = (((gray - gray.min()) / (gray.max() - gray.min())) * 255).astype(np.uint8)
+        binary = 255 * (normalized < 200).astype(np.uint8)
+
+        # Find bounding box
+        coords = cv2.findNonZero(binary)  # Find all non-zero points (text)
+        x, y, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+
+        # Return cropped image
+        return img[y:y + h, x:x + w]
+
+    def prepare_input(self, img, random_padding: bool = False):
+        """
+        Convert an input image (numpy array) into a properly sized and padded array by:
+            - cropping the margins
+            - resizing while maintaining the aspect ratio
+            - padding to the target size
+        """
+        if img is None:
+            return None
+
+        try:
+            img = self.crop_margin_numpy(img)
+        except Exception:
+            # might throw an error for broken files
+            return None
+
+        if img.shape[0] == 0 or img.shape[1] == 0:
+            return None
+
+        # Resize while preserving aspect ratio
+        h, w = img.shape[:2]
+        scale = min(self.input_size[0] / h, self.input_size[1] / w)
+        new_h, new_w = int(h * scale), int(w * scale)
+        resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+
+        # Calculate padding
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
+
+        # Create and apply padding
+        channels = 3 if len(img.shape) == 3 else 1
+        padded_img = np.full((self.input_size[0], self.input_size[1], channels), 255, dtype=np.uint8)
+        padded_img[pad_height:pad_height + new_h, pad_width:pad_width + new_w] = resized_img
+
+        return padded_img
+
+    def _calculate_padding(self, new_w, new_h, random_padding):
+        """Calculate padding values for PIL images"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h
+
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
+
+        return (
+            pad_width,
+            pad_height,
+            delta_width - pad_width,
+            delta_height - pad_height,
+        )
+
+    def _get_padding_values(self, new_w, new_h, random_padding):
+        """Get padding values based on image dimensions and padding strategy"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h
+
+        if random_padding:
+            pad_width = np.random.randint(low=0, high=delta_width + 1)
+            pad_height = np.random.randint(low=0, high=delta_height + 1)
+        else:
+            pad_width = delta_width // 2
+            pad_height = delta_height // 2
+
+        return pad_width, pad_height
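+
+# Usage sketch (assumption: `img` is an HxWx3 uint8 numpy array containing a formula crop):
+#   processor = UnimerSwinImageProcessor()
+#   pixel_values = processor(img)        # torch.Tensor of shape (1, 192, 672)
+#   batch = pixel_values.unsqueeze(0)    # add a batch dimension before feeding the encoder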

+ 1084 - 0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py

@@ -0,0 +1,1084 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch UnimerSwin Transformer model.
+
+This implementation is identical to a regular Swin Transformer, without final layer norm on top of the final hidden
+states."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from transformers.utils import (
+    ModelOutput,
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    torch_int,
+)
+from .configuration_unimer_swin import UnimerSwinConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "UnimerSwinConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->UnimerSwin
+class UnimerSwinEncoderOutput(ModelOutput):
+    """
+    UnimerSwin encoder's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->UnimerSwin
+class UnimerSwinModelOutput(ModelOutput):
+    """
+    UnimerSwin model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+            Average pooling of the last layer hidden-state.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+# Copied from transformers.models.swin.modeling_swin.window_partition
+def window_partition(input_feature, window_size):
+    """
+    Partitions the given input into windows.
+    """
+    batch_size, height, width, num_channels = input_feature.shape
+    input_feature = input_feature.view(
+        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+    )
+    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+    return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.window_reverse
+def window_reverse(windows, window_size, height, width):
+    """
+    Merges windows to produce higher resolution features.
+    """
+    num_channels = windows.shape[-1]
+    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
+    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
+    return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->UnimerSwin
+class UnimerSwinEmbeddings(nn.Module):
+    """
+    Construct the patch and position embeddings. Optionally, also the mask token.
+    """
+
+    def __init__(self, config, use_mask_token=False):
+        super().__init__()
+
+        self.patch_embeddings = UnimerSwinPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.patch_grid = self.patch_embeddings.grid_size
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
+
+        if config.use_absolute_embeddings:
+            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+        else:
+            self.position_embeddings = None
+
+        ### code added. ###
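+        # NOTE: `use_2d_embeddings` is not declared by `UnimerSwinConfig.__init__`; it is
+        # expected to arrive via the checkpoint's config (extra kwargs are preserved by
+        # `PretrainedConfig`), otherwise this attribute lookup fails.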
+        if config.use_2d_embeddings:
+            self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
+            self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))
+        else:
+            self.row_embeddings = None
+            self.column_embeddings = None
+        ######
+
+        self.norm = nn.LayerNorm(config.embed_dim)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor],
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> Tuple[torch.Tensor]:
+        _, num_channels, height, width = pixel_values.shape
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+        embeddings = self.norm(embeddings)
+        batch_size, seq_len, _ = embeddings.size()
+
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        if self.position_embeddings is not None:
+            # if interpolate_pos_encoding:
+            #     embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+            # else:
+            #     embeddings = embeddings + self.position_embeddings
+            embeddings = embeddings + self.position_embeddings[:, :seq_len, :] # code edited.
+
+        ### code added. ###
+        if self.row_embeddings is not None and self.column_embeddings is not None:
+            # Row embeddings are repeated along the width (0, 0, ..., 1, 1, ...) while column
+            # embeddings are tiled across the rows (0, 1, 2, ..., 0, 1, 2, ...), giving each
+            # patch an additive 2D positional term.
+            row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1], dim=1)
+            column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1)
+            embeddings = embeddings + row_embeddings + column_embeddings
+        ######
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings, output_dimensions
+
+class StemLayer(nn.Module):
+    r""" Stem layer of InternImage
+    Args:
+        in_chans (int): number of input channels
+        out_chans (int): number of output channels
+        act_layer (nn.Module): activation layer class
+        norm_layer (str): normalization layer
+    """
+
+    def __init__(self, in_chans=3, out_chans=96, act_layer=nn.GELU, norm_layer='BN'):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_chans, out_chans // 2, kernel_size=3, stride=2, padding=1)
+        self.norm1 = self.build_norm_layer(out_chans // 2, norm_layer)
+        self.act = act_layer()
+        self.conv2 = nn.Conv2d(out_chans // 2, out_chans, kernel_size=3, stride=2, padding=1)
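+        # Two stride-2 convolutions give an overall 4x downsampling, matching the default
+        # patch_size of 4 that the replaced patchify projection would have produced.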
+
+    def build_norm_layer(self, dim, norm_layer):
+        layers = []
+        if norm_layer == 'BN':
+            layers.append(nn.BatchNorm2d(dim))
+        else:
+            raise NotImplementedError(f'build_norm_layer does not support {norm_layer}')
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.act(x)
+        x = self.conv2(x)
+        return x
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->UnimerSwin
+class UnimerSwinPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.embed_dim
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+        ### code edited. ###
+        # self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+        self.projection = StemLayer(in_chans=num_channels, out_chans=hidden_size)
+        ###
+
+    def maybe_pad(self, pixel_values, height, width):
+        if width % self.patch_size[1] != 0:
+            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
+            pixel_values = nn.functional.pad(pixel_values, pad_values)
+        if height % self.patch_size[0] != 0:
+            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
+            pixel_values = nn.functional.pad(pixel_values, pad_values)
+        return pixel_values
+
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+        _, num_channels, height, width = pixel_values.shape
+        # pad the input to be divisible by self.patch_size, if needed
+        pixel_values = self.maybe_pad(pixel_values, height, width)
+        embeddings = self.projection(pixel_values)
+        _, _, height, width = embeddings.shape
+        output_dimensions = (height, width)
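+        # (batch, embed_dim, H', W') -> (batch, H' * W', embed_dim)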
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+
+        return embeddings, output_dimensions
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
+class UnimerSwinPatchMerging(nn.Module):
+    """
+    Patch Merging Layer.
+
+    Args:
+        input_resolution (`Tuple[int]`):
+            Resolution of input feature.
+        dim (`int`):
+            Number of input channels.
+        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+            Normalization layer class.
+    """
+
+    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def maybe_pad(self, input_feature, height, width):
+        should_pad = (height % 2 == 1) or (width % 2 == 1)
+        if should_pad:
+            pad_values = (0, 0, 0, width % 2, 0, height % 2)
+            input_feature = nn.functional.pad(input_feature, pad_values)
+
+        return input_feature
+
+    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
+        height, width = input_dimensions
+        # `dim` is height * width
+        batch_size, dim, num_channels = input_feature.shape
+
+        input_feature = input_feature.view(batch_size, height, width, num_channels)
+        # pad input so that height and width are even, if needed
+        input_feature = self.maybe_pad(input_feature, height, width)
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_0 = input_feature[:, 0::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_1 = input_feature[:, 1::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_2 = input_feature[:, 0::2, 1::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_3 = input_feature[:, 1::2, 1::2, :]
+        # batch_size height/2 width/2 4*num_channels
+        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)  # batch_size height/2*width/2 4*C
+
+        input_feature = self.norm(input_feature)
+        input_feature = self.reduction(input_feature)
+
+        return input_feature
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinDropPath
+class UnimerSwinDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->UnimerSwin
+class UnimerSwinSelfAttention(nn.Module):
+    def __init__(self, config, dim, num_heads, window_size):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError(
+                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+            )
+
+        self.num_attention_heads = num_heads
+        self.attention_head_size = int(dim / num_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.window_size = (
+            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+        )
+
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
+        )
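+        # Each (query, key) pair within a window indexes this table by its relative (dy, dx)
+        # offset (via `relative_position_index`), giving a learned per-head positional bias.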
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        batch_size, dim, num_channels = hidden_states.shape
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
+        relative_position_bias = relative_position_bias.view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
+
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the UnimerSwinModel forward() function)
+            mask_shape = attention_mask.shape[0]
+            attention_scores = attention_scores.view(
+                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+            )
+            attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
+            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
+class UnimerSwinSelfOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, dim)
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->UnimerSwin
+class UnimerSwinAttention(nn.Module):
+    def __init__(self, config, dim, num_heads, window_size):
+        super().__init__()
+        self.self = UnimerSwinSelfAttention(config, dim, num_heads, window_size)
+        self.output = UnimerSwinSelfOutput(config, dim)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
+class UnimerSwinIntermediate(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinOutput
+class UnimerSwinOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class ConvEnhance(nn.Module):
+    """Depth-wise convolution to get the positional information.
+    """
+    def __init__(self, config, dim, k=3):
+        super(ConvEnhance, self).__init__()
+        self.proj = nn.Conv2d(dim,
+                              dim,
+                              (k,k),
+                              (1,1),
+                              (k // 2,k // 2),
+                              groups=dim)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x, size: Tuple[int, int]):
+        B, N, C = x.shape
+        H, W = size
+        assert N == H * W
+
+        feat = x.transpose(1, 2).view(B, C, H, W)
+        feat = self.proj(feat)
+        feat = self.act_fn(feat)
+        feat = feat.flatten(2).transpose(1, 2)
+
+        x = x + feat
+        return x
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->UnimerSwin
+class UnimerSwinLayer(nn.Module):
+    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.shift_size = shift_size
+        self.window_size = config.window_size
+        self.input_resolution = input_resolution
+        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+
+        self.ce = nn.ModuleList([ConvEnhance(config, dim=dim, k=3),
+                                  ConvEnhance(config, dim=dim, k=3)])
+
+        self.attention = UnimerSwinAttention(config, dim, num_heads, window_size=self.window_size)
+        self.drop_path = UnimerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+        self.intermediate = UnimerSwinIntermediate(config, dim)
+        self.output = UnimerSwinOutput(config, dim)
+
+    def set_shift_and_window_size(self, input_resolution):
+        if min(input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )
+
+    def get_attn_mask(self, height, width, dtype, device):
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
+            height_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
+            width_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
+            count = 0
+            for height_slice in height_slices:
+                for width_slice in width_slices:
+                    img_mask[:, height_slice, width_slice, :] = count
+                    count += 1
+
+            mask_windows = window_partition(img_mask, self.window_size)
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+        return attn_mask
+
+    def maybe_pad(self, hidden_states, height, width):
+        pad_right = (self.window_size - width % self.window_size) % self.window_size
+        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
+        hidden_states = nn.functional.pad(hidden_states, pad_values)
+        return hidden_states, pad_values
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not always_partition:
+            self.set_shift_and_window_size(input_dimensions)
+        else:
+            pass
+        height, width = input_dimensions
+        batch_size, _, channels = hidden_states.size()
+
+        hidden_states = self.ce[0](hidden_states, input_dimensions)
+        shortcut = hidden_states
+
+
+        hidden_states = self.layernorm_before(hidden_states)
+        hidden_states = hidden_states.view(batch_size, height, width, channels)
+
+        # pad hidden_states to multiples of window size
+        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+        _, height_pad, width_pad, _ = hidden_states.shape
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_hidden_states = hidden_states
+
+        # partition windows
+        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
+        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
+
+        attention_outputs = self.attention(
+            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
+        )
+
+        attention_output = attention_outputs[0]
+
+        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
+        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            attention_windows = shifted_windows
+
+        was_padded = pad_values[3] > 0 or pad_values[5] > 0
+        if was_padded:
+            attention_windows = attention_windows[:, :height, :width, :].contiguous()
+
+        attention_windows = attention_windows.view(batch_size, height * width, channels)
+
+        hidden_states = shortcut + self.drop_path(attention_windows)
+
+        hidden_states = self.ce[1](hidden_states, input_dimensions)
+        layer_output = self.layernorm_after(hidden_states)
+        layer_output = self.intermediate(layer_output)
+        layer_output = hidden_states + self.output(layer_output)
+
+        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+        return layer_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->UnimerSwin
+class UnimerSwinStage(nn.Module):
+    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
+        super().__init__()
+        self.config = config
+        self.dim = dim
+        self.blocks = nn.ModuleList(
+            [
+                UnimerSwinLayer(
+                    config=config,
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    shift_size=0,
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
+        else:
+            self.downsample = None
+
+        self.pointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        height, width = input_dimensions
+        for i, layer_module in enumerate(self.blocks):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            layer_outputs = layer_module(
+                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+            )
+
+            hidden_states = layer_outputs[0]
+
+        hidden_states_before_downsampling = hidden_states
+        if self.downsample is not None:
+            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+            output_dimensions = (height, width, height_downsampled, width_downsampled)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
+        else:
+            output_dimensions = (height, width, height, width)
+
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
+
+        if output_attentions:
+            stage_outputs += layer_outputs[1:]
+        return stage_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->UnimerSwin
+class UnimerSwinEncoder(nn.Module):
+    def __init__(self, config, grid_size):
+        super().__init__()
+        self.num_layers = len(config.depths)
+        self.config = config
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        self.layers = nn.ModuleList(
+            [
+                UnimerSwinStage(
+                    config=config,
+                    dim=int(config.embed_dim * 2**i_layer),
+                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+                    depth=config.depths[i_layer],
+                    num_heads=config.num_heads[i_layer],
+                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+                    downsample=UnimerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
+                )
+                for i_layer in range(self.num_layers)
+            ]
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, UnimerSwinEncoderOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_reshaped_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if output_hidden_states:
+            batch_size, _, hidden_size = hidden_states.shape
+            # rearrange b (h w) c -> b c h w
+            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+            all_hidden_states += (hidden_states,)
+            all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+        for i, layer_module in enumerate(self.layers):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+                )
+
+            hidden_states = layer_outputs[0]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]
+
+            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states.shape
+                # rearrange b (h w) c -> b c h w
+                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+            if output_attentions:
+                all_self_attentions += layer_outputs[3:]
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+        return UnimerSwinEncoderOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            reshaped_hidden_states=all_reshaped_hidden_states,
+        )
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->UnimerSwin
+class UnimerSwinPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = UnimerSwinConfig
+    base_model_prefix = "unimer-swin"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["UnimerSwinStage"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+SWIN_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`UnimerSwinConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIN_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`DonutImageProcessor.__call__`] for details.
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare UnimerSwin Model transformer outputting raw hidden-states without any specific head on top.",
+    SWIN_START_DOCSTRING,
+)
+class UnimerSwinModel(UnimerSwinPreTrainedModel):
+    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+        super().__init__(config)
+        self.config = config
+        self.num_layers = len(config.depths)
+        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+        self.embeddings = UnimerSwinEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = UnimerSwinEncoder(config, self.embeddings.patch_grid)
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=UnimerSwinModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, UnimerSwinModelOutput]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+        embedding_output, input_dimensions = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            input_dimensions,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = encoder_outputs[0]
+
+        pooled_output = None
+        if self.pooler is not None:
+            pooled_output = self.pooler(sequence_output.transpose(1, 2))
+            pooled_output = torch.flatten(pooled_output, 1)
+
+        if not return_dict:
+            output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+            return output
+
+        return UnimerSwinModelOutput(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+        )
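
The encoder and model classes above follow the Hugging Face Swin layout closely, so a quick shape check is enough to see what the vision tower produces. A minimal smoke-test sketch, assuming UnimerSwinConfig ships usable Swin-style defaults and that the classes are importable from the module paths added in this commit:

import torch

from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.unimer_swin.configuration_unimer_swin import (
    UnimerSwinConfig,
)
from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.unimer_swin.modeling_unimer_swin import (
    UnimerSwinModel,
)

config = UnimerSwinConfig()                      # assumed to carry Swin-style defaults
model = UnimerSwinModel(config).eval()

size = config.image_size                         # int or (height, width), depending on the config
height, width = (size, size) if isinstance(size, int) else size
pixel_values = torch.randn(1, getattr(config, 'num_channels', 3), height, width)

with torch.no_grad():
    out = model(pixel_values)
print(out.last_hidden_state.shape)               # (batch, num_patches_after_downsampling, num_features)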

+ 14 - 12
magic_pdf/model/sub_modules/model_init.py

@@ -5,12 +5,13 @@ from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
 from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
 from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
-from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
 
 try:
-    from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
+    from magic_pdf_ascend_plugin.libs.license_verifier import (
+        LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
+        load_license)
     from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
     from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
     license_key = load_license()
@@ -20,26 +21,26 @@ except Exception as e:
     if isinstance(e, ImportError):
         pass
     elif isinstance(e, LicenseFormatError):
-        logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
+        logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
     elif isinstance(e, LicenseSignatureError):
-        logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
+        logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
     elif isinstance(e, LicenseExpiredError):
-        logger.error("Ascend Plugin: License has expired. Please renew your license.")
+        logger.error('Ascend Plugin: License has expired. Please renew your license.')
     elif isinstance(e, FileNotFoundError):
-        logger.error("Ascend Plugin: Not found License file.")
+        logger.error('Ascend Plugin: Not found License file.')
     else:
-        logger.error(f"Ascend Plugin: {e}")
+        logger.error(f'Ascend Plugin: {e}')
     from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
     # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
     from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
 
-from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
-from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
+        from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
+        from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
         config = {
             'model_dir': model_path,
             'device': _device_
@@ -55,7 +56,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
 
 
 def mfd_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     mfd_model = YOLOv8MFDModel(weight, device)
     return mfd_model
@@ -67,19 +68,20 @@ def mfr_model_init(weight_dir, cfg_path, device='cpu'):
 
 
 def layout_model_init(weight, config_file, device):
+    from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
     model = Layoutlmv3_Predictor(weight, config_file, device)
     return model
 
 
 def doclayout_yolo_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = DocLayoutYOLOModel(weight, device)
     return model
 
 
 def langdetect_model_init(langdetect_model_weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = YOLOv11LangDetModel(langdetect_model_weight, device)
     return model
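
With the heavy table and layout backends now imported inside the init helpers, merely importing model_init no longer pulls in detectron2 or struct-eqtable. A usage sketch of one lazily-imported branch; the weight directory and max_time value are placeholders, not taken from this diff:

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.sub_modules.model_init import table_model_init

# TableMaster's paddle dependencies are only loaded when this branch runs.
table_model = table_model_init(
    MODEL_NAME.TABLE_MASTER,
    model_path='/path/to/models/TabRec/TableMaster',   # placeholder path
    max_time=400,                                      # illustrative value
    _device_='cpu',
)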

+ 17 - 11
magic_pdf/model/sub_modules/model_utils.py

@@ -1,25 +1,31 @@
 import time
-
 import torch
-from PIL import Image
 from loguru import logger
-
+import numpy as np
 from magic_pdf.libs.clean_memory import clean_memory
 
 
-def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
+def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
+
     crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
     crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-    # Create a white background with an additional width and height of 50
+
+    # Calculate new dimensions
     crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
     crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
 
-    # Crop image
-    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-    cropped_img = input_pil_img.crop(crop_box)
-    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+    # Create a white background array
+    return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
+
+    # Crop the original image using numpy slicing
+    cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
+
+    # Paste the cropped image onto the white background
+    return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
+                 crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
+
+    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
+                   crop_new_height]
     return return_image, return_list
 
 

+ 1 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -5,6 +5,7 @@ import cv2
 import numpy as np
 import torch
 
+
 from paddleocr import PaddleOCR
 from ppocr.utils.logging import get_logger
 from ppocr.utils.utility import alpha_to_color, binarize_img

+ 1 - 0
magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py

@@ -2,6 +2,7 @@ import os
 
 import cv2
 import numpy as np
+from paddleocr import PaddleOCR
 from ppstructure.table.predict_table import TableSystem
 from ppstructure.utility import init_args
 from PIL import Image

+ 1 - 1
magic_pdf/pdf_parse_union_core_v2.py

@@ -492,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     else:
         return [[x0, y0, x1, y1]]
 
-# @measure_time
+
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
 

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -2,7 +2,7 @@ weights:
   layoutlmv3: Layout/LayoutLMv3/model_final.pth
   doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
   yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
-  unimernet_small: MFR/unimernet_small_2501
+  unimernet_small: MFR/unimernet_hf_small_2503
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable
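
The MFR weight directory moves to the new Hugging Face-format checkpoint. A sketch of resolving the renamed key; the yaml path and models root below are illustrative, not how the project necessarily resolves them internally:

import os

import yaml

with open('magic_pdf/resources/model_config/model_configs.yaml') as f:
    weights = yaml.safe_load(f)['weights']

models_root = os.path.expanduser('~/models')                          # placeholder root
unimernet_dir = os.path.join(models_root, weights['unimernet_small'])
print(unimernet_dir)                                                  # .../MFR/unimernet_hf_small_2503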

+ 19 - 10
magic_pdf/tools/cli.py

@@ -1,15 +1,18 @@
 import os
 import shutil
 import tempfile
+from pathlib import Path
+
 import click
 import fitz
 from loguru import logger
-from pathlib import Path
 
 import magic_pdf.model as model_config
+from magic_pdf.data.batch_build_dataset import batch_build_dataset
 from magic_pdf.data.data_reader_writer import FileBasedDataReader
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.version import __version__
-from magic_pdf.tools.common import do_parse, parse_pdf_methods
+from magic_pdf.tools.common import batch_do_parse, do_parse, parse_pdf_methods
 from magic_pdf.utils.office_to_pdf import convert_file_to_pdf
 
 pdf_suffixes = ['.pdf']
@@ -94,30 +97,33 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     def read_fn(path: Path):
         if path.suffix in ms_office_suffixes:
             convert_file_to_pdf(str(path), temp_dir)
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
         elif path.suffix in image_suffixes:
             with open(str(path), 'rb') as f:
                 bits = f.read()
             pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
             with open(fn, 'wb') as f:
                 f.write(pdf_bytes)
         elif path.suffix in pdf_suffixes:
             fn = str(path)
         else:
-            raise Exception(f"Unknown file suffix: {path.suffix}")
-        
+            raise Exception(f'Unknown file suffix: {path.suffix}')
+
         disk_rw = FileBasedDataReader(os.path.dirname(fn))
         return disk_rw.read(os.path.basename(fn))
 
-    def parse_doc(doc_path: Path):
+    def parse_doc(doc_path: Path, dataset: Dataset | None = None):
         try:
             file_name = str(Path(doc_path).stem)
-            pdf_data = read_fn(doc_path)
+            if dataset is None:
+                pdf_data_or_dataset = read_fn(doc_path)
+            else:
+                pdf_data_or_dataset = dataset
             do_parse(
                 output_dir,
                 file_name,
-                pdf_data,
+                pdf_data_or_dataset,
                 [],
                 method,
                 debug_able,
@@ -130,9 +136,12 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
             logger.exception(e)
 
     if os.path.isdir(path):
+        doc_paths = []
         for doc_path in Path(path).glob('*'):
             if doc_path.suffix in pdf_suffixes + image_suffixes + ms_office_suffixes:
-                parse_doc(doc_path)
+                doc_paths.append(doc_path)
+        datasets = batch_build_dataset(doc_paths, 4, lang)
+        batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths], datasets, method, debug_able, lang=lang)
     else:
         parse_doc(Path(path))
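
Directory input now builds all datasets up front and funnels them through a single batch parse instead of looping file by file. A sketch of the same flow called directly; the paths are placeholders and the worker count mirrors the CLI's hard-coded 4:

from pathlib import Path

from magic_pdf.data.batch_build_dataset import batch_build_dataset
from magic_pdf.tools.common import batch_do_parse

doc_paths = [p for p in Path('/data/pdfs').glob('*') if p.suffix == '.pdf']  # placeholder input dir
datasets = batch_build_dataset(doc_paths, 4, None)                           # 4 build workers, no fixed lang
batch_do_parse('/data/output', [p.stem for p in doc_paths], datasets, 'auto', False)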
 

+ 88 - 11
magic_pdf/tools/common.py

@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset, PymuDocDataset
 from magic_pdf.libs.draw_bbox import draw_char_bbox
-from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.operators.models import InferenceResult
+from magic_pdf.model.doc_analyze_by_custom_model import (batch_doc_analyze,
+                                                         doc_analyze)
 
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
@@ -67,10 +67,10 @@ def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_i
     return output_bytes
 
 
-def do_parse(
+def _do_parse(
     output_dir,
     pdf_file_name,
-    pdf_bytes,
+    pdf_bytes_or_dataset,
     model_list,
     parse_method,
     debug_able,
@@ -92,16 +92,21 @@ def do_parse(
     formula_enable=None,
     table_enable=None,
 ):
+    from magic_pdf.operators.models import InferenceResult
     if debug_able:
         logger.warning('debug mode is on')
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
         # f_draw_char_bbox = True
 
-    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
-        pdf_bytes, start_page_id, end_page_id
-    )
-
+    if isinstance(pdf_bytes_or_dataset, bytes):
+        pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+            pdf_bytes_or_dataset, start_page_id, end_page_id
+        )
+        ds = PymuDocDataset(pdf_bytes, lang=lang)
+    else:
+        ds = pdf_bytes_or_dataset
+    pdf_bytes = ds._raw_data
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
 
     image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
@@ -109,8 +114,6 @@ def do_parse(
     )
     image_dir = str(os.path.basename(local_image_dir))
 
-    ds = PymuDocDataset(pdf_bytes, lang=lang)
-
     if len(model_list) == 0:
         if model_config.__use_inside_model__:
             if parse_method == 'auto':
@@ -241,5 +244,79 @@ def do_parse(
 
     logger.info(f'local output dir is {local_md_dir}')
 
+def do_parse(
+    output_dir,
+    pdf_file_name,
+    pdf_bytes_or_dataset,
+    model_list,
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    parallel_count = 1
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    if parallel_count > 1:
+        if isinstance(pdf_bytes_or_dataset, bytes):
+            pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+                pdf_bytes_or_dataset, start_page_id, end_page_id
+            )
+            ds = PymuDocDataset(pdf_bytes, lang=lang)
+        else:
+            ds = pdf_bytes_or_dataset
+        batch_do_parse(
+            output_dir, [pdf_file_name], [ds], parse_method, debug_able,
+            lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable,
+            f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
+            f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
+            f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
+            f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
+            f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+            f_draw_char_bbox=f_draw_char_bbox,
+        )
+    else:
+        _do_parse(
+            output_dir, pdf_file_name, pdf_bytes_or_dataset, model_list, parse_method, debug_able,
+            start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
+            layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable,
+            f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
+            f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
+            f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
+            f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
+            f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+            f_draw_char_bbox=f_draw_char_bbox,
+        )
+
+
+def batch_do_parse(
+    output_dir,
+    pdf_file_names: list[str],
+    pdf_bytes_or_datasets: list[bytes | Dataset],
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    dss = []
+    for v in pdf_bytes_or_datasets:
+        if isinstance(v, bytes):
+            dss.append(PymuDocDataset(v, lang=lang))
+        else:
+            dss.append(v)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
+    for idx, infer_result in enumerate(infer_results):
+        _do_parse(
+            output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able,
+            f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox,
+            f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json,
+            f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf,
+            f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode,
+            f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+            f_draw_char_bbox=f_draw_char_bbox,
+        )
+
 
 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])
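
do_parse now routes through batch_do_parse and batch_doc_analyze whenever MINERU_PARALLEL_INFERENCE_COUNT is set above 1. A sketch of flipping that switch from the caller's side; the file and output paths are placeholders:

import os

from magic_pdf.tools.common import do_parse

os.environ['MINERU_PARALLEL_INFERENCE_COUNT'] = '2'   # values > 1 take the batch path
with open('/path/to/sample.pdf', 'rb') as f:
    pdf_bytes = f.read()
do_parse('/path/to/output', 'sample', pdf_bytes, [], 'auto', False)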

+ 70 - 45
projects/web_api/app.py

@@ -3,6 +3,7 @@ import os
 from base64 import b64encode
 from glob import glob
 from io import StringIO
+import tempfile
 from typing import Tuple, Union
 
 import uvicorn
@@ -10,11 +11,12 @@ from fastapi import FastAPI, HTTPException, UploadFile
 from fastapi.responses import JSONResponse
 from loguru import logger
 
+from magic_pdf.data.read_api import read_local_images, read_local_office
 import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
 from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
 from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.operators.models import InferenceResult
@@ -24,6 +26,9 @@ model_config.__use_inside_model__ = True
 
 app = FastAPI()
 
+pdf_extensions = [".pdf"]
+office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
+image_extensions = [".png", ".jpg"]
 
 class MemoryDataWriter(DataWriter):
     def __init__(self):
@@ -46,8 +51,8 @@ class MemoryDataWriter(DataWriter):
 
 
 def init_writers(
-    pdf_path: str = None,
-    pdf_file: UploadFile = None,
+    file_path: str = None,
+    file: UploadFile = None,
     output_path: str = None,
     output_image_path: str = None,
 ) -> Tuple[
@@ -59,19 +64,19 @@ def init_writers(
     Initialize writers based on path type
 
     Args:
-        pdf_path: PDF file path (local path or S3 path)
-        pdf_file: Uploaded PDF file object
+        file_path: file path (local path or S3 path)
+        file: Uploaded file object
         output_path: Output directory path
         output_image_path: Image output directory path
 
     Returns:
-        Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
-        file content
+        Tuple[writer, image_writer, file_bytes, file_extension]: Returns the initialized
+        writers, the file content, and its extension
     """
-    if pdf_path:
-        is_s3_path = pdf_path.startswith("s3://")
+    file_extension: str = None
+    if file_path:
+        is_s3_path = file_path.startswith("s3://")
         if is_s3_path:
-            bucket = get_bucket_name(pdf_path)
+            bucket = get_bucket_name(file_path)
             ak, sk, endpoint = get_s3_config(bucket)
 
             writer = S3DataWriter(
@@ -84,25 +89,29 @@ def init_writers(
             temp_reader = S3DataReader(
                 "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
             )
-            pdf_bytes = temp_reader.read(pdf_path)
+            file_bytes = temp_reader.read(file_path)
+            file_extension = os.path.splitext(file_path)[1]
         else:
             writer = FileBasedDataWriter(output_path)
             image_writer = FileBasedDataWriter(output_image_path)
             os.makedirs(output_image_path, exist_ok=True)
-            with open(pdf_path, "rb") as f:
-                pdf_bytes = f.read()
+            with open(file_path, "rb") as f:
+                file_bytes = f.read()
+            file_extension = os.path.splitext(file_path)[1]
     else:
         # 处理上传的文件
-        pdf_bytes = pdf_file.file.read()
+        file_bytes = file.file.read()
+        file_extension = os.path.splitext(file.filename)[1]
         writer = FileBasedDataWriter(output_path)
         image_writer = FileBasedDataWriter(output_image_path)
         os.makedirs(output_image_path, exist_ok=True)
 
-    return writer, image_writer, pdf_bytes
+    return writer, image_writer, file_bytes, file_extension
 
 
-def process_pdf(
-    pdf_bytes: bytes,
+def process_file(
+    file_bytes: bytes,
+    file_extension: str,
     parse_method: str,
     image_writer: Union[S3DataWriter, FileBasedDataWriter],
 ) -> Tuple[InferenceResult, PipeResult]:
@@ -110,14 +119,30 @@ def process_pdf(
     Process PDF file content
 
     Args:
-        pdf_bytes: Binary content of PDF file
+        file_bytes: Binary content of the file
+        file_extension: File extension, including the leading dot (e.g. '.pdf')
         parse_method: Parse method ('ocr', 'txt', 'auto')
         image_writer: Image writer
 
     Returns:
         Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
     """
-    ds = PymuDocDataset(pdf_bytes)
+
+    ds: Union[PymuDocDataset, ImageDataset]
+    if file_extension in pdf_extensions:
+        ds = PymuDocDataset(file_bytes)
+    elif file_extension in office_extensions:
+        # Needs office-based parsing
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_office(temp_dir)[0]
+    elif file_extension in image_extensions:
+        # Needs OCR-based parsing
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_images(temp_dir)[0]
     infer_result: InferenceResult = None
     pipe_result: PipeResult = None
 
@@ -145,13 +170,13 @@ def encode_image(image_path: str) -> str:
 
 
 @app.post(
-    "/pdf_parse",
+    "/file_parse",
     tags=["projects"],
-    summary="Parse PDF files (supports local files and S3)",
+    summary="Parse files (supports local files and S3)",
 )
-async def pdf_parse(
-    pdf_file: UploadFile = None,
-    pdf_path: str = None,
+async def file_parse(
+    file: UploadFile = None,
+    file_path: str = None,
     parse_method: str = "auto",
     is_json_md_dump: bool = False,
     output_dir: str = "output",
@@ -165,10 +190,10 @@ async def pdf_parse(
     to the specified directory.
 
     Args:
-        pdf_file: The PDF file to be parsed. Must not be specified together with
-            `pdf_path`
-        pdf_path: The path to the PDF file to be parsed. Must not be specified together
-            with `pdf_file`
+        file: The file to be parsed (PDF, Office document, or image). Must not be
+            specified together with `file_path`
+        file_path: The path to the file to be parsed. Must not be specified together
+            with `file`
         parse_method: Parsing method, can be auto, ocr, or txt. Default is auto. If
             results are not satisfactory, try ocr
         is_json_md_dump: Whether to write parsed data to .json and .md files. Default
@@ -181,31 +206,31 @@ async def pdf_parse(
         return_content_list: Whether to return parsed PDF content list. Default to False
     """
     try:
-        if (pdf_file is None and pdf_path is None) or (
-            pdf_file is not None and pdf_path is not None
+        if (file is None and file_path is None) or (
+            file is not None and file_path is not None
         ):
             return JSONResponse(
-                content={"error": "Must provide either pdf_file or pdf_path"},
+                content={"error": "Must provide either file or file_path"},
                 status_code=400,
             )
 
         # Get PDF filename
-        pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
+        file_name = os.path.basename(file_path if file_path else file.filename).split(
             "."
         )[0]
-        output_path = f"{output_dir}/{pdf_name}"
+        output_path = f"{output_dir}/{file_name}"
         output_image_path = f"{output_path}/images"
 
         # Initialize readers/writers and get PDF content
-        writer, image_writer, pdf_bytes = init_writers(
-            pdf_path=pdf_path,
-            pdf_file=pdf_file,
+        writer, image_writer, file_bytes, file_extension = init_writers(
+            file_path=file_path,
+            file=file,
             output_path=output_path,
             output_image_path=output_image_path,
         )
 
         # Process PDF
-        infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
+        infer_result, pipe_result = process_file(file_bytes, file_extension, parse_method, image_writer)
 
         # Use MemoryDataWriter to get results
         content_list_writer = MemoryDataWriter()
@@ -226,23 +251,23 @@ async def pdf_parse(
         # If results need to be saved
         if is_json_md_dump:
             writer.write_string(
-                f"{pdf_name}_content_list.json", content_list_writer.get_value()
+                f"{file_name}_content_list.json", content_list_writer.get_value()
             )
-            writer.write_string(f"{pdf_name}.md", md_content)
+            writer.write_string(f"{file_name}.md", md_content)
             writer.write_string(
-                f"{pdf_name}_middle.json", middle_json_writer.get_value()
+                f"{file_name}_middle.json", middle_json_writer.get_value()
             )
             writer.write_string(
-                f"{pdf_name}_model.json",
+                f"{file_name}_model.json",
                 json.dumps(model_json, indent=4, ensure_ascii=False),
             )
             # Save visualization results
-            pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
-            pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
+            pipe_result.draw_layout(os.path.join(output_path, f"{file_name}_layout.pdf"))
+            pipe_result.draw_span(os.path.join(output_path, f"{file_name}_spans.pdf"))
             pipe_result.draw_line_sort(
-                os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
+                os.path.join(output_path, f"{file_name}_line_sort.pdf")
             )
-            infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
+            infer_result.draw_model(os.path.join(output_path, f"{file_name}_model.pdf"))
 
         # Build return data
         data = {}
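
With the endpoint renamed to /file_parse and the parameters renamed to file/file_path, a client call looks roughly like the sketch below; the host, port, and sample filename are assumptions, and only parameters visible in this diff are used:

import requests

with open('slides.pptx', 'rb') as f:                            # any supported suffix: .pdf/.doc(x)/.ppt(x)/.png/.jpg
    resp = requests.post(
        'http://127.0.0.1:8000/file_parse',                     # host/port are assumptions
        files={'file': f},
        params={'parse_method': 'auto', 'return_content_list': True},
    )
resp.raise_for_status()
print(sorted(resp.json().keys()))                               # inspect which fields the service returned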

+ 3 - 2
requirements.txt

@@ -7,7 +7,8 @@ numpy>=1.21.6,<2.0.0
 pydantic>=2.7.2
 PyMuPDF>=1.24.9,<=1.24.14
 scikit-learn>=1.0.2
-torch>=2.2.2
-transformers
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
+transformers>=4.49.0
 pdfminer.six==20231228
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
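
The torch pin now excludes 2.5.0/2.5.1 and caps at 2.6.0. A quick sketch for checking a local environment against that range; using packaging here is an assumption, it is not listed as a dependency of this project:

from packaging.specifiers import SpecifierSet

import torch

supported = SpecifierSet('>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0')
local = torch.__version__.split('+')[0]                 # strip any '+cu121'-style suffix
assert local in supported, f'torch {local} is outside the supported range'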

+ 13 - 5
scripts/download_models.py

@@ -1,4 +1,5 @@
 import json
+import shutil
 import os
 
 import requests
@@ -16,7 +17,7 @@ def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
         config_version = data.get('config_version', '0.0.0')
-        if config_version < '1.1.1':
+        if config_version < '1.2.0':
             data = download_json(url)
     else:
         data = download_json(url)
@@ -32,12 +33,13 @@ def download_and_modify_json(url, local_filename, modifications):
 
 if __name__ == '__main__':
     mineru_patterns = [
-        "models/Layout/LayoutLMv3/*",
+        # "models/Layout/LayoutLMv3/*",
         "models/Layout/YOLO/*",
         "models/MFD/YOLO/*",
-        "models/MFR/unimernet_small_2501/*",
-        "models/TabRec/TableMaster/*",
-        "models/TabRec/StructEqTable/*",
+        "models/MFR/unimernet_hf_small_2503/*",
+        "models/OCR/paddleocr/*",
+        # "models/TabRec/TableMaster/*",
+        # "models/TabRec/StructEqTable/*",
     ]
     model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
     layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
@@ -45,6 +47,12 @@ if __name__ == '__main__':
     print(f'model_dir is: {model_dir}')
     print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
 
+    paddleocr_model_dir = model_dir + '/OCR/paddleocr'
+    user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
+    if os.path.exists(user_paddleocr_dir):
+        shutil.rmtree(user_paddleocr_dir)
+    shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
+
     json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
     config_file_name = 'magic-pdf.json'
     home_dir = os.path.expanduser('~')
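
The config_version gate above compares plain strings, which happens to order '1.1.1' before '1.2.0' but misorders multi-digit components. A hedged sketch of a stricter check (packaging is not imported by this script; the helper name is illustrative):

from packaging.version import Version

assert '1.10.0' < '1.2.0'                        # lexicographic comparison gets this wrong
assert Version('1.10.0') > Version('1.2.0')      # semantic comparison gets it right

def needs_refresh(config_version: str, minimum: str = '1.2.0') -> bool:
    return Version(config_version) < Version(minimum)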

+ 13 - 5
scripts/download_models_hf.py

@@ -1,5 +1,6 @@
 import json
 import os
+import shutil
 
 import requests
 from huggingface_hub import snapshot_download
@@ -16,7 +17,7 @@ def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
         config_version = data.get('config_version', '0.0.0')
-        if config_version < '1.1.1':
+        if config_version < '1.2.0':
             data = download_json(url)
     else:
         data = download_json(url)
@@ -33,12 +34,13 @@ def download_and_modify_json(url, local_filename, modifications):
 if __name__ == '__main__':
 
     mineru_patterns = [
-        "models/Layout/LayoutLMv3/*",
+        # "models/Layout/LayoutLMv3/*",
         "models/Layout/YOLO/*",
         "models/MFD/YOLO/*",
-        "models/MFR/unimernet_small_2501/*",
-        "models/TabRec/TableMaster/*",
-        "models/TabRec/StructEqTable/*",
+        "models/MFR/unimernet_hf_small_2503/*",
+        "models/OCR/paddleocr/*",
+        # "models/TabRec/TableMaster/*",
+        # "models/TabRec/StructEqTable/*",
     ]
     model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
 
@@ -52,6 +54,12 @@ if __name__ == '__main__':
     print(f'model_dir is: {model_dir}')
     print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
 
+    paddleocr_model_dir = model_dir + '/OCR/paddleocr'
+    user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
+    if os.path.exists(user_paddleocr_dir):
+        shutil.rmtree(user_paddleocr_dir)
+    shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
+
     json_url = 'https://github.com/opendatalab/MinerU/raw/master/magic-pdf.template.json'
     config_file_name = 'magic-pdf.json'
     home_dir = os.path.expanduser('~')

+ 11 - 8
setup.py

@@ -36,29 +36,32 @@ if __name__ == '__main__':
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
                      ],
-            "full": ["unimernet==0.2.3",  # unimernet升级0.2.3,移除torchtext/eva-decord的依赖
-                     "torch>=2.2.2,<=2.3.1",  # torch2.4.0及之后版本未测试,先卡住版本上限
-                     "torchvision>=0.17.2,<=0.18.1",  # torchvision 受torch版本约束
+            "full": [
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
                      "ultralytics>=8.3.48",  # yolov8,公式检测
                      "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
                      "paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'",  # 解决linux的段异常问题
                      "paddlepaddle==2.6.1;platform_system=='Windows'",  # windows版本3.0.0效率下降,需锁定2.6.1
-                     "struct-eqtable==0.3.2",  # 表格解析
-                     "einops",  # struct-eqtable依赖
-                     "accelerate",  # struct-eqtable依赖
                      "doclayout_yolo==0.0.2b1",  # doclayout_yolo
                      "rapidocr-paddle>=1.4.5,<2.0.0",  # rapidocr-paddle
                      "rapidocr_onnxruntime>=1.4.4,<2.0.0",
                      "rapid_table>=1.0.3,<2.0.0",  # rapid_table
                      "PyYAML",  # yaml
+                     "ftfy"
                      "openai",  # openai SDK
-                     "detectron2"
                      ],
             "old_linux":[
                 "albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统
-            ]
+            ],
+            "layoutlmv3":[
+                "detectron2"
+            ],
+            "struct_eqtable":[
+                "struct-eqtable==0.3.2",  # 表格解析
+                "einops",  # struct-eqtable依赖
+                "accelerate",  # struct-eqtable依赖
+            ],
         },
         description="A practical tool for converting PDF to Markdown",  # 简短描述
         long_description=long_description,  # 详细描述

+ 24 - 0
signatures/version1/cla.json

@@ -183,6 +183,30 @@
       "created_at": "2025-02-26T09:23:25Z",
       "repoId": 765083837,
       "pullRequestNo": 1785
+    },
+    {
+      "name": "rschutski",
+      "id": 179498169,
+      "comment_id": 2705150371,
+      "created_at": "2025-03-06T23:16:30Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1863
+    },
+    {
+      "name": "qbit-",
+      "id": 4794088,
+      "comment_id": 2705914730,
+      "created_at": "2025-03-07T09:09:13Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1863
+    },
+    {
+      "name": "mauryaland",
+      "id": 22381129,
+      "comment_id": 2717322316,
+      "created_at": "2025-03-12T10:03:11Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1906
     }
   ]
 }