
add chart2table (#3941)

Zhang Zelun, 6 months ago
parent commit cab88be119

+ 1 - 0
.precommit/check_imports.py

@@ -75,6 +75,7 @@ MOD_TO_DEP = {
     "shapely": "shapely",
     "soundfile": "soundfile",
     "starlette": "starlette",
+    "tiktoken": "tiktoken",
     "tokenizers": "tokenizers",
     "tqdm": "tqdm",
     "typing_extensions": "typing-extensions",

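For context, MOD_TO_DEP maps an importable module name to the pip distribution that provides it (which is why "typing_extensions" maps to "typing-extensions"). A minimal sketch of how such a mapping can be used to report missing dependencies; the helper below is illustrative only and not part of the repository:

import importlib.util

MOD_TO_DEP = {"tiktoken": "tiktoken", "typing_extensions": "typing-extensions"}

def missing_dependencies(mod_to_dep):
    # Return the pip package names whose modules cannot be imported.
    return [dep for mod, dep in mod_to_dep.items() if importlib.util.find_spec(mod) is None]

print(missing_dependencies(MOD_TO_DEP))  # e.g. ["tiktoken"] if tiktoken is absent
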
+ 33 - 10
paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py

@@ -18,14 +18,26 @@ from .base_batch_sampler import BaseBatchSampler
 
 
 class DocVLMBatchSampler(BaseBatchSampler):
-    def __init__(self):
+
+    model_names_only_supports_batchsize_of_one = {"PP-DocBee-2B", "PP-DocBee-7B"}
+
+    def __init__(self, model_name, batch_size: int = 1) -> None:
         """Initializes the BaseBatchSampler.
 
         Args:
+            model_name (str): The name of the model.
             batch_size (int, optional): The size of each batch. Only support 1.
         """
-        super().__init__()
-        self.batch_size = 1
+        self.model_name = model_name
+        if (
+            self.model_name in self.model_names_only_supports_batchsize_of_one
+            and batch_size != 1
+        ):
+            logging.warning(
+                f"doc vlm batch sampler only support batch size 1 for {self.model_name}, but got {batch_size} and it will not take effect."
+            )
+            batch_size = 1
+        super().__init__(batch_size)
 
     def sample(self, inputs):
         """Generate list of input file path.
@@ -37,14 +49,22 @@ class DocVLMBatchSampler(BaseBatchSampler):
             list: list of file path.
         """
         if isinstance(inputs, dict):
-            yield [inputs]
-        elif isinstance(inputs, list) and all(isinstance(i, dict) for i in inputs):
-            yield inputs
-        else:
+            inputs = [inputs]
+        if not (isinstance(inputs, list) and all(isinstance(i, dict) for i in inputs)):
             raise TypeError(
-                f"Not supported input data type! Only `dict` are supported, but got: {type(inputs)}."
+                f"Not supported input data type! Only `Dict` or `List[Dict]` are supported, but got: {type(inputs)}."
             )
 
+        batch = []
+        for input_ in inputs:
+            batch.append(input_)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+
+        if len(batch) > 0:
+            yield batch
+
     @BaseBatchSampler.batch_size.setter
     def batch_size(self, batch_size):
         """Sets the batch size.
@@ -56,9 +76,12 @@ class DocVLMBatchSampler(BaseBatchSampler):
             Warning: If the batch size is not equal 1.
         """
         # only support batch size 1
-        if batch_size != 1:
+        if (
+            self.model_name in self.model_names_only_supports_batchsize_of_one
+            and batch_size != 1
+        ):
             logging.warning(
-                f"doc vlm batch sampler only support batch size 1, but got {batch_size}."
+                f"doc vlm batch sampler only support batch size 1 for {self.model_name}, but got {batch_size} and it will not take effect."
             )
         else:
             self._batch_size = batch_size

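For reference, a standalone sketch of the batching behavior introduced above: wrap a single dict, validate the input type, then yield fixed-size chunks. The function and sample data below are illustrative, not part of the sampler class:

def sample_batches(inputs, batch_size=2):
    # Mirrors DocVLMBatchSampler.sample(): accept Dict or List[Dict], chunk by batch_size.
    if isinstance(inputs, dict):
        inputs = [inputs]
    if not (isinstance(inputs, list) and all(isinstance(i, dict) for i in inputs)):
        raise TypeError(f"Only Dict or List[Dict] are supported, but got: {type(inputs)}.")
    batch = []
    for item in inputs:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

samples = [{"image": f"img_{i}.png", "query": "..."} for i in range(5)]
print([len(b) for b in sample_batches(samples, batch_size=2)])  # [2, 2, 1]
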
+ 1 - 0
paddlex/inference/models/common/tokenizer/__init__.py

@@ -16,4 +16,5 @@ from .bert_tokenizer import BertTokenizer
 from .clip_tokenizer import CLIPTokenizer
 from .gpt_tokenizer import GPTTokenizer
 from .qwen2_tokenizer import MIXQwen2Tokenizer, Qwen2Tokenizer
+from .qwen_tokenizer import QWenTokenizer
 from .tokenizer_utils import PretrainedTokenizer

+ 288 - 0
paddlex/inference/models/common/tokenizer/qwen_tokenizer.py

@@ -0,0 +1,288 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import importlib.util
+import os
+import unicodedata
+from typing import Collection, Dict, List, Set, Tuple, Union
+
+from .tokenizer_utils import PretrainedTokenizer
+from .tokenizer_utils_base import AddedToken
+
+__all__ = ["QWenTokenizer"]
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
+
+PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+ENDOFTEXT = "<|endoftext|>"
+IMSTART = "<|im_start|>"
+IMEND = "<|im_end|>"
+# as the default behavior is changed to allow special tokens in
+# regular texts, the surface forms of special tokens need to be
+# as different as possible to minimize the impact
+EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
+SPECIAL_TOKENS = (
+    ENDOFTEXT,
+    IMSTART,
+    IMEND,
+) + EXTRAS
+
+tiktoken = None
+
+
+def is_tiktoken_available():
+    return importlib.util.find_spec("tiktoken") is not None
+
+
+def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+    with open(tiktoken_bpe_file, "rb") as f:
+        contents = f.read()
+    return {
+        base64.b64decode(token): int(rank)
+        for token, rank in (line.split() for line in contents.splitlines() if line)
+    }
+
+
+class QWenTokenizer(PretrainedTokenizer):
+    """QWen tokenizer."""
+
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]
+    resource_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab_file,
+        errors="replace",
+        padding_side="left",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not is_tiktoken_available():
+            raise ValueError(
+                "tiktoken is not installed, please install it use: pip install tiktoken"
+            )
+
+        # Bind the module-level placeholder so later references use the real module.
+        global tiktoken
+        import tiktoken as tk
+
+        tiktoken = tk
+
+        self.errors = errors  # how to handle errors in decoding
+
+        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
+        self.special_tokens = {
+            token: index
+            for index, token in enumerate(
+                SPECIAL_TOKENS, start=len(self.mergeable_ranks)
+            )
+        }
+
+        enc = tiktoken.Encoding(
+            "Qwen",
+            pat_str=PAT_STR,
+            mergeable_ranks=self.mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+        assert (
+            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
+
+        self.decoder = {
+            v: k for k, v in self.mergeable_ranks.items()
+        }  # type: dict[int, bytes|str]
+        self.decoder.update({v: k for k, v in self.special_tokens.items()})
+
+        self.tokenizer = enc  # type: tiktoken.Encoding
+
+        self.eod_id = self.tokenizer.eot_token
+        self.im_start_id = self.special_tokens[IMSTART]
+        self.im_end_id = self.special_tokens[IMEND]
+
+        if "pad_token_id" in kwargs:
+            self.pad_token_id = kwargs["pad_token_id"]
+        if "eos_token_id" in kwargs:
+            self.eos_token_id = kwargs["eos_token_id"]
+
+    def __len__(self) -> int:
+        return self.tokenizer.n_vocab
+
+    def get_vocab(self) -> Dict[bytes, int]:
+        return self.mergeable_ranks
+
+    def convert_tokens_to_ids(
+        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
+    ) -> List[int]:
+        ids = []
+        if isinstance(tokens, (str, bytes)):
+            if tokens in self.special_tokens:
+                return self.special_tokens[tokens]
+            else:
+                return self.mergeable_ranks.get(tokens)
+        for token in tokens:
+            if token in self.special_tokens:
+                ids.append(self.special_tokens[token])
+            else:
+                ids.append(self.mergeable_ranks.get(token))
+        return ids
+
+    def _update_tiktoken(self, tokens: List[str], special_tokens: bool = False) -> int:
+        if special_tokens:
+            added_tokens = []
+            for token in tokens:
+                if token in self.special_tokens:
+                    continue
+
+                token_id = len(self.mergeable_ranks) + len(self.special_tokens)
+                self.special_tokens[token] = token_id
+                self.decoder[token_id] = token
+
+                added_tokens.append(token)
+
+            import tiktoken
+
+            self.tokenizer = tiktoken.Encoding(
+                "Qwen",
+                pat_str=PAT_STR,
+                mergeable_ranks=self.mergeable_ranks,
+                special_tokens=self.special_tokens,
+            )
+
+            return len(added_tokens)
+        else:
+            raise ValueError("Adding regular tokens is not supported")
+
+    def _add_tokens(
+        self,
+        new_tokens: Union[List[str], List[AddedToken]],
+        special_tokens: bool = False,
+    ) -> int:
+        if not special_tokens and new_tokens:
+            raise ValueError("Adding regular tokens is not supported")
+        new_tokens_str = []
+        for token in new_tokens:
+            surface_form = token.content if isinstance(token, AddedToken) else token
+            new_tokens_str.append(surface_form)
+
+        return self._update_tiktoken(new_tokens_str, special_tokens)
+
+    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
+        """
+        Save only the vocabulary of the tokenizer (vocabulary).
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+        file_path = os.path.join(save_directory, "qwen.tiktoken")
+        with open(file_path, "w", encoding="utf8") as w:
+            for k, v in self.mergeable_ranks.items():
+                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
+                w.write(line)
+        return (file_path,)
+
+    def tokenize(
+        self,
+        text: str,
+        allowed_special: Union[Set, str] = "all",
+        disallowed_special: Union[Collection, str] = (),
+        **kwargs,
+    ) -> List[Union[bytes, str]]:
+        """
+        Converts a string into a sequence of tokens.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            allowed_special (`Literal["all"]` or `set`):
+                The surface forms of the tokens to be encoded as special tokens in regular texts.
+                Default to "all".
+            disallowed_special (`Literal["all"]` or `Collection`):
+                The surface forms of the tokens that should not be in regular texts and trigger errors.
+                Default to an empty tuple.
+
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific encode method.
+
+        Returns:
+            `List[bytes|str]`: The list of tokens.
+        """
+        tokens = []
+        text = unicodedata.normalize("NFC", text)
+
+        # this implementation takes a detour: text -> token id -> token surface forms
+        for t in self.tokenizer.encode(
+            text, allowed_special=allowed_special, disallowed_special=disallowed_special
+        ):
+            tokens.append(self.decoder[t])
+        return tokens
+
+    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
+        """
+        Converts a sequence of tokens into a single string.
+        """
+        text = ""
+        temp = b""
+        for t in tokens:
+            if isinstance(t, str):
+                if temp:
+                    text += temp.decode("utf-8", errors=self.errors)
+                    temp = b""
+                text += t
+            elif isinstance(t, bytes):
+                temp += t
+            else:
+                raise TypeError("token should only be of type bytes or str")
+        if temp:
+            text += temp.decode("utf-8", errors=self.errors)
+        return text
+
+    @property
+    def vocab_size(self):
+        return self.tokenizer.n_vocab
+
+    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
+        """Converts an id to a token, special tokens included"""
+        if index in self.decoder:
+            return self.decoder[index]
+        raise ValueError("unknown ids")
+
+    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
+        """Converts a token to an id using the vocab, special tokens included"""
+        if token in self.special_tokens:
+            return self.special_tokens[token]
+        if token in self.mergeable_ranks:
+            return self.mergeable_ranks[token]
+        raise ValueError("unknown token")
+
+    def _tokenize(self, text: str, **kwargs):
+        """
+        Converts a string into a sequence of tokens (strings), using the tokenizer. Splits into words for word-based
+        vocabularies or into sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPiece).
+
+        Does NOT take care of added tokens.
+        """
+        raise NotImplementedError
+
+    def _decode(
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        errors: str = None,
+        **kwargs,
+    ) -> str:
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+        if skip_special_tokens:
+            token_ids = [i for i in token_ids if i < self.eod_id]
+        return self.tokenizer.decode(token_ids, errors=errors or self.errors)

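The vocabulary file that QWenTokenizer reads and writes (qwen.tiktoken) is plain text with one base64-encoded token and its rank per line. A minimal round-trip sketch of that format with toy data, mirroring _load_tiktoken_bpe and save_vocabulary (no tiktoken install required):

import base64

toy_ranks = {b"hello": 0, b" world": 1}

# Serialize the way save_vocabulary does: "base64(token_bytes) rank" per line.
contents = "\n".join(
    base64.b64encode(tok).decode("utf8") + " " + str(rank) for tok, rank in toy_ranks.items()
).encode("utf8")

# Parse the way _load_tiktoken_bpe does.
parsed = {
    base64.b64decode(token): int(rank)
    for token, rank in (line.split() for line in contents.splitlines() if line)
}
assert parsed == toy_ranks
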
+ 345 - 0
paddlex/inference/models/common/vlm/fusion_ops.py

@@ -0,0 +1,345 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import paddle
+import paddle.nn.functional as F
+
+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+
+try:
+    from paddle.incubate.nn.functional import swiglu
+except ImportError:
+
+    def swiglu(x, y=None):
+        if y is None:
+            x, y = paddle.chunk(x, chunks=2, axis=-1)
+        return F.silu(x) * y
+
+
+from paddle.utils import try_import
+from paddlenlp.utils.tools import get_env_device
+
+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+try:
+    if get_env_device() in ["npu", "mlu", "gcu"]:
+        from paddle.base import core
+
+        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
+            if lib.endswith(".so"):
+                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
+                    lib
+                )
+    from paddle.nn.functional.flash_attention import flash_attention
+except:
+    flash_attention = None
+
+from paddlenlp.transformers.refined_recompute import no_recompute
+from paddlenlp.transformers.ring_flash_attention import RingFlashAttention
+
+
+def fusion_rope(
+    query_states,
+    key_states,
+    value_states,
+    hidden_states,
+    position_ids,
+    past_key_value,
+    rotary_emb,
+    context_parallel_degree=-1,
+):
+    if get_env_device() not in ["gcu", "intel_hpu"]:
+        assert past_key_value is None, "fused rotary does not support cached KV for now"
+    batch_size, seq_length, num_heads, head_dim = query_states.shape
+    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
+    if context_parallel_degree > 1:
+        assert (
+            get_env_device() == "gpu"
+        ), "context parallel only support cuda device for now"
+        kv_seq_len *= context_parallel_degree
+    if get_env_device() not in ["gcu", "intel_hpu"]:
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
+    if get_env_device() == "npu":
+        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[
+            0
+        ]
+        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
+    elif get_env_device() == "intel_hpu":
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-3]
+        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
+        sin = sin.squeeze().unsqueeze(0).unsqueeze(0)
+        query_states, _, _ = (
+            paddle.incubate.nn.functional.fused_rotary_position_embedding(
+                paddle.transpose(query_states, [0, 2, 1, 3]),
+                None,
+                None,
+                sin=sin,
+                cos=cos,
+                position_ids=position_ids,
+            )
+        )
+        key_states, _, _ = (
+            paddle.incubate.nn.functional.fused_rotary_position_embedding(
+                paddle.transpose(key_states, [0, 2, 1, 3]),
+                None,
+                None,
+                sin=sin,
+                cos=cos,
+                position_ids=position_ids,
+            )
+        )
+        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
+        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
+    elif get_env_device() == "gcu":
+        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
+        query_states, key_states = core.eager._run_custom_op(
+            "fused_rotary_embedding_gcu",
+            query_states,
+            key_states,
+            cos_sin,
+            position_ids,
+            True,
+        )
+    else:
+        # Paddle > 2.6 (or develop) supports q and k/v with different num_heads in the fused rope op.
+        paddle_version = float(paddle.__version__[:3])
+        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (
+            num_heads != num_key_value_heads
+        ):
+            query_states, _, _ = fused_rotary_position_embedding(
+                query_states,
+                None,
+                None,
+                sin=sin,
+                cos=cos,
+                position_ids=position_ids,
+                use_neox_rotary_style=False,
+            )
+            key_states, _, _ = fused_rotary_position_embedding(
+                key_states,
+                None,
+                None,
+                sin=sin,
+                cos=cos,
+                position_ids=position_ids,
+                use_neox_rotary_style=False,
+            )
+        else:
+            query_states, key_states, _ = fused_rotary_position_embedding(
+                query_states,
+                key_states,
+                v=None,
+                sin=sin,
+                cos=cos,
+                position_ids=position_ids,
+                use_neox_rotary_style=False,
+            )
+    return query_states, key_states
+
+
+def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
+    if use_fast_ln:
+        fast_ln = try_import("fast_ln")
+        return fast_ln.fast_rms_norm(x_in, w, eps)[0]
+    else:
+        fused_ln = try_import("fused_ln")
+        return fused_ln.fused_rms_norm(x_in, w, eps)[0]
+
+
+def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
+    if get_env_device() == "npu":
+        return core.eager._run_custom_op(
+            "rms_norm_npu", hidden_states, weight, variance_epsilon
+        )[0]
+    if get_env_device() == "mlu":
+        return core.eager._run_custom_op(
+            "rms_norm_mlu", hidden_states, weight, variance_epsilon
+        )[0]
+    elif get_env_device() == "gcu":
+        return core.eager._run_custom_op(
+            "rms_norm_gcu", hidden_states, weight, variance_epsilon
+        )[0]
+    elif get_env_device() == "intel_hpu":
+        return paddle.incubate.nn.functional.fused_rms_norm(
+            hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
+        )[0]
+
+    return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln)
+
+
+def fusion_flash_attention(
+    query_states,
+    config,
+    key_states,
+    value_states,
+    attention_mask,
+    output_attentions,
+    alibi=None,
+    attn_mask_startend_row_indices=None,
+    sequence_parallel=False,
+    reshard_layer=None,
+    npu_is_casual=False,
+    skip_recompute=False,
+):
+    # Note:
+    # 1. The head_dim of query_states and key_states should be the same; the head_dim of value_states is the one used for the output reshape.
+    bsz, q_len, num_heads, _ = query_states.shape
+    _, kv_seq_len, _, head_dim = value_states.shape
+    version = paddle.version.full_version
+    if version != "0.0.0" and version <= "2.5.2":
+        if alibi is not None:
+            raise ValueError("Flash Attention doesn't support alibi")
+        if config.context_parallel_degree > 1:
+            raise ValueError(
+                f"Context parallel is not implemented in version {version}"
+            )
+        attn_output, attn_weights = flash_attention(
+            query_states,
+            key_states,
+            value_states,
+            causal=True,
+            return_softmax=output_attentions,
+        )
+    else:
+        if alibi is not None:
+            alibi = alibi.reshape([bsz, num_heads, 1, -1])
+            attention_mask = attention_mask.cast(alibi.dtype) + alibi
+        if get_env_device() == "npu":
+            if config.context_parallel_degree > 1:
+                raise ValueError("Context parallel is not implemented for npu")
+            attn_output = core.eager._run_custom_op(
+                "flash_attention_npu",
+                query_states,
+                key_states,
+                value_states,
+                None,
+                attention_mask,
+                None,
+                None,
+                0.0,
+                attention_mask is None,
+                True,
+                False,
+                npu_is_casual,
+                False,
+            )[0]
+        elif get_env_device() == "gcu":
+            if config.context_parallel_degree > 1:
+                raise ValueError("Context parallel is not implemented for gcu")
+            attn_output = core.eager._run_custom_op(
+                "fused_sdp_flash_attention_gcu",
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                0.0,
+                attention_mask is None,
+                True,
+            )[0]
+        elif get_env_device() == "intel_hpu":
+            if config.context_parallel_degree > 1:
+                raise ValueError("Context parallel is not implemented for intel_hpu")
+            scaling_factor = query_states.shape[3] ** -0.5
+            attention_mask = attention_mask.astype(query_states.dtype)
+            attn_output = paddle.incubate.nn.functional.fused_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                0.0,
+                attention_mask is None,
+                scaling_factor,
+                False,
+            )
+        else:
+            if config.context_parallel_degree > 1:
+                attn_output = RingFlashAttention.apply(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_mask=None,
+                    is_causal=True,
+                )
+            else:
+                if attn_mask_startend_row_indices is not None:
+                    assert (
+                        alibi is None
+                    ), "flashmask_attention or flash_attention_with_sparse_mask not support alibi"
+                    if len(attn_mask_startend_row_indices.shape) == 2:
+                        attn_mask_startend_row_indices = paddle.unsqueeze(
+                            attn_mask_startend_row_indices, axis=1
+                        )
+
+                    if hasattr(F, "flashmask_attention"):
+                        attn_output = no_recompute(
+                            F.flashmask_attention,
+                            query_states,
+                            key_states,
+                            value_states,
+                            startend_row_indices=attn_mask_startend_row_indices.unsqueeze(
+                                -1
+                            ),
+                            causal=True,
+                            enable=skip_recompute,
+                        )
+                    else:
+                        attn_output = no_recompute(
+                            F.flash_attention_with_sparse_mask,
+                            query_states,
+                            key_states,
+                            value_states,
+                            attn_mask_start_row_indices=attn_mask_startend_row_indices,
+                            is_causal=True,
+                            enable=skip_recompute,
+                        )
+                else:
+                    attn_output = no_recompute(
+                        F.scaled_dot_product_attention,
+                        query_states,
+                        key_states,
+                        value_states,
+                        attn_mask=attention_mask,
+                        is_causal=query_states.shape[1] != 1,
+                        enable=skip_recompute,
+                    )
+        attn_weights = None
+
+    if reshard_layer is not None:
+        # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
+        attn_output = reshard_layer(
+            attn_output,
+            split_axis=1,
+            concat_axis=2,
+        )
+        # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
+        assert (
+            config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
+        ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
+        q_len = q_len // config.sep_parallel_degree
+        num_heads = num_heads * config.sep_parallel_degree
+
+    if sequence_parallel:
+        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
+    else:
+        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
+    return (attn_output, attn_weights) if output_attentions else attn_output

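The swiglu fallback defined near the top of this file gives the reference semantics for the fused op: split the projection in half along the last axis and gate one half with SiLU of the other. A minimal sketch with a toy tensor (assumes a working Paddle install):

import paddle
import paddle.nn.functional as F

def swiglu_reference(x, y=None):
    # Same as the fallback in fusion_ops.py: chunk the last axis when y is not given.
    if y is None:
        x, y = paddle.chunk(x, chunks=2, axis=-1)
    return F.silu(x) * y

x = paddle.randn([2, 8])          # last dim must be even when y is None
print(swiglu_reference(x).shape)  # [2, 4]
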
+ 830 - 0
paddlex/inference/models/doc_vlm/modeling/GOT_ocr_2_0.py

@@ -0,0 +1,830 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import List, Optional, Tuple, Type
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from ...common.vlm.transformers.model_outputs import CausalLMOutputWithPast
+from .qwen2 import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
+
+
+class MLPBlock(paddle.nn.Layer):
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_dim: int,
+        act: Type[paddle.nn.Layer] = paddle.nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+        self.act = act()
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        return self.lin2(self.act(self.lin1(x)))
+
+
+class LayerNorm2d(paddle.nn.Layer):
+    def __init__(self, num_channels: int, epsilon: float = 1e-06) -> None:
+        super().__init__()
+        self.weight = paddle.base.framework.EagerParamBase.from_tensor(
+            tensor=paddle.ones(shape=num_channels)
+        )
+        self.bias = paddle.base.framework.EagerParamBase.from_tensor(
+            tensor=paddle.zeros(shape=num_channels)
+        )
+        self.epsilon = epsilon
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        u = x.mean(axis=1, keepdim=True)
+        s = (x - u).pow(y=2).mean(axis=1, keepdim=True)
+        x = (x - u) / paddle.sqrt(x=s + self.epsilon)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
+
+class ImageEncoderViT(paddle.nn.Layer):
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Layer] = nn.LayerNorm,
+        act_layer: Type[nn.Layer] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
+    ) -> None:
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Layer): Normalization layer.
+            act_layer (nn.Layer): Activation layer.
+            use_abs_pos (bool): If True, use absolute positional embeddings.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks.
+            global_attn_indexes (list): Indexes for blocks using global attention.
+        """
+        super().__init__()
+        self.img_size = img_size
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        self.pos_embed: Optional[paddle.base.framework.EagerParamBase.from_tensor] = (
+            None
+        )
+        if use_abs_pos:
+            self.pos_embed = paddle.base.framework.EagerParamBase.from_tensor(
+                tensor=paddle.zeros(
+                    shape=[1, img_size // patch_size, img_size // patch_size, embed_dim]
+                )
+            )
+
+        self.blocks = paddle.nn.LayerList()
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.blocks.append(block)
+
+        self.neck = nn.Sequential(
+            nn.Conv2D(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias_attr=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2D(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias_attr=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+        self.net_2 = nn.Conv2D(
+            256, 512, kernel_size=3, stride=2, padding=1, bias_attr=False
+        )
+        self.net_3 = nn.Conv2D(
+            512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False
+        )
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.neck(x.transpose([0, 3, 1, 2]))
+        x = self.net_2(x)
+        x = self.net_3(x)
+        return x
+
+
+class Block(paddle.nn.Layer):
+    """Transformer blocks with support of window attention and residual propagation blocks"""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Layer] = nn.LayerNorm,
+        act_layer: Type[nn.Layer] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Layer): Normalization layer.
+            act_layer (nn.Layer): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If it equals 0, then
+                use global attention.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = MLPBlock(
+            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
+        )
+
+        self.window_size = window_size
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + x
+        x = x + self.mlp(self.norm2(x))
+
+        return x
+
+
+class Attention(paddle.nn.Layer):
+    """Multi-head Attention block with relative position embeddings."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool):  If True, add a learnable bias to query, key, value.
+            rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            assert (
+                input_size is not None
+            ), "Input size must be provided if using relative positional encoding."
+            self.rel_pos_h = paddle.base.framework.EagerParamBase.from_tensor(
+                tensor=paddle.zeros(shape=[2 * input_size[0] - 1, head_dim])
+            )
+            self.rel_pos_w = paddle.base.framework.EagerParamBase.from_tensor(
+                tensor=paddle.zeros(shape=[2 * input_size[1] - 1, head_dim])
+            )
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        B, H, W, _ = tuple(x.shape)
+        qkv = (
+            self.qkv(x)
+            .reshape([B, H * W, 3, self.num_heads, -1])
+            .transpose([2, 0, 3, 1, 4])
+        )
+        q, k, v = qkv.reshape([3, B * self.num_heads, H * W, -1]).unbind(axis=0)
+
+        attn = (q * self.scale) @ k.transpose([0, 2, 1])
+
+        if self.use_rel_pos:
+            attn = add_decomposed_rel_pos(
+                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
+            )
+
+        attn = F.softmax(attn, axis=-1)
+        x = (
+            (attn @ v)
+            .reshape([B, self.num_heads, H, W, -1])
+            .transpose([0, 2, 3, 1, 4])
+            .reshape([B, H, W, -1])
+        )
+        x = self.proj(x)
+
+        return x
+
+
+def window_partition(
+    x: paddle.Tensor, window_size: int
+) -> Tuple[paddle.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = tuple(x.shape)
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, pad=(0, pad_w, 0, pad_h), data_format="NHWC")
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.reshape(
+        [B, Hp // window_size, window_size, Wp // window_size, window_size, C]
+    )
+    windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(
+    windows: paddle.Tensor,
+    window_size: int,
+    pad_hw: Tuple[int, int],
+    hw: Tuple[int, int],
+) -> paddle.Tensor:
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = tuple(windows.shape)[0] // (Hp * Wp // window_size // window_size)
+    x = windows.reshape(
+        [B, Hp // window_size, Wp // window_size, window_size, window_size, -1]
+    )
+    x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, -1])
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :]
+    return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: paddle.Tensor) -> paddle.Tensor:
+    """
+    Get relative positional embeddings according to the relative positions of
+        query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    if tuple(rel_pos.shape)[0] != max_rel_dist:
+        rel_pos_resized = paddle.nn.functional.interpolate(
+            rel_pos.reshape([1, tuple(rel_pos.shape)[0], -1]).transpose([0, 2, 1]),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape([-1, max_rel_dist]).transpose([1, 0])
+    else:
+        rel_pos_resized = rel_pos
+
+    q_coords = paddle.arange(end=q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = paddle.arange(end=k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = q_coords - k_coords + (k_size - 1) * max(q_size / k_size, 1.0)
+    return rel_pos_resized[relative_coords.astype(dtype="int64")]
+
+
+def add_decomposed_rel_pos(
+    attn: paddle.Tensor,
+    q: paddle.Tensor,
+    rel_pos_h: paddle.Tensor,
+    rel_pos_w: paddle.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> paddle.Tensor:
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = tuple(q.shape)
+    r_q = q.reshape([B, q_h, q_w, dim])
+    rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+    attn = (
+        attn.reshape([B, q_h, q_w, k_h, k_w])
+        + rel_h[:, :, :, :, None]
+        + rel_w[:, :, :, None, :]
+    ).reshape([B, q_h * q_w, k_h * k_w])
+
+    return attn
+
+
+class PatchEmbed(paddle.nn.Layer):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        padding: Tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+        self.proj = nn.Conv2D(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.transpose([0, 2, 3, 1])
+        return x
+
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<imgpad>"
+DEFAULT_IM_START_TOKEN = "<img>"
+DEFAULT_IM_END_TOKEN = "</img>"
+
+
+class Qwen2LMHead(nn.Layer):
+    def __init__(
+        self,
+        config,
+        embedding_weights=None,
+        transpose_y=False,
+        tensor_parallel_output=1,
+    ):
+        super(Qwen2LMHead, self).__init__()
+        self.config = config
+        vocab_size = config.vocab_size
+
+        self.transpose_y = transpose_y
+        if transpose_y:
+            # only for weight from embedding_weights
+            if embedding_weights is not None:
+                self.weight = embedding_weights
+            else:
+                self.weight = self.create_parameter(
+                    shape=[vocab_size, config.hidden_size],
+                    dtype=paddle.get_default_dtype(),
+                )
+        else:
+            # for weight from model init
+            self.weight = self.create_parameter(
+                shape=[config.hidden_size, vocab_size],
+                dtype=paddle.get_default_dtype(),
+            )
+
+    def forward(self, hidden_states, tensor_parallel_output=1):
+        logits = paddle.matmul(hidden_states, self.weight, transpose_y=self.transpose_y)
+        return logits
+
+
+class GOTConfig(Qwen2Config):
+    model_type = "GOT"
+
+
+class GOTQwenModel(Qwen2Model):
+    config_class = GOTConfig
+
+    def __init__(self, config: Qwen2Config):
+        super(GOTQwenModel, self).__init__(config)
+        self.vision_tower_high = ImageEncoderViT(
+            depth=12,
+            embed_dim=768,
+            img_size=1024,
+            mlp_ratio=4,
+            norm_layer=partial(paddle.nn.LayerNorm, epsilon=1e-6),
+            num_heads=12,
+            patch_size=16,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=[2, 5, 8, 11],
+            window_size=14,
+            out_chans=256,
+        )
+        self.mm_projector_vary = nn.Linear(1024, 1024)
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[paddle.Tensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        # HACK: replace back original embeddings for LLaVA pretraining
+        orig_embeds_params = getattr(self, "orig_embeds_params", None)
+        if orig_embeds_params is not None:
+            with paddle.no_grad():
+                self.get_input_embeddings().weight[: -self.num_new_tokens] = (
+                    orig_embeds_params[: -self.num_new_tokens].data
+                )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        vision_tower_high = getattr(self, "vision_tower_high", None)
+
+        if (
+            vision_tower_high is not None
+            and (input_ids.shape[1] != 1 or self.training)
+            and images is not None
+        ):
+            use_im_start_end = getattr(self.config, "use_im_start_end", -1)
+
+            im_patch_token = getattr(self.config, "im_patch_token", -1)
+            im_start_token = getattr(self.config, "im_start_token", -1)
+            im_end_token = getattr(self.config, "im_end_token", -1)
+
+            im_patch_token = 151859
+            im_start_token = 151857
+            im_end_token = 151858
+
+            image_features = []
+
+            for image in images:
+                if self.training:
+                    image = image[1]
+                P, C, H, W = image.shape
+                if P == 1:
+                    with paddle.set_grad_enabled(False):
+                        cnn_feature = vision_tower_high(image)
+                        cnn_feature = cnn_feature.flatten(2).transpose(
+                            [0, 2, 1]
+                        )  # 256*1024
+                    image_feature = self.mm_projector_vary(cnn_feature)
+                    image_features.append(image_feature)
+
+                else:
+                    image_patches = paddle.unbind(image)
+                    image_patches_features = []
+                    for image_patch in image_patches:
+                        image_p = paddle.stack([image_patch])
+                        with paddle.set_grad_enabled(False):
+                            cnn_feature_p = vision_tower_high(image_p)
+                            cnn_feature_p = cnn_feature_p.flatten(2).transpose(
+                                [0, 2, 1]
+                            )
+                        image_feature_p = self.mm_projector_vary(cnn_feature_p)
+                        image_patches_features.append(image_feature_p)
+                    image_feature = paddle.concat(image_patches_features, axis=1)
+                    image_features.append(image_feature)
+
+            dummy_image_features_2 = paddle.zeros(
+                [256, 1024], dtype=inputs_embeds.dtype
+            )
+            dummy_image_features = dummy_image_features_2
+            use_im_start_end = True
+            new_input_embeds = []
+            for cur_input_ids, cur_input_embeds, cur_image_features in zip(
+                input_ids, inputs_embeds, image_features
+            ):
+                if (cur_input_ids == im_patch_token).sum() == 0:
+                    # multimodal LLM, but the current sample is not multimodal
+                    cur_input_embeds = (
+                        cur_input_embeds + (0.0 * dummy_image_features).sum()
+                    )
+                    new_input_embeds.append(cur_input_embeds)
+                    continue
+
+                if use_im_start_end:
+                    if (cur_input_ids == im_start_token).sum() != (
+                        cur_input_ids == im_end_token
+                    ).sum():
+                        raise ValueError(
+                            "The number of image start tokens and image end tokens should be the same."
+                        )
+
+                    image_start_tokens = paddle.where(cur_input_ids == im_start_token)[
+                        0
+                    ]
+                    for image_start_token_pos, per_cur_image_features in zip(
+                        image_start_tokens, cur_image_features
+                    ):
+                        num_patches = per_cur_image_features.shape[0]
+
+                        if (
+                            cur_input_ids[image_start_token_pos + num_patches + 1]
+                            != im_end_token
+                        ):
+                            raise ValueError(
+                                "The image end token should follow the image start token."
+                            )
+
+                        cur_input_embeds = paddle.concat(
+                            (
+                                cur_input_embeds[: image_start_token_pos + 1],
+                                per_cur_image_features,
+                                cur_input_embeds[
+                                    image_start_token_pos + num_patches + 1 :
+                                ],
+                            ),
+                            axis=0,
+                        )
+
+                    new_input_embeds.append(cur_input_embeds)
+                else:
+                    raise NotImplementedError
+
+            inputs_embeds = paddle.stack(new_input_embeds, axis=0)
+
+        return super().forward(
+            input_ids=None,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
+class GOTQwenForCausalLM(Qwen2ForCausalLM):
+    config_class = GOTConfig
+
+    def __init__(self, config):
+        super(Qwen2ForCausalLM, self).__init__(config)
+        self.qwen2 = GOTQwenModel(config)
+
+        self.vocab_size = config.vocab_size
+
+        if config.tie_word_embeddings:
+            self.lm_head = Qwen2LMHead(
+                config,
+                embedding_weights=self.qwen2.embed_tokens.weight,
+                transpose_y=True,
+            )
+            self.tie_weights()
+        else:
+            self.lm_head = Qwen2LMHead(config)
+
+        self.eval()
+
+    def get_model(self):
+        return self.qwen2
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        labels: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[paddle.Tensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.qwen2(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            images=images,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.astype(dtype="float32")
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            loss_fct = nn.CrossEntropyLoss(reduction="sum")
+            shift_logits = shift_logits.reshape([-1, self.config.vocab_size])
+            shift_labels = shift_labels.reshape([-1])
+
+            loss = loss_fct(shift_logits, shift_labels)
+            label_sum = paddle.sum(shift_labels != -100)
+            loss = loss / label_sum
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs
+    ):
+        batch_size, seq_length = input_ids.shape
+        attention_mask = paddle.ones((batch_size, seq_length), dtype=paddle.bool)
+
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[1]
+            if past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.astype(dtype="int64").cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "images": kwargs.get("images", None),
+            }
+        )
+        return model_inputs
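For reference, a standalone sketch (not part of this diff) of how the position ids above are derived from a padded attention mask; paddle.where stands in for the in-place masked_fill_:

import paddle

# Padded batch: the first sequence has two left-padded positions.
attention_mask = paddle.to_tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]], dtype="bool")
position_ids = attention_mask.astype("int64").cumsum(-1) - 1
# Clamp padded positions to 1, mirroring masked_fill_(attention_mask == 0, 1).
position_ids = paddle.where(attention_mask, position_ids, paddle.ones_like(position_ids))
print(position_ids)  # [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]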
+
+
+class PPChart2TableInference(GOTQwenForCausalLM):
+
+    def generate(self, inputs, **kwargs):
+        max_new_tokens = kwargs.get("max_new_tokens", 1024)
+        no_repeat_ngram_size = kwargs.get("no_repeat_ngram_size", 20)
+
+        with paddle.no_grad():
+            generated_ids = super().generate(
+                inputs["input_ids"],
+                images=inputs["images"],
+                do_sample=False,
+                num_beams=1,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                max_new_tokens=max_new_tokens,
+            )
+
+        return generated_ids

+ 1 - 0
paddlex/inference/models/doc_vlm/modeling/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .GOT_ocr_2_0 import PPChart2TableInference
 from .qwen2_vl import PPDocBeeInference, Qwen2VLForConditionalGeneration

+ 1706 - 0
paddlex/inference/models/doc_vlm/modeling/qwen2.py

@@ -0,0 +1,1706 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import partial
+from typing import List, Optional, Tuple, Union
+
+import paddle
+import paddle.distributed.fleet.meta_parallel as mpu
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import Tensor
+from paddle.distributed import fleet
+from paddle.distributed.fleet.utils import sequence_parallel_utils
+
+from .....utils import logging
+from .....utils.env import get_device_type
+from ...common.vlm import fusion_ops
+from ...common.vlm.activations import ACT2FN
+from ...common.vlm.transformers import PretrainedConfig, PretrainedModel
+from ...common.vlm.transformers.model_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
+
+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        GatherOp,
+        ScatterOp,
+        mark_as_sequence_parallel_parameter,
+    )
+except ImportError:
+    pass
+
+try:
+    from paddle.nn.functional.flash_attention import flash_attention
+except ImportError:
+    flash_attention = None
+
+
+Linear = nn.Linear
+ColumnParallelLinear = mpu.ColumnParallelLinear
+RowParallelLinear = mpu.RowParallelLinear
+ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
+RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
+
+
+class Qwen2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
+    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen2Model`].
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 22016):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 32):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by mean-pooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+    """
+
+    model_type = "qwen2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=4096,
+        intermediate_size=22016,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        seq_length=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        pad_token_id=0,
+        bos_token_id=151643,
+        eos_token_id=151643,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=28,
+        attention_dropout=0.0,
+        rope_scaling_factor=1.0,
+        rope_scaling_type=None,
+        dpo_config=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.seq_length = seq_length
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+
+        self.use_cache = use_cache
+        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_scaling_type = rope_scaling_type
+
+        self.pad_token_id = pad_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.dpo_config = dpo_config
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
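A hedged construction sketch, not a configuration shipped with this PR: the values are illustrative, and fields that this file reads but does not set here (tensor-parallel and fusion flags) are assumed to fall back to the PretrainedConfig defaults.

# Small illustrative config; all keyword names come from Qwen2Config.__init__ above.
cfg = Qwen2Config(
    vocab_size=32000,
    hidden_size=512,
    intermediate_size=1376,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=4,     # GQA: 2 query heads per key/value head
    max_position_embeddings=2048,
)
print(cfg.num_key_value_heads, cfg.rope_theta)  # 4 10000.0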
+
+
+def get_triangle_upper_mask(x, mask=None):
+    if mask is not None:
+        return mask
+    # [bsz, n_head, q_len, kv_seq_len]
+    shape = x.shape
+    #  [bsz, 1, q_len, kv_seq_len]
+    shape[1] = 1
+    mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype)
+    mask = paddle.triu(mask, diagonal=1)
+    mask.stop_gradient = True
+    return mask
+
+
+def parallel_matmul(
+    x: Tensor, y: Tensor, transpose_y=True, tensor_parallel_output=True
+):
+    is_fleet_init = True
+    tensor_parallel_degree = 1
+    try:
+        hcg = fleet.get_hybrid_communicate_group()
+        model_parallel_group = hcg.get_model_parallel_group()
+        tensor_parallel_degree = hcg.get_model_parallel_world_size()
+    except Exception:
+        is_fleet_init = False
+
+    if paddle.in_dynamic_mode():
+        y_is_distributed = y.is_distributed
+    else:
+        y_is_distributed = tensor_parallel_degree > 1
+
+    if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
+        # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
+        input_parallel = paddle.distributed.collective._c_identity(
+            x, group=model_parallel_group
+        )
+        logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)
+
+        if tensor_parallel_output:
+            return logits
+
+        return paddle.distributed.collective._c_concat(
+            logits, group=model_parallel_group
+        )
+
+    else:
+        logits = paddle.matmul(x, y, transpose_y=transpose_y)
+        return logits
+
+
+def scaled_dot_product_attention(
+    query_states,
+    config,
+    key_states,
+    value_states,
+    attention_mask,
+    output_attentions,
+    attn_mask_startend_row_indices=None,
+    training=True,
+    sequence_parallel=False,
+    skip_recompute=False,
+):
+    bsz, q_len, num_heads, head_dim = query_states.shape
+    _, kv_seq_len, _, _ = value_states.shape
+
+    if config.use_flash_attention and flash_attention:
+        # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
+        # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
+
+        return fusion_ops.fusion_flash_attention(
+            query_states,
+            config,
+            key_states,
+            value_states,
+            attention_mask,
+            output_attentions,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            sequence_parallel=sequence_parallel,
+            skip_recompute=skip_recompute,
+        )
+    else:
+        #  [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
+        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
+        # merge with the next transpose
+        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
+        value_states = paddle.transpose(value_states, [0, 2, 1, 3])
+
+        # Add a pre-divided factor to avoid NaN under float16.
+        if paddle.in_dynamic_mode() and query_states.dtype == paddle.float16:
+            pre_divided_factor = 32
+        else:
+            pre_divided_factor = 1
+
+        attn_weights = paddle.matmul(
+            query_states / (math.sqrt(head_dim) * pre_divided_factor),
+            key_states.transpose([0, 1, 3, 2]),
+        )
+
+        if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
+            raise ValueError(
+                f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.shape}"
+            )
+
+        if attention_mask is None:
+            attention_mask = get_triangle_upper_mask(attn_weights)
+
+        attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
+        if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]:
+            raise ValueError(
+                f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
+            )
+
+        attn_weights = attn_weights + attention_mask
+
+        if not paddle.in_dynamic_mode():
+            attn_weights = F.softmax(
+                attn_weights * pre_divided_factor, axis=-1, dtype="float32"
+            ).astype(query_states.dtype)
+        else:
+            with paddle.amp.auto_cast(False):
+                attn_weights = F.softmax(
+                    attn_weights.astype("float32") * pre_divided_factor,
+                    axis=-1,
+                    dtype="float32",
+                ).astype(query_states.dtype)
+
+        attn_weights = F.dropout(
+            attn_weights, p=config.attention_dropout, training=training
+        )
+
+        attn_output = paddle.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose([0, 2, 1, 3])
+
+        if sequence_parallel:
+            attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
+        else:
+            attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
+        return (attn_output, attn_weights) if output_attentions else attn_output
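A self-contained sketch of the eager attention path above, with flash attention and the float16 pre-divided-factor trick omitted, using the same [bs, seq_len, num_heads, head_dim] layout:

import math
import paddle
import paddle.nn.functional as F

bs, q_len, num_heads, head_dim = 1, 4, 2, 8
# [bs, seq_len, heads, dim] -> [bs, heads, seq_len, dim]
q = paddle.randn([bs, q_len, num_heads, head_dim]).transpose([0, 2, 1, 3])
k = paddle.randn([bs, q_len, num_heads, head_dim]).transpose([0, 2, 1, 3])
v = paddle.randn([bs, q_len, num_heads, head_dim]).transpose([0, 2, 1, 3])
scores = paddle.matmul(q / math.sqrt(head_dim), k.transpose([0, 1, 3, 2]))
causal = paddle.triu(
    paddle.full([bs, 1, q_len, q_len], paddle.finfo(q.dtype).min, dtype=q.dtype),
    diagonal=1,
)
probs = F.softmax(scores + causal, axis=-1)
out = paddle.matmul(probs, v).transpose([0, 2, 1, 3]).reshape([bs, q_len, num_heads * head_dim])
print(out.shape)  # [1, 4, 16]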
+
+
+def is_casual_mask(attention_mask):
+    """
+    An attention_mask is causal if its upper-triangular part equals the mask itself.
+    """
+    return (paddle.triu(attention_mask) == attention_mask).all().item()
+
+
+def _make_causal_mask(input_ids_shape, past_key_values_length):
+    """
+    Make causal mask used for self-attention
+    """
+    batch_size, target_length = input_ids_shape  # target_length: seq_len
+
+    mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool"))
+
+    if past_key_values_length > 0:
+        # [tgt_len, tgt_len + past_len]
+        mask = paddle.concat(
+            [paddle.ones([target_length, past_key_values_length], dtype="bool"), mask],
+            axis=-1,
+        )
+
+    # [bs, 1, tgt_len, tgt_len + past_len]
+    return mask[None, None, :, :].expand(
+        [batch_size, 1, target_length, target_length + past_key_values_length]
+    )
+
+
+def _expand_2d_mask(mask, dtype, tgt_length):
+    """
+    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
+    """
+    batch_size, src_length = mask.shape[0], mask.shape[-1]
+    tgt_length = tgt_length if tgt_length is not None else src_length
+
+    mask = mask[:, None, None, :].astype("bool")
+    mask.stop_gradient = True
+    expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
+
+    return expanded_mask
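A shape sketch for the two mask helpers above, assuming they are imported from this module: batch of 2, query length 4, and 2 cached key/value tokens.

import paddle

causal = _make_causal_mask((2, 4), past_key_values_length=2)
print(causal.shape)      # [2, 1, 4, 6]: each query sees the cache plus itself and earlier tokens

padding = _expand_2d_mask(paddle.ones([2, 6], dtype="int64"), "bool", tgt_length=4)
print(padding.shape)     # [2, 1, 4, 6]: broadcast of the [batch, src_len] padding mask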
+
+
+class Qwen2RMSNorm(nn.Layer):
+    def __init__(self, config: Qwen2Config):
+        """
+        Qwen2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.weight = paddle.create_parameter(
+            shape=[self.hidden_size],
+            dtype=paddle.get_default_dtype(),
+            default_initializer=nn.initializer.Constant(1.0),
+        )
+        self.variance_epsilon = config.rms_norm_eps
+        self.config = config
+
+        if config.sequence_parallel:
+            mark_as_sequence_parallel_parameter(self.weight)
+
+    def forward(self, hidden_states):
+        if self.config.use_fused_rms_norm:
+            return fusion_ops.fusion_rms_norm(
+                hidden_states, self.weight, self.variance_epsilon, False
+            )
+
+        if paddle.in_dynamic_mode():
+            with paddle.amp.auto_cast(False):
+                variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
+                hidden_states = (
+                    paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
+                )
+        else:
+            variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
+            hidden_states = (
+                paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
+            )
+
+        if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
+            hidden_states = paddle.cast(hidden_states, self.weight.dtype)
+        return hidden_states * self.weight
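The eager branch above is the usual RMSNorm, x * rsqrt(mean(x^2) + eps) * weight; a standalone numeric sketch:

import paddle

x = paddle.randn([2, 5, 8], dtype="float32")
weight = paddle.ones([8])
eps = 1e-6
variance = x.pow(2).mean(-1, keepdim=True)        # per-token mean square
y = paddle.rsqrt(variance + eps) * x * weight     # normalize, then scale
print(y.shape)                                    # [2, 5, 8]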
+
+
+class Qwen2RotaryEmbedding(nn.Layer):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        # [dim / 2]
+        self.inv_freq = 1.0 / (
+            self.base
+            ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)
+        )
+        self._set_cos_sin_cache(seq_len=max_position_embeddings)
+
+    def _set_cos_sin_cache(self, seq_len):
+        self.max_seq_len_cached = seq_len
+        # [seq_len]
+        t = paddle.arange(seq_len, dtype="float32")
+        # [seq_len, dim/2]
+        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        # [seq_len, dim]
+        emb = paddle.concat([freqs, freqs], axis=-1)
+        # [1, seqlen, 1, dim]
+        self.cos_cached = emb.cos()[None, :, None, :]
+        self.sin_cached = emb.sin()[None, :, None, :]
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len)
+        cos = self.cos_cached[:, :seq_len, :, :]
+        sin = self.sin_cached[:, :seq_len, :, :]
+        return (
+            cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
+            sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+        )
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    if position_ids is None:
+        # Note: only hit when position_ids is None (e.g. pipeline-parallel pretraining)
+        cos = cos[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
+        sin = sin[:, : q.shape[1], :, :]  # [bs, seq_len, 1, dim]
+    else:
+        cos = cos.squeeze(axis=[0, 2])  # [seq_len, dim]
+        sin = sin.squeeze(axis=[0, 2])  # [seq_len, dim]
+        cos = cos[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+        sin = sin[position_ids].unsqueeze(2)  # [bs, seq_len, 1, dim]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
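A toy sanity check, assuming rotate_half above is in scope: with zero rotation angles (cos = 1, sin = 0) the rotary update is the identity.

import paddle

q = paddle.randn([1, 4, 2, 8])        # [bs, seq_len, num_heads, head_dim]
cos = paddle.ones([1, 4, 1, 8])       # cos(0)
sin = paddle.zeros([1, 4, 1, 8])      # sin(0)
q_embed = (q * cos) + (rotate_half(q) * sin)
print(bool(paddle.allclose(q_embed, q)))  # True: zero rotation is a no-op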
+
+
+class Qwen2MLP(nn.Layer):
+    def __init__(self, config: Qwen2Config, is_shared=False, skip_recompute_ops=None):
+        super().__init__()
+        if skip_recompute_ops is None:
+            skip_recompute_ops = {}
+        self.skip_recompute_ops = skip_recompute_ops
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.fuse_attention_ffn = config.fuse_attention_ffn
+
+        self.tensor_parallel_degree = config.tensor_parallel_degree
+
+        if config.sequence_parallel:
+            ColumnParallelLinear = ColumnSequenceParallelLinear
+            RowParallelLinear = RowSequenceParallelLinear
+
+        if config.tensor_parallel_degree > 1:
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size * 2,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            else:
+                self.gate_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size,
+                    gather_output=False,
+                    has_bias=False,
+                )
+                self.up_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.intermediate_size,
+                    gather_output=False,
+                    has_bias=False,
+                )
+            self.down_proj = RowParallelLinear(
+                self.intermediate_size,
+                self.hidden_size,
+                input_is_parallel=True,
+                has_bias=False,
+            )
+        else:
+            if self.fuse_attention_ffn:
+                self.gate_up_fused_proj = Linear(
+                    self.hidden_size, self.intermediate_size * 2, bias_attr=False
+                )
+            else:
+                self.gate_proj = Linear(
+                    self.hidden_size, self.intermediate_size, bias_attr=False
+                )  # w1
+                self.up_proj = Linear(
+                    self.hidden_size, self.intermediate_size, bias_attr=False
+                )  # w3
+            self.down_proj = Linear(
+                self.intermediate_size, self.hidden_size, bias_attr=False
+            )  # w2
+
+        if config.hidden_act == "silu":
+            self.act_fn = fusion_ops.swiglu
+            self.fuse_swiglu = True
+        else:
+            self.act_fn = ACT2FN[config.hidden_act]
+            self.fuse_swiglu = False
+
+    def forward(self, x):
+        if self.fuse_attention_ffn:
+            x = self.gate_up_fused_proj(x)
+            if self.fuse_swiglu:
+                y = None
+            else:
+                x, y = x.chunk(2, axis=-1)
+        else:
+            x, y = self.gate_proj(x), self.up_proj(x)
+
+        if self.fuse_swiglu:
+            x = self.act_fn(x, y)
+        else:
+            x = self.act_fn(x) * y
+
+        return self.down_proj(x)
+
+
+def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
+    """
+    This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=2). The hidden states go from
+    (batch, seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim).
+    """
+    batch, slen, num_key_value_heads, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+
+    hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
+    return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim])
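A shape sketch of the GQA expansion performed by repeat_kv, in this file's [batch, seq_len, num_kv_heads, head_dim] layout:

import paddle

kv = paddle.randn([1, 6, 2, 4])       # 2 key/value heads
out = repeat_kv(kv, n_rep=3)          # expand to match 6 query heads
print(out.shape)                      # [1, 6, 6, 4]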
+
+
+class Qwen2Attention(nn.Layer):
+    """
+    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+    and "Generating Long Sequences with Sparse Transformers".
+    """
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        layerwise_recompute: bool = True,
+        skip_recompute_ops=None,
+    ):
+        super().__init__()
+        if skip_recompute_ops is None:
+            skip_recompute_ops = {}
+        self.config = config
+        self.skip_recompute_ops = skip_recompute_ops
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+
+        self.head_dim = self.hidden_size // config.num_attention_heads
+
+        self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads % config.num_key_value_heads == 0
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
+
+        self.seq_length = config.seq_length
+        self.sequence_parallel = config.sequence_parallel
+
+        self.fuse_attention_qkv = config.fuse_attention_qkv
+
+        # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
+        # Enable_recompute defaults to False and is controlled by Trainer
+        self.enable_recompute = False
+        self.layerwise_recompute = layerwise_recompute
+        self.recompute_granularity = config.recompute_granularity
+        if config.tensor_parallel_degree > 1:
+            assert (
+                self.num_heads % config.tensor_parallel_degree == 0
+            ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
+            self.num_heads = self.num_heads // config.tensor_parallel_degree
+
+            assert (
+                self.num_key_value_heads % config.tensor_parallel_degree == 0
+            ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
+            self.num_key_value_heads = (
+                self.num_key_value_heads // config.tensor_parallel_degree
+            )
+
+        self.use_fused_rope = config.use_fused_rope
+        if self.use_fused_rope:
+            if (
+                get_device_type() not in ["gpu", "xpu"]
+                or fused_rotary_position_embedding is None
+            ):
+                logging.warning(
+                    "Enable fuse rope in the config, but fuse rope is not available. "
+                    "Will disable fuse rope. Try using latest gpu version of Paddle."
+                )
+                self.use_fused_rope = False
+
+        if config.sequence_parallel:
+            ColumnParallelLinear = ColumnSequenceParallelLinear
+            RowParallelLinear = RowSequenceParallelLinear
+
+        if config.tensor_parallel_degree > 1:
+            if self.fuse_attention_qkv:
+                self.qkv_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size
+                    + 2 * self.config.num_key_value_heads * self.head_dim,
+                    has_bias=True,
+                    gather_output=False,
+                )
+            else:
+                self.q_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    has_bias=True,
+                    gather_output=False,
+                )
+                self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False)  # fmt:skip
+                self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=True, gather_output=False)  # fmt:skip
+            self.o_proj = RowParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                has_bias=False,
+                input_is_parallel=True,
+            )
+        else:
+            if self.fuse_attention_qkv:
+                self.qkv_proj = Linear(
+                    self.hidden_size,
+                    self.hidden_size
+                    + 2 * self.config.num_key_value_heads * self.head_dim,
+                )
+            else:
+                self.q_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=True)
+                self.k_proj = Linear(
+                    self.hidden_size,
+                    self.config.num_key_value_heads * self.head_dim,
+                    bias_attr=True,
+                )
+                self.v_proj = Linear(
+                    self.hidden_size,
+                    self.config.num_key_value_heads * self.head_dim,
+                    bias_attr=True,
+                )
+            self.o_proj = Linear(self.hidden_size, self.hidden_size, bias_attr=False)
+
+        self.rotary_emb = Qwen2RotaryEmbedding(
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+
+        self.attn_func = scaled_dot_product_attention
+
+    def forward(
+        self,
+        hidden_states,
+        position_ids: Optional[Tuple[paddle.Tensor]] = None,
+        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
+
+        if self.fuse_attention_qkv:
+            mix_layer = self.qkv_proj(hidden_states)
+            if self.sequence_parallel:
+                target_shape = [
+                    -1,
+                    self.seq_length,
+                    self.num_key_value_heads,
+                    (self.num_key_value_groups + 2) * self.head_dim,
+                ]
+            else:
+                target_shape = [
+                    0,
+                    0,
+                    self.num_key_value_heads,
+                    (self.num_key_value_groups + 2) * self.head_dim,
+                ]
+            mix_layer = paddle.reshape_(mix_layer, target_shape)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[
+                    self.num_key_value_groups * self.head_dim,
+                    self.head_dim,
+                    self.head_dim,
+                ],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(
+                    query_states, [0, 0, self.num_heads, self.head_dim]
+                )
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+            if self.sequence_parallel:
+                target_query_shape = [
+                    -1,
+                    self.seq_length,
+                    self.num_heads,
+                    self.head_dim,
+                ]
+                target_key_value_shape = [
+                    -1,
+                    self.seq_length,
+                    self.num_key_value_heads,
+                    self.head_dim,
+                ]
+            else:
+                target_query_shape = [0, 0, self.num_heads, self.head_dim]
+                target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
+            query_states = query_states.reshape(shape=target_query_shape)
+            key_states = key_states.reshape(shape=target_key_value_shape)
+            value_states = value_states.reshape(shape=target_key_value_shape)
+
+        kv_seq_len = key_states.shape[-3]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-3]
+        if self.use_fused_rope:
+            assert past_key_value is None, "fuse rotary not support cache kv for now"
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states, _ = fused_rotary_position_embedding(
+                query_states,
+                key_states,
+                v=None,
+                sin=sin,
+                cos=cos,
+                position_ids=position_ids,
+                use_neox_rotary_style=False,
+            )
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(
+                query_states, key_states, cos, sin, position_ids
+            )
+
+        # [bs, seq_len, num_head, head_dim]
+        if past_key_value is not None:
+            key_states = paddle.concat([past_key_value[0], key_states], axis=1)
+            value_states = paddle.concat([past_key_value[1], value_states], axis=1)
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
+        # repeat k/v heads if n_kv_heads < n_heads
+        paddle_version = float(paddle.__version__[:3])
+        if not self.config.use_flash_attention or (
+            (paddle_version != 0.0) and (paddle_version <= 2.6)
+        ):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        outputs = self.attn_func(
+            query_states,
+            self.config,
+            key_states,
+            value_states,
+            attention_mask,
+            output_attentions,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            training=self.training,
+            sequence_parallel=self.sequence_parallel,
+        )
+        if output_attentions:
+            attn_output, attn_weights = outputs
+        else:
+            attn_output = outputs
+
+        # if sequence_parallel is true, the output shape is [q_len / n, bs, num_head * head_dim];
+        # otherwise it is [bs, q_len, num_head * head_dim], where n is the mp parallelism.
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        outputs = (attn_output,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        if use_cache:
+            outputs += (past_key_value,)
+
+        if type(outputs) is tuple and len(outputs) == 1:
+            outputs = outputs[0]
+
+        return outputs
+
+
+class Qwen2DecoderLayer(nn.Layer):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        layerwise_recompute: bool = False,
+        skip_recompute_ops=None,
+    ):
+        super().__init__()
+        if skip_recompute_ops is None:
+            skip_recompute_ops = {}
+        self.config = config
+        self.skip_recompute_ops = skip_recompute_ops
+        self.hidden_size = config.hidden_size
+        self.self_attn = Qwen2Attention(
+            config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops
+        )
+
+        self.mlp = Qwen2MLP(config, skip_recompute_ops=skip_recompute_ops)
+        self.input_layernorm = Qwen2RMSNorm(config)
+        self.post_attention_layernorm = Qwen2RMSNorm(config)
+
+        # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
+        # Enable_recompute defaults to False and is controlled by Trainer
+        self.enable_recompute = False
+        self.layerwise_recompute = layerwise_recompute
+        self.recompute_granularity = config.recompute_granularity
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
+        use_cache: Optional[bool] = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
+        """
+        Args:
+            hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`paddle.Tensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
+        """
+
+        # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        outputs = self.self_attn(
+            hidden_states,
+            position_ids,
+            past_key_value,
+            attention_mask,
+            output_attentions,
+            use_cache,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+        )
+
+        if type(outputs) is tuple:
+            hidden_states = outputs[0]
+        else:
+            hidden_states = outputs
+
+        if output_attentions:
+            self_attn_weights = outputs[1]
+
+        if use_cache:
+            present_key_value = outputs[2 if output_attentions else 1]
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if type(outputs) is tuple and len(outputs) == 1:
+            outputs = outputs[0]
+
+        return outputs
+
+
+class Qwen2PretrainedModel(PretrainedModel):
+    config_class = Qwen2Config
+    base_model_prefix = "qwen2"
+    _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"]
+
+    @classmethod
+    def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True):
+
+        from paddlenlp.transformers.conversion_utils import split_or_merge_func
+
+        fn = split_or_merge_func(
+            is_split=is_split,
+            tensor_parallel_degree=config.tensor_parallel_degree,
+            tensor_parallel_rank=config.tensor_parallel_rank,
+            num_attention_heads=config.num_attention_heads,
+        )
+
+        def get_tensor_parallel_split_mappings(num_layers):
+            final_actions = {}
+
+            base_actions = {
+                # Row Linear
+                "embed_tokens.weight": partial(fn, is_column=False),
+                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
+                "layers.0.mlp.down_proj.weight": partial(fn, is_column=False),
+            }
+
+            if config.tie_word_embeddings:
+                base_actions["lm_head.weight"] = partial(fn, is_column=False)
+            else:
+                base_actions["lm_head.weight"] = partial(fn, is_column=True)
+
+            if not config.vocab_size % config.tensor_parallel_degree == 0:
+                base_actions.pop("lm_head.weight")
+                base_actions.pop("embed_tokens.weight")
+            # Column Linear
+            if config.fuse_attention_qkv:
+                base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(
+                    fn, is_column=True
+                )
+                base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(
+                    fn, is_column=True
+                )
+            else:
+                base_actions["layers.0.self_attn.q_proj.weight"] = partial(
+                    fn, is_column=True
+                )
+                base_actions["layers.0.self_attn.q_proj.bias"] = partial(
+                    fn, is_column=True
+                )
+                # if we have enough num_key_value_heads to split, then split it.
+                if config.num_key_value_heads % config.tensor_parallel_degree == 0:
+                    base_actions["layers.0.self_attn.k_proj.weight"] = partial(
+                        fn, is_column=True
+                    )
+                    base_actions["layers.0.self_attn.v_proj.weight"] = partial(
+                        fn, is_column=True
+                    )
+                    base_actions["layers.0.self_attn.k_proj.bias"] = partial(
+                        fn, is_column=True
+                    )
+                    base_actions["layers.0.self_attn.v_proj.bias"] = partial(
+                        fn, is_column=True
+                    )
+
+            if config.fuse_attention_ffn:
+                base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
+                    fn, is_column=True, is_naive_2fuse=True
+                )
+            else:
+                base_actions["layers.0.mlp.gate_proj.weight"] = partial(
+                    fn, is_column=True
+                )
+                base_actions["layers.0.mlp.up_proj.weight"] = partial(
+                    fn, is_column=True
+                )
+
+            for key, action in base_actions.items():
+                if "layers.0." in key:
+                    for i in range(num_layers):
+                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
+                final_actions[key] = action
+
+            return final_actions
+
+        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
+
+        return mappings
+
+    @classmethod
+    def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False):
+        # return parameter fuse utils
+        from paddlenlp.transformers.conversion_utils import split_or_fuse_func
+
+        fn = split_or_fuse_func(is_fuse=is_fuse)
+
+        # last key is fused key, other keys are to be fused.
+        fuse_qkv_keys = [
+            (
+                "layers.0.self_attn.q_proj.weight",
+                "layers.0.self_attn.k_proj.weight",
+                "layers.0.self_attn.v_proj.weight",
+                "layers.0.self_attn.qkv_proj.weight",
+            ),
+            (
+                "layers.0.self_attn.q_proj.bias",
+                "layers.0.self_attn.k_proj.bias",
+                "layers.0.self_attn.v_proj.bias",
+                "layers.0.self_attn.qkv_proj.bias",
+            ),
+        ]
+
+        fuse_gate_up_keys = (
+            "layers.0.mlp.gate_proj.weight",
+            "layers.0.mlp.up_proj.weight",
+            "layers.0.mlp.gate_up_fused_proj.weight",
+        )
+        num_heads = config.num_attention_heads
+        num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
+        fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
+        fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)
+
+        final_actions = {}
+        if is_fuse:
+            if fuse_attention_qkv:
+                for i in range(config.num_hidden_layers):
+                    for fuse_keys in fuse_qkv_keys:
+                        keys = tuple(
+                            [
+                                key.replace("layers.0.", f"layers.{i}.")
+                                for key in fuse_keys
+                            ]
+                        )
+                        final_actions[keys] = partial(
+                            fn,
+                            is_qkv=True,
+                            num_heads=num_heads,
+                            num_key_value_heads=num_key_value_heads,
+                        )
+            if fuse_attention_ffn:
+                for i in range(config.num_hidden_layers):
+                    keys = tuple(
+                        [
+                            key.replace("layers.0.", f"layers.{i}.")
+                            for key in fuse_gate_up_keys
+                        ]
+                    )
+                    final_actions[keys] = fn
+        else:
+            if not fuse_attention_qkv:
+                for i in range(config.num_hidden_layers):
+                    for fuse_keys in fuse_qkv_keys:
+                        keys = tuple(
+                            [
+                                key.replace("layers.0.", f"layers.{i}.")
+                                for key in fuse_keys
+                            ]
+                        )
+                        final_actions[keys] = partial(
+                            fn,
+                            split_nums=3,
+                            is_qkv=True,
+                            num_heads=num_heads,
+                            num_key_value_heads=num_key_value_heads,
+                        )
+            if not fuse_attention_ffn:
+                for i in range(config.num_hidden_layers):
+                    keys = tuple(
+                        [
+                            key.replace("layers.0.", f"layers.{i}.")
+                            for key in fuse_gate_up_keys
+                        ]
+                    )
+                    final_actions[keys] = partial(fn, split_nums=2)
+        return final_actions
+
+
+class Qwen2Model(Qwen2PretrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
+
+    Args:
+        config: Qwen2Config
+    """
+
+    def __init__(self, config: Qwen2Config):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.hidden_size = config.hidden_size
+        self.sequence_parallel = config.sequence_parallel
+        self.recompute_granularity = config.recompute_granularity
+        self.no_recompute_layers = (
+            config.no_recompute_layers if config.no_recompute_layers is not None else []
+        )
+
+        # Recompute defaults to False and is controlled by Trainer
+        self.enable_recompute = False
+        if (
+            config.tensor_parallel_degree > 1
+            and config.vocab_size % config.tensor_parallel_degree == 0
+        ):
+            self.embed_tokens = mpu.VocabParallelEmbedding(
+                self.vocab_size,
+                self.hidden_size,
+                weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
+            )
+        else:
+            self.embed_tokens = nn.Embedding(
+                self.vocab_size,
+                self.hidden_size,
+            )
+
+        self.layers = nn.LayerList(
+            [
+                Qwen2DecoderLayer(
+                    config=config,
+                    layerwise_recompute=layer_idx not in self.no_recompute_layers,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = Qwen2RMSNorm(config)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    @staticmethod
+    def _prepare_decoder_attention_mask(
+        attention_mask, input_shape, past_key_values_length, dtype
+    ):
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            if len(attention_mask.shape) == 2:
+                expanded_attn_mask = _expand_2d_mask(
+                    attention_mask, dtype, tgt_length=input_shape[-1]
+                )
+                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
+                if input_shape[-1] > 1:
+                    combined_attention_mask = _make_causal_mask(
+                        input_shape,
+                        past_key_values_length=past_key_values_length,
+                    )
+                    expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+            # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
+            elif len(attention_mask.shape) == 3:
+                expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
+            # if attention_mask is already 4-D, do nothing
+            else:
+                expanded_attn_mask = attention_mask
+        else:
+            expanded_attn_mask = _make_causal_mask(
+                input_shape,
+                past_key_values_length=past_key_values_length,
+            )
+        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
+        if get_device_type() == "xpu":
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y)
+        else:
+            expanded_attn_mask = paddle.where(
+                expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min
+            ).astype(dtype)
+        return expanded_attn_mask
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # fmt:skip
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )
+
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.layers))
+        # NOTE: convert to a list so cache entries can be cleared in time
+        past_key_values = list(past_key_values)
+
+        seq_length_with_past = seq_length
+        cache_length = 0
+        if past_key_values[0] is not None:
+            cache_length = past_key_values[0][0].shape[1]
+            seq_length_with_past += cache_length
+        if inputs_embeds is None:
+            # [bs, seq_len, dim]
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self.sequence_parallel:
+            # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
+            bs, seq_len, hidden_size = inputs_embeds.shape
+            inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size])
+            # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
+            inputs_embeds = ScatterOp.apply(inputs_embeds)
+
+        # [bs, seq_len]
+        attention_mask = (
+            paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+            if attention_mask is None
+            else attention_mask
+        )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+        )  # [bs, 1, seq_len, seq_len]
+        if self.config.use_flash_attention:
+            attention_mask = None if is_casual_mask(attention_mask) else attention_mask
+
+        if position_ids is None:
+            position_ids = paddle.arange(seq_length, dtype="int64").expand(
+                (batch_size, seq_length)
+            )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            past_key_value = (
+                past_key_values[idx] if past_key_values is not None else None
+            )
+
+            has_gradient = not hidden_states.stop_gradient
+            if (
+                self.enable_recompute
+                and idx not in self.no_recompute_layers
+                and has_gradient
+                and self.recompute_granularity == "full"
+            ):
+                layer_outputs = self.recompute_training_full(
+                    decoder_layer,
+                    hidden_states,
+                    position_ids,
+                    attention_mask,
+                    output_attentions,
+                    past_key_value,
+                    use_cache,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    position_ids,
+                    attention_mask,
+                    output_attentions,
+                    past_key_value,
+                    use_cache,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                )
+
+            # NOTE: clear the outdated cache entry after use to save memory
+            past_key_value = past_key_values[idx] = None
+            if type(layer_outputs) is tuple:
+                hidden_states = layer_outputs[0]
+            else:
+                hidden_states = layer_outputs
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class Qwen2PretrainingCriterion(nn.Layer):
+    """
+    Criterion for Qwen2 pretraining.
+    It calculates the final loss.
+    """
+
+    def __init__(self, config: Qwen2Config):
+        super(Qwen2PretrainingCriterion, self).__init__()
+        self.ignore_index = getattr(config, "ignore_index", -100)
+        self.config = config
+        self.enable_parallel_cross_entropy = (
+            config.tensor_parallel_degree > 1 and config.tensor_parallel_output
+        )
+
+        if self.enable_parallel_cross_entropy:  # lm_head is expected to be distributed in this case
+            self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index)
+        else:
+            self.loss_func = paddle.nn.CrossEntropyLoss(
+                reduction="none", ignore_index=self.ignore_index
+            )
+
+    def forward(self, prediction_scores, masked_lm_labels):
+        if self.enable_parallel_cross_entropy:
+            if prediction_scores.shape[-1] == self.config.vocab_size:
+                logging.warning(
+                    f"enable_parallel_cross_entropy, the vocab_size should be splitted: {prediction_scores.shape[-1]}, {self.config.vocab_size}"
+                )
+                self.loss_func = paddle.nn.CrossEntropyLoss(
+                    reduction="none", ignore_index=self.ignore_index
+                )
+
+        with paddle.amp.auto_cast(False):
+            masked_lm_loss = self.loss_func(
+                prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)
+            )
+
+            # skip positions hit by ignore_index (their loss is 0)
+            # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
+            # loss = paddle.mean(masked_lm_loss)
+            binary_sequence = paddle.where(
+                masked_lm_loss > 0,
+                paddle.ones_like(masked_lm_loss),
+                paddle.zeros_like(masked_lm_loss),
+            )
+            count = paddle.sum(binary_sequence)
+            if count == 0:
+                loss = paddle.sum(masked_lm_loss * binary_sequence)
+            else:
+                loss = paddle.sum(masked_lm_loss * binary_sequence) / count
+
+        return loss
+
+
+class Qwen2LMHead(nn.Layer):
+    def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False):
+        super(Qwen2LMHead, self).__init__()
+        self.config = config
+        if (
+            config.tensor_parallel_degree > 1
+            and config.vocab_size % config.tensor_parallel_degree == 0
+        ):
+            vocab_size = config.vocab_size // config.tensor_parallel_degree
+        else:
+            vocab_size = config.vocab_size
+
+        self.transpose_y = transpose_y
+        if transpose_y:
+            if embedding_weights is not None:
+                self.weight = embedding_weights
+            else:
+                self.weight = self.create_parameter(
+                    shape=[vocab_size, config.hidden_size],
+                    dtype=paddle.get_default_dtype(),
+                )
+        else:
+            self.weight = self.create_parameter(
+                shape=[config.hidden_size, vocab_size],
+                dtype=paddle.get_default_dtype(),
+            )
+
+        # Must set distributed attr for Tensor Parallel !
+        self.weight.is_distributed = (
+            True if (vocab_size != config.vocab_size) else False
+        )
+        if self.weight.is_distributed:
+            # for tie_word_embeddings
+            self.weight.split_axis = 0 if self.transpose_y else 1
+
+    def forward(self, hidden_states, tensor_parallel_output=None):
+        if self.config.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+            seq_length = self.config.seq_length
+            hidden_states = paddle.reshape_(
+                hidden_states, [-1, seq_length, self.config.hidden_size]
+            )
+
+        if tensor_parallel_output is None:
+            tensor_parallel_output = self.config.tensor_parallel_output
+
+        logits = parallel_matmul(
+            hidden_states,
+            self.weight,
+            transpose_y=self.transpose_y,
+            tensor_parallel_output=tensor_parallel_output,
+        )
+        return logits
+
+
+class Qwen2ForCausalLM(Qwen2PretrainedModel):
+    enable_to_static_method = True
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config: Qwen2Config):
+        super().__init__(config)
+        self.qwen2 = Qwen2Model(config)
+        if config.tie_word_embeddings:
+            self.lm_head = Qwen2LMHead(
+                config,
+                embedding_weights=self.qwen2.embed_tokens.weight,
+                transpose_y=True,
+            )
+            self.tie_weights()
+        else:
+            self.lm_head = Qwen2LMHead(config)
+        self.criterion = Qwen2PretrainingCriterion(config)
+        self.vocab_size = config.vocab_size
+
+    def get_input_embeddings(self):
+        return self.qwen2.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.qwen2.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.qwen2 = decoder
+
+    def get_decoder(self):
+        return self.qwen2
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        use_cache=False,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs,
+    ):
+        batch_size, seq_length = input_ids.shape
+        position_ids = kwargs.get(
+            "position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))
+        )
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(axis=-1)
+            position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _get_model_inputs_spec(self, dtype: str):
+        return {
+            "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
+            "attention_mask": paddle.static.InputSpec(
+                shape=[None, None], dtype="int64"
+            ),
+            "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"),
+        }
+
+    @staticmethod
+    def update_model_kwargs_for_generation(
+        outputs, model_kwargs, is_encoder_decoder=False
+    ):
+        # update cache
+        if (
+            isinstance(outputs, tuple)
+            and len(outputs) > 1
+            and not isinstance(outputs[1], paddle.Tensor)
+        ):
+            model_kwargs["past_key_values"] = outputs[1]
+
+        if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs:
+            model_kwargs["past_key_values"] = outputs.past_key_values
+
+        # update position_ids
+        if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
+            position_ids = model_kwargs["position_ids"]
+            model_kwargs["position_ids"] = paddle.concat(
+                [position_ids, position_ids[..., -1:] + 1], axis=-1
+            )
+
+        if not is_encoder_decoder and "attention_mask" in model_kwargs:
+            # TODO: support attention mask for other models
+            attention_mask = model_kwargs["attention_mask"]
+            if len(attention_mask.shape) == 2:
+                model_kwargs["attention_mask"] = paddle.concat(
+                    [
+                        attention_mask,
+                        paddle.ones(
+                            [attention_mask.shape[0], 1], dtype=attention_mask.dtype
+                        ),
+                    ],
+                    axis=-1,
+                )
+            elif len(attention_mask.shape) == 4:
+                model_kwargs["attention_mask"] = paddle.concat(
+                    [
+                        attention_mask,
+                        paddle.ones(
+                            [*attention_mask.shape[:3], 1], dtype=attention_mask.dtype
+                        ),
+                    ],
+                    axis=-1,
+                )[:, :, -1:, :]
+
+        return model_kwargs
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        labels: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+        >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if attn_mask_startend_row_indices is not None and attention_mask is not None:
+            logging.warning(
+                "You have provided both attn_mask_startend_row_indices and attention_mask. "
+                "The attn_mask_startend_row_indices will be used."
+            )
+            attention_mask = None
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.qwen2(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+        )
+
+        hidden_states = outputs[0]
+
+        # if labels is None, we need the full output rather than the tensor-parallel-split one;
+        # tensor_parallel_output is only meaningful together with ParallelCrossEntropy
+        tensor_parallel_output = (
+            self.config.tensor_parallel_output
+            and self.config.tensor_parallel_degree > 1
+        )
+
+        logits = self.lm_head(
+            hidden_states, tensor_parallel_output=tensor_parallel_output
+        )
+        loss = None
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
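For readers skimming the modeling diff: the token-level loss in `Qwen2PretrainingCriterion.forward` is averaged only over positions that actually contribute, since `ignore_index` positions produce a loss of exactly 0 and are masked out of the mean. A minimal standalone sketch with made-up values, not part of the change:

```python
# Illustrative only: reproduces the masking/averaging done in
# Qwen2PretrainingCriterion.forward with dummy per-token losses.
import paddle

masked_lm_loss = paddle.to_tensor([[2.0], [0.0], [4.0]])  # 0.0 marks an ignored token
binary_sequence = paddle.where(
    masked_lm_loss > 0,
    paddle.ones_like(masked_lm_loss),
    paddle.zeros_like(masked_lm_loss),
)
count = paddle.sum(binary_sequence)                          # 2 valid tokens
loss = paddle.sum(masked_lm_loss * binary_sequence) / count  # (2.0 + 4.0) / 2 = 3.0
print(float(loss))
```

The `count == 0` branch in the real code only guards against a batch in which every label is `ignore_index`, in which case the loss degenerates to 0.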

+ 41 - 19
paddlex/inference/models/doc_vlm/predictor.py

@@ -34,8 +34,16 @@ class DocVLMPredictor(BasePredictor):
             *args: Arbitrary positional arguments passed to the superclass.
             **kwargs: Arbitrary keyword arguments passed to the superclass.
         """
+        import paddle
+
         super().__init__(*args, **kwargs)
         self.device = kwargs.get("device", None)
+        self.dtype = (
+            "bfloat16"
+            if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
+            else "float32"
+        )
+
         self.infer, self.processor = self._build(**kwargs)
 
     def _build_batch_sampler(self):
@@ -44,7 +52,7 @@ class DocVLMPredictor(BasePredictor):
         Returns:
             DocVLMBatchSampler: An instance of DocVLMBatchSampler.
         """
-        return DocVLMBatchSampler()
+        return DocVLMBatchSampler(self.model_name)
 
     def _get_result_class(self):
         """Returns the result class, DocVLMResult.
@@ -61,9 +69,10 @@ class DocVLMPredictor(BasePredictor):
             model: An instance of Paddle model, could be either a dynamic model or a static model.
             processor: The corresponding processor for the model.
         """
-        import paddle
+        from .modeling import PPChart2TableInference, PPDocBeeInference
 
-        from .modeling import PPDocBeeInference
+        # build processor
+        processor = self.build_processor()
 
         # build model
         if "PP-DocBee" in self.model_name:
@@ -71,18 +80,24 @@ class DocVLMPredictor(BasePredictor):
                 raise ValueError(
                     f"PP-DocBee series do not support `use_hpip=True` for now."
                 )
-            dtype = (
-                "bfloat16"
-                if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
-                else "float32"
-            )
             with TemporaryDeviceChanger(self.device):
-                model = PPDocBeeInference.from_pretrained(self.model_dir, dtype=dtype)
+                model = PPDocBeeInference.from_pretrained(
+                    self.model_dir, dtype=self.dtype
+                )
+        elif "PP-Chart2Table" in self.model_name:
+            if kwargs.get("use_hpip", False):
+                raise ValueError(
+                    f"PP-Chart2Table series do not support `use_hpip=True` for now."
+                )
+            with TemporaryDeviceChanger(self.device):
+                model = PPChart2TableInference.from_pretrained(
+                    self.model_dir,
+                    dtype=self.dtype,
+                    pad_token_id=processor.tokenizer.eos_token_id,
+                )
         else:
             raise NotImplementedError(f"Model {self.model_name} is not supported.")
 
-        # build processor
-        processor = self.build_processor()
         return model, processor
 
     def process(self, data: List[dict], **kwargs):
@@ -96,15 +111,11 @@ class DocVLMPredictor(BasePredictor):
         Returns:
             dict: A dictionary containing the raw sample information and prediction results for every instance of the batch.
         """
-        assert (
-            isinstance(data, List) and len(data) == 1
-        ), "data must be a list of length 1"
-        assert isinstance(data[0], dict)
+        assert all(isinstance(i, dict) for i in data)
 
-        data = data[0]
         src_data = copy.copy(data)
         # preprocess
-        data = self.processor.preprocess(**data)
+        data = self.processor.preprocess(data)
         data = self._switch_inputs_to_device(data)
 
         # do infer
@@ -118,8 +129,13 @@ class DocVLMPredictor(BasePredictor):
         return result_dict
 
     def build_processor(self, **kwargs):
-        from ..common.tokenizer import MIXQwen2Tokenizer
-        from .processors import PPDocBeeProcessor, Qwen2VLImageProcessor
+        from ..common.tokenizer import MIXQwen2Tokenizer, QWenTokenizer
+        from .processors import (
+            GOTImageProcessor,
+            PPChart2TableProcessor,
+            PPDocBeeProcessor,
+            Qwen2VLImageProcessor,
+        )
 
         if "PP-DocBee" in self.model_name:
             image_processor = Qwen2VLImageProcessor()
@@ -127,6 +143,12 @@ class DocVLMPredictor(BasePredictor):
             return PPDocBeeProcessor(
                 image_processor=image_processor, tokenizer=tokenizer
             )
+        elif "PP-Chart2Table" in self.model_name:
+            image_processor = GOTImageProcessor(1024)
+            tokenizer = QWenTokenizer.from_pretrained(self.model_dir)
+            return PPChart2TableProcessor(
+                image_processor=image_processor, tokenizer=tokenizer, dtype=self.dtype
+            )
         else:
             raise NotImplementedError
 

+ 97 - 0
paddlex/inference/models/doc_vlm/processors/GOT_ocr_2_0.py

@@ -0,0 +1,97 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Union
+
+import numpy as np
+import paddle
+import requests
+from paddle.vision import transforms
+from PIL import Image
+
+from ....utils.benchmark import benchmark
+
+MEAN = (0.48145466, 0.4578275, 0.40821073)
+STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+class GOTImageProcessor(object):
+    def __init__(self, image_size=1024):
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((image_size, image_size), interpolation="bicubic"),
+                transforms.ToTensor(),
+                transforms.Normalize(MEAN, STD),
+            ]
+        )
+
+    def __call__(self, image):
+        return self.transform(image)
+
+
+class PPChart2TableProcessor(object):
+    def __init__(self, image_processor, tokenizer, dtype, **kwargs):
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+        self.dtype = dtype
+
+        prompt = (
+            "<|im_start|>system\n"
+            "You should follow the instructions carefully and explain your answers in detail.<|im_end|><|im_start|>user\n"
+            "<img>" + "<imgpad>" * 256 + "</img>\n"
+            "Chart to table<|im_end|><|im_start|>assistant\n"
+        )
+        self.input_ids = paddle.to_tensor(self.tokenizer([prompt]).input_ids)
+
+    @benchmark.timeit
+    def preprocess(self, image: Union[str, Image.Image, np.ndarray, Dict, List]):
+        if isinstance(image, (str, Image.Image, np.ndarray)):
+            image = [image]
+        elif isinstance(image, dict):
+            image = [image["image"]]
+
+        assert isinstance(image, list)
+        images = [
+            image_["image"] if isinstance(image_, dict) else image_ for image_ in image
+        ]
+        images = [
+            self.image_processor(self._load_image(image)).unsqueeze(0).to(self.dtype)
+            for image in images
+        ]
+        img_cnt = len(images)
+
+        input_ids = paddle.tile(self.input_ids, [img_cnt, 1])
+
+        return {"input_ids": input_ids, "images": images}
+
+    @benchmark.timeit
+    def postprocess(self, model_pred, *args, **kwargs):
+        return self.tokenizer.batch_decode(
+            model_pred[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+    def _load_image(self, image_file):
+        from io import BytesIO
+
+        if isinstance(image_file, Image.Image):
+            image = image_file.convert("RGB")
+        elif isinstance(image_file, np.ndarray):
+            image = Image.fromarray(image_file)
+        elif image_file.startswith("http") or image_file.startswith("https"):
+            response = requests.get(image_file)
+            image = Image.open(BytesIO(response.content)).convert("RGB")
+        else:
+            image = Image.open(image_file).convert("RGB")
+        return image
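The pieces above can also be exercised on their own, outside `DocVLMPredictor`, which may be handy when debugging the prompt or the image tensor shapes. A rough sketch that goes through the internal modules directly (the weights directory and image paths are placeholders; `QWenTokenizer` is the reason `tiktoken` becomes a dependency further down):

```python
# Sketch only: build the PP-Chart2Table processor by hand and inspect
# what preprocess() hands to the model.
from paddlex.inference.models.common.tokenizer import QWenTokenizer
from paddlex.inference.models.doc_vlm.processors import (
    GOTImageProcessor,
    PPChart2TableProcessor,
)

tokenizer = QWenTokenizer.from_pretrained("path/to/PP-Chart2Table_infer")  # placeholder dir
processor = PPChart2TableProcessor(
    image_processor=GOTImageProcessor(1024),
    tokenizer=tokenizer,
    dtype="float32",
)

batch = processor.preprocess([{"image": "chart_1.png"}, {"image": "chart_2.png"}])
print(batch["input_ids"].shape)  # [2, prompt_len]: the fixed prompt is tiled per image
print(len(batch["images"]))      # 2 tensors, each of shape [1, 3, 1024, 1024]
```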

+ 1 - 0
paddlex/inference/models/doc_vlm/processors/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .GOT_ocr_2_0 import GOTImageProcessor, PPChart2TableProcessor
 from .qwen2_vl import PPDocBeeProcessor, Qwen2VLImageProcessor

+ 7 - 1
paddlex/inference/models/doc_vlm/processors/qwen2_vl.py

@@ -670,10 +670,16 @@ class PPDocBeeProcessor(Qwen2VLProcessor):
     """
 
     @benchmark.timeit
-    def preprocess(self, image: Union[str, Image.Image, np.ndarray], query: str):
+    def preprocess(self, input_dicts):
         """
         PreProcess for PP-DocBee Series
         """
+        assert (
+            isinstance(input_dicts, list) and len(input_dicts) == 1
+        ), f"PP-DocBee series only supports batchsize of one, but received {len(input_dicts)} samples."
+        input_dict = input_dicts[0]
+        image = input_dict["image"]
+        query = input_dict["query"]
         image_inputs = fetch_image(image)
         image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>"
         text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{query}<|im_end|>\n<|im_start|>assistant\n"

+ 1 - 0
paddlex/inference/utils/official_models.py

@@ -335,6 +335,7 @@ PP-LCNet_x1_0_vehicle_attribute_infer.tar",
     "YOLO-Worldv2-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/YOLO-Worldv2-L_infer.tar",
     "PP-DocBee-2B": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-DocBee-2B_infer.tar",
     "PP-DocBee-7B": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-DocBee-7B_infer.tar",
+    "PP-Chart2Table": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-Chart2Table_infer.tar",
 }
 
 

+ 1 - 1
paddlex/modules/doc_vlm/model_list.py

@@ -13,4 +13,4 @@
 # limitations under the License.
 
 
-MODELS = ["PP-DocBee-2B", "PP-DocBee-7B"]
+MODELS = ["PP-DocBee-2B", "PP-DocBee-7B", "PP-Chart2Table"]

+ 2 - 0
setup.py

@@ -69,6 +69,7 @@ DEP_SPECS = {
     "shapely": "",
     "soundfile": "",
     "starlette": ">= 0.36",
+    "tiktoken": "",
     "tokenizers": "== 0.19.1",
     "tqdm": "",
     "typing-extensions": "",
@@ -116,6 +117,7 @@ EXTRAS = {
             # For the same reason as in `cv`
             "PyMuPDF",
             "regex",
+            "tiktoken",
         ],
         "ie": [
             "ftfy",