|
|
@@ -28,6 +28,11 @@ from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
|
|
|
from paddle.distributed.fleet.utils import recompute
|
|
|
|
|
|
from .....utils import logging
|
|
|
+from ....utils.benchmark import (
|
|
|
+ benchmark,
|
|
|
+ get_inference_operations,
|
|
|
+ set_inference_operations,
|
|
|
+)
|
|
|
from ...common.vlm.activations import ACT2FN
|
|
|
from ...common.vlm.bert_padding import index_first_axis, pad_input, unpad_input
|
|
|
from ...common.vlm.flash_attn_utils import has_flash_attn_func
|
|
|
@@ -2575,6 +2580,9 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
|
|
|
|
|
|
|
|
|
class PPDocBeeInference(Qwen2VLForConditionalGeneration):
|
|
|
+ set_inference_operations(get_inference_operations() + ["docbee_generate"])
|
|
|
+
|
|
|
+ @benchmark.timeit_with_options(name="docbee_generate")
|
|
|
def generate(self, inputs, **kwargs):
|
|
|
max_new_tokens = kwargs.get("max_new_tokens", 2048)
|
|
|
temperature = kwargs.get("temperature", 0.1)
|