Kaynağa Gözat

update:Integrate the PDF-Extract-Kit inside

myhloli 1 yıl önce
ebeveyn
işleme
1fac6aa72d
30 değiştirilmiş dosya ile 6049 ekleme ve 131 silme
  1. 1 0
      magic_pdf/cli/magicpdf.py
  2. 27 22
      magic_pdf/model/doc_analyze_by_custom_model.py
  3. 8 0
      magic_pdf/model/model_list.py
  4. 70 109
      magic_pdf/model/pdf_extract_kit.py
  5. 0 0
      magic_pdf/model/pek_sub_modules/__init__.py
  6. 0 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py
  7. 179 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py
  8. 671 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py
  9. 476 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py
  10. 7 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py
  11. 2 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py
  12. 171 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py
  13. 124 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py
  14. 136 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py
  15. 284 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py
  16. 213 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py
  17. 7 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py
  18. 24 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py
  19. 60 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py
  20. 1282 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py
  21. 32 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py
  22. 34 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py
  23. 141 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py
  24. 163 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py
  25. 1236 0
      magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py
  26. 36 0
      magic_pdf/model/pek_sub_modules/post_process.py
  27. 259 0
      magic_pdf/model/pek_sub_modules/self_modify.py
  28. 46 0
      magic_pdf/resources/model_config/UniMERNet/demo.yaml
  29. 351 0
      magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml
  30. 9 0
      magic_pdf/resources/model_config/model_configs.yaml

+ 1 - 0
magic_pdf/cli/magicpdf.py

@@ -85,6 +85,7 @@ def do_parse(
     orig_model_list = copy.deepcopy(model_list)
 
     local_image_dir, local_md_dir = prepare_env(pdf_file_name, parse_method)
+    logger.info(f"local output dir is {local_md_dir}")
     image_writer, md_writer = DiskReaderWriter(local_image_dir), DiskReaderWriter(local_md_dir)
     image_dir = str(os.path.basename(local_image_dir))
 

+ 27 - 22
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,7 +1,7 @@
 import fitz
 import numpy as np
 from loguru import logger
-from magic_pdf.model.model_list import MODEL
+from magic_pdf.model.model_list import MODEL, MODEL_TYPE
 import magic_pdf.model as model_config
 
 
@@ -34,8 +34,8 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
             pm = page.get_pixmap(matrix=mat, alpha=False)
 
             # if width or height > 3000 pixels, don't enlarge the image
-            if pix.width > 3000 or pix.height > 3000:
-                pix = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+            if pm.width > 3000 or pm.height > 3000:
+                pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
             img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
             img = np.array(img)
@@ -44,31 +44,36 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
     return images
 
 
-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.Paddle):
-
+def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.Paddle,
+                model_type=MODEL_TYPE.SINGLE_PAGE):
+    custom_model = None
     if model_config.__use_inside_model__:
-        from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+        if model == MODEL.Paddle:
+            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
+        elif model == MODEL.PEK:
+            from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+            custom_model = CustomPEKModel(ocr=ocr, show_log=show_log)
+        else:
+            logger.error("Not allow model_name!")
+            exit(1)
     else:
         logger.error("use_inside_model is False, not allow to use inside model")
         exit(1)
 
     images = load_images_from_pdf(pdf_bytes)
-    custom_model = None
-    if model == MODEL.Paddle:
-        custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
-    else:
-        pass
-    model_json = []
-    for index, img_dict in enumerate(images):
-        img = img_dict["img"]
-        page_width = img_dict["width"]
-        page_height = img_dict["height"]
-        result = custom_model(img)
-        page_info = {"page_no": index, "height": page_height, "width": page_width}
-        page_dict = {"layout_dets": result, "page_info": page_info}
-
-        model_json.append(page_dict)
 
-    # @todo 把公式识别放在后置位置,待整本全部模型结果出来之后再补公式数据
+    model_json = []
+    if model_type == MODEL_TYPE.SINGLE_PAGE:
+        for index, img_dict in enumerate(images):
+            img = img_dict["img"]
+            page_width = img_dict["width"]
+            page_height = img_dict["height"]
+            result = custom_model(img)
+            page_info = {"page_no": index, "height": page_height, "width": page_width}
+            page_dict = {"layout_dets": result, "page_info": page_info}
+            model_json.append(page_dict)
+    elif model_type == MODEL_TYPE.MULTI_PAGE:
+        model_json = custom_model(images)
 
     return model_json

+ 8 - 0
magic_pdf/model/model_list.py

@@ -1,2 +1,10 @@
 class MODEL:
     Paddle = "pp_structure_v2"
+    PEK = "pdf_extract_kit"
+
+
+class MODEL_TYPE:
+    # 单页解析
+    SINGLE_PAGE = 1
+    # 多页解析
+    MULTI_PAGE = 2

+ 70 - 109
magic_pdf/model/pdf_extract_kit.py

@@ -1,126 +1,87 @@
-
 import os
-import time
-
-import cv2
-import fitz
 import numpy as np
-import torch
-import unimernet.tasks as tasks
 import yaml
-from PIL import Image
-from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
 from ultralytics import YOLO
+from loguru import logger
+from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
 from unimernet.common.config import Config
+import unimernet.tasks as tasks
 from unimernet.processors import load_processor
+import argparse
+from torchvision import transforms
 
+from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 
 
-class CustomPEKModel:
-    def __init__(self, ocr: bool = False, show_log: bool = False):
-        ## ======== model init ========##
-        with open('configs/model_configs.yaml') as f:
-            model_configs = yaml.load(f, Loader=yaml.FullLoader)
-        img_size = model_configs['model_args']['img_size']
-        conf_thres = model_configs['model_args']['conf_thres']
-        iou_thres = model_configs['model_args']['iou_thres']
-        device = model_configs['model_args']['device']
-        dpi = model_configs['model_args']['pdf_dpi']
-        mfd_model = mfd_model_init(model_configs['model_args']['mfd_weight'])
-        mfr_model, mfr_vis_processors = mfr_model_init(model_configs['model_args']['mfr_weight'], device=device)
-        mfr_transform = transforms.Compose([mfr_vis_processors, ])
-        layout_model = layout_model_init(model_configs['model_args']['layout_weight'])
-        ocr_model = ModifiedPaddleOCR(show_log=True)
-        print(now.strftime('%Y-%m-%d %H:%M:%S'))
-        print('Model init done!')
-        ## ======== model init ========##
+def layout_model_init(weight, config_file):
+    model = Layoutlmv3_Predictor(weight, config_file)
+    return model
 
-    def __call__(self, image):
 
-        # layout检测 + 公式检测
-        doc_layout_result = []
-        latex_filling_list = []
-        mf_image_list = []
+def mfr_model_init(weight_dir, cfg_path, device='cpu'):
+    args = argparse.Namespace(cfg_path=cfg_path, options=None)
+    cfg = Config(args)
+    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
+    cfg.config.model.model_config.model_name = weight_dir
+    cfg.config.model.tokenizer_config.path = weight_dir
+    task = tasks.setup_task(cfg)
+    model = task.build_model(cfg)
+    model = model.to(device)
+    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
+    return model, vis_processor
 
-            img_H, img_W = image.shape[0], image.shape[1]
-            layout_res = layout_model(image, ignore_catids=[])
-            # 公式检测
-            mfd_res = mfd_model.predict(image, imgsz=img_size, conf=conf_thres, iou=iou_thres, verbose=True)[0]
-            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': 13 + int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 2),
-                    'latex': '',
-                }
-                layout_res['layout_dets'].append(new_item)
-                latex_filling_list.append(new_item)
-                bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
-                mf_image_list.append(bbox_img)
 
-            layout_res['page_info'] = dict(
-                page_no=idx,
-                height=img_H,
-                width=img_W
+class CustomPEKModel:
+    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
+        """
+        ======== model init ========
+        """
+        # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
+        current_file_path = os.path.abspath(__file__)
+        # 获取当前文件所在的目录(model)
+        current_dir = os.path.dirname(current_file_path)
+        # 上一级目录(magic_pdf)
+        root_dir = os.path.dirname(current_dir)
+        # model_config目录
+        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
+        # 构建 model_configs.yaml 文件的完整路径
+        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
+        with open(config_path, "r") as f:
+            self.configs = yaml.load(f, Loader=yaml.FullLoader)
+        # 初始化解析配置
+        self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
+        self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
+        self.apply_ocr = ocr
+        logger.info(
+            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
+                self.apply_layout, self.apply_formula, self.apply_ocr
             )
-            doc_layout_result.append(layout_res)
+        )
+        assert self.apply_layout, "DocAnalysis must contain layout model."
+        # 初始化解析方案
+        self.device = self.configs["config"]["device"]
+        logger.info("using device: {}".format(self.device))
+        # 初始化layout模型
+        self.layout_model = layout_model_init(
+            os.path.join(root_dir, self.configs['weights']['layout']),
+            os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")
+        )
+        # 初始化公式识别
+        if self.apply_formula:
+            # 初始化公式检测模型
+            self.mfd_model = YOLO(model=str(os.path.join(root_dir, self.configs["weights"]["mfd"])))
+            # 初始化公式解析模型
+            mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
+            self.mfr_model, mfr_vis_processors = mfr_model_init(
+                os.path.join(root_dir, self.configs["weights"]["mfr"]), mfr_config_path,
+                device=self.device)
+            self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
+        # 初始化ocr
+        if self.apply_ocr:
+            self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
 
-        # 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。
-        a = time.time()
-        dataset = MathDataset(mf_image_list, transform=mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
-        mfr_res = []
-        gpu_total_cost = 0
-        for imgs in dataloader:
-            imgs = imgs.to(device)
-            gpu_start = time.time()
-            output = mfr_model.generate({'image': imgs})
-            gpu_cost = time.time() - gpu_start
-            gpu_total_cost += gpu_cost
-            print(f"gpu_cost: {gpu_cost}")
-            mfr_res.extend(output['pred_str'])
-        print(f"gpu_total_cost: {gpu_total_cost}")
-        for res, latex in zip(latex_filling_list, mfr_res):
-            res['latex'] = latex_rm_whitespace(latex)
-        b = time.time()
-        print("formula nums:", len(mf_image_list), "mfr time:", round(b - a, 2))
+        logger.info('DocAnalysis init done!')
 
-        # ocr识别
-        for idx, image in enumerate(img_list):
-            pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
-            single_page_res = doc_layout_result[idx]['layout_dets']
-            single_page_mfdetrec_res = []
-            for res in single_page_res:
-                if int(res['category_id']) in [13, 14]:
-                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                    single_page_mfdetrec_res.append({
-                        "bbox": [xmin, ymin, xmax, ymax],
-                    })
-            for res in single_page_res:
-                if int(res['category_id']) in [0, 1, 2, 4, 6, 7]:  # 需要进行ocr的类别
-                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                    crop_box = [xmin, ymin, xmax, ymax]
-                    cropped_img = Image.new('RGB', pil_img.size, 'white')
-                    cropped_img.paste(pil_img.crop(crop_box), crop_box)
-                    cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
-                    ocr_res = ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
-                    if ocr_res:
-                        for box_ocr_res in ocr_res:
-                            p1, p2, p3, p4 = box_ocr_res[0]
-                            text, score = box_ocr_res[1]
-                            doc_layout_result[idx]['layout_dets'].append({
-                                'category_id': 15,
-                                'poly': p1 + p2 + p3 + p4,
-                                'score': round(score, 2),
-                                'text': text,
-                            })
 
-        output_dir = args.output
-        os.makedirs(output_dir, exist_ok=True)
-        basename = os.path.basename(single_pdf)[0:-4]
-        with open(os.path.join(output_dir, f'{basename}.json'), 'w') as f:
-            json.dump(doc_layout_result, f)
+    def __call__(self, image):
+        pass

+ 0 - 0
magic_pdf/model/pek_sub_modules/__init__.py


+ 0 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py


+ 179 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py

@@ -0,0 +1,179 @@
+# --------------------------------------------------------------------------------
+# VIT: Multi-Path Vision Transformer for Dense Prediction
+# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+# All Rights Reserved.
+# Written by Youngwan Lee
+# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# CoaT: https://github.com/mlpc-ucsd/CoaT
+# --------------------------------------------------------------------------------
+
+
+import torch
+
+from detectron2.layers import (
+    ShapeSpec,
+)
+from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
+from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
+
+from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
+from .deit import deit_base_patch16, mae_base_patch16
+from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
+from transformers import AutoConfig
+
+__all__ = [
+    "build_vit_fpn_backbone",
+]
+
+
+class VIT_Backbone(Backbone):
+    """
+    Implement VIT backbone.
+    """
+
+    def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
+                 config_path=None, image_only=False, cfg=None):
+        super().__init__()
+        self._out_features = out_features
+        if 'base' in name:
+            self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
+            self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+        else:
+            self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
+            self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+
+        if name == 'beit_base_patch16':
+            model_func = beit_base_patch16
+        elif name == 'dit_base_patch16':
+            model_func = dit_base_patch16
+        elif name == "deit_base_patch16":
+            model_func = deit_base_patch16
+        elif name == "mae_base_patch16":
+            model_func = mae_base_patch16
+        elif name == "dit_large_patch16":
+            model_func = dit_large_patch16
+        elif name == "beit_large_patch16":
+            model_func = beit_large_patch16
+
+        if 'beit' in name or 'dit' in name:
+            if pos_type == "abs":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_abs_pos_emb=True,
+                                           **model_kwargs)
+            elif pos_type == "shared_rel":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_shared_rel_pos_bias=True,
+                                           **model_kwargs)
+            elif pos_type == "rel":
+                self.backbone = model_func(img_size=img_size,
+                                           out_features=out_features,
+                                           drop_path_rate=drop_path,
+                                           use_rel_pos_bias=True,
+                                           **model_kwargs)
+            else:
+                raise ValueError()
+        elif "layoutlmv3" in name:
+            config = AutoConfig.from_pretrained(config_path)
+            # disable relative bias as DiT
+            config.has_spatial_attention_bias = False
+            config.has_relative_attention_bias = False
+            self.backbone = LayoutLMv3Model(config, detection=True,
+                                               out_features=out_features, image_only=image_only)
+        else:
+            self.backbone = model_func(img_size=img_size,
+                                       out_features=out_features,
+                                       drop_path_rate=drop_path,
+                                       **model_kwargs)
+        self.name = name
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+
+        Returns:
+            dict[str->Tensor]: names and the corresponding features
+        """
+        if "layoutlmv3" in self.name:
+            return self.backbone.forward(
+                input_ids=x["input_ids"] if "input_ids" in x else None,
+                bbox=x["bbox"] if "bbox" in x else None,
+                images=x["images"] if "images" in x else None,
+                attention_mask=x["attention_mask"] if "attention_mask" in x else None,
+                # output_hidden_states=True,
+            )
+        assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+        return self.backbone.forward_features(x)
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+            )
+            for name in self._out_features
+        }
+
+
+def build_VIT_backbone(cfg):
+    """
+    Create a VIT instance from config.
+
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        A VIT backbone instance.
+    """
+    # fmt: off
+    name = cfg.MODEL.VIT.NAME
+    out_features = cfg.MODEL.VIT.OUT_FEATURES
+    drop_path = cfg.MODEL.VIT.DROP_PATH
+    img_size = cfg.MODEL.VIT.IMG_SIZE
+    pos_type = cfg.MODEL.VIT.POS_TYPE
+
+    model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
+
+    if 'layoutlmv3' in name:
+        if cfg.MODEL.CONFIG_PATH != '':
+            config_path = cfg.MODEL.CONFIG_PATH
+        else:
+            config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '')  # layoutlmv3 pre-trained models
+            config_path = config_path.replace('model_final.pth', '')  # detection fine-tuned models
+    else:
+        config_path = None
+
+    return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
+                        config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
+
+
+@BACKBONE_REGISTRY.register()
+def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
+    """
+    Create a VIT w/ FPN backbone.
+
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+    """
+    bottom_up = build_VIT_backbone(cfg)
+    in_features = cfg.MODEL.FPN.IN_FEATURES
+    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
+    backbone = FPN(
+        bottom_up=bottom_up,
+        in_features=in_features,
+        out_channels=out_channels,
+        norm=cfg.MODEL.FPN.NORM,
+        top_block=LastLevelMaxPool(),
+        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
+    )
+    return backbone

+ 671 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py

@@ -0,0 +1,671 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implement of Vision Transformers as described in
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+Status/TODO:
+* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
+* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
+* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
+* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import warnings
+import math
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # commit this for the orignal BERT implement
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+            proj_drop=0., window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token 2 cls & cls to cls
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer("relative_position_index", relative_position_index)
+
+            # trunc_normal_(self.relative_position_bias_table, std=.0)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            if training_window_size == self.window_size:
+                relative_position_bias = \
+                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                        self.window_size[0] * self.window_size[1] + 1,
+                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+            else:
+                training_window_size = tuple(training_window_size.tolist())
+                new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+                # new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
+                new_relative_position_bias_table = F.interpolate(
+                    self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                                 2 * self.window_size[0] - 1,
+                                                                                 2 * self.window_size[1] - 1),
+                    size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                    align_corners=False)
+                new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                         new_num_relative_distance - 3).permute(
+                    1, 0)
+                new_relative_position_bias_table = torch.cat(
+                    [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+                # get pair-wise relative position index for each token inside the window
+                coords_h = torch.arange(training_window_size[0])
+                coords_w = torch.arange(training_window_size[1])
+                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+                relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+                relative_coords[:, :, 1] += training_window_size[1] - 1
+                relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+                relative_position_index = \
+                    torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                                dtype=relative_coords.dtype)
+                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+                relative_position_index[0, 0:] = new_num_relative_distance - 3
+                relative_position_index[0:, 0] = new_num_relative_distance - 2
+                relative_position_index[0, 0] = new_num_relative_distance - 1
+
+                relative_position_bias = \
+                    new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                        training_window_size[0] * training_window_size[1] + 1,
+                        training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None, training_window_size=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(
+                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
+                                                            training_window_size=training_window_size))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches_w = self.patch_shape[0]
+        self.num_patches_h = self.patch_shape[1]
+        # the so-called patch_shape is the patch shape during pre-training
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, position_embedding=None, **kwargs):
+        # FIXME look at relaxing size constraints
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        Hp, Wp = x.shape[2], x.shape[3]
+
+        if position_embedding is not None:
+            # interpolate the position embedding to the corresponding size
+            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
+                                                                                                                  1, 2)
+            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
+            x = x + position_embedding
+
+        x = x.flatten(2).transpose(1, 2)
+        return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token 2 cls & cls to cls
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self, training_window_size):
+        if training_window_size == self.window_size:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        else:
+            training_window_size = tuple(training_window_size.tolist())
+            new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+            # new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
+            new_relative_position_bias_table = F.interpolate(
+                self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
+                                                                             2 * self.window_size[0] - 1,
+                                                                             2 * self.window_size[1] - 1),
+                size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
+                align_corners=False)
+            new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
+                                                                                     new_num_relative_distance - 3).permute(
+                1, 0)
+            new_relative_position_bias_table = torch.cat(
+                [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(training_window_size[0])
+            coords_w = torch.arange(training_window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += training_window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
+                            dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = new_num_relative_distance - 3
+            relative_position_index[0:, 0] = new_num_relative_distance - 2
+            relative_position_index[0, 0] = new_num_relative_distance - 1
+
+            relative_position_bias = \
+                new_relative_position_bias_table[relative_position_index.view(-1)].view(
+                    training_window_size[0] * training_window_size[1] + 1,
+                    training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+        return relative_position_bias
+
+
+class BEiT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 img_size=[224, 224],
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=80,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=None,
+                 init_values=None,
+                 use_abs_pos_emb=False,
+                 use_rel_pos_bias=False,
+                 use_shared_rel_pos_bias=False,
+                 use_checkpoint=True,
+                 pretrained=None,
+                 out_features=None,
+                 ):
+
+        super(BEiT, self).__init__()
+
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.use_checkpoint = use_checkpoint
+
+        if hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+            for i in range(depth)])
+
+        # trunc_normal_(self.mask_token, std=.02)
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                # nn.SyncBatchNorm(embed_dim),
+                nn.BatchNorm2d(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        logger = get_root_logger()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self,
+                            filename=self.init_cfg['checkpoint'],
+                            strict=False,
+                            logger=logger,
+                            beit_spec_expand_rel_pos = self.use_rel_pos_bias,
+                            )
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward_features(self, x):
+        B, C, H, W = x.shape
+        x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+        # Hp, Wp are HW for patches
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.pos_embed is not None:
+            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.pos_drop(x)
+
+        features = []
+        training_window_size = torch.tensor([Hp, Wp])
+
+        rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
+
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
+            else:
+                x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+            if i in self.out_indices:
+                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def beit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def beit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=None,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_base_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=0.1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def dit_large_patch16(pretrained=False, **kwargs):
+    model = BEiT(
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        init_values=1e-5,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+if __name__ == '__main__':
+    model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
+    model = model.to("cuda:0")
+    input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
+    input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
+    input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
+    output1 = model(input1)
+    output2 = model(input2)
+    output3 = model(input3)
+    print("all done")

+ 476 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py

@@ -0,0 +1,476 @@
+"""
+Mostly copy-paste from DINO and timm library:
+https://github.com/facebookresearch/dino
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+import warnings
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import trunc_normal_, drop_path, to_2tuple
+from functools import partial
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
+                                      C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+
+        self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+
+        self.num_patches_w, self.num_patches_h = self.window_size
+
+        self.num_patches = self.window_size[0] * self.window_size[1]
+        self.img_size = img_size
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(in_chans, embed_dim,
+                              kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        x = self.proj(x)
+        return x
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(
+                    1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class ViT(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self,
+                 model_name='vit_base_patch16_224',
+                 img_size=384,
+                 patch_size=16,
+                 in_chans=3,
+                 embed_dim=1024,
+                 depth=24,
+                 num_heads=16,
+                 num_classes=19,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.1,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 hybrid_backbone=None,
+                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                 norm_cfg=None,
+                 pos_embed_interp=False,
+                 random_init=False,
+                 align_corners=False,
+                 use_checkpoint=False,
+                 num_extra_tokens=1,
+                 out_features=None,
+                 **kwargs,
+                 ):
+
+        super(ViT, self).__init__()
+        self.model_name = model_name
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.depth = depth
+        self.num_heads = num_heads
+        self.num_classes = num_classes
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.qk_scale = qk_scale
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.hybrid_backbone = hybrid_backbone
+        self.norm_layer = norm_layer
+        self.norm_cfg = norm_cfg
+        self.pos_embed_interp = pos_embed_interp
+        self.random_init = random_init
+        self.align_corners = align_corners
+        self.use_checkpoint = use_checkpoint
+        self.num_extra_tokens = num_extra_tokens
+        self.out_features = out_features
+        self.out_indices = [int(name[5:]) for name in out_features]
+
+        # self.num_stages = self.depth
+        # self.out_indices = tuple(range(self.num_stages))
+
+        if self.hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
+        self.num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        if self.num_extra_tokens == 2:
+            self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+        self.pos_embed = nn.Parameter(torch.zeros(
+            1, self.num_patches + self.num_extra_tokens, self.embed_dim))
+        self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+        # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
+        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
+                                                self.depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
+                qk_scale=self.qk_scale,
+                drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
+            for i in range(self.depth)])
+
+        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+        # self.repr = nn.Linear(embed_dim, representation_size)
+        # self.repr_act = nn.Tanh()
+
+        if patch_size == 16:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                nn.SyncBatchNorm(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        elif patch_size == 8:
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Identity()
+
+            self.fpn3 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=2, stride=2),
+            )
+
+            self.fpn4 = nn.Sequential(
+                nn.MaxPool2d(kernel_size=4, stride=4),
+            )
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        if self.num_extra_tokens==2:
+            trunc_normal_(self.dist_token, std=0.2)
+        self.apply(self._init_weights)
+        # self.fix_init_weight()
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    '''
+    def init_weights(self):
+        logger = get_root_logger()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+        if self.init_cfg is None:
+            logger.warn(f'No pre-trained weights for '
+                        f'{self.__class__.__name__}, '
+                        f'training start from scratch')
+        else:
+            assert 'checkpoint' in self.init_cfg, f'Only support ' \
+                                                  f'specify `Pretrained` in ' \
+                                                  f'`init_cfg` in ' \
+                                                  f'{self.__class__.__name__} '
+            logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+            load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
+    '''
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def _conv_filter(self, state_dict, patch_size=16):
+        """ convert patch embedding weight from manual patchify + linear proj to conv"""
+        out_dict = {}
+        for k, v in state_dict.items():
+            if 'patch_embed.proj.weight' in k:
+                v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+            out_dict[k] = v
+        return out_dict
+
+    def to_2D(self, x):
+        n, hw, c = x.shape
+        h = w = int(math.sqrt(hw))
+        x = x.transpose(1, 2).reshape(n, c, h, w)
+        return x
+
+    def to_1D(self, x):
+        n, c, h, w = x.shape
+        x = x.reshape(n, c, -1).transpose(1, 2)
+        return x
+
+    def interpolate_pos_encoding(self, x, w, h):
+        npatch = x.shape[1] - self.num_extra_tokens
+        N = self.pos_embed.shape[1] - self.num_extra_tokens
+        if npatch == N and w == h:
+            return self.pos_embed
+
+        class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
+
+        patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
+
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size[0]
+        h0 = h // self.patch_embed.patch_size[1]
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x, mask=None):
+        B, nc, w, h = x.shape
+        # patch linear embedding
+        x = self.patch_embed(x)
+
+        # mask image modeling
+        if mask is not None:
+            x = self.mask_model(x, mask)
+        x = x.flatten(2).transpose(1, 2)
+
+        # add the [CLS] token to the embed patch tokens
+        all_tokens = [self.cls_token.expand(B, -1, -1)]
+
+        if self.num_extra_tokens == 2:
+            dist_tokens = self.dist_token.expand(B, -1, -1)
+            all_tokens.append(dist_tokens)
+        all_tokens.append(x)
+
+        x = torch.cat(all_tokens, dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return self.pos_drop(x)
+
+    def forward_features(self, x):
+        # print(f"==========shape of x is {x.shape}==========")
+        B, _, H, W = x.shape
+        Hp, Wp = H // self.patch_size, W // self.patch_size
+        x = self.prepare_tokens(x)
+
+        features = []
+        for i, blk in enumerate(self.blocks):
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+            if i in self.out_indices:
+                xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+                features.append(xp.contiguous())
+
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        feat_out = {}
+
+        for name, value in zip(self.out_features, features):
+            feat_out[name] = value
+
+        return feat_out
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
+def deit_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=2,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+def mae_base_patch16(pretrained=False, **kwargs):
+    model = ViT(
+        patch_size=16,
+        drop_rate=0.,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        num_classes=1000,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        use_checkpoint=True,
+        num_extra_tokens=1,
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model

+ 7 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py

@@ -0,0 +1,7 @@
+from .models import (
+    LayoutLMv3Config,
+    LayoutLMv3ForTokenClassification,
+    LayoutLMv3ForQuestionAnswering,
+    LayoutLMv3ForSequenceClassification,
+    LayoutLMv3Tokenizer,
+)

+ 2 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py

@@ -0,0 +1,2 @@
+# flake8: noqa
+from .data_collator import DataCollatorForKeyValueExtraction

+ 171 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py

@@ -0,0 +1,171 @@
+'''
+Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
+'''
+
+
+import json
+import os
+from pathlib import Path
+import datasets
+from .image_utils import load_image, normalize_bbox
+logger = datasets.logging.get_logger(__name__)
+_CITATION = """\
+@article{park2019cord,
+  title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
+  author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk}
+  booktitle={Document Intelligence Workshop at Neural Information Processing Systems}
+  year={2019}
+}
+"""
+_DESCRIPTION = """\
+https://github.com/clovaai/cord/
+"""
+
+def quad_to_box(quad):
+    # test 87 is wrongly annotated
+    box = (
+        max(0, quad["x1"]),
+        max(0, quad["y1"]),
+        quad["x3"],
+        quad["y3"]
+    )
+    if box[3] < box[1]:
+        bbox = list(box)
+        tmp = bbox[3]
+        bbox[3] = bbox[1]
+        bbox[1] = tmp
+        box = tuple(bbox)
+    if box[2] < box[0]:
+        bbox = list(box)
+        tmp = bbox[2]
+        bbox[2] = bbox[0]
+        bbox[0] = tmp
+        box = tuple(bbox)
+    return box
+
+def _get_drive_url(url):
+    base_url = 'https://drive.google.com/uc?id='
+    split_url = url.split('/')
+    return base_url + split_url[5]
+
+_URLS = [
+    _get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
+    _get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
+    # If you failed to download the dataset through the automatic downloader,
+    # you can download it manually and modify the code to get the local dataset.
+    # Or you can use the following links. Please follow the original LICENSE of CORD for usage.
+    # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
+    # "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
+]
+
+class CordConfig(datasets.BuilderConfig):
+    """BuilderConfig for CORD"""
+    def __init__(self, **kwargs):
+        """BuilderConfig for CORD.
+        Args:
+          **kwargs: keyword arguments forwarded to super.
+        """
+        super(CordConfig, self).__init__(**kwargs)
+
+class Cord(datasets.GeneratorBasedBuilder):
+    BUILDER_CONFIGS = [
+        CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
+    ]
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "id": datasets.Value("string"),
+                    "words": datasets.Sequence(datasets.Value("string")),
+                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
+                    "ner_tags": datasets.Sequence(
+                        datasets.features.ClassLabel(
+                            names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
+                        )
+                    ),
+                    "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
+                    "image_path": datasets.Value("string"),
+                }
+            ),
+            supervised_keys=None,
+            citation=_CITATION,
+            homepage="https://github.com/clovaai/cord/",
+        )
+
+    def _split_generators(self, dl_manager):
+        """Returns SplitGenerators."""
+        """Uses local files located with data_dir"""
+        downloaded_file = dl_manager.download_and_extract(_URLS)
+        # move files from the second URL together with files from the first one.
+        dest = Path(downloaded_file[0])/"CORD"
+        for split in ["train", "dev", "test"]:
+            for file_type in ["image", "json"]:
+                if split == "test" and file_type == "json":
+                    continue
+                files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
+                for f in files:
+                    os.rename(f, dest/split/file_type/f.name)
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
+            ),
+        ]
+
+    def get_line_bbox(self, bboxs):
+        x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
+        y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
+
+        x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
+
+        assert x1 >= x0 and y1 >= y0
+        bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
+        return bbox
+
+    def _generate_examples(self, filepath):
+        logger.info("⏳ Generating examples from = %s", filepath)
+        ann_dir = os.path.join(filepath, "json")
+        img_dir = os.path.join(filepath, "image")
+        for guid, file in enumerate(sorted(os.listdir(ann_dir))):
+            words = []
+            bboxes = []
+            ner_tags = []
+            file_path = os.path.join(ann_dir, file)
+            with open(file_path, "r", encoding="utf8") as f:
+                data = json.load(f)
+            image_path = os.path.join(img_dir, file)
+            image_path = image_path.replace("json", "png")
+            image, size = load_image(image_path)
+            for item in data["valid_line"]:
+                cur_line_bboxes = []
+                line_words, label = item["words"], item["category"]
+                line_words = [w for w in line_words if w["text"].strip() != ""]
+                if len(line_words) == 0:
+                    continue
+                if label == "other":
+                    for w in line_words:
+                        words.append(w["text"])
+                        ner_tags.append("O")
+                        cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
+                else:
+                    words.append(line_words[0]["text"])
+                    ner_tags.append("B-" + label.upper())
+                    cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
+                    for w in line_words[1:]:
+                        words.append(w["text"])
+                        ner_tags.append("I-" + label.upper())
+                        cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
+                # by default: --segment_level_layout 1
+                # if do not want to use segment_level_layout, comment the following line
+                cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
+                bboxes.extend(cur_line_bboxes)
+            # yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
+            yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
+                         "image": image, "image_path": image_path}

+ 124 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py

@@ -0,0 +1,124 @@
+import torch
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from transformers import BatchEncoding, PreTrainedTokenizerBase
+from transformers.data.data_collator import (
+    DataCollatorMixin,
+    _torch_collate_batch,
+)
+from transformers.file_utils import PaddingStrategy
+
+from typing import NewType
+InputDataClass = NewType("InputDataClass", Any)
+
+def pre_calc_rel_mat(segment_ids):
+    valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),
+                             device=segment_ids.device, dtype=torch.bool)
+    for i in range(segment_ids.shape[0]):
+        for j in range(segment_ids.shape[1]):
+            valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]
+
+    return valid_span
+
+@dataclass
+class DataCollatorForKeyValueExtraction(DataCollatorMixin):
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels.
+    Args:
+        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+            The tokenizer used for encoding the data.
+        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+              sequence if provided).
+            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+              maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+              different lengths).
+        max_length (:obj:`int`, `optional`):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set will pad the sequence to a multiple of the provided value.
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+
+    def __call__(self, features):
+        label_name = "label" if "label" in features[0].keys() else "labels"
+        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
+
+        images = None
+        if "images" in features[0]:
+            images = torch.stack([torch.tensor(d.pop("images")) for d in features])
+            IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1
+
+        batch = self.tokenizer.pad(
+            features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
+            return_tensors="pt" if labels is None else None,
+        )
+
+        if images is not None:
+            batch["images"] = images
+            batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
+                     for k, v in batch.items()}
+            visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)
+            batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)
+
+        if labels is None:
+            return batch
+
+        has_bbox_input = "bbox" in features[0]
+        has_position_input = "position_ids" in features[0]
+        padding_idx=self.tokenizer.pad_token_id
+        sequence_length = torch.tensor(batch["input_ids"]).shape[1]
+        padding_side = self.tokenizer.padding_side
+        if padding_side == "right":
+            batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
+            if has_bbox_input:
+                batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
+            if has_position_input:
+                batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
+                                          for position_id in batch["position_ids"]]
+
+        else:
+            batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
+            if has_bbox_input:
+                batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
+            if has_position_input:
+                batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id))
+                                          + position_id for position_id in batch["position_ids"]]
+
+        if 'segment_ids' in batch:
+            assert 'position_ids' in batch
+            for i in range(len(batch['segment_ids'])):
+                batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [
+                    batch['segment_ids'][i][-1] + 2] * IMAGE_LEN
+
+        batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
+
+        if 'segment_ids' in batch:
+            valid_span = pre_calc_rel_mat(
+                segment_ids=batch['segment_ids']
+            )
+            batch['valid_span'] = valid_span
+            del batch['segment_ids']
+
+        if images is not None:
+            visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100
+            batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)
+
+        return batch

+ 136 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py

@@ -0,0 +1,136 @@
+# coding=utf-8
+'''
+Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
+'''
+import json
+import os
+
+import datasets
+
+from .image_utils import load_image, normalize_bbox
+
+
+logger = datasets.logging.get_logger(__name__)
+
+
+_CITATION = """\
+@article{Jaume2019FUNSDAD,
+  title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
+  author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
+  journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
+  year={2019},
+  volume={2},
+  pages={1-6}
+}
+"""
+
+_DESCRIPTION = """\
+https://guillaumejaume.github.io/FUNSD/
+"""
+
+
+class FunsdConfig(datasets.BuilderConfig):
+    """BuilderConfig for FUNSD"""
+
+    def __init__(self, **kwargs):
+        """BuilderConfig for FUNSD.
+
+        Args:
+          **kwargs: keyword arguments forwarded to super.
+        """
+        super(FunsdConfig, self).__init__(**kwargs)
+
+
+class Funsd(datasets.GeneratorBasedBuilder):
+    """Conll2003 dataset."""
+
+    BUILDER_CONFIGS = [
+        FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
+    ]
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "id": datasets.Value("string"),
+                    "tokens": datasets.Sequence(datasets.Value("string")),
+                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
+                    "ner_tags": datasets.Sequence(
+                        datasets.features.ClassLabel(
+                            names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
+                        )
+                    ),
+                    "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
+                    "image_path": datasets.Value("string"),
+                }
+            ),
+            supervised_keys=None,
+            homepage="https://guillaumejaume.github.io/FUNSD/",
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager):
+        """Returns SplitGenerators."""
+        downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
+            ),
+        ]
+
+    def get_line_bbox(self, bboxs):
+        x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
+        y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
+
+        x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
+
+        assert x1 >= x0 and y1 >= y0
+        bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
+        return bbox
+
+    def _generate_examples(self, filepath):
+        logger.info("⏳ Generating examples from = %s", filepath)
+        ann_dir = os.path.join(filepath, "annotations")
+        img_dir = os.path.join(filepath, "images")
+        for guid, file in enumerate(sorted(os.listdir(ann_dir))):
+            tokens = []
+            bboxes = []
+            ner_tags = []
+
+            file_path = os.path.join(ann_dir, file)
+            with open(file_path, "r", encoding="utf8") as f:
+                data = json.load(f)
+            image_path = os.path.join(img_dir, file)
+            image_path = image_path.replace("json", "png")
+            image, size = load_image(image_path)
+            for item in data["form"]:
+                cur_line_bboxes = []
+                words, label = item["words"], item["label"]
+                words = [w for w in words if w["text"].strip() != ""]
+                if len(words) == 0:
+                    continue
+                if label == "other":
+                    for w in words:
+                        tokens.append(w["text"])
+                        ner_tags.append("O")
+                        cur_line_bboxes.append(normalize_bbox(w["box"], size))
+                else:
+                    tokens.append(words[0]["text"])
+                    ner_tags.append("B-" + label.upper())
+                    cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
+                    for w in words[1:]:
+                        tokens.append(w["text"])
+                        ner_tags.append("I-" + label.upper())
+                        cur_line_bboxes.append(normalize_bbox(w["box"], size))
+                # by default: --segment_level_layout 1
+                # if do not want to use segment_level_layout, comment the following line
+                cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
+                # box = normalize_bbox(item["box"], size)
+                # cur_line_bboxes = [box for _ in range(len(words))]
+                bboxes.extend(cur_line_bboxes)
+            yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
+                         "image": image, "image_path": image_path}

+ 284 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py

@@ -0,0 +1,284 @@
+import torchvision.transforms.functional as F
+import warnings
+import math
+import random
+import numpy as np
+from PIL import Image
+import torch
+
+from detectron2.data.detection_utils import read_image
+from detectron2.data.transforms import ResizeTransform, TransformList
+
+def normalize_bbox(bbox, size):
+    return [
+        int(1000 * bbox[0] / size[0]),
+        int(1000 * bbox[1] / size[1]),
+        int(1000 * bbox[2] / size[0]),
+        int(1000 * bbox[3] / size[1]),
+    ]
+
+
+def load_image(image_path):
+    image = read_image(image_path, format="BGR")
+    h = image.shape[0]
+    w = image.shape[1]
+    img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
+    image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1)  # copy to make it writeable
+    return image, (w, h)
+
+
+def crop(image, i, j, h, w, boxes=None):
+    cropped_image = F.crop(image, i, j, h, w)
+
+    if boxes is not None:
+        # Currently we cannot use this case since when some boxes is out of the cropped image,
+        # it may be better to drop out these boxes along with their text input (instead of min or clamp)
+        # which haven't been implemented here
+        max_size = torch.as_tensor([w, h], dtype=torch.float32)
+        cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
+        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+        cropped_boxes = cropped_boxes.clamp(min=0)
+        boxes = cropped_boxes.reshape(-1, 4)
+
+    return cropped_image, boxes
+
+
+def resize(image, size, interpolation, boxes=None):
+    # It seems that we do not need to resize boxes here, since the boxes will be resized to 1000x1000 finally,
+    # which is compatible with a square image size of 224x224
+    rescaled_image = F.resize(image, size, interpolation)
+
+    if boxes is None:
+        return rescaled_image, None
+
+    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+    ratio_width, ratio_height = ratios
+
+    # boxes = boxes.copy()
+    scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
+
+    return rescaled_image, scaled_boxes
+
+
+def clamp(num, min_value, max_value):
+    return max(min(num, max_value), min_value)
+
+
+def get_bb(bb, page_size):
+    bbs = [float(j) for j in bb]
+    xs, ys = [], []
+    for i, b in enumerate(bbs):
+        if i % 2 == 0:
+            xs.append(b)
+        else:
+            ys.append(b)
+    (width, height) = page_size
+    return_bb = [
+        clamp(min(xs), 0, width - 1),
+        clamp(min(ys), 0, height - 1),
+        clamp(max(xs), 0, width - 1),
+        clamp(max(ys), 0, height - 1),
+    ]
+    return_bb = [
+            int(1000 * return_bb[0] / width),
+            int(1000 * return_bb[1] / height),
+            int(1000 * return_bb[2] / width),
+            int(1000 * return_bb[3] / height),
+        ]
+    return return_bb
+
+
+class ToNumpy:
+
+    def __call__(self, pil_img):
+        np_img = np.array(pil_img, dtype=np.uint8)
+        if np_img.ndim < 3:
+            np_img = np.expand_dims(np_img, axis=-1)
+        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+        return np_img
+
+
+class ToTensor:
+
+    def __init__(self, dtype=torch.float32):
+        self.dtype = dtype
+
+    def __call__(self, pil_img):
+        np_img = np.array(pil_img, dtype=np.uint8)
+        if np_img.ndim < 3:
+            np_img = np.expand_dims(np_img, axis=-1)
+        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+        return torch.from_numpy(np_img).to(dtype=self.dtype)
+
+
+_pil_interpolation_to_str = {
+    F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',
+    F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',
+    F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',
+    F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',
+    F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING',
+    F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',
+}
+
+
+def _pil_interp(method):
+    if method == 'bicubic':
+        return F.InterpolationMode.BICUBIC
+    elif method == 'lanczos':
+        return F.InterpolationMode.LANCZOS
+    elif method == 'hamming':
+        return F.InterpolationMode.HAMMING
+    else:
+        # default bilinear, do we want to allow nearest?
+        return F.InterpolationMode.BILINEAR
+
+
+class Compose:
+    """Composes several transforms together. This transform does not support torchscript.
+    Please, see the note below.
+
+    Args:
+        transforms (list of ``Transform`` objects): list of transforms to compose.
+
+    Example:
+        >>> transforms.Compose([
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.PILToTensor(),
+        >>>     transforms.ConvertImageDtype(torch.float),
+        >>> ])
+
+    .. note::
+        In order to script the transformations, please use ``torch.nn.Sequential`` as below.
+
+        >>> transforms = torch.nn.Sequential(
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+        >>> )
+        >>> scripted_transforms = torch.jit.script(transforms)
+
+        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
+        `lambda` functions or ``PIL.Image``.
+
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img, augmentation=False, box=None):
+        for t in self.transforms:
+            img = t(img, augmentation, box)
+        return img
+
+
+class RandomResizedCropAndInterpolationWithTwoPic:
+    """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+    is finally resized to given size.
+    This is popularly used to train the Inception networks.
+    Args:
+        size: expected output size of each edge
+        scale: range of size of the origin size cropped
+        ratio: range of aspect ratio of the origin aspect ratio cropped
+        interpolation: Default: PIL.Image.BILINEAR
+    """
+
+    def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
+                 interpolation='bilinear', second_interpolation='lanczos'):
+        if isinstance(size, tuple):
+            self.size = size
+        else:
+            self.size = (size, size)
+        if second_size is not None:
+            if isinstance(second_size, tuple):
+                self.second_size = second_size
+            else:
+                self.second_size = (second_size, second_size)
+        else:
+            self.second_size = None
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            warnings.warn("range should be of kind (min, max)")
+
+        self.interpolation = _pil_interp(interpolation)
+        self.second_interpolation = _pil_interp(second_interpolation)
+        self.scale = scale
+        self.ratio = ratio
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        """Get parameters for ``crop`` for a random sized crop.
+        Args:
+            img (PIL Image): Image to be cropped.
+            scale (tuple): range of size of the origin size cropped
+            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                sized crop.
+        """
+        area = img.size[0] * img.size[1]
+
+        for attempt in range(10):
+            target_area = random.uniform(*scale) * area
+            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+            aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if w <= img.size[0] and h <= img.size[1]:
+                i = random.randint(0, img.size[1] - h)
+                j = random.randint(0, img.size[0] - w)
+                return i, j, h, w
+
+        # Fallback to central crop
+        in_ratio = img.size[0] / img.size[1]
+        if in_ratio < min(ratio):
+            w = img.size[0]
+            h = int(round(w / min(ratio)))
+        elif in_ratio > max(ratio):
+            h = img.size[1]
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = img.size[0]
+            h = img.size[1]
+        i = (img.size[1] - h) // 2
+        j = (img.size[0] - w) // 2
+        return i, j, h, w
+
+    def __call__(self, img, augmentation=False, box=None):
+        """
+        Args:
+            img (PIL Image): Image to be cropped and resized.
+        Returns:
+            PIL Image: Randomly cropped and resized image.
+        """
+        if augmentation:
+            i, j, h, w = self.get_params(img, self.scale, self.ratio)
+            img = F.crop(img, i, j, h, w)
+            # img, box = crop(img, i, j, h, w, box)
+        img = F.resize(img, self.size, self.interpolation)
+        second_img = F.resize(img, self.second_size, self.second_interpolation) \
+            if self.second_size is not None else None
+        return img, second_img
+
+    def __repr__(self):
+        if isinstance(self.interpolation, (tuple, list)):
+            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
+        else:
+            interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+        format_string += ', interpolation={0}'.format(interpolate_str)
+        if self.second_size is not None:
+            format_string += ', second_size={0}'.format(self.second_size)
+            format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
+        format_string += ')'
+        return format_string
+
+
+def pil_loader(path: str) -> Image.Image:
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, 'rb') as f:
+        img = Image.open(f)
+        return img.convert('RGB')

+ 213 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py

@@ -0,0 +1,213 @@
+import os
+import json
+
+import torch
+from torch.utils.data.dataset import Dataset
+from torchvision import transforms
+from PIL import Image
+
+from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic
+
+XFund_label2ids = {
+    "O":0,
+    'B-HEADER':1,
+    'I-HEADER':2,
+    'B-QUESTION':3,
+    'I-QUESTION':4,
+    'B-ANSWER':5,
+    'I-ANSWER':6,
+}
+
+class xfund_dataset(Dataset):
+    def box_norm(self, box, width, height):
+        def clip(min_num, num, max_num):
+            return min(max(num, min_num), max_num)
+
+        x0, y0, x1, y1 = box
+        x0 = clip(0, int((x0 / width) * 1000), 1000)
+        y0 = clip(0, int((y0 / height) * 1000), 1000)
+        x1 = clip(0, int((x1 / width) * 1000), 1000)
+        y1 = clip(0, int((y1 / height) * 1000), 1000)
+        assert x1 >= x0
+        assert y1 >= y0
+        return [x0, y0, x1, y1]
+
+    def get_segment_ids(self, bboxs):
+        segment_ids = []
+        for i in range(len(bboxs)):
+            if i == 0:
+                segment_ids.append(0)
+            else:
+                if bboxs[i - 1] == bboxs[i]:
+                    segment_ids.append(segment_ids[-1])
+                else:
+                    segment_ids.append(segment_ids[-1] + 1)
+        return segment_ids
+
+    def get_position_ids(self, segment_ids):
+        position_ids = []
+        for i in range(len(segment_ids)):
+            if i == 0:
+                position_ids.append(2)
+            else:
+                if segment_ids[i] == segment_ids[i - 1]:
+                    position_ids.append(position_ids[-1] + 1)
+                else:
+                    position_ids.append(2)
+        return position_ids
+
+    def load_data(
+            self,
+            data_file,
+    ):
+        # re-org data format
+        total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
+        for i in range(len(data_file['documents'])):
+            width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][
+                'height']
+
+            cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []
+            for j in range(len(data_file['documents'][i]['document'])):
+                cur_item = data_file['documents'][i]['document'][j]
+                cur_doc_lines.append(cur_item['text'])
+                cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
+                cur_doc_ner_tags.append(cur_item['label'])
+            total_data['id'] += [len(total_data['id'])]
+            total_data['lines'] += [cur_doc_lines]
+            total_data['bboxes'] += [cur_doc_bboxes]
+            total_data['ner_tags'] += [cur_doc_ner_tags]
+            total_data['image_path'] += [data_file['documents'][i]['img']['fname']]
+
+        # tokenize text and get bbox/label
+        total_input_ids, total_bboxs, total_label_ids = [], [], []
+        for i in range(len(total_data['lines'])):
+            cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
+            for j in range(len(total_data['lines'][i])):
+                cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']
+                if len(cur_input_ids) == 0: continue
+
+                cur_label = total_data['ner_tags'][i][j].upper()
+                if cur_label == 'OTHER':
+                    cur_labels = ["O"] * len(cur_input_ids)
+                    for k in range(len(cur_labels)):
+                        cur_labels[k] = self.label2ids[cur_labels[k]]
+                else:
+                    cur_labels = [cur_label] * len(cur_input_ids)
+                    cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
+                    for k in range(1, len(cur_labels)):
+                        cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
+                assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)
+                cur_doc_input_ids += cur_input_ids
+                cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
+                cur_doc_labels += cur_labels
+            assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
+            assert len(cur_doc_input_ids) > 0
+
+            total_input_ids.append(cur_doc_input_ids)
+            total_bboxs.append(cur_doc_bboxs)
+            total_label_ids.append(cur_doc_labels)
+        assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)
+
+        # split text to several slices because of over-length
+        input_ids, bboxs, labels = [], [], []
+        segment_ids, position_ids = [], []
+        image_path = []
+        for i in range(len(total_input_ids)):
+            start = 0
+            cur_iter = 0
+            while start < len(total_input_ids[i]):
+                end = min(start + 510, len(total_input_ids[i]))
+
+                input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])
+                bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])
+                labels.append([-100] + total_label_ids[i][start: end] + [-100])
+
+                cur_segment_ids = self.get_segment_ids(bboxs[-1])
+                cur_position_ids = self.get_position_ids(cur_segment_ids)
+                segment_ids.append(cur_segment_ids)
+                position_ids.append(cur_position_ids)
+                image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
+
+                start = end
+                cur_iter += 1
+
+        assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
+        assert len(segment_ids) == len(image_path)
+
+        res = {
+            'input_ids': input_ids,
+            'bbox': bboxs,
+            'labels': labels,
+            'segment_ids': segment_ids,
+            'position_ids': position_ids,
+            'image_path': image_path,
+        }
+        return res
+
+    def __init__(
+            self,
+            args,
+            tokenizer,
+            mode
+    ):
+        self.args = args
+        self.mode = mode
+        self.cur_la = args.language
+        self.tokenizer = tokenizer
+        self.label2ids = XFund_label2ids
+
+
+        self.common_transform = Compose([
+            RandomResizedCropAndInterpolationWithTwoPic(
+                size=args.input_size, interpolation=args.train_interpolation,
+            ),
+        ])
+
+        self.patch_transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor((0.5, 0.5, 0.5)),
+                std=torch.tensor((0.5, 0.5, 0.5)))
+        ])
+
+        data_file = json.load(
+            open(os.path.join(args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val')),
+                 'r'))
+
+        self.feature = self.load_data(data_file)
+
+    def __len__(self):
+        return len(self.feature['input_ids'])
+
+    def __getitem__(self, index):
+        input_ids = self.feature["input_ids"][index]
+
+        # attention_mask = self.feature["attention_mask"][index]
+        attention_mask = [1] * len(input_ids)
+        labels = self.feature["labels"][index]
+        bbox = self.feature["bbox"][index]
+        segment_ids = self.feature['segment_ids'][index]
+        position_ids = self.feature['position_ids'][index]
+
+        img = pil_loader(self.feature['image_path'][index])
+        for_patches, _ = self.common_transform(img, augmentation=False)
+        patch = self.patch_transform(for_patches)
+
+        assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
+
+        res = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "bbox": bbox,
+            "segment_ids": segment_ids,
+            "position_ids": position_ids,
+            "images": patch,
+        }
+        return res
+
+def pil_loader(path: str) -> Image.Image:
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, 'rb') as f:
+        img = Image.open(f)
+        return img.convert('RGB')

+ 7 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py

@@ -0,0 +1,7 @@
+from .layoutlmv3 import (
+    LayoutLMv3Config,
+    LayoutLMv3ForTokenClassification,
+    LayoutLMv3ForQuestionAnswering,
+    LayoutLMv3ForSequenceClassification,
+    LayoutLMv3Tokenizer,
+)

+ 24 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py

@@ -0,0 +1,24 @@
+from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \
+    AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
+from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter
+
+from .configuration_layoutlmv3 import LayoutLMv3Config
+from .modeling_layoutlmv3 import (
+    LayoutLMv3ForTokenClassification,
+    LayoutLMv3ForQuestionAnswering,
+    LayoutLMv3ForSequenceClassification,
+    LayoutLMv3Model,
+)
+from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
+from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
+
+
+#AutoConfig.register("layoutlmv3", LayoutLMv3Config)
+#AutoModel.register(LayoutLMv3Config, LayoutLMv3Model)
+#AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification)
+#AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering)
+#AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification)
+#AutoTokenizer.register(
+#    LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast
+#)
+SLOW_TO_FAST_CONVERTERS.update({"LayoutLMv3Tokenizer": RobertaConverter})

+ 60 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py

@@ -0,0 +1,60 @@
+# coding=utf-8
+from transformers.models.bert.configuration_bert import BertConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
+    "layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/resolve/main/config.json",
+    # See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3
+}
+
+
+class LayoutLMv3Config(BertConfig):
+    model_type = "layoutlmv3"
+
+    def __init__(
+        self,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        max_2d_position_embeddings=1024,
+        coordinate_size=None,
+        shape_size=None,
+        has_relative_attention_bias=False,
+        rel_pos_bins=32,
+        max_rel_pos=128,
+        has_spatial_attention_bias=False,
+        rel_2d_pos_bins=64,
+        max_rel_2d_pos=256,
+        visual_embed=True,
+        mim=False,
+        wpa_task=False,
+        discrete_vae_weight_path='',
+        discrete_vae_type='dall-e',
+        input_size=224,
+        second_input_size=112,
+        device='cuda',
+        **kwargs
+    ):
+        """Constructs RobertaConfig."""
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        self.max_2d_position_embeddings = max_2d_position_embeddings
+        self.coordinate_size = coordinate_size
+        self.shape_size = shape_size
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.rel_pos_bins = rel_pos_bins
+        self.max_rel_pos = max_rel_pos
+        self.has_spatial_attention_bias = has_spatial_attention_bias
+        self.rel_2d_pos_bins = rel_2d_pos_bins
+        self.max_rel_2d_pos = max_rel_2d_pos
+        self.visual_embed = visual_embed
+        self.mim = mim
+        self.wpa_task = wpa_task
+        self.discrete_vae_weight_path = discrete_vae_weight_path
+        self.discrete_vae_type = discrete_vae_type
+        self.input_size = input_size
+        self.second_input_size = second_input_size
+        self.device = device

+ 1282 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py

@@ -0,0 +1,1282 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch LayoutLMv3 model. """
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers import apply_chunking_to_forward
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    TokenClassifierOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from transformers.models.roberta.modeling_roberta import (
+    RobertaIntermediate,
+    RobertaLMHead,
+    RobertaOutput,
+    RobertaSelfOutput,
+)
+from transformers.utils import logging
+
+from .configuration_layoutlmv3 import LayoutLMv3Config
+from timm.models.layers import to_2tuple
+
+
+logger = logging.get_logger(__name__)
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        # The following variables are used in detection mycheckpointer.py
+        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.num_patches_w = self.patch_shape[0]
+        self.num_patches_h = self.patch_shape[1]
+
+    def forward(self, x, position_embedding=None):
+        x = self.proj(x)
+
+        if position_embedding is not None:
+            # interpolate the position embedding to the corresponding size
+            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3, 1, 2)
+            Hp, Wp = x.shape[2], x.shape[3]
+            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
+            x = x + position_embedding
+
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+class LayoutLMv3Embeddings(nn.Module):
+    """
+    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+    """
+
+    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+        # End copy
+        self.padding_idx = config.pad_token_id
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+        )
+
+        self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+        self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+        self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+        self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+
+    def _calc_spatial_position_embeddings(self, bbox):
+        try:
+            assert torch.all(0 <= bbox) and torch.all(bbox <= 1023)
+            left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
+            upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
+            right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
+            lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
+        except IndexError as e:
+            raise IndexError("The :obj:`bbox` coordinate values should be within 0-1000 range.") from e
+
+        h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))
+        w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))
+
+        # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)
+        spatial_position_embeddings = torch.cat(
+            [
+                left_position_embeddings,
+                upper_position_embeddings,
+                right_position_embeddings,
+                lower_position_embeddings,
+                h_position_embeddings,
+                w_position_embeddings,
+            ],
+            dim=-1,
+        )
+        return spatial_position_embeddings
+
+    def create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
+        """
+        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+        are ignored. This is modified from fairseq's `utils.make_positions`.
+
+        Args:
+            x: torch.Tensor x:
+
+        Returns: torch.Tensor
+        """
+        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+        mask = input_ids.ne(padding_idx).int()
+        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+        return incremental_indices.long() + padding_idx
+
+    def forward(
+        self,
+        input_ids=None,
+        bbox=None,
+        token_type_ids=None,
+        position_ids=None,
+        inputs_embeds=None,
+        past_key_values_length=0,
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = self.create_position_ids_from_input_ids(
+                    input_ids, self.padding_idx, past_key_values_length).to(input_ids.device)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        position_embeddings = self.position_embeddings(position_ids)
+        embeddings += position_embeddings
+
+        spatial_position_embeddings = self._calc_spatial_position_embeddings(bbox)
+
+        embeddings = embeddings + spatial_position_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+        Args:
+            inputs_embeds: torch.Tensor≈
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+
+class LayoutLMv3PreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = LayoutLMv3Config
+    base_model_prefix = "layoutlmv3"
+
+    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class LayoutLMv3SelfAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.has_relative_attention_bias = config.has_relative_attention_bias
+        self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def cogview_attn(self, attention_scores, alpha=32):
+        '''
+        https://arxiv.org/pdf/2105.13290.pdf
+        Section 2.4 Stabilization of training: Precision Bottleneck Relaxation (PB-Relax).
+        A replacement of the original nn.Softmax(dim=-1)(attention_scores)
+        Seems the new attention_probs will result in a slower speed and a little bias
+        Can use torch.allclose(standard_attention_probs, cogview_attention_probs, atol=1e-08) for comparison
+        The smaller atol (e.g., 1e-08), the better.
+        '''
+        scaled_attention_scores = attention_scores / alpha
+        max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
+        # max_value = scaled_attention_scores.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1)
+        new_attention_scores = (scaled_attention_scores - max_value) * alpha
+        return nn.Softmax(dim=-1)(new_attention_scores)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        rel_pos=None,
+        rel_2d_pos=None,
+    ):
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
+        # Changing the computational order into QT(K/√d) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf)
+        attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
+
+        if self.has_relative_attention_bias and self.has_spatial_attention_bias:
+            attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
+        elif self.has_relative_attention_bias:
+            attention_scores += rel_pos / math.sqrt(self.attention_head_size)
+
+        # if self.has_relative_attention_bias:
+        #     attention_scores += rel_pos
+        # if self.has_spatial_attention_bias:
+        #     attention_scores += rel_2d_pos
+
+        # attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        # attention_probs = nn.Softmax(dim=-1)(attention_scores)  # comment the line below and use this line for speedup
+        attention_probs = self.cogview_attn(attention_scores)  # to stablize training
+        # assert torch.allclose(attention_probs, nn.Softmax(dim=-1)(attention_scores), atol=1e-8)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+class LayoutLMv3Attention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.self = LayoutLMv3SelfAttention(config)
+        self.output = RobertaSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        rel_pos=None,
+        rel_2d_pos=None,
+    ):
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+            rel_pos=rel_pos,
+            rel_2d_pos=rel_2d_pos,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class LayoutLMv3Layer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = LayoutLMv3Attention(config)
+        assert not config.is_decoder and not config.add_cross_attention, \
+            "This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder."
+        self.intermediate = RobertaIntermediate(config)
+        self.output = RobertaOutput(config)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        rel_pos=None,
+        rel_2d_pos=None,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+            rel_pos=rel_pos,
+            rel_2d_pos=rel_2d_pos,
+        )
+        attention_output = self_attention_outputs[0]
+
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class LayoutLMv3Encoder(nn.Module):
+    def __init__(self, config, detection=False, out_features=None):
+        super().__init__()
+        self.config = config
+        self.detection = detection
+        self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+        self.has_relative_attention_bias = config.has_relative_attention_bias
+        self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+        if self.has_relative_attention_bias:
+            self.rel_pos_bins = config.rel_pos_bins
+            self.max_rel_pos = config.max_rel_pos
+            self.rel_pos_onehot_size = config.rel_pos_bins
+            self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False)
+
+        if self.has_spatial_attention_bias:
+            self.max_rel_2d_pos = config.max_rel_2d_pos
+            self.rel_2d_pos_bins = config.rel_2d_pos_bins
+            self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
+            self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
+            self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
+
+        if self.detection:
+            self.gradient_checkpointing = True
+            embed_dim = self.config.hidden_size
+            self.out_features = out_features
+            self.out_indices = [int(name[5:]) for name in out_features]
+            self.fpn1 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+                # nn.SyncBatchNorm(embed_dim),
+                nn.BatchNorm2d(embed_dim),
+                nn.GELU(),
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn2 = nn.Sequential(
+                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+            )
+
+            self.fpn3 = nn.Identity()
+
+            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+            self.ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+
+    def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        ret = 0
+        if bidirectional:
+            num_buckets //= 2
+            ret += (relative_position > 0).long() * num_buckets
+            n = torch.abs(relative_position)
+        else:
+            n = torch.max(-relative_position, torch.zeros_like(relative_position))
+        # now n is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        val_if_large = max_exact + (
+                torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+        ).to(torch.long)
+        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+        ret += torch.where(is_small, n, val_if_large)
+        return ret
+
+    def _cal_1d_pos_emb(self, hidden_states, position_ids, valid_span):
+        VISUAL_NUM = 196 + 1
+
+        rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
+
+        if valid_span is not None:
+            # for the text part, if two words are not in the same line,
+            # set their distance to the max value (position_ids.shape[-1])
+            rel_pos_mat[(rel_pos_mat > 0) & (valid_span == False)] = position_ids.shape[1]
+            rel_pos_mat[(rel_pos_mat < 0) & (valid_span == False)] = -position_ids.shape[1]
+
+            # image-text, minimum distance
+            rel_pos_mat[:, -VISUAL_NUM:, :-VISUAL_NUM] = 0
+            rel_pos_mat[:, :-VISUAL_NUM, -VISUAL_NUM:] = 0
+
+        rel_pos = self.relative_position_bucket(
+            rel_pos_mat,
+            num_buckets=self.rel_pos_bins,
+            max_distance=self.max_rel_pos,
+        )
+        rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)
+        rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)
+        rel_pos = rel_pos.contiguous()
+        return rel_pos
+
+    def _cal_2d_pos_emb(self, hidden_states, bbox):
+        position_coord_x = bbox[:, :, 0]
+        position_coord_y = bbox[:, :, 3]
+        rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
+        rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
+        rel_pos_x = self.relative_position_bucket(
+            rel_pos_x_2d_mat,
+            num_buckets=self.rel_2d_pos_bins,
+            max_distance=self.max_rel_2d_pos,
+        )
+        rel_pos_y = self.relative_position_bucket(
+            rel_pos_y_2d_mat,
+            num_buckets=self.rel_2d_pos_bins,
+            max_distance=self.max_rel_2d_pos,
+        )
+        rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
+        rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
+        rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2)
+        rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2)
+        rel_pos_x = rel_pos_x.contiguous()
+        rel_pos_y = rel_pos_y.contiguous()
+        rel_2d_pos = rel_pos_x + rel_pos_y
+        return rel_2d_pos
+
+    def forward(
+        self,
+        hidden_states,
+        bbox=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        position_ids=None,
+        Hp=None,
+        Wp=None,
+        valid_span=None,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
+
+        rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids, valid_span) if self.has_relative_attention_bias else None
+        rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
+
+        if self.detection:
+            feat_out = {}
+            j = 0
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+                        # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos)
+                        # The above line will cause error:
+                        # RuntimeError: Trying to backward through the graph a second time
+                        # (or directly access saved tensors after they have already been freed).
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    rel_pos,
+                    rel_2d_pos
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    rel_pos=rel_pos,
+                    rel_2d_pos=rel_2d_pos,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+            if self.detection and i in self.out_indices:
+                xp = hidden_states[:, -Hp*Wp:, :].permute(0, 2, 1).reshape(len(hidden_states), -1, Hp, Wp)
+                feat_out[self.out_features[j]] = self.ops[j](xp.contiguous())
+                j += 1
+
+        if self.detection:
+            return feat_out
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
+    """
+    """
+
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
+    def __init__(self, config, detection=False, out_features=None, image_only=False):
+        super().__init__(config)
+        self.config = config
+        assert not config.is_decoder and not config.add_cross_attention, \
+            "This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder."
+        self.detection = detection
+        if not self.detection:
+            self.image_only = False
+        else:
+            assert config.visual_embed
+            self.image_only = image_only
+
+        if not self.image_only:
+            self.embeddings = LayoutLMv3Embeddings(config)
+        self.encoder = LayoutLMv3Encoder(config, detection=detection, out_features=out_features)
+
+        if config.visual_embed:
+            embed_dim = self.config.hidden_size
+            # use the default pre-training parameters for fine-tuning (e.g., input_size)
+            # when the input_size is larger in fine-tuning, we will interpolate the position embedding in forward
+            self.patch_embed = PatchEmbed(embed_dim=embed_dim)
+
+            patch_size = 16
+            size = int(self.config.input_size / patch_size)
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+            self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, embed_dim))
+            self.pos_drop = nn.Dropout(p=0.)
+
+            self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+            self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+            if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+                self._init_visual_bbox(img_size=(size, size))
+
+            from functools import partial
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+            self.norm = norm_layer(embed_dim)
+
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def _init_visual_bbox(self, img_size=(14, 14), max_len=1000):
+        visual_bbox_x = torch.div(torch.arange(0, max_len * (img_size[1] + 1), max_len),
+                                  img_size[1], rounding_mode='trunc')
+        visual_bbox_y = torch.div(torch.arange(0, max_len * (img_size[0] + 1), max_len),
+                                  img_size[0], rounding_mode='trunc')
+        visual_bbox = torch.stack(
+            [
+                visual_bbox_x[:-1].repeat(img_size[0], 1),
+                visual_bbox_y[:-1].repeat(img_size[1], 1).transpose(0, 1),
+                visual_bbox_x[1:].repeat(img_size[0], 1),
+                visual_bbox_y[1:].repeat(img_size[1], 1).transpose(0, 1),
+            ],
+            dim=-1,
+        ).view(-1, 4)
+
+        cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
+        self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)
+
+    def _calc_visual_bbox(self, device, dtype, bsz):  # , img_size=(14, 14), max_len=1000):
+        visual_bbox = self.visual_bbox.repeat(bsz, 1, 1)
+        visual_bbox = visual_bbox.to(device).type(dtype)
+        return visual_bbox
+
+    def forward_image(self, x):
+        if self.detection:
+            x = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+        else:
+            x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.pos_embed is not None and self.detection:
+            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
+
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None and not self.detection:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        x = self.norm(x)
+        return x
+
+    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    def forward(
+        self,
+        input_ids=None,
+        bbox=None,
+        attention_mask=None,
+        token_type_ids=None,
+        valid_span=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        images=None,
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        use_cache = False
+
+        # if input_ids is not None and inputs_embeds is not None:
+        #     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if input_ids is not None:
+            input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
+            device = input_ids.device
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length = input_shape
+            device = inputs_embeds.device
+        elif images is not None:
+            batch_size = len(images)
+            device = images.device
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds or images")
+
+        if not self.image_only:
+            # past_key_values_length
+            past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+            if attention_mask is None:
+                attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        # extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+        encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        if not self.image_only:
+            if bbox is None:
+                bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
+
+            embedding_output = self.embeddings(
+                input_ids=input_ids,
+                bbox=bbox,
+                position_ids=position_ids,
+                token_type_ids=token_type_ids,
+                inputs_embeds=inputs_embeds,
+                past_key_values_length=past_key_values_length,
+            )
+
+        final_bbox = final_position_ids = None
+        Hp = Wp = None
+        if images is not None:
+            patch_size = 16
+            Hp, Wp = int(images.shape[2] / patch_size), int(images.shape[3] / patch_size)
+            visual_emb = self.forward_image(images)
+            if self.detection:
+                visual_attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device)
+                if self.image_only:
+                    attention_mask = visual_attention_mask
+                else:
+                    attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
+            elif self.image_only:
+                attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device)
+
+            if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+                if self.config.has_spatial_attention_bias:
+                    visual_bbox = self._calc_visual_bbox(device, dtype=torch.long, bsz=batch_size)
+                    if self.image_only:
+                        final_bbox = visual_bbox
+                    else:
+                        final_bbox = torch.cat([bbox, visual_bbox], dim=1)
+
+                visual_position_ids = torch.arange(0, visual_emb.shape[1], dtype=torch.long, device=device).repeat(
+                    batch_size, 1)
+                if self.image_only:
+                    final_position_ids = visual_position_ids
+                else:
+                    position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
+                    position_ids = position_ids.expand_as(input_ids)
+                    final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
+
+            if self.image_only:
+                embedding_output = visual_emb
+            else:
+                embedding_output = torch.cat([embedding_output, visual_emb], dim=1)
+            embedding_output = self.LayerNorm(embedding_output)
+            embedding_output = self.dropout(embedding_output)
+        elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
+            if self.config.has_spatial_attention_bias:
+                final_bbox = bbox
+            if self.config.has_relative_attention_bias:
+                position_ids = self.embeddings.position_ids[:, :input_shape[1]]
+                position_ids = position_ids.expand_as(input_ids)
+                final_position_ids = position_ids
+
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            bbox=final_bbox,
+            position_ids=final_position_ids,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            Hp=Hp,
+            Wp=Wp,
+            valid_span=valid_span,
+        )
+
+        if self.detection:
+            return encoder_outputs
+
+        sequence_output = encoder_outputs[0]
+        pooled_output = None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+class LayoutLMv3ClassificationHead(nn.Module):
+    """
+    Head for sentence-level classification tasks.
+    Reference: RobertaClassificationHead
+    """
+
+    def __init__(self, config, pool_feature=False):
+        super().__init__()
+        self.pool_feature = pool_feature
+        if pool_feature:
+            self.dense = nn.Linear(config.hidden_size*3, config.hidden_size)
+        else:
+            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, x):
+        # x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.layoutlmv3 = LayoutLMv3Model(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        if config.num_labels < 10:
+            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        else:
+            self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        bbox=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        valid_span=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        images=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+            1]``.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.layoutlmv3(
+            input_ids,
+            bbox=bbox,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            images=images,
+            valid_span=valid_span,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                active_loss = attention_mask.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.layoutlmv3 = LayoutLMv3Model(config)
+        # self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+        self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        valid_span=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        bbox=None,
+        images=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+            sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+            sequence are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.layoutlmv3(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            bbox=bbox,
+            images=images,
+            valid_span=valid_span,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.layoutlmv3 = LayoutLMv3Model(config)
+        self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        valid_span=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        bbox=None,
+        images=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.layoutlmv3(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            bbox=bbox,
+            images=images,
+            valid_span=valid_span,
+        )
+
+        sequence_output = outputs[0][:, 0, :]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )

+ 32 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py

@@ -0,0 +1,32 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for LayoutLMv3, refer to RoBERTa."""
+
+from transformers.models.roberta import RobertaTokenizer
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "merges_file": "merges.txt",
+}
+
+class LayoutLMv3Tokenizer(RobertaTokenizer):
+    vocab_files_names = VOCAB_FILES_NAMES
+    # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]

+ 34 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py

@@ -0,0 +1,34 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Tokenization classes for LayoutLMv3, refer to RoBERTa."""
+
+
+from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
+from transformers.utils import logging
+
+from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class LayoutLMv3TokenizerFast(RobertaTokenizerFast):
+    vocab_files_names = VOCAB_FILES_NAMES
+    # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = LayoutLMv3Tokenizer

+ 141 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py

@@ -0,0 +1,141 @@
+from .visualizer import Visualizer
+from .rcnn_vl import *
+from .backbone import *
+
+from detectron2.config import get_cfg
+from detectron2.config import CfgNode as CN
+from detectron2.data import MetadataCatalog, DatasetCatalog
+from detectron2.data.datasets import register_coco_instances
+from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor
+
+def add_vit_config(cfg):
+    """
+    Add config for VIT.
+    """
+    _C = cfg
+
+    _C.MODEL.VIT = CN()
+
+    # CoaT model name.
+    _C.MODEL.VIT.NAME = ""
+
+    # Output features from CoaT backbone.
+    _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
+
+    _C.MODEL.VIT.IMG_SIZE = [224, 224]
+
+    _C.MODEL.VIT.POS_TYPE = "shared_rel"
+
+    _C.MODEL.VIT.DROP_PATH = 0.
+
+    _C.MODEL.VIT.MODEL_KWARGS = "{}"
+
+    _C.SOLVER.OPTIMIZER = "ADAMW"
+
+    _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
+
+    _C.AUG = CN()
+
+    _C.AUG.DETR = False
+
+    _C.MODEL.IMAGE_ONLY = True
+    _C.PUBLAYNET_DATA_DIR_TRAIN = ""
+    _C.PUBLAYNET_DATA_DIR_TEST = ""
+    _C.FOOTNOTE_DATA_DIR_TRAIN = ""
+    _C.FOOTNOTE_DATA_DIR_VAL = ""
+    _C.SCIHUB_DATA_DIR_TRAIN = ""
+    _C.SCIHUB_DATA_DIR_TEST = ""
+    _C.JIAOCAI_DATA_DIR_TRAIN = ""
+    _C.JIAOCAI_DATA_DIR_TEST = ""
+    _C.ICDAR_DATA_DIR_TRAIN = ""
+    _C.ICDAR_DATA_DIR_TEST = ""
+    _C.M6DOC_DATA_DIR_TEST = ""
+    _C.DOCSTRUCTBENCH_DATA_DIR_TEST = ""
+    _C.DOCSTRUCTBENCHv2_DATA_DIR_TEST = ""
+    _C.CACHE_DIR = ""
+    _C.MODEL.CONFIG_PATH = ""
+
+    # effective update steps would be MAX_ITER/GRADIENT_ACCUMULATION_STEPS
+    # maybe need to set MAX_ITER *= GRADIENT_ACCUMULATION_STEPS
+    _C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1
+
+
+def setup(args):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg()
+    # add_coat_config(cfg)
+    add_vit_config(cfg)
+    cfg.merge_from_file(args.config_file)
+    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2  # set threshold for this model
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    default_setup(cfg, args)
+    
+    register_coco_instances(
+        "scihub_train",
+        {},
+        cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
+        cfg.SCIHUB_DATA_DIR_TRAIN
+    )
+    
+    return cfg
+
+
+class DotDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(DotDict, self).__init__(*args, **kwargs)
+
+    def __getattr__(self, key):
+        if key not in self.keys():
+            return None
+        value = self[key]
+        if isinstance(value, dict):
+            value = DotDict(value)
+        return value
+    
+    def __setattr__(self, key, value):
+        self[key] = value
+        
+class Layoutlmv3_Predictor(object):
+    def __init__(self, weights, config_file):
+        layout_args = {
+            "config_file": config_file,
+            "resume": False,
+            "eval_only": False,
+            "num_gpus": 1,
+            "num_machines": 1,
+            "machine_rank": 0,
+            "dist_url": "tcp://127.0.0.1:57823",
+            "opts": ["MODEL.WEIGHTS", weights],
+        }
+        layout_args = DotDict(layout_args)
+
+        cfg = setup(layout_args)
+        self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption", "table_footnote", "isolate_formula", "formula_caption"]
+        MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
+        self.predictor = DefaultPredictor(cfg)
+        
+    def __call__(self, image, ignore_catids=[]):
+        page_layout_result = {
+            "layout_dets": []
+        }
+        outputs = self.predictor(image)
+        boxes = outputs["instances"].to("cpu")._fields["pred_boxes"].tensor.tolist()
+        labels = outputs["instances"].to("cpu")._fields["pred_classes"].tolist()
+        scores = outputs["instances"].to("cpu")._fields["scores"].tolist()
+        for bbox_idx in range(len(boxes)):
+            if labels[bbox_idx] in ignore_catids:
+                continue
+            page_layout_result["layout_dets"].append({
+                "category_id": labels[bbox_idx],
+                "poly": [
+                    boxes[bbox_idx][0], boxes[bbox_idx][1],
+                    boxes[bbox_idx][2], boxes[bbox_idx][1],
+                    boxes[bbox_idx][2], boxes[bbox_idx][3],
+                    boxes[bbox_idx][0], boxes[bbox_idx][3],
+                ],
+                "score": scores[bbox_idx]
+            })
+        return page_layout_result

+ 163 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py

@@ -0,0 +1,163 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+import torch
+from torch import nn
+
+from detectron2.config import configurable
+from detectron2.structures import ImageList, Instances
+from detectron2.utils.events import get_event_storage
+
+from detectron2.modeling.backbone import Backbone, build_backbone
+from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
+
+from detectron2.modeling.meta_arch import GeneralizedRCNN
+
+from detectron2.modeling.postprocessing import detector_postprocess
+from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image
+from contextlib import contextmanager
+from itertools import count
+
+@META_ARCH_REGISTRY.register()
+class VLGeneralizedRCNN(GeneralizedRCNN):
+    """
+    Generalized R-CNN. Any models that contains the following three components:
+    1. Per-image feature extraction (aka backbone)
+    2. Region proposal generation
+    3. Per-region feature extraction and prediction
+    """
+
+    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
+                Each item in the list contains the inputs for one image.
+                For now, each item in the list is a dict that contains:
+
+                * image: Tensor, image in (C, H, W) format.
+                * instances (optional): groundtruth :class:`Instances`
+                * proposals (optional): :class:`Instances`, precomputed proposals.
+
+                Other information that's included in the original dicts, such as:
+
+                * "height", "width" (int): the output resolution of the model, used in inference.
+                  See :meth:`postprocess` for details.
+
+        Returns:
+            list[dict]:
+                Each dict is the output for one input image.
+                The dict contains one key "instances" whose value is a :class:`Instances`.
+                The :class:`Instances` object has the following keys:
+                "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
+        """
+        if not self.training:
+            return self.inference(batched_inputs)
+
+        images = self.preprocess_image(batched_inputs)
+        if "instances" in batched_inputs[0]:
+            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+        else:
+            gt_instances = None
+
+        # features = self.backbone(images.tensor)
+        input = self.get_batch(batched_inputs, images)
+        features = self.backbone(input)
+
+        if self.proposal_generator is not None:
+            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
+        else:
+            assert "proposals" in batched_inputs[0]
+            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
+            proposal_losses = {}
+
+        _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
+        if self.vis_period > 0:
+            storage = get_event_storage()
+            if storage.iter % self.vis_period == 0:
+                self.visualize_training(batched_inputs, proposals)
+
+        losses = {}
+        losses.update(detector_losses)
+        losses.update(proposal_losses)
+        return losses
+
+    def inference(
+        self,
+        batched_inputs: List[Dict[str, torch.Tensor]],
+        detected_instances: Optional[List[Instances]] = None,
+        do_postprocess: bool = True,
+    ):
+        """
+        Run inference on the given inputs.
+
+        Args:
+            batched_inputs (list[dict]): same as in :meth:`forward`
+            detected_instances (None or list[Instances]): if not None, it
+                contains an `Instances` object per image. The `Instances`
+                object contains "pred_boxes" and "pred_classes" which are
+                known boxes in the image.
+                The inference will then skip the detection of bounding boxes,
+                and only predict other per-ROI outputs.
+            do_postprocess (bool): whether to apply post-processing on the outputs.
+
+        Returns:
+            When do_postprocess=True, same as in :meth:`forward`.
+            Otherwise, a list[Instances] containing raw network outputs.
+        """
+        assert not self.training
+
+        images = self.preprocess_image(batched_inputs)
+        # features = self.backbone(images.tensor)
+        input = self.get_batch(batched_inputs, images)
+        features = self.backbone(input)
+
+        if detected_instances is None:
+            if self.proposal_generator is not None:
+                proposals, _ = self.proposal_generator(images, features, None)
+            else:
+                assert "proposals" in batched_inputs[0]
+                proposals = [x["proposals"].to(self.device) for x in batched_inputs]
+
+            results, _ = self.roi_heads(images, features, proposals, None)
+        else:
+            detected_instances = [x.to(self.device) for x in detected_instances]
+            results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
+
+        if do_postprocess:
+            assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
+            return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
+        else:
+            return results
+
+    def get_batch(self, examples, images):
+        if len(examples) >= 1 and "bbox" not in examples[0]:  # image_only
+            return {"images": images.tensor}
+
+        return input
+
+    def _batch_inference(self, batched_inputs, detected_instances=None):
+        """
+        Execute inference on a list of inputs,
+        using batch size = self.batch_size (e.g., 2), instead of the length of the list.
+
+        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`
+        """
+        if detected_instances is None:
+            detected_instances = [None] * len(batched_inputs)
+
+        outputs = []
+        inputs, instances = [], []
+        for idx, input, instance in zip(count(), batched_inputs, detected_instances):
+            inputs.append(input)
+            instances.append(instance)
+            if len(inputs) == 2 or idx == len(batched_inputs) - 1:
+                outputs.extend(
+                    self.inference(
+                        inputs,
+                        instances if instances[0] is not None else None,
+                        do_postprocess=True,  # False
+                    )
+                )
+                inputs, instances = [], []
+        return outputs

+ 1236 - 0
magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py

@@ -0,0 +1,1236 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import colorsys
+import logging
+import math
+import numpy as np
+from enum import Enum, unique
+import cv2
+import matplotlib as mpl
+import matplotlib.colors as mplc
+import matplotlib.figure as mplfigure
+import pycocotools.mask as mask_util
+import torch
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+from PIL import Image
+
+from detectron2.data import MetadataCatalog
+from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
+from detectron2.utils.file_io import PathManager
+
+from detectron2.utils.colormap import random_color
+
+import pdb
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["ColorMode", "VisImage", "Visualizer"]
+
+
+_SMALL_OBJECT_AREA_THRESH = 1000
+_LARGE_MASK_AREA_THRESH = 120000
+_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
+_BLACK = (0, 0, 0)
+_RED = (1.0, 0, 0)
+
+_KEYPOINT_THRESHOLD = 0.05
+
+#CLASS_NAMES = ["footnote", "footer", "header"]
+
+@unique
+class ColorMode(Enum):
+    """
+    Enum of different color modes to use for instance visualizations.
+    """
+
+    IMAGE = 0
+    """
+    Picks a random color for every instance and overlay segmentations with low opacity.
+    """
+    SEGMENTATION = 1
+    """
+    Let instances of the same category have similar colors
+    (from metadata.thing_colors), and overlay them with
+    high opacity. This provides more attention on the quality of segmentation.
+    """
+    IMAGE_BW = 2
+    """
+    Same as IMAGE, but convert all areas without masks to gray-scale.
+    Only available for drawing per-instance mask predictions.
+    """
+
+
+class GenericMask:
+    """
+    Attribute:
+        polygons (list[ndarray]): list[ndarray]: polygons for this mask.
+            Each ndarray has format [x, y, x, y, ...]
+        mask (ndarray): a binary mask
+    """
+
+    def __init__(self, mask_or_polygons, height, width):
+        self._mask = self._polygons = self._has_holes = None
+        self.height = height
+        self.width = width
+
+        m = mask_or_polygons
+        if isinstance(m, dict):
+            # RLEs
+            assert "counts" in m and "size" in m
+            if isinstance(m["counts"], list):  # uncompressed RLEs
+                h, w = m["size"]
+                assert h == height and w == width
+                m = mask_util.frPyObjects(m, h, w)
+            self._mask = mask_util.decode(m)[:, :]
+            return
+
+        if isinstance(m, list):  # list[ndarray]
+            self._polygons = [np.asarray(x).reshape(-1) for x in m]
+            return
+
+        if isinstance(m, np.ndarray):  # assumed to be a binary mask
+            assert m.shape[1] != 2, m.shape
+            assert m.shape == (
+                height,
+                width,
+            ), f"mask shape: {m.shape}, target dims: {height}, {width}"
+            self._mask = m.astype("uint8")
+            return
+
+        raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
+
+    @property
+    def mask(self):
+        if self._mask is None:
+            self._mask = self.polygons_to_mask(self._polygons)
+        return self._mask
+
+    @property
+    def polygons(self):
+        if self._polygons is None:
+            self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+        return self._polygons
+
+    @property
+    def has_holes(self):
+        if self._has_holes is None:
+            if self._mask is not None:
+                self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
+            else:
+                self._has_holes = False  # if original format is polygon, does not have holes
+        return self._has_holes
+
+    def mask_to_polygons(self, mask):
+        # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
+        # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
+        # Internal contours (holes) are placed in hierarchy-2.
+        # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
+        mask = np.ascontiguousarray(mask)  # some versions of cv2 does not support incontiguous arr
+        res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+        hierarchy = res[-1]
+        if hierarchy is None:  # empty mask
+            return [], False
+        has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
+        res = res[-2]
+        res = [x.flatten() for x in res]
+        # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
+        # We add 0.5 to turn them into real-value coordinate space. A better solution
+        # would be to first +0.5 and then dilate the returned polygon by 0.5.
+        res = [x + 0.5 for x in res if len(x) >= 6]
+        return res, has_holes
+
+    def polygons_to_mask(self, polygons):
+        rle = mask_util.frPyObjects(polygons, self.height, self.width)
+        rle = mask_util.merge(rle)
+        return mask_util.decode(rle)[:, :]
+
+    def area(self):
+        return self.mask.sum()
+
+    def bbox(self):
+        p = mask_util.frPyObjects(self.polygons, self.height, self.width)
+        p = mask_util.merge(p)
+        bbox = mask_util.toBbox(p)
+        bbox[2] += bbox[0]
+        bbox[3] += bbox[1]
+        return bbox
+
+
+class _PanopticPrediction:
+    """
+    Unify different panoptic annotation/prediction formats
+    """
+
+    def __init__(self, panoptic_seg, segments_info, metadata=None):
+        if segments_info is None:
+            assert metadata is not None
+            # If "segments_info" is None, we assume "panoptic_img" is a
+            # H*W int32 image storing the panoptic_id in the format of
+            # category_id * label_divisor + instance_id. We reserve -1 for
+            # VOID label.
+            label_divisor = metadata.label_divisor
+            segments_info = []
+            for panoptic_label in np.unique(panoptic_seg.numpy()):
+                if panoptic_label == -1:
+                    # VOID region.
+                    continue
+                pred_class = panoptic_label // label_divisor
+                isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
+                segments_info.append(
+                    {
+                        "id": int(panoptic_label),
+                        "category_id": int(pred_class),
+                        "isthing": bool(isthing),
+                    }
+                )
+        del metadata
+
+        self._seg = panoptic_seg
+
+        self._sinfo = {s["id"]: s for s in segments_info}  # seg id -> seg info
+        segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
+        areas = areas.numpy()
+        sorted_idxs = np.argsort(-areas)
+        self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
+        self._seg_ids = self._seg_ids.tolist()
+        for sid, area in zip(self._seg_ids, self._seg_areas):
+            if sid in self._sinfo:
+                self._sinfo[sid]["area"] = float(area)
+
+    def non_empty_mask(self):
+        """
+        Returns:
+            (H, W) array, a mask for all pixels that have a prediction
+        """
+        empty_ids = []
+        for id in self._seg_ids:
+            if id not in self._sinfo:
+                empty_ids.append(id)
+        if len(empty_ids) == 0:
+            return np.zeros(self._seg.shape, dtype=np.uint8)
+        assert (
+            len(empty_ids) == 1
+        ), ">1 ids corresponds to no labels. This is currently not supported"
+        return (self._seg != empty_ids[0]).numpy().astype(np.bool)
+
+    def semantic_masks(self):
+        for sid in self._seg_ids:
+            sinfo = self._sinfo.get(sid)
+            if sinfo is None or sinfo["isthing"]:
+                # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
+                continue
+            yield (self._seg == sid).numpy().astype(np.bool), sinfo
+
+    def instance_masks(self):
+        for sid in self._seg_ids:
+            sinfo = self._sinfo.get(sid)
+            if sinfo is None or not sinfo["isthing"]:
+                continue
+            mask = (self._seg == sid).numpy().astype(np.bool)
+            if mask.sum() > 0:
+                yield mask, sinfo
+
+
+def _create_text_labels(classes, scores, class_names, is_crowd=None):
+    """
+    Args:
+        classes (list[int] or None):
+        scores (list[float] or None):
+        class_names (list[str] or None):
+        is_crowd (list[bool] or None):
+
+    Returns:
+        list[str] or None
+    """
+    #class_names = CLASS_NAMES
+    labels = None
+    if classes is not None:
+        if class_names is not None and len(class_names) > 0:
+            labels = [class_names[i] for i in classes]
+        else:
+            labels = [str(i) for i in classes]
+            
+    if scores is not None:
+        if labels is None:
+            labels = ["{:.0f}%".format(s * 100) for s in scores]
+        else:
+            labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
+    if labels is not None and is_crowd is not None:
+        labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
+    return labels
+
+
+class VisImage:
+    def __init__(self, img, scale=1.0):
+        """
+        Args:
+            img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
+            scale (float): scale the input image
+        """
+        self.img = img
+        self.scale = scale
+        self.width, self.height = img.shape[1], img.shape[0]
+        self._setup_figure(img)
+
+    def _setup_figure(self, img):
+        """
+        Args:
+            Same as in :meth:`__init__()`.
+
+        Returns:
+            fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
+            ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
+        """
+        fig = mplfigure.Figure(frameon=False)
+        self.dpi = fig.get_dpi()
+        # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
+        # (https://github.com/matplotlib/matplotlib/issues/15363)
+        fig.set_size_inches(
+            (self.width * self.scale + 1e-2) / self.dpi,
+            (self.height * self.scale + 1e-2) / self.dpi,
+        )
+        self.canvas = FigureCanvasAgg(fig)
+        # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
+        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
+        ax.axis("off")
+        self.fig = fig
+        self.ax = ax
+        self.reset_image(img)
+
+    def reset_image(self, img):
+        """
+        Args:
+            img: same as in __init__
+        """
+        img = img.astype("uint8")
+        self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
+
+    def save(self, filepath):
+        """
+        Args:
+            filepath (str): a string that contains the absolute path, including the file name, where
+                the visualized image will be saved.
+        """
+        self.fig.savefig(filepath)
+
+    def get_image(self):
+        """
+        Returns:
+            ndarray:
+                the visualized image of shape (H, W, 3) (RGB) in uint8 type.
+                The shape is scaled w.r.t the input image using the given `scale` argument.
+        """
+        canvas = self.canvas
+        s, (width, height) = canvas.print_to_buffer()
+        # buf = io.BytesIO()  # works for cairo backend
+        # canvas.print_rgba(buf)
+        # width, height = self.width, self.height
+        # s = buf.getvalue()
+
+        buffer = np.frombuffer(s, dtype="uint8")
+
+        img_rgba = buffer.reshape(height, width, 4)
+        rgb, alpha = np.split(img_rgba, [3], axis=2)
+        return rgb.astype("uint8")
+
+
+class Visualizer:
+    """
+    Visualizer that draws data about detection/segmentation on images.
+
+    It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
+    that draw primitive objects to images, as well as high-level wrappers like
+    `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
+    that draw composite data in some pre-defined style.
+
+    Note that the exact visualization style for the high-level wrappers are subject to change.
+    Style such as color, opacity, label contents, visibility of labels, or even the visibility
+    of objects themselves (e.g. when the object is too small) may change according
+    to different heuristics, as long as the results still look visually reasonable.
+
+    To obtain a consistent style, you can implement custom drawing functions with the
+    abovementioned primitive methods instead. If you need more customized visualization
+    styles, you can process the data yourself following their format documented in
+    tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
+    intend to satisfy everyone's preference on drawing styles.
+
+    This visualizer focuses on high rendering quality rather than performance. It is not
+    designed to be used for real-time applications.
+    """
+
+    # TODO implement a fast, rasterized version using OpenCV
+
+    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
+        """
+        Args:
+            img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
+                the height and width of the image respectively. C is the number of
+                color channels. The image is required to be in RGB format since that
+                is a requirement of the Matplotlib library. The image is also expected
+                to be in the range [0, 255].
+            metadata (Metadata): dataset metadata (e.g. class names and colors)
+            instance_mode (ColorMode): defines one of the pre-defined style for drawing
+                instances on an image.
+        """
+        self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
+        if metadata is None:
+            metadata = MetadataCatalog.get("__nonexist__")
+        self.metadata = metadata
+        self.output = VisImage(self.img, scale=scale)
+        self.cpu_device = torch.device("cpu")
+
+        # too small texts are useless, therefore clamp to 9
+        self._default_font_size = max(
+            np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
+        )
+        self._instance_mode = instance_mode
+        self.keypoint_threshold = _KEYPOINT_THRESHOLD
+
+    def draw_instance_predictions(self, predictions):
+        """
+        Draw instance-level prediction results on an image.
+
+        Args:
+            predictions (Instances): the output of an instance detection/segmentation
+                model. Following fields will be used to draw:
+                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
+        scores = predictions.scores if predictions.has("scores") else None
+        classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
+        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))
+        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
+
+        if predictions.has("pred_masks"):
+            masks = np.asarray(predictions.pred_masks)
+            masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
+        else:
+            masks = None
+
+        if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
+            colors = [
+                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
+            ]
+            alpha = 0.8
+        else:
+            colors = None
+            alpha = 0.5
+
+        if self._instance_mode == ColorMode.IMAGE_BW:
+            self.output.reset_image(
+                self._create_grayscale_image(
+                    (predictions.pred_masks.any(dim=0) > 0).numpy()
+                    if predictions.has("pred_masks")
+                    else None
+                )
+            )
+            alpha = 0.3
+
+        self.overlay_instances(
+            masks=masks,
+            boxes=boxes,
+            labels=labels,
+            keypoints=keypoints,
+            assigned_colors=colors,
+            alpha=alpha,
+        )
+        return self.output
+
+    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
+        """
+        Draw semantic segmentation predictions/labels.
+
+        Args:
+            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
+                Each value is the integer label of the pixel.
+            area_threshold (int): segments with less than `area_threshold` are not drawn.
+            alpha (float): the larger it is, the more opaque the segmentations are.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        if isinstance(sem_seg, torch.Tensor):
+            sem_seg = sem_seg.numpy()
+        labels, areas = np.unique(sem_seg, return_counts=True)
+        sorted_idxs = np.argsort(-areas).tolist()
+        labels = labels[sorted_idxs]
+        for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
+            try:
+                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
+            except (AttributeError, IndexError):
+                mask_color = None
+
+            binary_mask = (sem_seg == label).astype(np.uint8)
+            text = self.metadata.stuff_classes[label]
+            self.draw_binary_mask(
+                binary_mask,
+                color=mask_color,
+                edge_color=_OFF_WHITE,
+                text=text,
+                alpha=alpha,
+                area_threshold=area_threshold,
+            )
+        return self.output
+
+    def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7):
+        """
+        Draw panoptic prediction annotations or results.
+
+        Args:
+            panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
+                segment.
+            segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
+                If it is a ``list[dict]``, each dict contains keys "id", "category_id".
+                If None, category id of each pixel is computed by
+                ``pixel // metadata.label_divisor``.
+            area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
+
+        if self._instance_mode == ColorMode.IMAGE_BW:
+            self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
+
+        # draw mask for all semantic segments first i.e. "stuff"
+        for mask, sinfo in pred.semantic_masks():
+            category_idx = sinfo["category_id"]
+            try:
+                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
+            except AttributeError:
+                mask_color = None
+
+            text = self.metadata.stuff_classes[category_idx]
+            self.draw_binary_mask(
+                mask,
+                color=mask_color,
+                edge_color=_OFF_WHITE,
+                text=text,
+                alpha=alpha,
+                area_threshold=area_threshold,
+            )
+
+        # draw mask for all instances second
+        all_instances = list(pred.instance_masks())
+        if len(all_instances) == 0:
+            return self.output
+        masks, sinfo = list(zip(*all_instances))
+        category_ids = [x["category_id"] for x in sinfo]
+
+        try:
+            scores = [x["score"] for x in sinfo]
+        except KeyError:
+            scores = None
+        labels = _create_text_labels(
+            category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo]
+        )
+
+        try:
+            colors = [
+                self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids
+            ]
+        except AttributeError:
+            colors = None
+        self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)
+
+        return self.output
+
+    draw_panoptic_seg_predictions = draw_panoptic_seg  # backward compatibility
+
+    def draw_dataset_dict(self, dic):
+        """
+        Draw annotations/segmentaions in Detectron2 Dataset format.
+
+        Args:
+            dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        annos = dic.get("annotations", None)
+        if annos:
+            if "segmentation" in annos[0]:
+                masks = [x["segmentation"] for x in annos]
+            else:
+                masks = None
+            if "keypoints" in annos[0]:
+                keypts = [x["keypoints"] for x in annos]
+                keypts = np.array(keypts).reshape(len(annos), -1, 3)
+            else:
+                keypts = None
+
+            boxes = [
+                BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
+                if len(x["bbox"]) == 4
+                else x["bbox"]
+                for x in annos
+            ]
+
+            colors = None
+            category_ids = [x["category_id"] for x in annos]
+            if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
+                colors = [
+                    self._jitter([x / 255 for x in self.metadata.thing_colors[c]])
+                    for c in category_ids
+                ]
+            names = self.metadata.get("thing_classes", None)
+            labels = _create_text_labels(
+                category_ids,
+                scores=None,
+                class_names=names,
+                is_crowd=[x.get("iscrowd", 0) for x in annos],
+            )
+            self.overlay_instances(
+                labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
+            )
+
+        sem_seg = dic.get("sem_seg", None)
+        if sem_seg is None and "sem_seg_file_name" in dic:
+            with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
+                sem_seg = Image.open(f)
+                sem_seg = np.asarray(sem_seg, dtype="uint8")
+        if sem_seg is not None:
+            self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)
+
+        pan_seg = dic.get("pan_seg", None)
+        if pan_seg is None and "pan_seg_file_name" in dic:
+            with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
+                pan_seg = Image.open(f)
+                pan_seg = np.asarray(pan_seg)
+                from panopticapi.utils import rgb2id
+
+                pan_seg = rgb2id(pan_seg)
+        if pan_seg is not None:
+            segments_info = dic["segments_info"]
+            pan_seg = torch.tensor(pan_seg)
+            self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
+        return self.output
+
+    def overlay_instances(
+        self,
+        *,
+        boxes=None,
+        labels=None,
+        masks=None,
+        keypoints=None,
+        assigned_colors=None,
+        alpha=0.5,
+    ):
+        """
+        Args:
+            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
+                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
+                or a :class:`RotatedBoxes`,
+                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
+                for the N objects in a single image,
+            labels (list[str]): the text to be displayed for each instance.
+            masks (masks-like object): Supported types are:
+
+                * :class:`detectron2.structures.PolygonMasks`,
+                  :class:`detectron2.structures.BitMasks`.
+                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
+                  The first level of the list corresponds to individual instances. The second
+                  level to all the polygon that compose the instance, and the third level
+                  to the polygon coordinates. The third level should have the format of
+                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
+                * list[ndarray]: each ndarray is a binary mask of shape (H, W).
+                * list[dict]: each dict is a COCO-style RLE.
+            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
+                where the N is the number of instances and K is the number of keypoints.
+                The last dimension corresponds to (x, y, visibility or score).
+            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
+                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
+                for full list of formats that the colors are accepted in.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        num_instances = 0
+        if boxes is not None:
+            boxes = self._convert_boxes(boxes)
+            num_instances = len(boxes)
+        if masks is not None:
+            masks = self._convert_masks(masks)
+            if num_instances:
+                assert len(masks) == num_instances
+            else:
+                num_instances = len(masks)
+        if keypoints is not None:
+            if num_instances:
+                assert len(keypoints) == num_instances
+            else:
+                num_instances = len(keypoints)
+            keypoints = self._convert_keypoints(keypoints)
+        if labels is not None:
+            assert len(labels) == num_instances
+        if assigned_colors is None:
+            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
+        if num_instances == 0:
+            return self.output
+        if boxes is not None and boxes.shape[1] == 5:
+            return self.overlay_rotated_instances(
+                boxes=boxes, labels=labels, assigned_colors=assigned_colors
+            )
+
+        # Display in largest to smallest order to reduce occlusion.
+        areas = None
+        if boxes is not None:
+            areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
+        elif masks is not None:
+            areas = np.asarray([x.area() for x in masks])
+
+        if areas is not None:
+            sorted_idxs = np.argsort(-areas).tolist()
+            # Re-order overlapped instances in descending order.
+            boxes = boxes[sorted_idxs] if boxes is not None else None
+            labels = [labels[k] for k in sorted_idxs] if labels is not None else None
+            masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
+            assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
+            keypoints = keypoints[sorted_idxs] if keypoints is not None else None
+
+        for i in range(num_instances):
+            color = assigned_colors[i]
+            if boxes is not None:
+                self.draw_box(boxes[i], edge_color=color)
+
+            if masks is not None:
+                for segment in masks[i].polygons:
+                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
+
+            if labels is not None:
+                # first get a box
+                if boxes is not None:
+                    x0, y0, x1, y1 = boxes[i]
+                    text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.
+                    horiz_align = "left"
+                elif masks is not None:
+                    # skip small mask without polygon
+                    if len(masks[i].polygons) == 0:
+                        continue
+
+                    x0, y0, x1, y1 = masks[i].bbox()
+
+                    # draw text in the center (defined by median) when box is not drawn
+                    # median is less sensitive to outliers.
+                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
+                    horiz_align = "center"
+                else:
+                    continue  # drawing the box confidence for keypoints isn't very useful.
+                # for small objects, draw text at the side to avoid occlusion
+                instance_area = (y1 - y0) * (x1 - x0)
+                if (
+                    instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
+                    or y1 - y0 < 40 * self.output.scale
+                ):
+                    if y1 >= self.output.height - 5:
+                        text_pos = (x1, y0)
+                    else:
+                        text_pos = (x0, y1)
+
+                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
+                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+                font_size = (
+                    np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
+                    * 0.5
+                    * self._default_font_size
+                )
+                self.draw_text(
+                    labels[i],
+                    text_pos,
+                    color=lighter_color,
+                    horizontal_alignment=horiz_align,
+                    font_size=font_size,
+                )
+
+        # draw keypoints
+        if keypoints is not None:
+            for keypoints_per_instance in keypoints:
+                self.draw_and_connect_keypoints(keypoints_per_instance)
+
+        return self.output
+
+    def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
+        """
+        Args:
+            boxes (ndarray): an Nx5 numpy array of
+                (x_center, y_center, width, height, angle_degrees) format
+                for the N objects in a single image.
+            labels (list[str]): the text to be displayed for each instance.
+            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
+                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
+                for full list of formats that the colors are accepted in.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        num_instances = len(boxes)
+
+        if assigned_colors is None:
+            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
+        if num_instances == 0:
+            return self.output
+
+        # Display in largest to smallest order to reduce occlusion.
+        if boxes is not None:
+            areas = boxes[:, 2] * boxes[:, 3]
+
+        sorted_idxs = np.argsort(-areas).tolist()
+        # Re-order overlapped instances in descending order.
+        boxes = boxes[sorted_idxs]
+        labels = [labels[k] for k in sorted_idxs] if labels is not None else None
+        colors = [assigned_colors[idx] for idx in sorted_idxs]
+
+        for i in range(num_instances):
+            self.draw_rotated_box_with_label(
+                boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
+            )
+
+        return self.output
+
+    def draw_and_connect_keypoints(self, keypoints):
+        """
+        Draws keypoints of an instance and follows the rules for keypoint connections
+        to draw lines between appropriate keypoints. This follows color heuristics for
+        line color.
+
+        Args:
+            keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
+                and the last dimension corresponds to (x, y, probability).
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        visible = {}
+        keypoint_names = self.metadata.get("keypoint_names")
+        for idx, keypoint in enumerate(keypoints):
+            # draw keypoint
+            x, y, prob = keypoint
+            if prob > self.keypoint_threshold:
+                self.draw_circle((x, y), color=_RED)
+                if keypoint_names:
+                    keypoint_name = keypoint_names[idx]
+                    visible[keypoint_name] = (x, y)
+
+        if self.metadata.get("keypoint_connection_rules"):
+            for kp0, kp1, color in self.metadata.keypoint_connection_rules:
+                if kp0 in visible and kp1 in visible:
+                    x0, y0 = visible[kp0]
+                    x1, y1 = visible[kp1]
+                    color = tuple(x / 255.0 for x in color)
+                    self.draw_line([x0, x1], [y0, y1], color=color)
+
+        # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
+        # Note that this strategy is specific to person keypoints.
+        # For other keypoints, it should just do nothing
+        try:
+            ls_x, ls_y = visible["left_shoulder"]
+            rs_x, rs_y = visible["right_shoulder"]
+            mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
+        except KeyError:
+            pass
+        else:
+            # draw line from nose to mid-shoulder
+            nose_x, nose_y = visible.get("nose", (None, None))
+            if nose_x is not None:
+                self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
+
+            try:
+                # draw line from mid-shoulder to mid-hip
+                lh_x, lh_y = visible["left_hip"]
+                rh_x, rh_y = visible["right_hip"]
+            except KeyError:
+                pass
+            else:
+                mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
+                self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
+        return self.output
+
+    """
+    Primitive drawing functions:
+    """
+
+    def draw_text(
+        self,
+        text,
+        position,
+        *,
+        font_size=None,
+        color="g",
+        horizontal_alignment="center",
+        rotation=0,
+    ):
+        """
+        Args:
+            text (str): class label
+            position (tuple): a tuple of the x and y coordinates to place text on image.
+            font_size (int, optional): font of the text. If not provided, a font size
+                proportional to the image width is calculated and used.
+            color: color of the text. Refer to `matplotlib.colors` for full list
+                of formats that are accepted.
+            horizontal_alignment (str): see `matplotlib.text.Text`
+            rotation: rotation angle in degrees CCW
+
+        Returns:
+            output (VisImage): image object with text drawn.
+        """
+        if not font_size:
+            font_size = self._default_font_size
+
+        # since the text background is dark, we don't want the text to be dark
+        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
+        color[np.argmax(color)] = max(0.8, np.max(color))
+
+        x, y = position
+        self.output.ax.text(
+            x,
+            y,
+            text,
+            size=font_size * self.output.scale,
+            family="sans-serif",
+            bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
+            verticalalignment="top",
+            horizontalalignment=horizontal_alignment,
+            color=color,
+            zorder=10,
+            rotation=rotation,
+        )
+        return self.output
+
+    def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
+        """
+        Args:
+            box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
+                are the coordinates of the image's top left corner. x1 and y1 are the
+                coordinates of the image's bottom right corner.
+            alpha (float): blending efficient. Smaller values lead to more transparent masks.
+            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
+                for full list of formats that are accepted.
+            line_style (string): the string to use to create the outline of the boxes.
+
+        Returns:
+            output (VisImage): image object with box drawn.
+        """
+        x0, y0, x1, y1 = box_coord
+        width = x1 - x0
+        height = y1 - y0
+
+        linewidth = max(self._default_font_size / 4, 1)
+
+        self.output.ax.add_patch(
+            mpl.patches.Rectangle(
+                (x0, y0),
+                width,
+                height,
+                fill=False,
+                edgecolor=edge_color,
+                linewidth=linewidth * self.output.scale,
+                alpha=alpha,
+                linestyle=line_style,
+            )
+        )
+        return self.output
+
+    def draw_rotated_box_with_label(
+        self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
+    ):
+        """
+        Draw a rotated box with label on its top-left corner.
+
+        Args:
+            rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
+                where cnt_x and cnt_y are the center coordinates of the box.
+                w and h are the width and height of the box. angle represents how
+                many degrees the box is rotated CCW with regard to the 0-degree box.
+            alpha (float): blending efficient. Smaller values lead to more transparent masks.
+            edge_color: color of the outline of the box. Refer to `matplotlib.colors`
+                for full list of formats that are accepted.
+            line_style (string): the string to use to create the outline of the boxes.
+            label (string): label for rotated box. It will not be rendered when set to None.
+
+        Returns:
+            output (VisImage): image object with box drawn.
+        """
+        cnt_x, cnt_y, w, h, angle = rotated_box
+        area = w * h
+        # use thinner lines when the box is small
+        linewidth = self._default_font_size / (
+            6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
+        )
+
+        theta = angle * math.pi / 180.0
+        c = math.cos(theta)
+        s = math.sin(theta)
+        rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
+        # x: left->right ; y: top->down
+        rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
+        for k in range(4):
+            j = (k + 1) % 4
+            self.draw_line(
+                [rotated_rect[k][0], rotated_rect[j][0]],
+                [rotated_rect[k][1], rotated_rect[j][1]],
+                color=edge_color,
+                linestyle="--" if k == 1 else line_style,
+                linewidth=linewidth,
+            )
+
+        if label is not None:
+            text_pos = rotated_rect[1]  # topleft corner
+
+            height_ratio = h / np.sqrt(self.output.height * self.output.width)
+            label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
+            font_size = (
+                np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
+            )
+            self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
+
+        return self.output
+
+    def draw_circle(self, circle_coord, color, radius=3):
+        """
+        Args:
+            circle_coord (list(int) or tuple(int)): contains the x and y coordinates
+                of the center of the circle.
+            color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+                formats that are accepted.
+            radius (int): radius of the circle.
+
+        Returns:
+            output (VisImage): image object with box drawn.
+        """
+        x, y = circle_coord
+        self.output.ax.add_patch(
+            mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
+        )
+        return self.output
+
+    def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
+        """
+        Args:
+            x_data (list[int]): a list containing x values of all the points being drawn.
+                Length of list should match the length of y_data.
+            y_data (list[int]): a list containing y values of all the points being drawn.
+                Length of list should match the length of x_data.
+            color: color of the line. Refer to `matplotlib.colors` for a full list of
+                formats that are accepted.
+            linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
+                for a full list of formats that are accepted.
+            linewidth (float or None): width of the line. When it's None,
+                a default value will be computed and used.
+
+        Returns:
+            output (VisImage): image object with line drawn.
+        """
+        if linewidth is None:
+            linewidth = self._default_font_size / 3
+        linewidth = max(linewidth, 1)
+        self.output.ax.add_line(
+            mpl.lines.Line2D(
+                x_data,
+                y_data,
+                linewidth=linewidth * self.output.scale,
+                color=color,
+                linestyle=linestyle,
+            )
+        )
+        return self.output
+
+    def draw_binary_mask(
+        self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=0
+    ):
+        """
+        Args:
+            binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
+                W is the image width. Each value in the array is either a 0 or 1 value of uint8
+                type.
+            color: color of the mask. Refer to `matplotlib.colors` for a full list of
+                formats that are accepted. If None, will pick a random color.
+            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
+                full list of formats that are accepted.
+            text (str): if None, will be drawn in the object's center of mass.
+            alpha (float): blending efficient. Smaller values lead to more transparent masks.
+            area_threshold (float): a connected component small than this will not be shown.
+
+        Returns:
+            output (VisImage): image object with mask drawn.
+        """
+        if color is None:
+            color = random_color(rgb=True, maximum=1)
+        color = mplc.to_rgb(color)
+
+        has_valid_segment = False
+        binary_mask = binary_mask.astype("uint8")  # opencv needs uint8
+        mask = GenericMask(binary_mask, self.output.height, self.output.width)
+        shape2d = (binary_mask.shape[0], binary_mask.shape[1])
+
+        if not mask.has_holes:
+            # draw polygons for regular masks
+            for segment in mask.polygons:
+                area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
+                if area < (area_threshold or 0):
+                    continue
+                has_valid_segment = True
+                segment = segment.reshape(-1, 2)
+                self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
+        else:
+            # TODO: Use Path/PathPatch to draw vector graphics:
+            # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
+            rgba = np.zeros(shape2d + (4,), dtype="float32")
+            rgba[:, :, :3] = color
+            rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
+            has_valid_segment = True
+            self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
+
+        if text is not None and has_valid_segment:
+            # TODO sometimes drawn on wrong objects. the heuristics here can improve.
+            lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
+            _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
+            largest_component_id = np.argmax(stats[1:, -1]) + 1
+
+            # draw text on the largest component, as well as other very large components.
+            for cid in range(1, _num_cc):
+                if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
+                    # median is more stable than centroid
+                    # center = centroids[largest_component_id]
+                    center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
+                    self.draw_text(text, center, color=lighter_color)
+        return self.output
+
+    def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
+        """
+        Args:
+            segment: numpy array of shape Nx2, containing all the points in the polygon.
+            color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+                formats that are accepted.
+            edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
+                full list of formats that are accepted. If not provided, a darker shade
+                of the polygon color will be used instead.
+            alpha (float): blending efficient. Smaller values lead to more transparent masks.
+
+        Returns:
+            output (VisImage): image object with polygon drawn.
+        """
+        if edge_color is None:
+            # make edge color darker than the polygon color
+            if alpha > 0.8:
+                edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
+            else:
+                edge_color = color
+        edge_color = mplc.to_rgb(edge_color) + (1,)
+
+        polygon = mpl.patches.Polygon(
+            segment,
+            fill=True,
+            facecolor=mplc.to_rgb(color) + (alpha,),
+            edgecolor=edge_color,
+            linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
+        )
+        self.output.ax.add_patch(polygon)
+        return self.output
+
+    """
+    Internal methods:
+    """
+
+    def _jitter(self, color):
+        """
+        Randomly modifies given color to produce a slightly different color than the color given.
+
+        Args:
+            color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
+                picked. The values in the list are in the [0.0, 1.0] range.
+
+        Returns:
+            jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
+                color after being jittered. The values in the list are in the [0.0, 1.0] range.
+        """
+        color = mplc.to_rgb(color)
+        vec = np.random.rand(3)
+        # better to do it in another color space
+        vec = vec / np.linalg.norm(vec) * 0.5
+        res = np.clip(vec + color, 0, 1)
+        return tuple(res)
+
+    def _create_grayscale_image(self, mask=None):
+        """
+        Create a grayscale version of the original image.
+        The colors in masked area, if given, will be kept.
+        """
+        img_bw = self.img.astype("f4").mean(axis=2)
+        img_bw = np.stack([img_bw] * 3, axis=2)
+        if mask is not None:
+            img_bw[mask] = self.img[mask]
+        return img_bw
+
+    def _change_color_brightness(self, color, brightness_factor):
+        """
+        Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
+        less or more saturation than the original color.
+
+        Args:
+            color: color of the polygon. Refer to `matplotlib.colors` for a full list of
+                formats that are accepted.
+            brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
+                0 will correspond to no change, a factor in [-1.0, 0) range will result in
+                a darker color and a factor in (0, 1.0] range will result in a lighter color.
+
+        Returns:
+            modified_color (tuple[double]): a tuple containing the RGB values of the
+                modified color. Each value in the tuple is in the [0.0, 1.0] range.
+        """
+        assert brightness_factor >= -1.0 and brightness_factor <= 1.0
+        color = mplc.to_rgb(color)
+        polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
+        modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
+        modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
+        modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
+        modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
+        return modified_color
+
+    def _convert_boxes(self, boxes):
+        """
+        Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
+        """
+        if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
+            return boxes.tensor.detach().numpy()
+        else:
+            return np.asarray(boxes)
+
+    def _convert_masks(self, masks_or_polygons):
+        """
+        Convert different format of masks or polygons to a tuple of masks and polygons.
+
+        Returns:
+            list[GenericMask]:
+        """
+
+        m = masks_or_polygons
+        if isinstance(m, PolygonMasks):
+            m = m.polygons
+        if isinstance(m, BitMasks):
+            m = m.tensor.numpy()
+        if isinstance(m, torch.Tensor):
+            m = m.numpy()
+        ret = []
+        for x in m:
+            if isinstance(x, GenericMask):
+                ret.append(x)
+            else:
+                ret.append(GenericMask(x, self.output.height, self.output.width))
+        return ret
+
+    def _convert_keypoints(self, keypoints):
+        if isinstance(keypoints, Keypoints):
+            keypoints = keypoints.tensor
+        keypoints = np.asarray(keypoints)
+        return keypoints
+
+    def get_output(self):
+        """
+        Returns:
+            output (VisImage): the image output containing the visualizations added
+            to the image.
+        """
+        return self.output

+ 36 - 0
magic_pdf/model/pek_sub_modules/post_process.py

@@ -0,0 +1,36 @@
+import re
+
+def layout_rm_equation(layout_res):
+    rm_idxs = []
+    for idx, ele in enumerate(layout_res['layout_dets']):
+        if ele['category_id'] == 10:
+            rm_idxs.append(idx)
+    
+    for idx in rm_idxs[::-1]:
+        del layout_res['layout_dets'][idx]
+    return layout_res
+
+
+def get_croped_image(image_pil, bbox):
+    x_min, y_min, x_max, y_max = bbox
+    croped_img = image_pil.crop((x_min, y_min, x_max, y_max))
+    return croped_img
+
+
+def latex_rm_whitespace(s: str):
+    """Remove unnecessary whitespace from LaTeX code.
+    """
+    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
+    letter = '[a-zA-Z]'
+    noletter = '[\W_^\d]'
+    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+    news = s
+    while True:
+        s = news
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
+        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
+        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        if news == s:
+            break
+    return s

+ 259 - 0
magic_pdf/model/pek_sub_modules/self_modify.py

@@ -0,0 +1,259 @@
+import time
+import copy
+import base64
+import cv2
+import numpy as np
+from io import BytesIO
+from PIL import Image
+
+from paddleocr import PaddleOCR
+from paddleocr.ppocr.utils.logging import get_logger
+from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
+from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
+logger = get_logger()
+
+def img_decode(content: bytes):
+    np_arr = np.frombuffer(content, dtype=np.uint8)
+    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
+
+def check_img(img):
+    if isinstance(img, bytes):
+        img = img_decode(img)
+    if isinstance(img, str):
+        image_file = img
+        img, flag_gif, flag_pdf = check_and_read(image_file)
+        if not flag_gif and not flag_pdf:
+            with open(image_file, 'rb') as f:
+                img_str = f.read()
+                img = img_decode(img_str)
+            if img is None:
+                try:
+                    buf = BytesIO()
+                    image = BytesIO(img_str)
+                    im = Image.open(image)
+                    rgb = im.convert('RGB')
+                    rgb.save(buf, 'jpeg')
+                    buf.seek(0)
+                    image_bytes = buf.read()
+                    data_base64 = str(base64.b64encode(image_bytes),
+                                      encoding="utf-8")
+                    image_decode = base64.b64decode(data_base64)
+                    img_array = np.frombuffer(image_decode, np.uint8)
+                    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
+                except:
+                    logger.error("error in loading image:{}".format(image_file))
+                    return None
+        if img is None:
+            logger.error("error in loading image:{}".format(image_file))
+            return None
+    if isinstance(img, np.ndarray) and len(img.shape) == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+    return img
+
+def sorted_boxes(dt_boxes):
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        dt_boxes(array):detected text boxes with shape [4, 2]
+    return:
+        sorted boxes(array) with shape [4, 2]
+    """
+    num_boxes = dt_boxes.shape[0]
+    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes = list(sorted_boxes)
+
+    for i in range(num_boxes - 1):
+        for j in range(i, -1, -1):
+            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
+                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+                tmp = _boxes[j]
+                _boxes[j] = _boxes[j + 1]
+                _boxes[j + 1] = tmp
+            else:
+                break
+    return _boxes
+
+
+def formula_in_text(mf_bbox, text_bbox):
+    x1, y1, x2, y2 = mf_bbox
+    x3, y3 = text_bbox[0]
+    x4, y4 = text_bbox[2]
+    left_box, right_box = None, None
+    same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2
+    if not same_line:
+        return False, left_box, right_box
+    else:
+        drop_origin = False
+        left_x = x1 - 1
+        right_x = x2 + 1
+        if x3 < x1 and x2 < x4:
+            drop_origin = True
+            left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
+            right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
+        if x3 < x1 and x1 <= x4 <= x2:
+            drop_origin = True
+            left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
+        if x1 <= x3 <= x2 and x2 < x4:
+            drop_origin = True
+            right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
+        if x1 <= x3 < x4 <= x2:
+            drop_origin = True
+        return drop_origin, left_box, right_box
+
+    
+def update_det_boxes(dt_boxes, mfdetrec_res):
+    new_dt_boxes = dt_boxes
+    for mf_box in mfdetrec_res:
+        flag, left_box, right_box = False, None, None
+        for idx, text_box in enumerate(new_dt_boxes):
+            ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
+            if ret:
+                new_dt_boxes.pop(idx)
+                if left_box is not None:
+                    new_dt_boxes.append(left_box)
+                if right_box is not None:
+                    new_dt_boxes.append(right_box)
+                break
+            
+    return new_dt_boxes
+
+class ModifiedPaddleOCR(PaddleOCR):
+    def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
+        """
+        OCR with PaddleOCR
+        args:
+            img: img for OCR, support ndarray, img_path and list or ndarray
+            det: use text detection or not. If False, only rec will be exec. Default is True
+            rec: use text recognition or not. If False, only det will be exec. Default is True
+            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
+            bin: binarize image to black and white. Default is False.
+            inv: invert image colors. Default is False.
+            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
+        """
+        assert isinstance(img, (np.ndarray, list, str, bytes))
+        if isinstance(img, list) and det == True:
+            logger.error('When input a list of images, det must be false')
+            exit(0)
+        if cls == True and self.use_angle_cls == False:
+            logger.warning(
+                'Since the angle classifier is not initialized, it will not be used during the forward process'
+            )
+
+        img = check_img(img)
+        # for infer pdf file
+        if isinstance(img, list):
+            if self.page_num > len(img) or self.page_num == 0:
+                self.page_num = len(img)
+            imgs = img[:self.page_num]
+        else:
+            imgs = [img]
+
+        def preprocess_image(_image):
+            _image = alpha_to_color(_image, alpha_color)
+            if inv:
+                _image = cv2.bitwise_not(_image)
+            if bin:
+                _image = binarize_img(_image)
+            return _image
+
+        if det and rec:
+            ocr_res = []
+            for idx, img in enumerate(imgs):
+                img = preprocess_image(img)
+                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
+                if not dt_boxes and not rec_res:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [[box.tolist(), res]
+                           for box, res in zip(dt_boxes, rec_res)]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        elif det and not rec:
+            ocr_res = []
+            for idx, img in enumerate(imgs):
+                img = preprocess_image(img)
+                dt_boxes, elapse = self.text_detector(img)
+                if not dt_boxes:
+                    ocr_res.append(None)
+                    continue
+                tmp_res = [box.tolist() for box in dt_boxes]
+                ocr_res.append(tmp_res)
+            return ocr_res
+        else:
+            ocr_res = []
+            cls_res = []
+            for idx, img in enumerate(imgs):
+                if not isinstance(img, list):
+                    img = preprocess_image(img)
+                    img = [img]
+                if self.use_angle_cls and cls:
+                    img, cls_res_tmp, elapse = self.text_classifier(img)
+                    if not rec:
+                        cls_res.append(cls_res_tmp)
+                rec_res, elapse = self.text_recognizer(img)
+                ocr_res.append(rec_res)
+            if not rec:
+                return cls_res
+            return ocr_res
+        
+    def __call__(self, img, cls=True, mfd_res=None):
+        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
+
+        if img is None:
+            logger.debug("no valid image provided")
+            return None, None, time_dict
+
+        start = time.time()
+        ori_im = img.copy()
+        dt_boxes, elapse = self.text_detector(img)
+        time_dict['det'] = elapse
+
+        if dt_boxes is None:
+            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
+            end = time.time()
+            time_dict['all'] = end - start
+            return None, None, time_dict
+        else:
+            logger.debug("dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), elapse))
+        img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+        if mfd_res:
+            bef = time.time()
+            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+            aft = time.time()
+            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
+                len(dt_boxes), aft-bef))
+
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            if self.args.det_box_type == "quad":
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)
+            else:
+                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+        if self.use_angle_cls and cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(
+                img_crop_list)
+            time_dict['cls'] = elapse
+            logger.debug("cls num  : {}, elapsed : {}".format(
+                len(img_crop_list), elapse))
+
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        time_dict['rec'] = elapse
+        logger.debug("rec_res num  : {}, elapsed : {}".format(
+            len(rec_res), elapse))
+        if self.args.save_crop_res:
+            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
+                                   rec_res)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+        end = time.time()
+        time_dict['all'] = end - start
+        return filter_boxes, filter_rec_res, time_dict

+ 46 - 0
magic_pdf/resources/model_config/UniMERNet/demo.yaml

@@ -0,0 +1,46 @@
+model:
+  arch: unimernet
+  model_type: unimernet
+  model_config:
+    model_name: ./models
+    max_seq_len: 1024
+    length_aware: False
+  load_pretrained: True
+  pretrained: ./models/pytorch_model.bin
+  tokenizer_config:
+    path: ./models
+
+datasets:
+  formula_rec_eval:
+    vis_processor:
+      eval:
+        name: "formula_image_eval"
+        image_size:
+          - 192
+          - 672
+   
+run:
+  runner: runner_iter
+  task: unimernet_train
+
+  batch_size_train: 64
+  batch_size_eval: 64
+  num_workers: 1
+
+  iters_per_inner_epoch: 2000
+  max_iters: 60000
+
+  seed: 42
+  output_dir: "../output/demo"
+
+  evaluate: True
+  test_splits: [ "eval" ]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
+  distributed_type: ddp  # or fsdp when train llm
+
+  generate_cfg:
+    temperature: 0.0

+ 351 - 0
magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml

@@ -0,0 +1,351 @@
+AUG:
+  DETR: true
+CACHE_DIR: /mnt/localdata/users/yupanhuang/cache/huggingface
+CUDNN_BENCHMARK: false
+DATALOADER:
+  ASPECT_RATIO_GROUPING: true
+  FILTER_EMPTY_ANNOTATIONS: false
+  NUM_WORKERS: 4
+  REPEAT_THRESHOLD: 0.0
+  SAMPLER_TRAIN: TrainingSampler
+DATASETS:
+  PRECOMPUTED_PROPOSAL_TOPK_TEST: 1000
+  PRECOMPUTED_PROPOSAL_TOPK_TRAIN: 2000
+  PROPOSAL_FILES_TEST: []
+  PROPOSAL_FILES_TRAIN: []
+  TEST:
+  - scihub_train
+  TRAIN:
+  - scihub_train
+GLOBAL:
+  HACK: 1.0
+ICDAR_DATA_DIR_TEST: ''
+ICDAR_DATA_DIR_TRAIN: ''
+INPUT:
+  CROP:
+    ENABLED: true
+    SIZE:
+    - 384
+    - 600
+    TYPE: absolute_range
+  FORMAT: RGB
+  MASK_FORMAT: polygon
+  MAX_SIZE_TEST: 1333
+  MAX_SIZE_TRAIN: 1333
+  MIN_SIZE_TEST: 800
+  MIN_SIZE_TRAIN:
+  - 480
+  - 512
+  - 544
+  - 576
+  - 608
+  - 640
+  - 672
+  - 704
+  - 736
+  - 768
+  - 800
+  MIN_SIZE_TRAIN_SAMPLING: choice
+  RANDOM_FLIP: horizontal
+MODEL:
+  ANCHOR_GENERATOR:
+    ANGLES:
+    - - -90
+      - 0
+      - 90
+    ASPECT_RATIOS:
+    - - 0.5
+      - 1.0
+      - 2.0
+    NAME: DefaultAnchorGenerator
+    OFFSET: 0.0
+    SIZES:
+    - - 32
+    - - 64
+    - - 128
+    - - 256
+    - - 512
+  BACKBONE:
+    FREEZE_AT: 2
+    NAME: build_vit_fpn_backbone
+  CONFIG_PATH: ''
+  DEVICE: cpu
+  FPN:
+    FUSE_TYPE: sum
+    IN_FEATURES:
+    - layer3
+    - layer5
+    - layer7
+    - layer11
+    NORM: ''
+    OUT_CHANNELS: 256
+  IMAGE_ONLY: true
+  KEYPOINT_ON: false
+  LOAD_PROPOSALS: false
+  MASK_ON: true
+  META_ARCHITECTURE: VLGeneralizedRCNN
+  PANOPTIC_FPN:
+    COMBINE:
+      ENABLED: true
+      INSTANCES_CONFIDENCE_THRESH: 0.5
+      OVERLAP_THRESH: 0.5
+      STUFF_AREA_LIMIT: 4096
+    INSTANCE_LOSS_WEIGHT: 1.0
+  PIXEL_MEAN:
+  - 127.5
+  - 127.5
+  - 127.5
+  PIXEL_STD:
+  - 127.5
+  - 127.5
+  - 127.5
+  PROPOSAL_GENERATOR:
+    MIN_SIZE: 0
+    NAME: RPN
+  RESNETS:
+    DEFORM_MODULATED: false
+    DEFORM_NUM_GROUPS: 1
+    DEFORM_ON_PER_STAGE:
+    - false
+    - false
+    - false
+    - false
+    DEPTH: 50
+    NORM: FrozenBN
+    NUM_GROUPS: 1
+    OUT_FEATURES:
+    - res4
+    RES2_OUT_CHANNELS: 256
+    RES5_DILATION: 1
+    STEM_OUT_CHANNELS: 64
+    STRIDE_IN_1X1: true
+    WIDTH_PER_GROUP: 64
+  RETINANET:
+    BBOX_REG_LOSS_TYPE: smooth_l1
+    BBOX_REG_WEIGHTS:
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+    FOCAL_LOSS_ALPHA: 0.25
+    FOCAL_LOSS_GAMMA: 2.0
+    IN_FEATURES:
+    - p3
+    - p4
+    - p5
+    - p6
+    - p7
+    IOU_LABELS:
+    - 0
+    - -1
+    - 1
+    IOU_THRESHOLDS:
+    - 0.4
+    - 0.5
+    NMS_THRESH_TEST: 0.5
+    NORM: ''
+    NUM_CLASSES: 10
+    NUM_CONVS: 4
+    PRIOR_PROB: 0.01
+    SCORE_THRESH_TEST: 0.05
+    SMOOTH_L1_LOSS_BETA: 0.1
+    TOPK_CANDIDATES_TEST: 1000
+  ROI_BOX_CASCADE_HEAD:
+    BBOX_REG_WEIGHTS:
+    - - 10.0
+      - 10.0
+      - 5.0
+      - 5.0
+    - - 20.0
+      - 20.0
+      - 10.0
+      - 10.0
+    - - 30.0
+      - 30.0
+      - 15.0
+      - 15.0
+    IOUS:
+    - 0.5
+    - 0.6
+    - 0.7
+  ROI_BOX_HEAD:
+    BBOX_REG_LOSS_TYPE: smooth_l1
+    BBOX_REG_LOSS_WEIGHT: 1.0
+    BBOX_REG_WEIGHTS:
+    - 10.0
+    - 10.0
+    - 5.0
+    - 5.0
+    CLS_AGNOSTIC_BBOX_REG: true
+    CONV_DIM: 256
+    FC_DIM: 1024
+    NAME: FastRCNNConvFCHead
+    NORM: ''
+    NUM_CONV: 0
+    NUM_FC: 2
+    POOLER_RESOLUTION: 7
+    POOLER_SAMPLING_RATIO: 0
+    POOLER_TYPE: ROIAlignV2
+    SMOOTH_L1_BETA: 0.0
+    TRAIN_ON_PRED_BOXES: false
+  ROI_HEADS:
+    BATCH_SIZE_PER_IMAGE: 512
+    IN_FEATURES:
+    - p2
+    - p3
+    - p4
+    - p5
+    IOU_LABELS:
+    - 0
+    - 1
+    IOU_THRESHOLDS:
+    - 0.5
+    NAME: CascadeROIHeads
+    NMS_THRESH_TEST: 0.5
+    NUM_CLASSES: 10
+    POSITIVE_FRACTION: 0.25
+    PROPOSAL_APPEND_GT: true
+    SCORE_THRESH_TEST: 0.05
+  ROI_KEYPOINT_HEAD:
+    CONV_DIMS:
+    - 512
+    - 512
+    - 512
+    - 512
+    - 512
+    - 512
+    - 512
+    - 512
+    LOSS_WEIGHT: 1.0
+    MIN_KEYPOINTS_PER_IMAGE: 1
+    NAME: KRCNNConvDeconvUpsampleHead
+    NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS: true
+    NUM_KEYPOINTS: 17
+    POOLER_RESOLUTION: 14
+    POOLER_SAMPLING_RATIO: 0
+    POOLER_TYPE: ROIAlignV2
+  ROI_MASK_HEAD:
+    CLS_AGNOSTIC_MASK: false
+    CONV_DIM: 256
+    NAME: MaskRCNNConvUpsampleHead
+    NORM: ''
+    NUM_CONV: 4
+    POOLER_RESOLUTION: 14
+    POOLER_SAMPLING_RATIO: 0
+    POOLER_TYPE: ROIAlignV2
+  RPN:
+    BATCH_SIZE_PER_IMAGE: 256
+    BBOX_REG_LOSS_TYPE: smooth_l1
+    BBOX_REG_LOSS_WEIGHT: 1.0
+    BBOX_REG_WEIGHTS:
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+    BOUNDARY_THRESH: -1
+    CONV_DIMS:
+    - -1
+    HEAD_NAME: StandardRPNHead
+    IN_FEATURES:
+    - p2
+    - p3
+    - p4
+    - p5
+    - p6
+    IOU_LABELS:
+    - 0
+    - -1
+    - 1
+    IOU_THRESHOLDS:
+    - 0.3
+    - 0.7
+    LOSS_WEIGHT: 1.0
+    NMS_THRESH: 0.7
+    POSITIVE_FRACTION: 0.5
+    POST_NMS_TOPK_TEST: 1000
+    POST_NMS_TOPK_TRAIN: 2000
+    PRE_NMS_TOPK_TEST: 1000
+    PRE_NMS_TOPK_TRAIN: 2000
+    SMOOTH_L1_BETA: 0.0
+  SEM_SEG_HEAD:
+    COMMON_STRIDE: 4
+    CONVS_DIM: 128
+    IGNORE_VALUE: 255
+    IN_FEATURES:
+    - p2
+    - p3
+    - p4
+    - p5
+    LOSS_WEIGHT: 1.0
+    NAME: SemSegFPNHead
+    NORM: GN
+    NUM_CLASSES: 10
+  VIT:
+    DROP_PATH: 0.1
+    IMG_SIZE:
+    - 224
+    - 224
+    NAME: layoutlmv3_base
+    OUT_FEATURES:
+    - layer3
+    - layer5
+    - layer7
+    - layer11
+    POS_TYPE: abs
+  WEIGHTS: 
+OUTPUT_DIR: 
+SCIHUB_DATA_DIR_TRAIN: /mnt/petrelfs/share_data/zhaozhiyuan/publaynet/layout_scihub/train
+SEED: 42
+SOLVER:
+  AMP:
+    ENABLED: true
+  BACKBONE_MULTIPLIER: 1.0
+  BASE_LR: 0.0002
+  BIAS_LR_FACTOR: 1.0
+  CHECKPOINT_PERIOD: 2000
+  CLIP_GRADIENTS:
+    CLIP_TYPE: full_model
+    CLIP_VALUE: 1.0
+    ENABLED: true
+    NORM_TYPE: 2.0
+  GAMMA: 0.1
+  GRADIENT_ACCUMULATION_STEPS: 1
+  IMS_PER_BATCH: 32
+  LR_SCHEDULER_NAME: WarmupCosineLR
+  MAX_ITER: 20000
+  MOMENTUM: 0.9
+  NESTEROV: false
+  OPTIMIZER: ADAMW
+  REFERENCE_WORLD_SIZE: 0
+  STEPS:
+  - 10000
+  WARMUP_FACTOR: 0.01
+  WARMUP_ITERS: 333
+  WARMUP_METHOD: linear
+  WEIGHT_DECAY: 0.05
+  WEIGHT_DECAY_BIAS: null
+  WEIGHT_DECAY_NORM: 0.0
+TEST:
+  AUG:
+    ENABLED: false
+    FLIP: true
+    MAX_SIZE: 4000
+    MIN_SIZES:
+    - 400
+    - 500
+    - 600
+    - 700
+    - 800
+    - 900
+    - 1000
+    - 1100
+    - 1200
+  DETECTIONS_PER_IMAGE: 100
+  EVAL_PERIOD: 1000
+  EXPECTED_RESULTS: []
+  KEYPOINT_OKS_SIGMAS: []
+  PRECISE_BN:
+    ENABLED: false
+    NUM_ITER: 200
+VERSION: 2
+VIS_PERIOD: 0

+ 9 - 0
magic_pdf/resources/model_config/model_configs.yaml

@@ -0,0 +1,9 @@
+config:
+  device: cpu
+  layout: True
+  formula: True
+
+weights:
+  layout: resources/models/Layout/model_final.pth
+  mfd: resources/models/MFD/weights.pt
+  mfr: resources/models/MFR/UniMERNet