Преглед изворни кода

add whisper and docbee models to benchmark (#3819)

* add multilingual_speech_recognition module to benchmark

* add docbee to benchmark

* add docbee to benchmark
zhang-prog пре 7 месеци
родитељ
комит
ad8719438a

+ 4 - 2
paddlex/inference/common/batch_sampler/audio_batch_sampler.py

@@ -57,9 +57,11 @@ class AudioBatchSampler(BaseBatchSampler):
             if inputs.startswith("http"):
                 inputs = self._download_from_url(inputs)
             yield [inputs]
+        elif isinstance(inputs, list):
+            yield inputs
         else:
-            logging.warning(
-                f"Not supported input data type! Only `str` are supported, but got: {input}."
+            raise TypeError(
+                f"Unsupported input data type! Only `str` and `list` are supported, but got: {type(inputs)}."
             )
 
     @BaseBatchSampler.batch_size.setter

+ 4 - 2
paddlex/inference/common/batch_sampler/doc_vlm_batch_sampler.py

@@ -38,9 +38,11 @@ class DocVLMBatchSampler(BaseBatchSampler):
         """
         if isinstance(inputs, dict):
             yield [inputs]
+        elif isinstance(inputs, list) and all(isinstance(i, dict) for i in inputs):
+            yield inputs
         else:
-            logging.warning(
-                f"Not supported input data type! Only `dict` are supported, but got: {input}."
+            raise TypeError(
+                f"Unsupported input data type! Only `dict` and `list` of `dict` are supported, but got: {type(inputs)}."
             )
 
     @BaseBatchSampler.batch_size.setter

+ 8 - 0
paddlex/inference/models/doc_vlm/modeling/qwen2_vl.py

@@ -28,6 +28,11 @@ from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.utils import recompute
 
 from .....utils import logging
+from ....utils.benchmark import (
+    benchmark,
+    get_inference_operations,
+    set_inference_operations,
+)
 from ...common.vlm.activations import ACT2FN
 from ...common.vlm.bert_padding import index_first_axis, pad_input, unpad_input
 from ...common.vlm.flash_attn_utils import has_flash_attn_func
@@ -2575,6 +2580,9 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel):
 
 
 class PPDocBeeInference(Qwen2VLForConditionalGeneration):
+    set_inference_operations(get_inference_operations() + ["docbee_generate"])
+
+    @benchmark.timeit_with_options(name="docbee_generate")
     def generate(self, inputs, **kwargs):
         max_new_tokens = kwargs.get("max_new_tokens", 2048)
         temperature = kwargs.get("temperature", 0.1)

+ 3 - 0
paddlex/inference/models/doc_vlm/processors/qwen2_vl.py

@@ -23,6 +23,7 @@ import requests
 from PIL import Image
 
 from .....utils import logging
+from ....utils.benchmark import benchmark
 from ...common.vision.funcs import resize
 from .common import (
     BatchFeature,
@@ -668,6 +669,7 @@ class PPDocBeeProcessor(Qwen2VLProcessor):
     PP-DocBee processor, based on Qwen2VLProcessor
     """
 
+    @benchmark.timeit
     def preprocess(self, image: Union[str, Image.Image, np.ndarray], query: str):
         """
         PreProcess for PP-DocBee Series
@@ -686,6 +688,7 @@ class PPDocBeeProcessor(Qwen2VLProcessor):
 
         return rst_inputs
 
+    @benchmark.timeit
     def postprocess(self, model_pred, *args, **kwargs):
         """
         Post process adapt for PaddleX

+ 7 - 1
paddlex/inference/models/multilingual_speech_recognition/processors.py

@@ -22,6 +22,11 @@ import numpy as np
 import paddle
 
 from ....utils.deps import function_requires_deps, is_dep_available
+from ...utils.benchmark import (
+    benchmark,
+    get_inference_operations,
+    set_inference_operations,
+)
 from ..common.tokenizer import GPTTokenizer
 
 if is_dep_available("soundfile"):
@@ -1825,7 +1830,8 @@ class Whisper(paddle.nn.Layer):
         return cache, hooks
 
     detect_language = detect_language
-    transcribe = transcribe
+    set_inference_operations(get_inference_operations() + ["speech_transcribe"])
+    transcribe = benchmark.timeit_with_options(name="speech_transcribe")(transcribe)
     decode = decode