
[Feat] Support lazy imports (#3764)

* Fix style

* Fix comment

* Improve

* Align opencv variants

* Fix header comments

* Fix 3D op bugs

* Rename 3d_bev_detection to legal module name m_3d_bev_detection

* Fix bug

* Fix requirements

* Support extras

* Use importlib.metadata

* Remove quick_check

* Allow repos not installing packages

* Fix style rule

* Remove unused code

* Remove unused code

* Always no deps

* Implement lazy imports

* Fix and update

* Fix import openai

* Fix bugs and protect pipeline initialization

* pandas as required dep

* Fix bugs

* Add constraints file

* Fix bugs

* Remove print

* Fix bug

* Hack for opencv

* Temporarily remove minimum version for paddle2onnx

* Alter paddle2onnx logic

* Force using patched albumentations and nuscenes-devkit

* Patch albucore

* Patch imgaug

* Fix bug

* Downgrade albucore version

* Fix bugs
Lin Manhui, 7 months ago
parent
commit f4a679ff9f
100 changed files with 2705 additions and 2112 deletions
  1. + 11 - 4      .pre-commit-config.yaml
  2. + 192 - 0     .precommit/check_imports.py
  3. + 13 - 18     .precommit/check_license_headers.py
  4. + 12 - 25     paddlex/__init__.py
  5. + 4 - 2       paddlex/inference/common/reader/audio_reader.py
  6. + 10 - 8      paddlex/inference/common/reader/image_reader.py
  7. + 5 - 1       paddlex/inference/models/anomaly_detection/processors.py
  8. + 2 - 0       paddlex/inference/models/base/predictor/base_predictor.py
  9. + 30 - 10     paddlex/inference/models/common/static_infer.py
  10. + 6 - 1      paddlex/inference/models/common/tokenizer/gpt_tokenizer.py
  11. + 24 - 13    paddlex/inference/models/common/tokenizer/tokenizer_utils.py
  12. + 12 - 3     paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py
  13. + 10 - 2     paddlex/inference/models/common/ts/funcs.py
  14. + 5 - 1      paddlex/inference/models/common/ts/processors.py
  15. + 8 - 1      paddlex/inference/models/common/vision/funcs.py
  16. + 25 - 20    paddlex/inference/models/common/vision/processors.py
  17. + 17 - 9     paddlex/inference/models/formula_recognition/processors.py
  18. + 9 - 2      paddlex/inference/models/formula_recognition/result.py
  19. + 5 - 1      paddlex/inference/models/instance_segmentation/result.py
  20. + 6 - 1      paddlex/inference/models/keypoint_detection/processors.py
  21. + 7 - 2      paddlex/inference/models/keypoint_detection/result.py
  22. + 5 - 6      paddlex/inference/models/m_3d_bev_detection/predictor.py
  23. + 15 - 8     paddlex/inference/models/m_3d_bev_detection/processors.py
  24. + 7 - 1      paddlex/inference/models/multilingual_speech_recognition/predictor.py
  25. + 1766 - 1761 paddlex/inference/models/multilingual_speech_recognition/processors.py
  26. + 7 - 1      paddlex/inference/models/object_detection/processors.py
  27. + 6 - 1      paddlex/inference/models/open_vocabulary_detection/processors/common.py
  28. + 16 - 6     paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py
  29. + 9 - 9      paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py
  30. + 5 - 1      paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py
  31. + 5 - 1      paddlex/inference/models/semantic_segmentation/processors.py
  32. + 5 - 1      paddlex/inference/models/table_structure_recognition/result.py
  33. + 9 - 2      paddlex/inference/models/text_detection/processors.py
  34. + 5 - 1      paddlex/inference/models/text_detection/result.py
  35. + 5 - 1      paddlex/inference/models/text_recognition/processors.py
  36. + 6 - 2      paddlex/inference/models/ts_anomaly_detection/result.py
  37. + 5 - 1      paddlex/inference/models/ts_classification/result.py
  38. + 5 - 1      paddlex/inference/models/ts_forecasting/processors.py
  39. + 5 - 1      paddlex/inference/models/ts_forecasting/result.py
  40. + 14 - 16    paddlex/inference/models/video_classification/processors.py
  41. + 5 - 1      paddlex/inference/models/video_classification/result.py
  42. + 16 - 3     paddlex/inference/models/video_detection/processors.py
  43. + 5 - 1      paddlex/inference/models/video_detection/result.py
  44. + 2 - 0      paddlex/inference/pipelines/anomaly_detection/pipeline.py
  45. + 3 - 0      paddlex/inference/pipelines/attribute_recognition/pipeline.py
  46. + 5 - 1      paddlex/inference/pipelines/attribute_recognition/result.py
  47. + 4 - 5      paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py
  48. + 7 - 2      paddlex/inference/pipelines/components/common/crop_image_regions.py
  49. + 40 - 9     paddlex/inference/pipelines/components/common/seal_det_warp.py
  50. + 6 - 1      paddlex/inference/pipelines/components/common/warp_image.py
  51. + 6 - 1      paddlex/inference/pipelines/components/faisser.py
  52. + 13 - 9     paddlex/inference/pipelines/components/retriever/base.py
  53. + 84 - 81    paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py
  54. + 2 - 0      paddlex/inference/pipelines/doc_preprocessor/pipeline.py
  55. + 2 - 0      paddlex/inference/pipelines/face_recognition/pipeline.py
  56. + 2 - 0      paddlex/inference/pipelines/formula_recognition/pipeline.py
  57. + 6 - 1      paddlex/inference/pipelines/formula_recognition/result.py
  58. + 2 - 0      paddlex/inference/pipelines/image_classification/pipeline.py
  59. + 2 - 0      paddlex/inference/pipelines/image_multilabel_classification/pipeline.py
  60. + 2 - 0      paddlex/inference/pipelines/instance_segmentation/pipeline.py
  61. + 2 - 0      paddlex/inference/pipelines/keypoint_detection/pipeline.py
  62. + 2 - 0      paddlex/inference/pipelines/layout_parsing/pipeline.py
  63. + 2 - 0      paddlex/inference/pipelines/layout_parsing/pipeline_v2.py
  64. + 2 - 0      paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py
  65. + 2 - 0      paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py
  66. + 2 - 0      paddlex/inference/pipelines/object_detection/pipeline.py
  67. + 2 - 0      paddlex/inference/pipelines/ocr/pipeline.py
  68. + 6 - 1      paddlex/inference/pipelines/ocr/result.py
  69. + 2 - 0      paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py
  70. + 2 - 0      paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py
  71. + 2 - 0      paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py
  72. + 10 - 1     paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py
  73. + 2 - 0      paddlex/inference/pipelines/pp_shitu_v2/pipeline.py
  74. + 2 - 0      paddlex/inference/pipelines/rotated_object_detection/pipeline.py
  75. + 2 - 0      paddlex/inference/pipelines/seal_recognition/pipeline.py
  76. + 2 - 0      paddlex/inference/pipelines/semantic_segmentation/pipeline.py
  77. + 2 - 0      paddlex/inference/pipelines/small_object_detection/pipeline.py
  78. + 2 - 0      paddlex/inference/pipelines/table_recognition/pipeline.py
  79. + 10 - 1     paddlex/inference/pipelines/table_recognition/pipeline_v2.py
  80. + 2 - 0      paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py
  81. + 2 - 0      paddlex/inference/pipelines/ts_classification/pipeline.py
  82. + 2 - 0      paddlex/inference/pipelines/ts_forecasting/pipeline.py
  83. + 2 - 0      paddlex/inference/pipelines/video_classification/pipeline.py
  84. + 2 - 0      paddlex/inference/pipelines/video_detection/pipeline.py
  85. + 4 - 0      paddlex/inference/serving/__init__.py
  86. + 20 - 11    paddlex/inference/serving/basic_serving/_app.py
  87. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py
  88. + 5 - 1      paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py
  89. + 6 - 1      paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py
  90. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py
  91. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py
  92. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py
  93. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py
  94. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py
  95. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py
  96. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py
  97. + 9 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py
  98. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py
  99. + 6 - 3      paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py
  100. + 6 - 3     paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py

+ 11 - 4
.pre-commit-config.yaml

@@ -56,8 +56,15 @@ repos:
 # check license
 -   repo: local
     hooks:
-    -   id: check-custom
-        name: Check Custom
-        entry: python .precommit/check_custom.py
+    -   id: check-license-headers
+        name: Check License Headers
+        entry: python .precommit/check_license_headers.py
         language: python
-        files: \.py$
+        files: .*\.py$
+    -   id: check-imports
+        name: Check Imports
+        entry: python .precommit/check_imports.py
+        language: python
+        files: ^paddlex/.*\.py$
+        additional_dependencies:
+            - stdlib-list==0.10.0
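
With these hooks registered, both checks can be exercised locally with the standard pre-commit CLI, e.g. `pre-commit run check-license-headers --all-files` or `pre-commit run check-imports --all-files`. Note that the license check covers every `.py` file, while the import check is restricted to the `paddlex` package by the `^paddlex/.*\.py$` pattern and only needs `stdlib-list` as an additional dependency, as declared above.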

+ 192 - 0
.precommit/check_imports.py

@@ -0,0 +1,192 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TODO: Less verbose
+
+import ast
+import pathlib
+import re
+import sys
+import traceback
+from collections import deque
+
+from stdlib_list import stdlib_list
+
+sys.path.append(str(pathlib.Path(__file__).parent.parent))
+from setup import DEP_SPECS, REQUIRED_DEPS
+
+# NOTE: We do not use `importlib.metadata.packages_distributions` here because
+# 1. It is supported only in Python 3.10+.
+# 2. It requires the packages to be installed, but we are doing a static check.
+MOD_TO_DEP = {
+    "aiohttp": "aiohttp",
+    "baidubce": "bce-python-sdk",
+    "chardet": "chardet",
+    "chinese_calendar": "chinese-calendar",
+    "colorlog": "colorlog",
+    "decord": "decord",
+    "faiss": "faiss-cpu",
+    "fastapi": "fastapi",
+    "filelock": "filelock",
+    "filetype": "filetype",
+    "ftfy": "ftfy",
+    "GPUtil": "GPUtil",
+    "imagesize": "imagesize",
+    "jinja2": "Jinja2",
+    "joblib": "joblib",
+    "langchain": "langchain",
+    "langchain_community": "langchain-community",
+    "langchain_core": "langchain-core",
+    "langchain_openai": "langchain-openai",
+    "lxml": "lxml",
+    "matplotlib": "matplotlib",
+    "numpy": "numpy",
+    "openai": "openai",
+    "cv2": "opencv-contrib-python",
+    "openpyxl": "openpyxl",
+    "packaging": "packaging",
+    "pandas": "pandas",
+    "PIL": "pillow",
+    "premailer": "premailer",
+    "prettytable": "prettytable",
+    "cpuinfo": "py-cpuinfo",
+    "pyclipper": "pyclipper",
+    "pycocotools": "pycocotools",
+    "pydantic": "pydantic",
+    "fitz": "PyMuPDF",
+    "yaml": "PyYAML",
+    "regex": "regex",
+    "requests": "requests",
+    "ruamel.yaml": "ruamel.yaml",
+    "skimage": "scikit-image",
+    "sklearn": "scikit-learn",
+    "shapely": "shapely",
+    "soundfile": "soundfile",
+    "starlette": "starlette",
+    "tokenizers": "tokenizers",
+    "tqdm": "tqdm",
+    "typing_extensions": "typing-extensions",
+    "ujson": "ujson",
+    "uvicorn": "uvicorn",
+    "yarl": "yarl",
+}
+assert (
+    set(MOD_TO_DEP.values()) == DEP_SPECS.keys()
+), f"`MOD_TO_DEP` should be updated to match `DEP_SPECS`. Symmetric difference: {set(MOD_TO_DEP.values()) ^ DEP_SPECS.keys()}"
+MOD_PATTERN = re.compile(
+    rf"^(?:{'|'.join([re.escape(mod) for mod in MOD_TO_DEP])})(?=\.|$)"
+)
+STDLIB_MODS = set(stdlib_list())
+SPECIAL_KNOWN_MODS = {
+    "paddleseg",
+    "paddleclas",
+    "paddledet",
+    "paddlets",
+    "paddlenlp",
+    "paddlespeech",
+    "parl",
+    "paddlemix",
+    "paddle3d",
+    "paddlevideo",
+}
+MANUALLY_MANAGED_HEAVY_MODS = {"paddle", "paddle_custom_device", "ultra_infer"}
+
+
+def check(file_path):
+    # TODO:
+    # 1. Handle more cases, e.g., `from ruamel import yaml`.
+    # 2. Find unused dependencies.
+    # 3. Better output format.
+
+    with open(file_path, "r", encoding="utf-8") as f:
+        file_contents = f.read()
+
+    try:
+        tree = ast.parse(file_contents)
+    except Exception:
+        print(
+            f"Failed to parse the source code in `{file_path}` into an AST node:\n{traceback.format_exc()}"
+        )
+        return False
+
+    # 1. Never import unknown modules
+    # 2. Don't import optional third-party modules at the top level
+    unknown_modules_found = False
+    top_level_imports_found = False
+    q = deque()
+    for child in ast.iter_child_nodes(tree):
+        q.append((child, 1))
+    while q:
+        node, level = q.popleft()
+        mods = set()
+        if isinstance(node, ast.Import):
+            for alias in node.names:
+                mod = alias.name
+                mods.add(mod)
+        elif isinstance(node, ast.ImportFrom):
+            if node.module and node.level == 0:
+                mod = node.module
+                mods.add(mod)
+        for mod in mods:
+            pos = f"{file_path}:{node.lineno}:{node.col_offset}"
+            tl = mod.split(".")[0]
+            if tl == "paddlex" or tl in SPECIAL_KNOWN_MODS or tl in STDLIB_MODS:
+                continue
+            elif tl in MANUALLY_MANAGED_HEAVY_MODS:
+                if level == 1:
+                    print(
+                        f"{pos}: Module of a manually managed heavy dependency imported at the top level: {mod}"
+                    )
+                    top_level_imports_found = True
+            elif match_ := MOD_PATTERN.match(mod):
+                if level == 1:
+                    dep = MOD_TO_DEP[match_.group(0)]
+                    if dep not in REQUIRED_DEPS:
+                        print(
+                            f"{pos}: Module of an optional dependency imported at the top level: {mod}"
+                        )
+                        top_level_imports_found = True
+            else:
+                print(f"{pos}: Unknown module imported: {mod}")
+                unknown_modules_found = True
+        for child in ast.iter_child_nodes(node):
+            q.append((child, level + 1))
+
+    return unknown_modules_found | (top_level_imports_found << 1)
+
+
+def main():
+    files = sys.argv[1:]
+    flag = 0
+    for file in files:
+        ret = check(file)
+        flag |= ret
+    if flag:
+        if flag & 1:
+            curr_script_path = pathlib.Path(__file__)
+            curr_script_path = curr_script_path.relative_to(
+                curr_script_path.parent.parent
+            )
+            print(
+                f"If a new dependency should be added, please update `setup.py` and `{curr_script_path}`."
+            )
+        if (flag >> 1) & 1:
+            print(
+                "Please put the imports from optional dependencies and manually managed heavy dependencies inside a conditional body or within a function body."
+            )
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
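
A rough illustration of what the new hook flags versus accepts, assuming `opencv-contrib-python` is not listed in `REQUIRED_DEPS` (the file and the unknown package name below are hypothetical):

    # paddlex/example.py (hypothetical)

    import cv2                  # flagged: optional dependency imported at the top level
    import paddle               # flagged: manually managed heavy dependency at the top level
    import some_unlisted_pkg    # flagged: unknown module (not in MOD_TO_DEP, stdlib, paddlex, ...)

    from paddlex.utils.deps import is_dep_available  # accepted: paddlex-internal import

    if is_dep_available("opencv-contrib-python"):
        import cv2              # accepted: nested under `if`, so not a level-1 import

    def run():
        import paddle           # accepted: deferred into a function body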

+ 13 - 18
.precommit/check_custom.py → .precommit/check_license_headers.py

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import re
 import sys
 
 YEAR_PATTERN = r"(?:20\d\d)"
-LICENSE_TEXT = re.escape(
-    """# Copyright (c) <YEAR_PATTERN> PaddlePaddle Authors. All Rights Reserved.
+LICENSE_HEADER_PATTERN = re.compile(
+    re.escape(
+        """# Copyright (c) <YEAR_PATTERN> PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -32,26 +32,21 @@ LICENSE_TEXT = re.escape(
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-).replace("<YEAR_PATTERN>", YEAR_PATTERN)
+    ).replace("<YEAR_PATTERN>", YEAR_PATTERN)
+)
 
 
 def check(file_path):
-    with open(file_path, "r") as f:
-        content = f.read()
+    with open(file_path, "r", encoding="utf-8") as f:
+        contents = f.read()
     # Exclude shebang line
-    if content.startswith("#!"):
-        content = content[content.index("\n") + 1 :]
-        if content.startswith("\n"):
-            content = content[1:]
-    if not re.match(LICENSE_TEXT, content):
-        print(f"License header missing in {file_path}")
+    if contents.startswith("#!"):
+        contents = contents[contents.index("\n") + 1 :]
+        if contents.startswith("\n"):
+            contents = contents[1:]
+    if not LICENSE_HEADER_PATTERN.match(contents):
+        print(f"License header missing in `{file_path}`")
         return False
-    if "paddlex" in file_path.split(os.sep):
-        if "import paddle" in content or "from paddle import " in content:
-            print(
-                f"Please use `lazy_paddle` instead `paddle` when import in {file_path}"
-            )
-            return False
     return True
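
In other words, a file now passes this hook if it starts with the Apache-2.0 header (any `20\d\d` year is accepted), optionally preceded by a shebang line and a single blank line; the paddle-specific "use `lazy_paddle`" rule is dropped here because top-level `paddle` imports are now caught by `check_imports.py` instead.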
 
 

+ 12 - 25
paddlex/__init__.py

@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import sys
 
-from .utils.lazy_loader import LazyLoader
-
-paddle = LazyLoader("lazy_paddle", globals(), "paddle")
-sys.modules["lazy_paddle"] = paddle
-
-import os
+_SPECIAL_MODS = ["paddle", "paddle_custom_device", "ultra_infer"]
+_loaded_special_mods = []
+for mod in _SPECIAL_MODS:
+    if mod in sys.modules:
+        _loaded_special_mods.append(mod)
 
 from . import version
 from .inference import create_pipeline, create_predictor
@@ -42,26 +42,13 @@ def _initialize():
     if flags.EAGER_INITIALIZATION:
         repo_manager.initialize()
 
-
-def _check_paddle_version():
-    """check paddle version"""
-
-    supported_versions = ["3.0", "0.0"]
-    device_type = paddle.device.get_device().split(":")[0]
-    if device_type.lower() == "xpu":
-        supported_versions.append("2.6")
-    version = paddle.__version__
-    # Recognizable version number: major.minor.patch
-    major, minor, patch = version.split(".")
-    # Ignore patch
-    version = f"{major}.{minor}"
-    if version not in supported_versions:
-        raise RuntimeError(
-            f"The {version} version of PaddlePaddle is not supported. "
-            f"Please install one of the following versions of PaddlePaddle: {supported_versions}."
-        )
+    __version__ = version.get_pdx_version()
 
 
 _initialize()
 
-__version__ = version.get_pdx_version()
+for mod in _SPECIAL_MODS:
+    if mod in sys.modules and mod not in _loaded_special_mods:
+        raise AssertionError(
+            f"`{mod}` is unexpectedly loaded. Please contact the PaddleX team to report this issue."
+        )
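
The practical effect, assuming the lazy-import machinery behaves as intended, is that importing `paddlex` alone no longer pulls in the heavy backends:

    import sys

    import paddlex  # noqa: F401

    # In a fresh process where paddle has not been imported elsewhere,
    # it should not appear in sys.modules merely because paddlex was imported.
    print("paddle" in sys.modules)  # expected: False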

+ 4 - 2
paddlex/inference/common/reader/audio_reader.py

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import lazy_paddle as paddle
-
+from ....utils.deps import class_requires_deps
 from ...utils.io import AudioReader
 
 
+@class_requires_deps("paddlepaddle")
 class ReadAudio:
     """Load audio from the file."""
 
@@ -29,6 +29,8 @@ class ReadAudio:
         self._audio_reader = AudioReader(backend="wav")
 
     def read(self, input):
+        import paddle
+
         if isinstance(input, str):
             audio, sample_rate = self._audio_reader.read(input)
             if sample_rate != 16000:

+ 10 - 8
paddlex/inference/common/reader/image_reader.py

@@ -12,23 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import cv2
 import numpy as np
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 from ...utils.io import ImageReader
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 
 @benchmark.timeit_with_options(name=None, is_read_operation=True)
+@class_requires_deps("opencv-contrib-python")
 class ReadImage:
     """Load image from the file."""
 
-    _FLAGS_DICT = {
-        "BGR": cv2.IMREAD_COLOR,
-        "RGB": cv2.IMREAD_COLOR,
-        "GRAY": cv2.IMREAD_GRAYSCALE,
-    }
-
     def __init__(self, format="BGR"):
         """
         Initialize the instance.
@@ -39,7 +37,11 @@ class ReadImage:
         """
         super().__init__()
         self.format = format
-        flags = self._FLAGS_DICT[self.format]
+        flags = {
+            "BGR": cv2.IMREAD_COLOR,
+            "RGB": cv2.IMREAD_COLOR,
+            "GRAY": cv2.IMREAD_GRAYSCALE,
+        }[self.format]
         self._img_reader = ImageReader(backend="opencv", flags=flags)
 
     def __call__(self, imgs):
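
This file shows the pattern applied throughout the commit: guard the module-level import with `is_dep_available`, move lookups that touch the module (here the `cv2.IMREAD_*` flags) into `__init__`, and declare the requirement with `class_requires_deps`. The real helpers live in `paddlex.utils.deps`, which is not part of the shown diffs; purely as an assumption for illustration, such a class decorator could be sketched as:

    import functools
    import importlib.metadata


    def class_requires_deps(*deps):
        """Hypothetical sketch: fail at instantiation time if a required package is missing."""

        def decorator(cls):
            original_init = cls.__init__

            @functools.wraps(original_init)
            def wrapped_init(self, *args, **kwargs):
                # Check installed distributions without importing them.
                missing = []
                for dep in deps:
                    try:
                        importlib.metadata.version(dep)
                    except importlib.metadata.PackageNotFoundError:
                        missing.append(dep)
                if missing:
                    raise RuntimeError(
                        f"{cls.__name__} requires missing packages: {', '.join(missing)}"
                    )
                original_init(self, *args, **kwargs)

            cls.__init__ = wrapped_init
            return cls

        return decorator

This is only a sketch of the mechanism; the actual PaddleX implementation may differ in its error type, message wording, and caching behavior.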

+ 5 - 1
paddlex/inference/models/anomaly_detection/processors.py

@@ -13,12 +13,16 @@
 # limitations under the License.
 
 import numpy as np
-from skimage import morphology
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 
+if is_dep_available("scikit-image"):
+    from skimage import morphology
+
 
 @benchmark.timeit
+@class_requires_deps("scikit-image")
 class MapToMask:
     """Map_to_mask"""
 

+ 2 - 0
paddlex/inference/models/base/predictor/base_predictor.py

@@ -21,6 +21,7 @@ from pydantic import ValidationError
 
 from ..... import constants
 from .....utils import logging
+from .....utils.deps import require_hpip
 from .....utils.device import get_default_device, parse_device
 from .....utils.flags import (
     INFER_BENCHMARK,
@@ -121,6 +122,7 @@ class BasePredictor(
                 logging.warning("`hpi_config` will be ignored when not using HPIP.")
             self._pp_option = self._prepare_pp_option(pp_option, device)
         else:
+            require_hpip()
             if pp_option is not None:
                 logging.warning("`pp_option` will be ignored when using HPIP.")
             self._hpi_config = self._prepare_hpi_config(hpi_config, device)

+ 30 - 10
paddlex/inference/models/common/static_infer.py

@@ -13,15 +13,18 @@
 # limitations under the License.
 
 import abc
-import importlib.util
 import subprocess
 from pathlib import Path
 from typing import List, Sequence
 
-import lazy_paddle as paddle
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import (
+    class_requires_deps,
+    function_requires_deps,
+    is_paddle2onnx_plugin_available,
+)
 from ....utils.device import constr_device
 from ....utils.flags import DEBUG, INFER_BENCHMARK_USE_NEW_INFER_API, USE_PIR_TRT
 from ...utils.benchmark import benchmark, set_inference_operations
@@ -50,7 +53,10 @@ set_inference_operations(INFERENCE_OPERATIONS)
 
 
 # XXX: Better use Paddle Inference API to do this
+@function_requires_deps("paddlepaddle")
 def _pd_dtype_to_np_dtype(pd_dtype):
+    import paddle
+
     if pd_dtype == paddle.inference.DataType.FLOAT64:
         return np.float64
     elif pd_dtype == paddle.inference.DataType.FLOAT32:
@@ -68,6 +74,7 @@ def _pd_dtype_to_np_dtype(pd_dtype):
 
 
 # old trt
+@function_requires_deps("paddlepaddle")
 def _collect_trt_shape_range_info(
     model_file,
     model_params,
@@ -76,6 +83,7 @@ def _collect_trt_shape_range_info(
     dynamic_shapes,
     dynamic_shape_input_data,
 ):
+    import paddle.inference
 
     dynamic_shape_input_data = dynamic_shape_input_data or {}
 
@@ -143,6 +151,7 @@ def _collect_trt_shape_range_info(
 
 
 # pir trt
+@function_requires_deps("paddlepaddle")
 def _convert_trt(
     trt_cfg_setting,
     pp_model_file,
@@ -152,6 +161,7 @@ def _convert_trt(
     dynamic_shapes,
     dynamic_shape_input_data,
 ):
+    import paddle.inference
     from paddle.tensorrt.export import Input, TensorRTConfig, convert
 
     def _set_trt_config():
@@ -239,12 +249,15 @@ def _concatenate(*callables):
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class PaddleCopyToDevice:
     def __init__(self, device_type, device_id):
         self.device_type = device_type
         self.device_id = device_id
 
     def __call__(self, arrs):
+        import paddle
+
         device_id = [self.device_id] if self.device_id is not None else self.device_id
         device = constr_device(self.device_type, device_id)
         paddle_tensors = [paddle.to_tensor(i, place=device) for i in arrs]
@@ -252,6 +265,7 @@ class PaddleCopyToDevice:
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class PaddleCopyToHost:
     def __call__(self, paddle_tensors):
         arrs = [i.numpy() for i in paddle_tensors]
@@ -259,6 +273,7 @@ class PaddleCopyToHost:
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class PaddleModelInfer:
     def __init__(self, predictor):
         super().__init__()
@@ -270,6 +285,7 @@ class PaddleModelInfer:
 
 # FIXME: Name might be misleading
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class PaddleInferChainLegacy:
     def __init__(self, predictor):
         self.predictor = predictor
@@ -299,6 +315,7 @@ class StaticInfer(metaclass=abc.ABCMeta):
         raise NotImplementedError
 
 
+@class_requires_deps("paddlepaddle")
 class PaddleInfer(StaticInfer):
     def __init__(
         self,
@@ -338,6 +355,9 @@ class PaddleInfer(StaticInfer):
         self,
     ):
         """_create"""
+        import paddle
+        import paddle.inference
+
         model_paths = get_model_paths(self.model_dir, self.model_file_prefix)
         if "paddle" not in model_paths:
             raise RuntimeError("No valid PaddlePaddle model found")
@@ -469,6 +489,8 @@ class PaddleInfer(StaticInfer):
 
     def _configure_trt(self, model_file, params_file, cache_dir):
         # TODO: Support calibration
+        import paddle.inference
+
         if USE_PIR_TRT:
             trt_save_path = cache_dir / "trt" / self.model_file_prefix
             _convert_trt(
@@ -566,6 +588,7 @@ class PaddleInfer(StaticInfer):
 
 # FIXME: Name might be misleading
 @benchmark.timeit
+@class_requires_deps("ultra-infer")
 class MultiBackendInfer(object):
     def __init__(self, ui_runtime):
         super().__init__()
@@ -579,6 +602,7 @@ class MultiBackendInfer(object):
 
 # TODO: It would be better to refactor the code to make `HPInfer` a higher-level
 # class that uses `PaddleInfer`.
+@class_requires_deps("ultra-infer", "paddlepaddle")
 class HPInfer(StaticInfer):
     def __init__(
         self,
@@ -644,16 +668,16 @@ class HPInfer(StaticInfer):
 
         model_paths = get_model_paths(self._model_dir, self._model_file_prefix)
         is_onnx_model_available = "onnx" in model_paths
-        # TODO: Give a warning if Paddle2ONNX is not available but can be used
-        # to select a better backend.
+        # TODO: Give a warning if the Paddle2ONNX plugin is not available but
+        # can be used to select a better backend.
         if self._config.auto_paddle2onnx:
-            if self._check_paddle2onnx():
+            if is_paddle2onnx_plugin_available():
                 is_onnx_model_available = (
                     is_onnx_model_available or "paddle" in model_paths
                 )
             else:
                 logging.debug(
-                    "Paddle2ONNX is not available. Automatic model conversion will not be performed."
+                    "The Paddle2ONNX plugin is not available. Automatic model conversion will not be performed."
                 )
         available_backends = []
         if "paddle" in model_paths:
@@ -843,7 +867,3 @@ class HPInfer(StaticInfer):
         ui_runtime = Runtime(ui_option)
 
         return ui_runtime
-
-    def _check_paddle2onnx(self):
-        # HACK
-        return importlib.util.find_spec("paddle2onnx") is not None
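
Throughout this file the same recipe is applied to the heavy backends: the module-level `import paddle` is removed, each class or function that actually needs it is tagged with `class_requires_deps("paddlepaddle")` or `function_requires_deps("paddlepaddle")`, and the import is performed inside the method or function body. The ad-hoc `importlib.util.find_spec("paddle2onnx")` probe is likewise replaced by the shared `is_paddle2onnx_plugin_available()` helper from `paddlex.utils.deps`, whose implementation is not part of the shown diffs.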

+ 6 - 1
paddlex/inference/models/common/tokenizer/gpt_tokenizer.py

@@ -19,8 +19,8 @@ from functools import lru_cache
 from typing import Dict, Optional, Union
 
 import numpy as np
-import regex as re
 
+from .....utils.deps import class_requires_deps
 from .tokenizer_utils import PretrainedTokenizer
 from .tokenizer_utils_base import (
     AddedToken,
@@ -75,6 +75,7 @@ def get_pairs(word):
     return pairs
 
 
+@class_requires_deps("regex")
 class GPTTokenizer(PretrainedTokenizer):
     """
     Constructs a GPT tokenizer based on byte-level Byte-Pair-Encoding.
@@ -176,6 +177,8 @@ class GPTTokenizer(PretrainedTokenizer):
         add_bos_token=False,
         **kwargs  # The token of newline.
     ):
+        import regex as re
+
         pad_token = (
             AddedToken(pad_token, lstrip=False, rstrip=False)
             if isinstance(pad_token, str)
@@ -293,6 +296,8 @@ class GPTTokenizer(PretrainedTokenizer):
 
     def _tokenize(self, text):
         """Tokenize a string."""
+        import regex as re
+
         bpe_tokens = []
         for token in re.findall(self.pat, text):
             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))

+ 24 - 13
paddlex/inference/models/common/tokenizer/tokenizer_utils.py

@@ -26,15 +26,14 @@ from dataclasses import asdict, dataclass
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import lazy_paddle as paddle
-import numpy
 import numpy as np
-import six
-from jinja2 import Template
-from jinja2.exceptions import TemplateError, TemplateSyntaxError
-from jinja2.sandbox import ImmutableSandboxedEnvironment
 
 from .....utils import logging
+from .....utils.deps import (
+    class_requires_deps,
+    function_requires_deps,
+    is_dep_available,
+)
 from .tokenizer_utils_base import (
     CHAT_TEMPLATE_CONFIG_NAME,
     AddedToken,
@@ -53,6 +52,11 @@ from .tokenizer_utils_base import (
 from .utils import convert_to_dict_message, fn_args_to_dict
 from .vocab import Vocab
 
+if is_dep_available("Jinja2"):
+    from jinja2 import Template
+    from jinja2.exceptions import TemplateError, TemplateSyntaxError
+    from jinja2.sandbox import ImmutableSandboxedEnvironment
+
 __all__ = [
     "ChatTemplate",
     "Trie",
@@ -62,6 +66,7 @@ __all__ = [
 ]
 
 
+@class_requires_deps("Jinja2")
 @dataclass
 class ChatTemplate:
     conversation: Union[List[str], None] = None
@@ -70,7 +75,7 @@ class ChatTemplate:
 
     @staticmethod
     @lru_cache()
-    def _compile_jinja_template(chat_template) -> Template:
+    def _compile_jinja_template(chat_template) -> "Template":
         def raise_exception(message):
             raise TemplateError(message)
 
@@ -196,12 +201,14 @@ class ChatTemplate:
         return cls.from_dict(config)
 
 
+@function_requires_deps("paddlepaddle")
 def adapt_stale_fwd_patch(self, name, value):
     """
     Since there are some monkey patches for forward of PretrainedModel, such as
     model compression, we make these patches compatible with the latest forward
     method.
     """
+
     if name == "forward":
         # NOTE(guosheng): In dygraph to static, `layer.forward` would be patched
         # by an instance of `StaticFunction`. And use string compare to avoid to
@@ -229,6 +236,8 @@ def adapt_stale_fwd_patch(self, name, value):
         ]
 
         if new_args:
+            import paddle
+
             if self.__module__.startswith("paddlenlp"):
                 logging.warning(
                     f"The `forward` method of {self.__class__ if isinstance(self, paddle.nn.Layer) else self} is patched and the patch "
@@ -634,6 +643,7 @@ def normalize_chars(text):
     return "".join(output)
 
 
+@class_requires_deps("paddlepaddle", "Jinja2")
 class ChatTemplateMixin:
     chat_template: Optional[ChatTemplate] = None
 
@@ -643,7 +653,7 @@ class ChatTemplateMixin:
         tokenize: bool = True,
         context_data: Dict[str, Any] = {},
         **tokenizer_kwargs,
-    ) -> Union[str, Dict[str, Union["numpy.ndarray", "paddle.Tensor"]]]:
+    ):
         """apply chat_template rules to conversation which should not be batched data
 
         Args:
@@ -652,7 +662,7 @@ class ChatTemplateMixin:
             tokenize (bool, optional): whether do tokenization. Defaults to True.
 
         Returns:
-            str | dict[str, Union["numpy.ndarray", "paddle.Tensor"]]: return the result of applied data
+            str | dict[str, Union[numpy.ndarray, paddle.Tensor]]: return the result of applied data
         """
         if not self.chat_template:
             raise ValueError(
@@ -677,7 +687,7 @@ class ChatTemplateMixin:
         self,
         conversation: Union[List[Dict[str, str]], str],
         context_data: Dict[str, Any] = {},
-    ) -> Union[str, Dict[str, Union["numpy.ndarray", "paddle.Tensor"]]]:
+    ):
         context_data = self.chat_template._init_context_data(context_data)
 
         if isinstance(conversation, str):
@@ -695,7 +705,7 @@ class ChatTemplateMixin:
         self,
         conversation: Union[Dict[str, str], str],
         add_generation_prompt=True,
-    ) -> Union[str, Dict[str, Union["numpy.ndarray", "paddle.Tensor"]]]:
+    ):
         if isinstance(conversation, str):
             conversations = [{"role": "user", "content": conversation}]
         elif isinstance(conversation, list):
@@ -932,8 +942,9 @@ class ChatTemplateMixin:
             logging.info("Chat-template config file saved in " + chat_template_file)
 
 
-@six.add_metaclass(InitTrackerMeta)
-class PretrainedTokenizer(ChatTemplateMixin, PretrainedTokenizerBase):
+class PretrainedTokenizer(
+    ChatTemplateMixin, PretrainedTokenizerBase, metaclass=InitTrackerMeta
+):
     """
     Base class for all tokenizers.
 

+ 12 - 3
paddlex/inference/models/common/tokenizer/tokenizer_utils_base.py

@@ -22,10 +22,10 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 
-import lazy_paddle as paddle
 import numpy as np
 
 from .....utils import logging
+from .....utils.deps import class_requires_deps, function_requires_deps
 
 __all__ = [
     "AddedToken",
@@ -125,10 +125,13 @@ class TensorType(ExplicitEnum):
     NUMPY = "np"
 
 
+@function_requires_deps("paddlepaddle")
 def to_py_obj(obj):
     """
     Convert a Paddle tensor, Numpy array or python list to a python list.
     """
+    import paddle
+
     if isinstance(obj, (dict, UserDict)):
         return {k: to_py_obj(v) for k, v in obj.items()}
     elif isinstance(obj, (list, tuple)):
@@ -183,6 +186,7 @@ class TokenSpan(NamedTuple):
     end: int
 
 
+@class_requires_deps("paddlepaddle")
 class BatchEncoding(UserDict):
     """
     Holds the output of the [`PretrainedTokenizerBase.__call__`],
@@ -719,6 +723,8 @@ class BatchEncoding(UserDict):
             prepend_batch_axis (`int`, *optional*, defaults to `False`):
                 Whether or not to add the batch dimension during the conversion.
         """
+        import paddle
+
         if tensor_type is None:
             return self
 
@@ -1304,6 +1310,7 @@ class SpecialTokensMixin:
         return all_ids
 
 
+@class_requires_deps("paddlepaddle")
 class PretrainedTokenizerBase(SpecialTokensMixin):
     """
     Base class for [`PretrainedTokenizer`].
@@ -2723,6 +2730,8 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
             verbose (`bool`, *optional*, defaults to `True`):
                 Whether or not to print more information and warnings.
         """
+        import paddle
+
         # If we have a list of dicts, let's convert it in a dict of lists
         if isinstance(encoded_inputs, (list, tuple)) and isinstance(
             encoded_inputs[0], (dict, BatchEncoding)
@@ -3336,7 +3345,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
 
     def batch_decode(
         self,
-        sequences: Union[List[int], List[List[int]], "np.ndarray", "paddle.Tensor"],
+        sequences,
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: bool = True,
         **kwargs,
@@ -3369,7 +3378,7 @@ class PretrainedTokenizerBase(SpecialTokensMixin):
 
     def decode(
         self,
-        token_ids: Union[int, List[int], "np.ndarray", "paddle.Tensor"],
+        token_ids,
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: bool = True,
         **kwargs,

+ 10 - 2
paddlex/inference/models/common/ts/funcs.py

@@ -15,12 +15,17 @@
 
 from typing import Callable, Dict, List, Optional, Union
 
-import chinese_calendar
 import numpy as np
 import pandas as pd
 from pandas.tseries import holiday as hd
 from pandas.tseries.offsets import DateOffset, Day, Easter
-from sklearn.preprocessing import StandardScaler
+
+from .....utils.deps import function_requires_deps, is_dep_available
+
+if is_dep_available("chinese-calendar"):
+    import chinese_calendar
+if is_dep_available("scikit-learn"):
+    from sklearn.preprocessing import StandardScaler
 
 MAX_WINDOW = 183 + 17
 EasterSunday = hd.Holiday("Easter Sunday", month=1, day=1, offset=[Easter(), Day(0)])
@@ -134,12 +139,14 @@ def _cal_weekofyear(
     return x.weekofyear / 51.0 - 0.5
 
 
+@function_requires_deps("chinese-calendar")
 def _cal_holiday(
     x: np.datetime64,
 ):
     return float(chinese_calendar.is_holiday(x))
 
 
+@function_requires_deps("chinese-calendar")
 def _cal_workday(
     x: np.datetime64,
 ):
@@ -443,6 +450,7 @@ def _distance_to_holiday(holiday) -> Callable[[pd.Timestamp], float]:
     return _distance_to_day
 
 
+@function_requires_deps("scikit-learn")
 def time_feature(
     dataset: Dict,
     freq: Optional[Union[str, int]],

+ 5 - 1
paddlex/inference/models/common/ts/processors.py

@@ -14,13 +14,16 @@
 
 from typing import Any, Dict, List
 
-import joblib
 import numpy as np
 import pandas as pd
 
+from .....utils.deps import class_requires_deps, is_dep_available
 from ....utils.benchmark import benchmark
 from .funcs import load_from_dataframe, time_feature
 
+if is_dep_available("joblib"):
+    import joblib
+
 __all__ = [
     "BuildTSDataset",
     "TSCutOff",
@@ -91,6 +94,7 @@ class TSCutOff:
 
 
 @benchmark.timeit
+@class_requires_deps("joblib")
 class TSNormalize:
     """Normalizes time series data using a pre-fitted scaler.
 

+ 8 - 1
paddlex/inference/models/common/vision/funcs.py

@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import cv2
 import numpy as np
 from PIL import Image
 
 from .....utils import logging
+from .....utils.deps import function_requires_deps, is_dep_available
+
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
 
 def check_image_size(input_):
@@ -47,6 +50,7 @@ def resize(im, target_size, interp, backend="cv2"):
     return im
 
 
+@function_requires_deps("opencv-contrib-python")
 def _cv2_resize(src, size, resample):
     return cv2.resize(src, size, interpolation=resample)
 
@@ -60,11 +64,13 @@ def _pil_resize(src, size, resample):
     return np.asarray(pil_img)
 
 
+@function_requires_deps("opencv-contrib-python")
 def flip_h(im):
     """flip image horizontally"""
     return cv2.flip(im, 1)
 
 
+@function_requires_deps("opencv-contrib-python")
 def flip_v(im):
     """flip image vertically"""
     return cv2.flip(im, 0)
@@ -77,6 +83,7 @@ def slice(im, coords):
     return im
 
 
+@function_requires_deps("opencv-contrib-python")
 def pad(im, pad, val):
     """padding image by value"""
     if isinstance(pad, int):

+ 25 - 20
paddlex/inference/models/common/vision/processors.py

@@ -14,31 +14,35 @@
 
 import math
 
-import cv2
 import numpy as np
 from PIL import Image
 
+from .....utils.deps import class_requires_deps, is_dep_available
 from ....utils.benchmark import benchmark
 from . import funcs as F
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
-class _BaseResize:
-    _CV2_INTERP_DICT = {
-        "NEAREST": cv2.INTER_NEAREST,
-        "LINEAR": cv2.INTER_LINEAR,
-        "BICUBIC": cv2.INTER_CUBIC,
-        "AREA": cv2.INTER_AREA,
-        "LANCZOS4": cv2.INTER_LANCZOS4,
-    }
-    _PIL_INTERP_DICT = {
-        "NEAREST": Image.NEAREST,
-        "BILINEAR": Image.BILINEAR,
-        "BICUBIC": Image.BICUBIC,
-        "BOX": Image.BOX,
-        "LANCZOS4": Image.LANCZOS,
-    }
 
+@class_requires_deps("opencv-contrib-python")
+class _BaseResize:
     def __init__(self, size_divisor, interp, backend="cv2"):
+        _CV2_INTERP_DICT = {
+            "NEAREST": cv2.INTER_NEAREST,
+            "LINEAR": cv2.INTER_LINEAR,
+            "BICUBIC": cv2.INTER_CUBIC,
+            "AREA": cv2.INTER_AREA,
+            "LANCZOS4": cv2.INTER_LANCZOS4,
+        }
+        _PIL_INTERP_DICT = {
+            "NEAREST": Image.NEAREST,
+            "BILINEAR": Image.BILINEAR,
+            "BICUBIC": Image.BICUBIC,
+            "BOX": Image.BOX,
+            "LANCZOS4": Image.LANCZOS,
+        }
+
         super().__init__()
 
         if size_divisor is not None:
@@ -50,9 +54,9 @@ class _BaseResize:
         try:
             interp = interp.upper()
             if backend == "cv2":
-                interp = self._CV2_INTERP_DICT[interp]
+                interp = _CV2_INTERP_DICT[interp]
             elif backend == "pil":
-                interp = self._PIL_INTERP_DICT[interp]
+                interp = _PIL_INTERP_DICT[interp]
             else:
                 raise ValueError("backend must be `cv2` or `pil`")
         except KeyError:
@@ -60,9 +64,9 @@ class _BaseResize:
                 "For backend '{}', `interp` should be one of {}. Please ensure the interpolation method matches the selected backend.".format(
                     backend,
                     (
-                        self._CV2_INTERP_DICT.keys()
+                        _CV2_INTERP_DICT.keys()
                         if backend == "cv2"
-                        else self._PIL_INTERP_DICT.keys()
+                        else _PIL_INTERP_DICT.keys()
                     ),
                 )
             )
@@ -216,6 +220,7 @@ class ResizeByShort(_BaseResize):
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class Normalize:
     """Normalize the three-channel image."""
 

+ 17 - 9
paddlex/inference/models/formula_recognition/processors.py

@@ -20,16 +20,21 @@ import re
 import tempfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import cv2
 import numpy as np
 from PIL import Image, ImageOps
-from tokenizers import AddedToken
-from tokenizers import Tokenizer as TokenizerFast
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+if is_dep_available("tokenizers"):
+    from tokenizers import AddedToken
+    from tokenizers import Tokenizer as TokenizerFast
+
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class MinMaxResize:
     """Class for resizing images to be within specified minimum and maximum dimensions, with padding and normalization."""
 
@@ -155,6 +160,7 @@ class MinMaxResize:
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class LatexTestTransform:
     """
     A transform class for processing images according to Latex test requirements.
@@ -307,6 +313,7 @@ class ToBatch(object):
 
 
 @benchmark.timeit
+@class_requires_deps("tokenizers")
 class LaTeXOCRDecode(object):
     """Class for decoding LaTeX OCR tokens based on a provided character list."""
 
@@ -317,8 +324,6 @@ class LaTeXOCRDecode(object):
             character_list (list): The list of characters to use for tokenization.
             **kwargs: Additional keyword arguments for initialization.
         """
-        from tokenizers import Tokenizer as TokenizerFast
-
         super(LaTeXOCRDecode, self).__init__()
         temp_path = tempfile.gettempdir()
         rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
@@ -408,6 +413,7 @@ class LaTeXOCRDecode(object):
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class UniMERNetImgDecode(object):
     """Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""
 
@@ -561,6 +567,7 @@ class UniMERNetImgDecode(object):
 
 
 @benchmark.timeit
+@class_requires_deps("tokenizers")
 class UniMERNetDecode(object):
     """Class for decoding tokenized inputs using UniMERNet tokenizer.
 
@@ -694,8 +701,8 @@ class UniMERNetDecode(object):
                         self._add_tokens(tokens, special_tokens=is_last_special)
 
     def _add_tokens(
-        self, new_tokens: List[Union[AddedToken, str]], special_tokens: bool = False
-    ) -> List[Union[AddedToken, str]]:
+        self, new_tokens: "List[Union[AddedToken, str]]", special_tokens: bool = False
+    ) -> "List[Union[AddedToken, str]]":
         """Adds new tokens to the tokenizer.
 
         Args:
@@ -711,7 +718,7 @@ class UniMERNetDecode(object):
         return self.tokenizer.add_tokens(new_tokens)
 
     def added_tokens_encoder(
-        self, added_tokens_decoder: Dict[int, AddedToken]
+        self, added_tokens_decoder: "Dict[int, AddedToken]"
     ) -> Dict[str, int]:
         """Creates an encoder dictionary from added tokens.
 
@@ -737,7 +744,7 @@ class UniMERNetDecode(object):
         return all_toks
 
     @property
-    def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:
+    def all_special_tokens_extended(self) -> "List[Union[str, AddedToken]]":
         """Retrieves all special tokens, including extended ones.
 
         Returns:
@@ -908,6 +915,7 @@ class UniMERNetDecode(object):
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class UniMERNetTestTransform:
     """
     A class for transforming images according to UniMERNet test specifications.

+ 9 - 2
paddlex/inference/models/formula_recognition/result.py

@@ -20,17 +20,21 @@ import tempfile
 from pathlib import Path
 from typing import List, Optional
 
-import cv2
-import fitz
 import numpy as np
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
 from ....utils import logging
+from ....utils.deps import function_requires_deps, is_dep_available
 from ....utils.file_interface import custom_open
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ...common.result import BaseCVResult, JsonMixin
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+if is_dep_available("PyMuPDF"):
+    import fitz
+
 
 class FormulaRecResult(BaseCVResult):
     def _get_input_fn(self):
@@ -205,6 +209,7 @@ def generate_pdf_file(
             )
 
 
+@function_requires_deps("opencv-contrib-python")
 def crop_white_area(image: np.ndarray) -> Optional[List[int]]:
     """
     Finds and returns the bounding box of the non-white area in an image.
@@ -231,6 +236,7 @@ def crop_white_area(image: np.ndarray) -> Optional[List[int]]:
         return None
 
 
+@function_requires_deps("PyMuPDF", "opencv-contrib-python")
 def pdf2img(pdf_path: str, img_path: str, is_padding: bool = False):
     """
     Converts a single-page PDF to an image, optionally cropping white areas and adding padding.
@@ -326,6 +332,7 @@ def env_valid() -> bool:
             formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
 
 
+@function_requires_deps("opencv-contrib-python")
 def draw_box_txt_fine(img_size: tuple, box: list, txt: str, font_path: str):
     """
     Draw box text.

+ 5 - 1
paddlex/inference/models/instance_segmentation/result.py

@@ -14,15 +14,19 @@
 
 import copy
 
-import cv2
 import numpy as np
 from PIL import Image
 
+from ....utils.deps import function_requires_deps, is_dep_available
 from ...common.result import BaseCVResult, JsonMixin
 from ...utils.color_map import get_colormap
 from ..object_detection.result import draw_box
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@function_requires_deps("opencv-contrib-python")
 def draw_segm(im, masks, mask_info, alpha=0.7):
     """
     Draw segmentation on image

+ 6 - 1
paddlex/inference/models/keypoint_detection/processors.py

@@ -15,13 +15,16 @@
 import math
 from typing import List, Optional, Sequence, Tuple, Union
 
-import cv2
 import numpy as np
 from numpy import ndarray
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 from ..object_detection.processors import get_affine_transform
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 Number = Union[int, float]
 Kpts = List[dict]
 
@@ -67,6 +70,7 @@ def get_warp_matrix(
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class TopDownAffine:
     """refer to https://github.com/open-mmlab/mmpose/blob/71ec36ebd63c475ab589afc817868e749a61491f/mmpose/datasets/transforms/topdown_transforms.py#L13
     Get the bbox image as the model input by affine transform.
@@ -199,6 +203,7 @@ def transform_preds(
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class KptPostProcess:
     """Save Result Transform"""
 

+ 7 - 2
paddlex/inference/models/keypoint_detection/result.py

@@ -15,13 +15,17 @@
 import copy
 import math
 
-import cv2
-import matplotlib.pyplot as plt
 import numpy as np
 from PIL import Image
 
+from ....utils.deps import function_requires_deps, is_dep_available
 from ...common.result import BaseCVResult, JsonMixin
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+if is_dep_available("matplotlib"):
+    import matplotlib.pyplot as plt
+
 
 def get_color(idx):
     idx = idx * 3
@@ -29,6 +33,7 @@ def get_color(idx):
     return color
 
 
+@function_requires_deps("matplotlib", "opencv-contrib-python")
 def draw_keypoints(img, results, visual_thresh=0.1, ids=None):
     plt.switch_backend("agg")
     skeletons = results["keypoints"]

+ 5 - 6
paddlex/inference/models/m_3d_bev_detection/predictor.py

@@ -16,10 +16,9 @@ import shutil
 import tempfile
 from typing import Any, Dict, Iterator, List, Tuple
 
-import lazy_paddle
-
 from ....modules.m_3d_bev_detection.model_list import MODELS
 from ....utils import logging
+from ....utils.deps import function_requires_deps
 from ....utils.func_register import FuncRegister
 from ...common.batch_sampler import Det3DBatchSampler
 from ...common.reader import ReadNuscenesData
@@ -76,16 +75,16 @@ class BEVDet3DPredictor(BasePredictor):
         """
         return BEV3DDetResult
 
+    @function_requires_deps("paddlepaddle")
     def _build(self) -> Tuple:
         """Build the preprocessors and inference engine based on the configuration.
 
         Returns:
             tuple: A tuple containing the preprocessors and inference engine.
         """
-        if (
-            lazy_paddle.is_compiled_with_cuda()
-            and not lazy_paddle.is_compiled_with_rocm()
-        ):
+        import paddle
+
+        if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
             from ....ops.iou3d_nms import nms_gpu  # noqa: F401
             from ....ops.voxelize import hard_voxelize  # noqa: F401
         else:

+ 15 - 8
paddlex/inference/models/m_3d_bev_detection/processors.py

@@ -15,19 +15,14 @@
 
 import numbers
 
-import cv2
 import numpy as np
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...common.reader.det_3d_reader import Sample
 from ...utils.benchmark import benchmark
 
-cv2_interp_codes = {
-    "nearest": cv2.INTER_NEAREST,
-    "bilinear": cv2.INTER_LINEAR,
-    "bicubic": cv2.INTER_CUBIC,
-    "area": cv2.INTER_AREA,
-    "lanczos": cv2.INTER_LANCZOS4,
-}
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
 
 @benchmark.timeit
@@ -45,6 +40,7 @@ class LoadPointsFromFile:
             shift_height (bool): Whether to shift height values.
             use_color (bool): Whether to include color attributes in the loaded points.
         """
+
         self.shift_height = shift_height
         self.use_color = use_color
         if isinstance(use_dim, int):
@@ -275,6 +271,7 @@ class LoadPointsFromMultiSweeps(object):
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class LoadMultiViewImageFromFiles:
     """Load multi-view images from files."""
 
@@ -342,6 +339,7 @@ class LoadMultiViewImageFromFiles:
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class ResizeImage:
     """Resize images & bbox & mask."""
 
@@ -595,6 +593,13 @@ class ResizeImage:
         Returns:
             numpy.ndarray or tuple: The resized image. If return_scale is True, returns a tuple containing the resized image and the scaling factors (w_scale, h_scale).
         """
+        cv2_interp_codes = {
+            "nearest": cv2.INTER_NEAREST,
+            "bilinear": cv2.INTER_LINEAR,
+            "bicubic": cv2.INTER_CUBIC,
+            "area": cv2.INTER_AREA,
+            "lanczos": cv2.INTER_LANCZOS4,
+        }
         h, w = img.shape[:2]
         if backend not in ["cv2", "pillow"]:
             raise ValueError(
@@ -673,6 +678,7 @@ class ResizeImage:
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class NormalizeImage:
     """Normalize the image."""
 
@@ -732,6 +738,7 @@ class NormalizeImage:
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class PadImage(object):
     """Pad the image & mask."""
 

+ 7 - 1
paddlex/inference/models/multilingual_speech_recognition/predictor.py

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import lazy_paddle as paddle
 import numpy as np
 
 from ....modules.multilingual_speech_recognition.model_list import MODELS
+from ....utils.deps import function_requires_deps
 from ....utils.download import download_and_extract
 from ...common.batch_sampler import AudioBatchSampler
 from ...utils.io import AudioReader
@@ -54,12 +54,15 @@ class WhisperPredictor(BasePredictor):
         """
         return WhisperResult
 
+    @function_requires_deps("paddlepaddle")
     def _build(self):
         """Build the model, audio reader based on the configuration.
 
         Returns:
             AudioReader: An instance of AudioReader.
         """
+        import paddle
+
         from .processors import ModelDimensions, Whisper
 
         # build model
@@ -74,6 +77,7 @@ class WhisperPredictor(BasePredictor):
         audio_reader = AudioReader(backend="wav")
         return audio_reader
 
+    @function_requires_deps("paddlepaddle")
     def process(self, batch_data):
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
@@ -84,6 +88,8 @@ class WhisperPredictor(BasePredictor):
         Returns:
             dict: A dictionary containing the input path and result. The result include 'text', 'segments' and 'language'.
         """
+        import paddle
+
         from .processors import log_mel_spectrogram
 
         # load mel_filters from resource_dir and extract feature for audio

+ 1766 - 1761
paddlex/inference/models/multilingual_speech_recognition/processors.py

@@ -18,1922 +18,1927 @@ from dataclasses import dataclass, field
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
-import lazy_paddle as paddle
 import numpy as np
-import soundfile
-import tqdm
 
+from ....utils.deps import function_requires_deps, is_dep_available
 from ..common.tokenizer import GPTTokenizer
 
-__all__ = [
-    "Whisper",
-    "Tokenizer",
-]
-
-
-def exact_div(x, y):
-    assert x % y == 0
-    return x // y
-
-
-_MODELS = ["large"]
-SAMPLE_RATE = 16000
-N_FFT = 400
-N_MELS = 80
-HOP_LENGTH = 160
-CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
-N_FRAMES = exact_div(
-    N_SAMPLES, HOP_LENGTH
-)  # 3000: number of frames in a mel spectrogram input
-
-
-@dataclass
-class ModelDimensions:
-    n_mels: int
-    n_audio_ctx: int
-    n_audio_state: int
-    n_audio_head: int
-    n_audio_layer: int
-    n_vocab: int
-    n_text_ctx: int
-    n_text_state: int
-    n_text_head: int
-    n_text_layer: int
-
-
-LANGUAGES = {
-    "en": "english",
-    "zh": "chinese",
-    "de": "german",
-    "es": "spanish",
-    "ru": "russian",
-    "ko": "korean",
-    "fr": "french",
-    "ja": "japanese",
-    "pt": "portuguese",
-    "tr": "turkish",
-    "pl": "polish",
-    "ca": "catalan",
-    "nl": "dutch",
-    "ar": "arabic",
-    "sv": "swedish",
-    "it": "italian",
-    "id": "indonesian",
-    "hi": "hindi",
-    "fi": "finnish",
-    "vi": "vietnamese",
-    "iw": "hebrew",
-    "uk": "ukrainian",
-    "el": "greek",
-    "ms": "malay",
-    "cs": "czech",
-    "ro": "romanian",
-    "da": "danish",
-    "hu": "hungarian",
-    "ta": "tamil",
-    "no": "norwegian",
-    "th": "thai",
-    "ur": "urdu",
-    "hr": "croatian",
-    "bg": "bulgarian",
-    "lt": "lithuanian",
-    "la": "latin",
-    "mi": "maori",
-    "ml": "malayalam",
-    "cy": "welsh",
-    "sk": "slovak",
-    "te": "telugu",
-    "fa": "persian",
-    "lv": "latvian",
-    "bn": "bengali",
-    "sr": "serbian",
-    "az": "azerbaijani",
-    "sl": "slovenian",
-    "kn": "kannada",
-    "et": "estonian",
-    "mk": "macedonian",
-    "br": "breton",
-    "eu": "basque",
-    "is": "icelandic",
-    "hy": "armenian",
-    "ne": "nepali",
-    "mn": "mongolian",
-    "bs": "bosnian",
-    "kk": "kazakh",
-    "sq": "albanian",
-    "sw": "swahili",
-    "gl": "galician",
-    "mr": "marathi",
-    "pa": "punjabi",
-    "si": "sinhala",
-    "km": "khmer",
-    "sn": "shona",
-    "yo": "yoruba",
-    "so": "somali",
-    "af": "afrikaans",
-    "oc": "occitan",
-    "ka": "georgian",
-    "be": "belarusian",
-    "tg": "tajik",
-    "sd": "sindhi",
-    "gu": "gujarati",
-    "am": "amharic",
-    "yi": "yiddish",
-    "lo": "lao",
-    "uz": "uzbek",
-    "fo": "faroese",
-    "ht": "haitian creole",
-    "ps": "pashto",
-    "tk": "turkmen",
-    "nn": "nynorsk",
-    "mt": "maltese",
-    "sa": "sanskrit",
-    "lb": "luxembourgish",
-    "my": "myanmar",
-    "bo": "tibetan",
-    "tl": "tagalog",
-    "mg": "malagasy",
-    "as": "assamese",
-    "tt": "tatar",
-    "haw": "hawaiian",
-    "ln": "lingala",
-    "ha": "hausa",
-    "ba": "bashkir",
-    "jw": "javanese",
-    "su": "sundanese",
-}
-
-# language code lookup by name, with a few language aliases
-TO_LANGUAGE_CODE = {
-    **{language: code for code, language in LANGUAGES.items()},
-    "burmese": "my",
-    "valencian": "ca",
-    "flemish": "nl",
-    "haitian": "ht",
-    "letzeburgesch": "lb",
-    "pushto": "ps",
-    "panjabi": "pa",
-    "moldavian": "ro",
-    "moldovan": "ro",
-    "sinhalese": "si",
-    "castilian": "es",
-}
-
-
-def compression_ratio(text) -> float:
-    return len(text) / len(zlib.compress(text.encode("utf-8")))
-
-
-def format_timestamp(
-    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
-):
-    assert seconds >= 0, "non-negative timestamp expected"
-    milliseconds = round(seconds * 1000.0)
-
-    hours = milliseconds // 3_600_000
-    milliseconds -= hours * 3_600_000
-
-    minutes = milliseconds // 60_000
-    milliseconds -= minutes * 60_000
-
-    seconds = milliseconds // 1_000
-    milliseconds -= seconds * 1_000
-
-    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
-    return (
-        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
-    )
-
-
-@dataclass(frozen=True)
-class Tokenizer:
-    """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
-
-    tokenizer: "GPTTokenizer"
-    language: Optional[str]
-    sot_sequence: Tuple[int]
-
-    def encode(self, text, **kwargs):
-        return self.tokenizer.encode(text, **kwargs)
+if is_dep_available("soundfile"):
+    import soundfile
+if is_dep_available("tqdm"):
+    import tqdm
 
-    def decode(
-        self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
-    ):
-        if len(token_ids) > 1:
-            ids_list = []
-            for ids in token_ids:
-                if paddle.is_tensor(ids):
-                    ids = ids.item()
-                if ids < len(self.tokenizer):
-                    ids_list.append(ids)
-            token_ids = ids_list
-        elif len(token_ids) == 1:
-            token_ids = token_ids[0]
-        else:
-            raise ValueError(f"token_ids {token_ids} load error.")
-
-        return self.tokenizer.decode(token_ids, **kwargs)
-
-    def decode_with_timestamps(self, tokens) -> str:
-        """
-        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
-        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
-        """
-        outputs = [[]]
-        for token in tokens:
-            if token >= self.timestamp_begin:
-                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
-                outputs.append(timestamp)
-                outputs.append([])
-            else:
-                outputs[-1].append(token)
-        outputs = [
-            s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
-        ]
-        return "".join(outputs)
-
-    @property
-    @lru_cache()
-    def eot(self) -> int:
-        return self.tokenizer.eos_token_id
-
-    @property
-    @lru_cache()
-    def sot(self) -> int:
-        return self._get_single_token_id("<|startoftranscript|>")
-
-    @property
-    @lru_cache()
-    def sot_lm(self) -> int:
-        return self._get_single_token_id("<|startoflm|>")
-
-    @property
-    @lru_cache()
-    def sot_prev(self) -> int:
-        return self._get_single_token_id("<|startofprev|>")
-
-    @property
-    @lru_cache()
-    def no_speech(self) -> int:
-        return self._get_single_token_id("<|nospeech|>")
-
-    @property
-    @lru_cache()
-    def no_timestamps(self) -> int:
-        return self._get_single_token_id("<|notimestamps|>")
-
-    @property
-    @lru_cache()
-    def timestamp_begin(self) -> int:
-        return self.tokenizer.all_special_ids[-1] + 1
-
-    @property
-    @lru_cache()
-    def language_token(self) -> int:
-        """Returns the token id corresponding to the value of the `language` field"""
-        if self.language is None:
-            raise ValueError("This tokenizer does not have language token configured")
-
-        additional_tokens = dict(
-            zip(
-                self.tokenizer.additional_special_tokens,
-                self.tokenizer.additional_special_tokens_ids,
-            )
-        )
-        candidate = f"<|{self.language}|>"
-        if candidate in additional_tokens:
-            return additional_tokens[candidate]
-
-        raise KeyError(f"Language {self.language} not found in tokenizer.")
-
-    @property
-    @lru_cache()
-    def all_language_tokens(self) -> Tuple[int]:
-        result = []
-        for token, token_id in zip(
-            self.tokenizer.additional_special_tokens,
-            self.tokenizer.additional_special_tokens_ids,
-        ):
-            if token.strip("<|>") in LANGUAGES:
-                result.append(token_id)
-        return tuple(result)
-
-    @property
-    @lru_cache()
-    def all_language_codes(self) -> Tuple[str]:
-        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
-
-    @property
-    @lru_cache()
-    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
-        return tuple(list(self.sot_sequence) + [self.no_timestamps])
-
-    @property
-    @lru_cache()
-    def non_speech_tokens(self) -> Tuple[int]:
-        """
-        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
-        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+if is_dep_available("paddlepaddle"):
 
-        - ♪♪♪
-        - ( SPEAKING FOREIGN LANGUAGE )
-        - [DAVID] Hey there,
+    import paddle
 
-        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
-        """
-        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
-        symbols += (
-            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
-        )
-
-        # symbols that may be a single token or multiple tokens depending on the tokenizer.
-        # In case they're multiple tokens, suppress the first token, which is safe because:
-        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
-        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
-        miscellaneous = set("♩♪♫♬♭♮♯")
-        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
-
-        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
-        result = {
-            self.tokenizer.encode(" -").input_ids[0],
-            self.tokenizer.encode(" '").input_ids[0],
-        }
-        for symbol in symbols + list(miscellaneous):
-            for tokens in [
-                self.tokenizer.encode(symbol).input_ids,
-                self.tokenizer.encode(" " + symbol).input_ids,
-            ]:
-                if len(tokens) == 1 or symbol in miscellaneous:
-                    result.add(tokens[0])
-
-        return tuple(sorted(result))
-
-    def _get_single_token_id(self, text) -> int:
-        tokens = self.tokenizer.encode(text).input_ids
-        assert len(tokens) == 1, f"{text} is not encoded as a single token"
-        return tokens[0]
-
-
-@lru_cache(maxsize=None)
-def build_tokenizer(resource_path: str, name: str = "gpt2"):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    path = os.path.join(resource_path, "assets", name)
-    tokenizer = GPTTokenizer.from_pretrained(path)
-
-    specials = [
-        "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
-        "<|translate|>",
-        "<|transcribe|>",
-        "<|startoflm|>",
-        "<|startofprev|>",
-        "<|nospeech|>",
-        "<|notimestamps|>",
+    __all__ = [
+        "Whisper",
+        "Tokenizer",
     ]
 
-    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
-    return tokenizer
-
-
-@lru_cache(maxsize=None)
-def get_tokenizer(
-    multilingual: bool,
-    resource_path: str,
-    *,
-    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
-    language: Optional[str] = None,
-) -> Tokenizer:
-    if language is not None:
-        language = language.lower()
-        if language not in LANGUAGES:
-            if language in TO_LANGUAGE_CODE:
-                language = TO_LANGUAGE_CODE[language]
-            else:
-                raise ValueError(f"Unsupported language: {language}")
-
-    if multilingual:
-        tokenizer_name = "multilingual"
-        task = task or "transcribe"
-        language = language or "en"
-    else:
-        tokenizer_name = "gpt2"
-        task = None
-        language = None
-
-    tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
-    all_special_ids: List[int] = tokenizer.all_special_ids
-    sot: int = all_special_ids[1]
-    translate: int = all_special_ids[-6]
-    transcribe: int = all_special_ids[-5]
-
-    langs = tuple(LANGUAGES.keys())
-    sot_sequence = [sot]
-    if language is not None:
-        sot_sequence.append(sot + 1 + langs.index(language))
-    if task is not None:
-        sot_sequence.append(transcribe if task == "transcribe" else translate)
-
-    return Tokenizer(
-        tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
-    )
-
-
-class MultiHeadAttention(paddle.nn.Layer):
-    def __init__(self, n_state: int, n_head: int):
-        super().__init__()
-        self.n_head = n_head
-        self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
-        self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
-        self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
-        self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
-
-    def forward(
-        self,
-        x: paddle.Tensor,
-        xa: Optional[paddle.Tensor] = None,
-        mask: Optional[paddle.Tensor] = None,
-        kv_cache: Optional[dict] = None,
+    def exact_div(x, y):
+        assert x % y == 0
+        return x // y
+
+    _MODELS = ["large"]
+    SAMPLE_RATE = 16000
+    N_FFT = 400
+    N_MELS = 80
+    HOP_LENGTH = 160
+    CHUNK_LENGTH = 30
+    N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
+    N_FRAMES = exact_div(
+        N_SAMPLES, HOP_LENGTH
+    )  # 3000: number of frames in a mel spectrogram input
+
+    @dataclass
+    class ModelDimensions:
+        n_mels: int
+        n_audio_ctx: int
+        n_audio_state: int
+        n_audio_head: int
+        n_audio_layer: int
+        n_vocab: int
+        n_text_ctx: int
+        n_text_state: int
+        n_text_head: int
+        n_text_layer: int
+
+    LANGUAGES = {
+        "en": "english",
+        "zh": "chinese",
+        "de": "german",
+        "es": "spanish",
+        "ru": "russian",
+        "ko": "korean",
+        "fr": "french",
+        "ja": "japanese",
+        "pt": "portuguese",
+        "tr": "turkish",
+        "pl": "polish",
+        "ca": "catalan",
+        "nl": "dutch",
+        "ar": "arabic",
+        "sv": "swedish",
+        "it": "italian",
+        "id": "indonesian",
+        "hi": "hindi",
+        "fi": "finnish",
+        "vi": "vietnamese",
+        "iw": "hebrew",
+        "uk": "ukrainian",
+        "el": "greek",
+        "ms": "malay",
+        "cs": "czech",
+        "ro": "romanian",
+        "da": "danish",
+        "hu": "hungarian",
+        "ta": "tamil",
+        "no": "norwegian",
+        "th": "thai",
+        "ur": "urdu",
+        "hr": "croatian",
+        "bg": "bulgarian",
+        "lt": "lithuanian",
+        "la": "latin",
+        "mi": "maori",
+        "ml": "malayalam",
+        "cy": "welsh",
+        "sk": "slovak",
+        "te": "telugu",
+        "fa": "persian",
+        "lv": "latvian",
+        "bn": "bengali",
+        "sr": "serbian",
+        "az": "azerbaijani",
+        "sl": "slovenian",
+        "kn": "kannada",
+        "et": "estonian",
+        "mk": "macedonian",
+        "br": "breton",
+        "eu": "basque",
+        "is": "icelandic",
+        "hy": "armenian",
+        "ne": "nepali",
+        "mn": "mongolian",
+        "bs": "bosnian",
+        "kk": "kazakh",
+        "sq": "albanian",
+        "sw": "swahili",
+        "gl": "galician",
+        "mr": "marathi",
+        "pa": "punjabi",
+        "si": "sinhala",
+        "km": "khmer",
+        "sn": "shona",
+        "yo": "yoruba",
+        "so": "somali",
+        "af": "afrikaans",
+        "oc": "occitan",
+        "ka": "georgian",
+        "be": "belarusian",
+        "tg": "tajik",
+        "sd": "sindhi",
+        "gu": "gujarati",
+        "am": "amharic",
+        "yi": "yiddish",
+        "lo": "lao",
+        "uz": "uzbek",
+        "fo": "faroese",
+        "ht": "haitian creole",
+        "ps": "pashto",
+        "tk": "turkmen",
+        "nn": "nynorsk",
+        "mt": "maltese",
+        "sa": "sanskrit",
+        "lb": "luxembourgish",
+        "my": "myanmar",
+        "bo": "tibetan",
+        "tl": "tagalog",
+        "mg": "malagasy",
+        "as": "assamese",
+        "tt": "tatar",
+        "haw": "hawaiian",
+        "ln": "lingala",
+        "ha": "hausa",
+        "ba": "bashkir",
+        "jw": "javanese",
+        "su": "sundanese",
+    }
+
+    # language code lookup by name, with a few language aliases
+    TO_LANGUAGE_CODE = {
+        **{language: code for code, language in LANGUAGES.items()},
+        "burmese": "my",
+        "valencian": "ca",
+        "flemish": "nl",
+        "haitian": "ht",
+        "letzeburgesch": "lb",
+        "pushto": "ps",
+        "panjabi": "pa",
+        "moldavian": "ro",
+        "moldovan": "ro",
+        "sinhalese": "si",
+        "castilian": "es",
+    }
+
+    def compression_ratio(text) -> float:
+        return len(text) / len(zlib.compress(text.encode("utf-8")))
+
+    def format_timestamp(
+        seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
     ):
-        q = self.query(x)
+        assert seconds >= 0, "non-negative timestamp expected"
+        milliseconds = round(seconds * 1000.0)
 
-        if kv_cache is None or xa is None or self.key not in kv_cache:
-            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
-            # otherwise, perform key/value projections for self- or cross-attention as usual.
-            k = self.key(x if xa is None else xa)
-            v = self.value(x if xa is None else xa)
-        else:
-            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache[self.key]
-            v = kv_cache[self.value]
-
-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv)
-
-    def qkv_attention(
-        self,
-        q: paddle.Tensor,
-        k: paddle.Tensor,
-        v: paddle.Tensor,
-        mask: Optional[paddle.Tensor] = None,
-    ):
-        n_batch, n_ctx, n_state = q.shape
-        scale = (n_state // self.n_head) ** -0.25
-        q = (
-            paddle.transpose(q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
-            * scale
-        )
-        k = (
-            paddle.transpose(k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1))
-            * scale
-        )
-        v = paddle.transpose(v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3))
+        hours = milliseconds // 3_600_000
+        milliseconds -= hours * 3_600_000
 
-        qk = q @ k
-        if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
+        minutes = milliseconds // 60_000
+        milliseconds -= minutes * 60_000
 
-        w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
-        return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
+        seconds = milliseconds // 1_000
+        milliseconds -= seconds * 1_000
 
+        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+        return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
 
-class ResidualAttentionBlock(paddle.nn.Layer):
-    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
-        super().__init__()
+    @dataclass(frozen=True)
+    class Tokenizer:
+        """A thin wrapper around `GPTTokenizer` providing quick access to special tokens"""
 
-        self.attn = MultiHeadAttention(n_state, n_head)
-        self.attn_ln = paddle.nn.LayerNorm(n_state)
+        tokenizer: "GPTTokenizer"
+        language: Optional[str]
+        sot_sequence: Tuple[int]
 
-        self.cross_attn = (
-            MultiHeadAttention(n_state, n_head) if cross_attention else None
-        )
-        self.cross_attn_ln = paddle.nn.LayerNorm(n_state) if cross_attention else None
+        def encode(self, text, **kwargs):
+            return self.tokenizer.encode(text, **kwargs)
 
-        n_mlp = n_state * 4
-        self.mlp = paddle.nn.Sequential(
-            paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
-            paddle.nn.GELU(),
-            paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
-        )
-        self.mlp_ln = paddle.nn.LayerNorm(n_state)
-
-    def forward(
-        self,
-        x: paddle.Tensor,
-        xa: Optional[paddle.Tensor] = None,
-        mask: Optional[paddle.Tensor] = None,
-        kv_cache: Optional[dict] = None,
-    ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
-        if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
-        x = x + self.mlp(self.mlp_ln(x))
-        return x
-
-
-def sinusoids(length, channels, max_timescale=10000):
-    """Returns sinusoids for positional embedding"""
-    assert channels % 2 == 0
-    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-    inv_timescales = paddle.exp(
-        -log_timescale_increment * paddle.arange(channels // 2, dtype=paddle.float32)
-    )
-    scaled_time = (
-        paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
-        * inv_timescales[np.newaxis, :]
-    )
-    return paddle.to_tensor(
-        paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
-    )
-
-
-class AudioEncoder(paddle.nn.Layer):
-    def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-    ):
-        super().__init__()
-        self.conv1 = paddle.nn.Conv1D(
-            n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
-        )
-        self.conv2 = paddle.nn.Conv1D(
-            n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
-        )
-        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
-
-        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
-            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-        )
-        self.ln_post = paddle.nn.LayerNorm(n_state)
-
-    def forward(self, x: paddle.Tensor):
-        """
-        x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
-            the mel spectrogram of the audio
-        """
-        x = paddle.nn.functional.gelu(self.conv1(x))
-        x = paddle.nn.functional.gelu(self.conv2(x))
-        x = paddle.transpose(x, (0, 2, 1))
-
-        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
-        x = x + self.positional_embedding
-
-        for block in self.blocks:
-            x = block(x)
-
-        x = self.ln_post(x)
-        return x
+        def decode(
+            self, token_ids: Union[int, List[int], np.ndarray, paddle.Tensor], **kwargs
+        ):
+            if len(token_ids) > 1:
+                ids_list = []
+                for ids in token_ids:
+                    if paddle.is_tensor(ids):
+                        ids = ids.item()
+                    if ids < len(self.tokenizer):
+                        ids_list.append(ids)
+                token_ids = ids_list
+            elif len(token_ids) == 1:
+                token_ids = token_ids[0]
+            else:
+                raise ValueError(f"token_ids {token_ids} load error.")
+
+            return self.tokenizer.decode(token_ids, **kwargs)
+
+        def decode_with_timestamps(self, tokens) -> str:
+            """
+            Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
+            This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
+            """
+            outputs = [[]]
+            for token in tokens:
+                if token >= self.timestamp_begin:
+                    timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
+                    outputs.append(timestamp)
+                    outputs.append([])
+                else:
+                    outputs[-1].append(token)
+            outputs = [
+                s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs
+            ]
+            return "".join(outputs)
+
+        @property
+        @lru_cache()
+        def eot(self) -> int:
+            return self.tokenizer.eos_token_id
+
+        @property
+        @lru_cache()
+        def sot(self) -> int:
+            return self._get_single_token_id("<|startoftranscript|>")
+
+        @property
+        @lru_cache()
+        def sot_lm(self) -> int:
+            return self._get_single_token_id("<|startoflm|>")
+
+        @property
+        @lru_cache()
+        def sot_prev(self) -> int:
+            return self._get_single_token_id("<|startofprev|>")
+
+        @property
+        @lru_cache()
+        def no_speech(self) -> int:
+            return self._get_single_token_id("<|nospeech|>")
+
+        @property
+        @lru_cache()
+        def no_timestamps(self) -> int:
+            return self._get_single_token_id("<|notimestamps|>")
+
+        @property
+        @lru_cache()
+        def timestamp_begin(self) -> int:
+            return self.tokenizer.all_special_ids[-1] + 1
+
+        @property
+        @lru_cache()
+        def language_token(self) -> int:
+            """Returns the token id corresponding to the value of the `language` field"""
+            if self.language is None:
+                raise ValueError(
+                    "This tokenizer does not have language token configured"
+                )
 
+            additional_tokens = dict(
+                zip(
+                    self.tokenizer.additional_special_tokens,
+                    self.tokenizer.additional_special_tokens_ids,
+                )
+            )
+            candidate = f"<|{self.language}|>"
+            if candidate in additional_tokens:
+                return additional_tokens[candidate]
 
-class TextDecoder(paddle.nn.Layer):
-    def __init__(
-        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-    ):
-        super().__init__()
+            raise KeyError(f"Language {self.language} not found in tokenizer.")
 
-        self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
-        self.positional_embedding = paddle.create_parameter(
-            shape=[n_ctx, n_state], dtype="float32"
-        )
+        @property
+        @lru_cache()
+        def all_language_tokens(self) -> Tuple[int]:
+            result = []
+            for token, token_id in zip(
+                self.tokenizer.additional_special_tokens,
+                self.tokenizer.additional_special_tokens_ids,
+            ):
+                if token.strip("<|>") in LANGUAGES:
+                    result.append(token_id)
+            return tuple(result)
+
+        @property
+        @lru_cache()
+        def all_language_codes(self) -> Tuple[str]:
+            return tuple(
+                self.decode([l]).strip("<|>") for l in self.all_language_tokens
+            )
 
-        self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
-            [
-                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
-                for _ in range(n_layer)
-            ]
-        )
-        self.ln = paddle.nn.LayerNorm(n_state)
+        @property
+        @lru_cache()
+        def sot_sequence_including_notimestamps(self) -> Tuple[int]:
+            return tuple(list(self.sot_sequence) + [self.no_timestamps])
+
+        @property
+        @lru_cache()
+        def non_speech_tokens(self) -> Tuple[int]:
+            """
+            Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+            annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+            - ♪♪♪
+            - ( SPEAKING FOREIGN LANGUAGE )
+            - [DAVID] Hey there,
+            keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
+            """
+            symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
+            symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+
+            # symbols that may be a single token or multiple tokens depending on the tokenizer.
+            # In case they're multiple tokens, suppress the first token, which is safe because:
+            # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+            # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+            miscellaneous = set("♩♪♫♬♭♮♯")
+            assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+            # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+            result = {
+                self.tokenizer.encode(" -").input_ids[0],
+                self.tokenizer.encode(" '").input_ids[0],
+            }
+            for symbol in symbols + list(miscellaneous):
+                for tokens in [
+                    self.tokenizer.encode(symbol).input_ids,
+                    self.tokenizer.encode(" " + symbol).input_ids,
+                ]:
+                    if len(tokens) == 1 or symbol in miscellaneous:
+                        result.add(tokens[0])
+
+            return tuple(sorted(result))
+
+        def _get_single_token_id(self, text) -> int:
+            tokens = self.tokenizer.encode(text).input_ids
+            assert len(tokens) == 1, f"{text} is not encoded as a single token"
+            return tokens[0]
+
+    @lru_cache(maxsize=None)
+    def build_tokenizer(resource_path: str, name: str = "gpt2"):
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        path = os.path.join(resource_path, "assets", name)
+        tokenizer = GPTTokenizer.from_pretrained(path)
+
+        specials = [
+            "<|startoftranscript|>",
+            *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+            "<|translate|>",
+            "<|transcribe|>",
+            "<|startoflm|>",
+            "<|startofprev|>",
+            "<|nospeech|>",
+            "<|notimestamps|>",
+        ]
 
-        mask = paddle.full(shape=[n_ctx, n_state], fill_value=-np.inf, dtype="float32")
-        mask = paddle.triu(mask, diagonal=1)
-        self.register_buffer("mask", mask, persistable=False)
+        tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+        return tokenizer
+
+    @lru_cache(maxsize=None)
+    def get_tokenizer(
+        multilingual: bool,
+        resource_path: str,
+        *,
+        task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+        language: Optional[str] = None,
+    ) -> Tokenizer:
+        if language is not None:
+            language = language.lower()
+            if language not in LANGUAGES:
+                if language in TO_LANGUAGE_CODE:
+                    language = TO_LANGUAGE_CODE[language]
+                else:
+                    raise ValueError(f"Unsupported language: {language}")
 
-    def forward(
-        self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
-    ):
-        """
-        x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
-            the encoded audio features to be attended on
-        """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
-
-        for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
-
-        x = self.ln(x)
-        logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
-
-        return logits
-
-
-@dataclass(frozen=True)
-class DecodingOptions:
-    task: str = (
-        "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
-    )
-    language: Optional[str] = (
-        None  # language that the audio is in; uses detected language if None
-    )
-    # sampling-related options
-    temperature: float = 0.0
-    sample_len: Optional[int] = None  # maximum number of tokens to sample
-    best_of: Optional[int] = (
-        None  # number of independent samples to collect, when t > 0
-    )
-    beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
-    patience: Optional[float] = (
-        None  # patience in beam search (https://arxiv.org/abs/2204.05424)
-    )
-
-    # options for ranking generations (either beams or best-of-N samples)
-    length_penalty: Optional[float] = (
-        None  # "alpha" in Google NMT, None defaults to length norm
-    )
-
-    # prompt, prefix, and token suppression
-    prompt: Optional[Union[str, List[int]]] = (
-        None  # text or tokens for the previous context
-    )
-    prefix: Optional[Union[str, List[int]]] = (
-        None  # text or tokens to prefix the current context
-    )
-    suppress_blank: bool = True  # this will suppress blank outputs
-
-    # list of tokens ids (or comma-separated token ids) to suppress
-    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
-    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
-
-    # timestamp sampling options
-    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = (
-        1.0  # the initial timestamp cannot be later than this
-    )
-
-    # implementation details
-    fp16: bool = False  # use fp16 for most of the calculation
-
-
-@dataclass(frozen=True)
-class DecodingResult:
-    audio_features: paddle.Tensor
-    language: str
-    language_probs: Optional[Dict[str, float]] = None
-    tokens: List[int] = field(default_factory=list)
-    text: str = ""
-    avg_logprob: float = np.nan
-    no_speech_prob: float = np.nan
-    temperature: float = np.nan
-    compression_ratio: float = np.nan
-
-
-class Inference:
-    def logits(
-        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
-    ) -> paddle.Tensor:
-        """Perform a forward pass on the decoder and return per-token logits"""
-        raise NotImplementedError
-
-    def rearrange_kv_cache(self, source_indices) -> None:
-        """Update the key-value cache according to the updated beams"""
-        raise NotImplementedError
-
-    def cleanup_caching(self) -> None:
-        """Clean up any resources or hooks after decoding is finished"""
-
-
-class WhisperInference(Inference):
-    def __init__(self, model: "Whisper", initial_token_length: int):
-        self.model: "Whisper" = model
-        self.initial_token_length = initial_token_length
-        self.kv_cache = {}
-        self.hooks = []
-
-    def logits(
-        self, tokens: paddle.Tensor, audio_features: paddle.Tensor
-    ) -> paddle.Tensor:
-        if not self.kv_cache:
-            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
-
-        if tokens.shape[-1] > self.initial_token_length:
-            # only need to use the last token except in the first forward pass
-            tokens = tokens[:, -1:]
-
-        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
-
-    def cleanup_caching(self):
-        for hook in self.hooks:
-            hook.remove()
-
-        self.kv_cache = {}
-        self.hooks = []
-
-    def rearrange_kv_cache(self, source_indices):
-        for module, tensor in self.kv_cache.items():
-            # update the key/value cache to contain the selected sequences
-            self.kv_cache[module] = tensor[source_indices].detach()
-
-
-@paddle.no_grad()
-def detect_language(
-    model: "Whisper",
-    mel: paddle.Tensor,
-    resource_path: str,
-    tokenizer: Tokenizer = None,
-) -> Tuple[paddle.Tensor, List[dict]]:
-    """
-    Detect the spoken language in the audio, and return them as list of strings, along with the ids
-    of the most probable language tokens and the probability distribution over all language tokens.
-    This is performed outside the main decode loop in order to not interfere with kv-caching.
-
-    Returns
-    -------
-    language_tokens : Tensor, shape = (batch_size,)
-        ids of the most probable language tokens, which appears after the startoftranscript token.
-    language_probs : List[Dict[str, float]], length = batch_size
-        list of dictionaries containing the probability distribution over all languages.
-    """
-    if tokenizer is None:
-        tokenizer = get_tokenizer(model.is_multilingual, resource_path=resource_path)
-    if (
-        tokenizer.language is None
-        or tokenizer.language_token not in tokenizer.sot_sequence
-    ):
-        raise ValueError(
-            "This model doesn't have language tokens so it can't perform lang id"
+        if multilingual:
+            tokenizer_name = "multilingual"
+            task = task or "transcribe"
+            language = language or "en"
+        else:
+            tokenizer_name = "gpt2"
+            task = None
+            language = None
+
+        tokenizer = build_tokenizer(resource_path=resource_path, name=tokenizer_name)
+        all_special_ids: List[int] = tokenizer.all_special_ids
+        sot: int = all_special_ids[1]
+        translate: int = all_special_ids[-6]
+        transcribe: int = all_special_ids[-5]
+
+        langs = tuple(LANGUAGES.keys())
+        sot_sequence = [sot]
+        if language is not None:
+            sot_sequence.append(sot + 1 + langs.index(language))
+        if task is not None:
+            sot_sequence.append(transcribe if task == "transcribe" else translate)
+
+        return Tokenizer(
+            tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)
         )
 
-    single = mel.ndim == 2
-    if single:
-        mel = mel.unsqueeze(0)
-
-    # skip encoder forward pass if already-encoded audio features were given
-    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
-        mel = model.encoder(mel)
-
-    # forward pass using a single token, startoftranscript
-    batch_size = mel.shape[0]
-    x = paddle.to_tensor([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
-    logits = model.logits(x, mel)[:, 0]
-
-    # collect detected languages; suppress all non-language tokens
-    mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
-    mask[list(tokenizer.all_language_tokens)] = False
-    logits[:, mask] = -np.inf
-    language_tokens = paddle.argmax(logits, axis=-1)
-    language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
-    language_probs = [
-        {
-            c: language_token_probs[i, j].tolist()
-            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
-        }
-        for i in range(batch_size)
-    ]
+    class MultiHeadAttention(paddle.nn.Layer):
+        def __init__(self, n_state: int, n_head: int):
+            super().__init__()
+            self.n_head = n_head
+            self.query = paddle.nn.Linear(n_state, n_state, bias_attr=True)
+            self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False)
+            self.value = paddle.nn.Linear(n_state, n_state, bias_attr=True)
+            self.out = paddle.nn.Linear(n_state, n_state, bias_attr=True)
+
+        def forward(
+            self,
+            x: paddle.Tensor,
+            xa: Optional[paddle.Tensor] = None,
+            mask: Optional[paddle.Tensor] = None,
+            kv_cache: Optional[dict] = None,
+        ):
+            q = self.query(x)
 
-    if single:
-        language_tokens = language_tokens[0]
-        language_probs = language_probs[0]
-
-    return language_tokens, language_probs
-
-
-def transcribe(
-    model: "Whisper",
-    mel: paddle.Tensor,
-    resource_path: str,
-    *,
-    verbose: Optional[bool] = None,
-    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
-    compression_ratio_threshold: Optional[float] = 2.4,
-    logprob_threshold: Optional[float] = -1.0,
-    no_speech_threshold: Optional[float] = 0.6,
-    condition_on_previous_text: bool = True,
-    **decode_options,
-):
-    """
-    Transcribe an audio file using Whisper
-
-    Parameters
-    ----------
-    model: Whisper
-        The Whisper model instance
-
-    mel: paddle.Tensor
-        The audio feature
-
-    verbose: bool
-        Whether to display the text being decoded to the console. If True, displays all the details,
-        If False, displays minimal details. If None, does not display anything
-
-    temperature: Union[float, Tuple[float, ...]]
-        Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
-        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
-
-    compression_ratio_threshold: float
-        If the gzip compression ratio is above this value, treat as failed
-
-    logprob_threshold: float
-        If the average log probability over sampled tokens is below this value, treat as failed
-
-    no_speech_threshold: float
-        If the no_speech probability is higher than this value AND the average log probability
-        over sampled tokens is below `logprob_threshold`, consider the segment as silent
-
-    condition_on_previous_text: bool
-        if True, the previous output of the model is provided as a prompt for the next window;
-        disabling may make the text inconsistent across windows, but the model becomes less prone to
-        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
-
-    decode_options: dict
-        Keyword arguments to construct `DecodingOptions` instances
-
-    Returns
-    -------
-    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
-    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
-    """
-    dtype = np.float32  # paddle only support float32
-
-    if dtype == np.float32:
-        decode_options["fp16"] = False
-
-    if (
-        decode_options.get("language") == "None"
-        or decode_options.get("language", None) is None
-    ):
-        if not model.is_multilingual:
-            decode_options["language"] = "en"
-        else:
-            if verbose:
-                print(
-                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+            if kv_cache is None or xa is None or self.key not in kv_cache:
+                # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+                # otherwise, perform key/value projections for self- or cross-attention as usual.
+                k = self.key(x if xa is None else xa)
+                v = self.value(x if xa is None else xa)
+            else:
+                # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+                k = kv_cache[self.key]
+                v = kv_cache[self.value]
+
+            wv = self.qkv_attention(q, k, v, mask)
+            return self.out(wv)
+
+        def qkv_attention(
+            self,
+            q: paddle.Tensor,
+            k: paddle.Tensor,
+            v: paddle.Tensor,
+            mask: Optional[paddle.Tensor] = None,
+        ):
+            n_batch, n_ctx, n_state = q.shape
+            scale = (n_state // self.n_head) ** -0.25
+            q = (
+                paddle.transpose(
+                    q.reshape([*q.shape[:2], self.n_head, -1]), (0, 2, 1, 3)
                 )
-            segment = pad_or_trim(mel, N_FRAMES)
-            _, probs = model.detect_language(segment, resource_path)
-            decode_options["language"] = max(probs, key=probs.get)
-            if verbose is not None:
-                print(
-                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                * scale
+            )
+            k = (
+                paddle.transpose(
+                    k.reshape([*k.shape[:2], self.n_head, -1]), (0, 2, 3, 1)
                 )
+                * scale
+            )
+            v = paddle.transpose(
+                v.reshape([*v.shape[:2], self.n_head, -1]), (0, 2, 1, 3)
+            )
 
-    language = decode_options["language"]
-    task = decode_options.get("task", "transcribe")
-    tokenizer = get_tokenizer(
-        model.is_multilingual, resource_path=resource_path, language=language, task=task
-    )
+            qk = q @ k
+            if mask is not None:
+                qk = qk + mask[:n_ctx, :n_ctx]
 
-    def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
-        temperatures = (
-            [temperature] if isinstance(temperature, (int, float)) else temperature
-        )
-        decode_result = None
-
-        for t in temperatures:
-            kwargs = {**decode_options}
-            if t > 0:
-                # disable beam_size and patience when t > 0
-                kwargs.pop("beam_size", None)
-                kwargs.pop("patience", None)
-            else:
-                # disable best_of when t == 0
-                kwargs.pop("best_of", None)
+            w = paddle.nn.functional.softmax(qk.astype(q.dtype), axis=-1)
+            return paddle.transpose((w @ v), (0, 2, 1, 3)).flatten(start_axis=2)
 
-            options = DecodingOptions(**kwargs, temperature=t)
-            decode_result = model.decode(segment, options, resource_path)
+    class ResidualAttentionBlock(paddle.nn.Layer):
+        def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+            super().__init__()
 
-            needs_fallback = False
-            if (
-                compression_ratio_threshold is not None
-                and decode_result.compression_ratio > compression_ratio_threshold
-            ):
-                needs_fallback = True  # too repetitive
-            if (
-                logprob_threshold is not None
-                and decode_result.avg_logprob < logprob_threshold
-            ):
-                needs_fallback = True  # average log probability is too low
-
-            if not needs_fallback:
-                break
-
-        return decode_result
-
-    seek = 0
-    input_stride = exact_div(
-        N_FRAMES, model.dims.n_audio_ctx
-    )  # mel frames per output token: 2
-    time_precision = (
-        input_stride * HOP_LENGTH / SAMPLE_RATE
-    )  # time per output token: 0.02 (seconds)
-    all_tokens = []
-    all_segments = []
-    prompt_reset_since = 0
-
-    initial_prompt = decode_options.pop("initial_prompt", None)
-    if initial_prompt and initial_prompt != "None":
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
-        all_tokens.extend(initial_prompt)
-    else:
-        initial_prompt = []
-
-    def add_segment(
-        *, start: float, end: float, text_tokens: paddle.Tensor, result: DecodingResult
-    ):
-        text = tokenizer.decode(
-            [token for token in text_tokens if token < tokenizer.eot]
-        )
-        if len(text.strip()) == 0:  # skip empty text output
-            return
+            self.attn = MultiHeadAttention(n_state, n_head)
+            self.attn_ln = paddle.nn.LayerNorm(n_state)
 
-        all_segments.append(
-            {
-                "id": len(all_segments),
-                "seek": seek,
-                "start": start,
-                "end": end,
-                "text": text,
-                "tokens": result.tokens,
-                "temperature": result.temperature,
-                "avg_logprob": result.avg_logprob,
-                "compression_ratio": result.compression_ratio,
-                "no_speech_prob": result.no_speech_prob,
-            }
-        )
-        if verbose:
-            print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
-
-    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
-    num_frames = mel.shape[-1]
-    previous_seek_value = seek
-
-    with tqdm.tqdm(
-        total=num_frames, unit="frames", disable=verbose is not False
-    ) as pbar:
-        while seek < num_frames:
-            timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, seek:], N_FRAMES)
-            segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
-
-            decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result: DecodingResult = decode_with_fallback(segment)
-            tokens = paddle.to_tensor(result.tokens)
-
-            if no_speech_threshold is not None:
-                # no voice activity check
-                should_skip = result.no_speech_prob > no_speech_threshold
-                if (
-                    logprob_threshold is not None
-                    and result.avg_logprob > logprob_threshold
-                ):
-                    # don't skip if the logprob is high enough, despite the no_speech_prob
-                    should_skip = False
-
-                if should_skip:
-                    seek += segment.shape[
-                        -1
-                    ]  # fast-forward to the next segment boundary
-                    continue
-
-            timestamp_tokens: paddle.Tensor = tokens.greater_equal(
-                paddle.to_tensor(tokenizer.timestamp_begin)
+            self.cross_attn = (
+                MultiHeadAttention(n_state, n_head) if cross_attention else None
+            )
+            self.cross_attn_ln = (
+                paddle.nn.LayerNorm(n_state) if cross_attention else None
             )
 
-            consecutive = paddle.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
-            if (
-                len(consecutive) > 0
-            ):  # if the output contains two consecutive timestamp tokens
-                consecutive = paddle.add(consecutive, paddle.to_tensor(1))
-                last_slice = 0
-                for current_slice in consecutive:
-                    sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_position = (
-                        sliced_tokens[0].item() - tokenizer.timestamp_begin
-                    )
-                    end_timestamp_position = (
-                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    )
-                    add_segment(
-                        start=timestamp_offset
-                        + start_timestamp_position * time_precision,
-                        end=timestamp_offset + end_timestamp_position * time_precision,
-                        text_tokens=sliced_tokens[1:-1],
-                        result=result,
-                    )
-                    last_slice = current_slice
-                last_timestamp_position = (
-                    tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-                )
-                seek += last_timestamp_position * input_stride
-                all_tokens.extend(tokens[: last_slice + 1].tolist())
-            else:
-                duration = segment_duration
-                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if (
-                    len(timestamps) > 0
-                    and timestamps[-1].item() != tokenizer.timestamp_begin
-                ):
-                    # no consecutive timestamps but it has a timestamp; use the last one.
-                    # single timestamp at the end means no speech after the last timestamp.
-                    last_timestamp_position = (
-                        timestamps[-1].item() - tokenizer.timestamp_begin
-                    )
-                    duration = last_timestamp_position * time_precision
-
-                add_segment(
-                    start=timestamp_offset,
-                    end=timestamp_offset + duration,
-                    text_tokens=tokens,
-                    result=result,
-                )
-
-                seek += segment.shape[-1]
-                all_tokens.extend(tokens.tolist())
-
-            if not condition_on_previous_text or result.temperature > 0.5:
-                # do not feed the prompt tokens if a high temperature was used
-                prompt_reset_since = len(all_tokens)
+            n_mlp = n_state * 4
+            self.mlp = paddle.nn.Sequential(
+                paddle.nn.Linear(n_state, n_mlp, bias_attr=True),
+                paddle.nn.GELU(),
+                paddle.nn.Linear(n_mlp, n_state, bias_attr=True),
+            )
+            self.mlp_ln = paddle.nn.LayerNorm(n_state)
+
+        def forward(
+            self,
+            x: paddle.Tensor,
+            xa: Optional[paddle.Tensor] = None,
+            mask: Optional[paddle.Tensor] = None,
+            kv_cache: Optional[dict] = None,
+        ):
+            x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+            if self.cross_attn:
+                x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+            x = x + self.mlp(self.mlp_ln(x))
+            return x
+
+    def sinusoids(length, channels, max_timescale=10000):
+        """Returns sinusoids for positional embedding"""
+        assert channels % 2 == 0
+        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+        inv_timescales = paddle.exp(
+            -log_timescale_increment
+            * paddle.arange(channels // 2, dtype=paddle.float32)
+        )
+        scaled_time = (
+            paddle.arange(length, dtype=paddle.float32)[:, np.newaxis]
+            * inv_timescales[np.newaxis, :]
+        )
+        return paddle.to_tensor(
+            paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
+        )
 
-            # update progress bar
-            pbar.update(min(num_frames, seek) - previous_seek_value)
-            previous_seek_value = seek
+    class AudioEncoder(paddle.nn.Layer):
+        def __init__(
+            self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        ):
+            super().__init__()
+            self.conv1 = paddle.nn.Conv1D(
+                n_mels, n_state, kernel_size=3, stride=1, padding=1, bias_attr=True
+            )
+            self.conv2 = paddle.nn.Conv1D(
+                n_state, n_state, kernel_size=3, stride=2, padding=1, bias_attr=True
+            )
+            self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
 
-    return dict(
-        text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
-        segments=all_segments,
-        language=language,
-    )
+            self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
+                [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+            )
+            self.ln_post = paddle.nn.LayerNorm(n_state)
+
+        def forward(self, x: paddle.Tensor):
+            """
+            x : paddle.Tensor, shape = (batch_size, n_mels, n_ctx)
+                the mel spectrogram of the audio
+            """
+            x = paddle.nn.functional.gelu(self.conv1(x))
+            x = paddle.nn.functional.gelu(self.conv2(x))
+            x = paddle.transpose(x, (0, 2, 1))
+
+            assert (
+                x.shape[1:] == self.positional_embedding.shape
+            ), "incorrect audio shape"
+            x = x + self.positional_embedding
+
+            for block in self.blocks:
+                x = block(x)
+
+            x = self.ln_post(x)
+            return x
+
+    class TextDecoder(paddle.nn.Layer):
+        def __init__(
+            self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        ):
+            super().__init__()
 
+            self.token_embedding = paddle.nn.Embedding(n_vocab, n_state)
+            self.positional_embedding = paddle.create_parameter(
+                shape=[n_ctx, n_state], dtype="float32"
+            )
 
-class SequenceRanker:
-    def rank(
-        self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
-    ) -> List[int]:
-        """
-        Given a list of groups of samples and their cumulative log probabilities,
-        return the indices of the samples in each group to select as the final result
-        """
-        raise NotImplementedError
+            self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList(
+                [
+                    ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                    for _ in range(n_layer)
+                ]
+            )
+            self.ln = paddle.nn.LayerNorm(n_state)
 
+            mask = paddle.full(
+                shape=[n_ctx, n_state], fill_value=-np.inf, dtype="float32"
+            )
+            mask = paddle.triu(mask, diagonal=1)
+            self.register_buffer("mask", mask, persistable=False)
 
-class MaximumLikelihoodRanker(SequenceRanker):
-    """
-    Select the sample with the highest log probabilities, penalized using either
-    a simple length normalization or Google NMT paper's length penalty
-    """
+        def forward(
+            self, x: paddle.Tensor, xa: paddle.Tensor, kv_cache: Optional[dict] = None
+        ):
+            """
+            x : paddle.LongTensor, shape = (batch_size, <= n_ctx)
+                the text tokens
+            xa : paddle.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
+                the encoded audio features to be attended on
+            """
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = (
+                self.token_embedding(x)
+                + self.positional_embedding[offset : offset + x.shape[-1]]
+            )
+            x = x.to(xa.dtype)
 
-    def __init__(self, length_penalty: Optional[float]):
-        self.length_penalty = length_penalty
+            for block in self.blocks:
+                x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
 
-    def rank(self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]):
-        def scores(logprobs, lengths):
-            result = []
-            for logprob, length in zip(logprobs, lengths):
-                if self.length_penalty is None or self.length_penalty == "None":
-                    penalty = length
-                else:
-                    # from the Google NMT paper
-                    penalty = ((5 + length) / 6) ** self.length_penalty
-                result.append(logprob / penalty)
-            return result
+            x = self.ln(x)
+            logits = x @ paddle.transpose(self.token_embedding.weight, (1, 0))
 
-        # get the sequence with the highest score
-        lengths = [[len(t) for t in s] for s in tokens]
-        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
+            return logits
 
+    @dataclass(frozen=True)
+    class DecodingOptions:
+        task: str = (
+            "transcribe"  # whether to perform X->X "transcribe" or X->English "translate"
+        )
+        language: Optional[str] = (
+            None  # language that the audio is in; uses detected language if None
+        )
+        # sampling-related options
+        temperature: float = 0.0
+        sample_len: Optional[int] = None  # maximum number of tokens to sample
+        best_of: Optional[int] = (
+            None  # number of independent samples to collect, when t > 0
+        )
+        beam_size: Optional[int] = None  # number of beams in beam search, when t == 0
+        patience: Optional[float] = (
+            None  # patience in beam search (https://arxiv.org/abs/2204.05424)
+        )
 
-class TokenDecoder:
-    def reset(self):
-        """Initialize any stateful variables for decoding a new sequence"""
+        # options for ranking generations (either beams or best-of-N samples)
+        length_penalty: Optional[float] = (
+            None  # "alpha" in Google NMT, None defaults to length norm
+        )
 
-    def update(
-        self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, bool]:
-        """Specify how to select the next token, based on the current trace and logits
+        # prompt, prefix, and token suppression
+        prompt: Optional[Union[str, List[int]]] = (
+            None  # text or tokens for the previous context
+        )
+        prefix: Optional[Union[str, List[int]]] = (
+            None  # text or tokens to prefix the current context
+        )
+        suppress_blank: bool = True  # this will suppress blank outputs
 
-        Parameters
-        ----------
-        tokens : Tensor, shape = (n_batch, current_sequence_length)
-            all tokens in the context so far, including the prefix and sot_sequence tokens
+        # list of tokens ids (or comma-separated token ids) to suppress
+        # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
+        suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
 
-        logits : Tensor, shape = (n_batch, vocab_size)
-            per-token logits of the probability distribution at the current step
+        # timestamp sampling options
+        without_timestamps: bool = (
+            False  # use <|notimestamps|> to sample text tokens only
+        )
+        max_initial_timestamp: Optional[float] = (
+            1.0  # the initial timestamp cannot be later than this
+        )
 
-        sum_logprobs : Tensor, shape = (n_batch)
-            cumulative log probabilities for each sequence
+        # implementation details
+        fp16: bool = False  # use fp16 for most of the calculation
+
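+    # Example (illustrative; field names are those of the dataclass above):
+    #   greedy decoding:      DecodingOptions(language="en", temperature=0.0)
+    #   beam-search decoding: DecodingOptions(language="en", beam_size=5, patience=1.0)
+    # `beam_size`/`patience` only apply when temperature == 0, while `best_of`
+    # applies to sampling with temperature > 0 (see DecodingTask._verify_options).
+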
+    @dataclass(frozen=True)
+    class DecodingResult:
+        audio_features: paddle.Tensor
+        language: str
+        language_probs: Optional[Dict[str, float]] = None
+        tokens: List[int] = field(default_factory=list)
+        text: str = ""
+        avg_logprob: float = np.nan
+        no_speech_prob: float = np.nan
+        temperature: float = np.nan
+        compression_ratio: float = np.nan
+
+    class Inference:
+        def logits(
+            self, tokens: paddle.Tensor, audio_features: paddle.Tensor
+        ) -> paddle.Tensor:
+            """Perform a forward pass on the decoder and return per-token logits"""
+            raise NotImplementedError
+
+        def rearrange_kv_cache(self, source_indices) -> None:
+            """Update the key-value cache according to the updated beams"""
+            raise NotImplementedError
+
+        def cleanup_caching(self) -> None:
+            """Clean up any resources or hooks after decoding is finished"""
+
+    class WhisperInference(Inference):
+        def __init__(self, model: "Whisper", initial_token_length: int):
+            self.model: "Whisper" = model
+            self.initial_token_length = initial_token_length
+            self.kv_cache = {}
+            self.hooks = []
+
+        def logits(
+            self, tokens: paddle.Tensor, audio_features: paddle.Tensor
+        ) -> paddle.Tensor:
+            if not self.kv_cache:
+                self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
+
+            if tokens.shape[-1] > self.initial_token_length:
+                # only need to use the last token except in the first forward pass
+                tokens = tokens[:, -1:]
+
+            return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
+
+        def cleanup_caching(self):
+            for hook in self.hooks:
+                hook.remove()
+
+            self.kv_cache = {}
+            self.hooks = []
+
+        def rearrange_kv_cache(self, source_indices):
+            for module, tensor in self.kv_cache.items():
+                # update the key/value cache to contain the selected sequences
+                self.kv_cache[module] = tensor[source_indices].detach()
 
+    @paddle.no_grad()
+    def detect_language(
+        model: "Whisper",
+        mel: paddle.Tensor,
+        resource_path: str,
+        tokenizer: Tokenizer = None,
+    ) -> Tuple[paddle.Tensor, List[dict]]:
+        """
+        Detect the spoken language in the audio, returning the ids of the most probable
+        language tokens along with the probability distribution over all language tokens.
+        This is performed outside the main decode loop in order to not interfere with kv-caching.
         Returns
         -------
-        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
-            the tokens, appended with the selected next token
-
-        completed : bool
-            True if all sequences has reached the end of text
-
+        language_tokens : Tensor, shape = (batch_size,)
+            ids of the most probable language tokens, which appear after the startoftranscript token.
+        language_probs : List[Dict[str, float]], length = batch_size
+            list of dictionaries containing the probability distribution over all languages.
         """
-        raise NotImplementedError
+        if tokenizer is None:
+            tokenizer = get_tokenizer(
+                model.is_multilingual, resource_path=resource_path
+            )
+        if (
+            tokenizer.language is None
+            or tokenizer.language_token not in tokenizer.sot_sequence
+        ):
+            raise ValueError(
+                "This model doesn't have language tokens so it can't perform lang id"
+            )
 
-    def finalize(
-        self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
-    ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
-        """Finalize search and return the final candidate sequences
+        single = mel.ndim == 2
+        if single:
+            mel = mel.unsqueeze(0)
+
+        # skip encoder forward pass if already-encoded audio features were given
+        if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
+            mel = model.encoder(mel)
+
+        # forward pass using a single token, startoftranscript
+        batch_size = mel.shape[0]
+        x = paddle.to_tensor([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
+        logits = model.logits(x, mel)[:, 0]
+
+        # collect detected languages; suppress all non-language tokens
+        mask = paddle.ones(paddle.to_tensor(logits.shape[-1]), dtype=bool)
+        mask[list(tokenizer.all_language_tokens)] = False
+        logits[:, mask] = -np.inf
+        language_tokens = paddle.argmax(logits, axis=-1)
+        language_token_probs = paddle.nn.functional.softmax(logits, axis=-1)
+        language_probs = [
+            {
+                c: language_token_probs[i, j].tolist()
+                for j, c in zip(
+                    tokenizer.all_language_tokens, tokenizer.all_language_codes
+                )
+            }
+            for i in range(batch_size)
+        ]
 
+        if single:
+            language_tokens = language_tokens[0]
+            language_probs = language_probs[0]
+
+        return language_tokens, language_probs
+
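+    # Usage sketch (illustrative; `mel` is a single log-mel spectrogram, so the
+    # returned `probs` is one dict mapping language codes to probabilities):
+    #   tokens, probs = detect_language(model, mel, resource_path)
+    #   language = max(probs, key=probs.get)  # same selection as in transcribe()
+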
+    @function_requires_deps("tqdm")
+    def transcribe(
+        model: "Whisper",
+        mel: paddle.Tensor,
+        resource_path: str,
+        *,
+        verbose: Optional[bool] = None,
+        temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+        compression_ratio_threshold: Optional[float] = 2.4,
+        logprob_threshold: Optional[float] = -1.0,
+        no_speech_threshold: Optional[float] = 0.6,
+        condition_on_previous_text: bool = True,
+        **decode_options,
+    ):
+        """
+        Transcribe an audio file using Whisper
         Parameters
         ----------
-        tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
-            all tokens in the context so far, including the prefix and sot_sequence
-
-        sum_logprobs : Tensor, shape = (batch_size, beam_size)
-            cumulative log probabilities for each sequence
-
+        model: Whisper
+            The Whisper model instance
+        mel: paddle.Tensor
+            The audio feature
+        verbose: bool
+            Whether to display the text being decoded to the console. If True, displays all the details.
+            If False, displays minimal details. If None, does not display anything.
+        temperature: Union[float, Tuple[float, ...]]
+            Temperature for sampling. It can be a tuple of temperatures, which will be successively used
+            upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
+        compression_ratio_threshold: float
+            If the gzip compression ratio is above this value, treat as failed
+        logprob_threshold: float
+            If the average log probability over sampled tokens is below this value, treat as failed
+        no_speech_threshold: float
+            If the no_speech probability is higher than this value AND the average log probability
+            over sampled tokens is below `logprob_threshold`, consider the segment as silent
+        condition_on_previous_text: bool
+            if True, the previous output of the model is provided as a prompt for the next window;
+            disabling may make the text inconsistent across windows, but the model becomes less prone to
+            getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
+        decode_options: dict
+            Keyword arguments to construct `DecodingOptions` instances
         Returns
         -------
-        tokens : Sequence[Sequence[Tensor]], length = batch_size
-            sequence of Tensors containing candidate token sequences, for each audio input
-
-        sum_logprobs : List[List[float]], length = batch_size
-            sequence of cumulative log probabilities corresponding to the above
-
+        A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
+        the spoken language ("language"), which is detected when `decode_options["language"]` is None.
         """
-        raise NotImplementedError
-
+        dtype = np.float32  # paddle only supports float32
 
-class GreedyDecoder(TokenDecoder):
-    def __init__(self, temperature: float, eot: int):
-        self.temperature = temperature
-        self.eot = eot
+        if dtype == np.float32:
+            decode_options["fp16"] = False
 
-    def update(
-        self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, bool]:
-        temperature = self.temperature
-        if temperature == 0:
-            next_tokens = paddle.argmax(logits, axis=-1)
-        else:
-            next_tokens = paddle.distribution.Categorical(
-                logits=logits / temperature
-            ).sample([1])
-            next_tokens = paddle.reshape(
-                next_tokens,
-                [
-                    next_tokens.shape[0] * next_tokens.shape[1],
-                ],
-            )
-
-        logprobs = paddle.nn.functional.log_softmax(
-            logits, axis=-1, dtype=paddle.float32
-        )
-        current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
-        sum_logprobs += current_logprobs * paddle.to_tensor(
-            (tokens[:, -1] != self.eot), dtype=paddle.float32
-        )
-
-        next_tokens[tokens[:, -1] == self.eot] = self.eot
-        tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
-
-        completed = paddle.all((tokens[:, -1] == self.eot))
-        return tokens, completed
+        if (
+            decode_options.get("language") == "None"
+            or decode_options.get("language", None) is None
+        ):
+            if not model.is_multilingual:
+                decode_options["language"] = "en"
+            else:
+                if verbose:
+                    print(
+                        "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+                    )
+                segment = pad_or_trim(mel, N_FRAMES)
+                _, probs = model.detect_language(segment, resource_path)
+                decode_options["language"] = max(probs, key=probs.get)
+                if verbose is not None:
+                    print(
+                        f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+                    )
 
-    def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
-        # make sure each sequence has at least one EOT token at the end
-        tokens = paddle.nn.functional.pad(
-            tokens, (0, 1), value=self.eot, data_format="NCL"
+        language = decode_options["language"]
+        task = decode_options.get("task", "transcribe")
+        tokenizer = get_tokenizer(
+            model.is_multilingual,
+            resource_path=resource_path,
+            language=language,
+            task=task,
         )
-        return tokens, sum_logprobs.tolist()
-
 
-class BeamSearchDecoder(TokenDecoder):
-    def __init__(
-        self,
-        beam_size: int,
-        eot: int,
-        inference: Inference,
-        patience: Optional[float] = None,
-    ):
-        self.beam_size = beam_size
-        self.eot = eot
-        self.inference = inference
-        self.patience = patience or 1.0
-        if patience is None or patience == "None":
-            self.patience = 1.0
-        else:
-            self.patience = patience
-        self.max_candidates: int = round(beam_size * self.patience)
-        self.finished_sequences = None
-
-        assert (
-            self.max_candidates > 0
-        ), f"Invalid beam size ({beam_size}) or patience ({patience})"
-
-    def reset(self):
-        self.finished_sequences = None
-
-    def update(
-        self, tokens: paddle.Tensor, logits: paddle.Tensor, sum_logprobs: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, bool]:
-        if tokens.shape[0] % self.beam_size != 0:
-            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
-
-        batch_size = tokens.shape[0] // self.beam_size
-        if self.finished_sequences is None:  # for the first update
-            self.finished_sequences = [{} for _ in range(batch_size)]
-
-        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
-        next_tokens, source_indices, finished_sequences = [], [], []
-        for i in range(batch_size):
-            scores, sources, finished = {}, {}, {}
-
-            # STEP 1: calculate the cumulative log probabilities for possible candidates
-            for j in range(self.beam_size):
-                idx = i * self.beam_size + j
-                prefix = tokens[idx].tolist()
-                logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
-                for logprob, token in zip(logprob, token):
-                    new_logprob = (sum_logprobs[idx] + logprob).item()
-                    sequence = tuple(prefix + [token.item()])
-                    scores[sequence] = new_logprob
-                    sources[sequence] = idx
-
-            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
-            saved = 0
-            for sequence in sorted(scores, key=scores.get, reverse=True):
-                if sequence[-1] == self.eot:
-                    finished[sequence] = scores[sequence]
+        def decode_with_fallback(segment: paddle.Tensor) -> DecodingResult:
+            temperatures = (
+                [temperature] if isinstance(temperature, (int, float)) else temperature
+            )
+            decode_result = None
+
+            for t in temperatures:
+                kwargs = {**decode_options}
+                if t > 0:
+                    # disable beam_size and patience when t > 0
+                    kwargs.pop("beam_size", None)
+                    kwargs.pop("patience", None)
                 else:
-                    sum_logprobs[len(next_tokens)] = scores[sequence]
-                    next_tokens.append(sequence)
-                    source_indices.append(sources[sequence])
+                    # disable best_of when t == 0
+                    kwargs.pop("best_of", None)
 
-                    saved += 1
-                    if saved == self.beam_size:
-                        break
+                options = DecodingOptions(**kwargs, temperature=t)
+                decode_result = model.decode(segment, options, resource_path)
 
-            finished_sequences.append(finished)
+                needs_fallback = False
+                if (
+                    compression_ratio_threshold is not None
+                    and decode_result.compression_ratio > compression_ratio_threshold
+                ):
+                    needs_fallback = True  # too repetitive
+                if (
+                    logprob_threshold is not None
+                    and decode_result.avg_logprob < logprob_threshold
+                ):
+                    needs_fallback = True  # average log probability is too low
 
-        tokens = paddle.to_tensor(next_tokens)
-        self.inference.rearrange_kv_cache(source_indices)
+                if not needs_fallback:
+                    break
 
-        # add newly finished sequences to self.finished_sequences
-        assert len(self.finished_sequences) == len(finished_sequences)
-        for previously_finished, newly_finished in zip(
-            self.finished_sequences, finished_sequences
+            return decode_result
+
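+        # Fallback behaviour (sketch of the loop above): with the default temperature
+        # schedule (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) a segment is first decoded greedily
+        # (with beam search / patience if configured); if the result is too repetitive
+        # (compression_ratio > 2.4) or its average log-probability falls below -1.0,
+        # decoding is retried at the next, higher temperature using sampling.
+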
+        seek = 0
+        input_stride = exact_div(
+            N_FRAMES, model.dims.n_audio_ctx
+        )  # mel frames per output token: 2
+        time_precision = (
+            input_stride * HOP_LENGTH / SAMPLE_RATE
+        )  # time per output token: 0.02 (seconds)
+        all_tokens = []
+        all_segments = []
+        prompt_reset_since = 0
+
+        initial_prompt = decode_options.pop("initial_prompt", None)
+        if initial_prompt and initial_prompt != "None":
+            initial_prompt = tokenizer.encode(" " + initial_prompt.strip()).input_ids
+            all_tokens.extend(initial_prompt)
+        else:
+            initial_prompt = []
+
+        def add_segment(
+            *,
+            start: float,
+            end: float,
+            text_tokens: paddle.Tensor,
+            result: DecodingResult,
         ):
-            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
-                if len(previously_finished) >= self.max_candidates:
-                    break  # the candidate list is full
-                previously_finished[seq] = newly_finished[seq]
-
-        # mark as completed if all audio has enough number of samples
-        completed = all(
-            len(sequences) >= self.max_candidates
-            for sequences in self.finished_sequences
-        )
-        return tokens, completed
+            text = tokenizer.decode(
+                [token for token in text_tokens if token < tokenizer.eot]
+            )
+            if len(text.strip()) == 0:  # skip empty text output
+                return
+
+            all_segments.append(
+                {
+                    "id": len(all_segments),
+                    "seek": seek,
+                    "start": start,
+                    "end": end,
+                    "text": text,
+                    "tokens": result.tokens,
+                    "temperature": result.temperature,
+                    "avg_logprob": result.avg_logprob,
+                    "compression_ratio": result.compression_ratio,
+                    "no_speech_prob": result.no_speech_prob,
+                }
+            )
+            if verbose:
+                print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")
+
+        # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
+        num_frames = mel.shape[-1]
+        previous_seek_value = seek
+
+        with tqdm.tqdm(
+            total=num_frames, unit="frames", disable=verbose is not False
+        ) as pbar:
+            while seek < num_frames:
+                timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+                segment = pad_or_trim(mel[:, seek:], N_FRAMES)
+                segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
+
+                decode_options["prompt"] = all_tokens[prompt_reset_since:]
+                result: DecodingResult = decode_with_fallback(segment)
+                tokens = paddle.to_tensor(result.tokens)
+
+                if no_speech_threshold is not None:
+                    # no voice activity check
+                    should_skip = result.no_speech_prob > no_speech_threshold
+                    if (
+                        logprob_threshold is not None
+                        and result.avg_logprob > logprob_threshold
+                    ):
+                        # don't skip if the logprob is high enough, despite the no_speech_prob
+                        should_skip = False
+
+                    if should_skip:
+                        seek += segment.shape[
+                            -1
+                        ]  # fast-forward to the next segment boundary
+                        continue
+
+                timestamp_tokens: paddle.Tensor = tokens.greater_equal(
+                    paddle.to_tensor(tokenizer.timestamp_begin)
+                )
 
-    def finalize(self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
-        # collect all finished sequences, including patience, and add unfinished ones if not enough
-        sum_logprobs = sum_logprobs.cpu()
-        for i, sequences in enumerate(self.finished_sequences):
-            if (
-                len(sequences) < self.beam_size
-            ):  # when not enough sequences are finished
-                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
-                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
-                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
-                    if len(sequences) >= self.beam_size:
-                        break
+                consecutive = paddle.where(
+                    timestamp_tokens[:-1] & timestamp_tokens[1:]
+                )[0]
+                if (
+                    len(consecutive) > 0
+                ):  # if the output contains two consecutive timestamp tokens
+                    consecutive = paddle.add(consecutive, paddle.to_tensor(1))
+                    last_slice = 0
+                    for current_slice in consecutive:
+                        sliced_tokens = tokens[last_slice:current_slice]
+                        start_timestamp_position = (
+                            sliced_tokens[0].item() - tokenizer.timestamp_begin
+                        )
+                        end_timestamp_position = (
+                            sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                        )
+                        add_segment(
+                            start=timestamp_offset
+                            + start_timestamp_position * time_precision,
+                            end=timestamp_offset
+                            + end_timestamp_position * time_precision,
+                            text_tokens=sliced_tokens[1:-1],
+                            result=result,
+                        )
+                        last_slice = current_slice
+                    last_timestamp_position = (
+                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
+                    )
+                    seek += last_timestamp_position * input_stride
+                    all_tokens.extend(tokens[: last_slice + 1].tolist())
+                else:
+                    duration = segment_duration
+                    timestamps = tokens[timestamp_tokens.nonzero().flatten()]
+                    if (
+                        len(timestamps) > 0
+                        and timestamps[-1].item() != tokenizer.timestamp_begin
+                    ):
+                        # no consecutive timestamps but it has a timestamp; use the last one.
+                        # single timestamp at the end means no speech after the last timestamp.
+                        last_timestamp_position = (
+                            timestamps[-1].item() - tokenizer.timestamp_begin
+                        )
+                        duration = last_timestamp_position * time_precision
 
-        tokens: List[List[paddle.Tensor]] = [
-            [paddle.to_tensor(seq) for seq in sequences.keys()]
-            for sequences in self.finished_sequences
-        ]
-        sum_logprobs: List[List[float]] = [
-            list(sequences.values()) for sequences in self.finished_sequences
-        ]
-        return tokens, sum_logprobs
+                    add_segment(
+                        start=timestamp_offset,
+                        end=timestamp_offset + duration,
+                        text_tokens=tokens,
+                        result=result,
+                    )
 
+                    seek += segment.shape[-1]
+                    all_tokens.extend(tokens.tolist())
 
-class LogitFilter:
-    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
-        """Apply any filtering or masking to logits in-place
+                if not condition_on_previous_text or result.temperature > 0.5:
+                    # do not feed the prompt tokens if a high temperature was used
+                    prompt_reset_since = len(all_tokens)
 
-        Parameters
-        ----------
-        logits : Tensor, shape = (n_batch, vocab_size)
-            per-token logits of the probability distribution at the current step
+                # update progress bar
+                pbar.update(min(num_frames, seek) - previous_seek_value)
+                previous_seek_value = seek
 
-        tokens : Tensor, shape = (n_batch, current_sequence_length)
-            all tokens in the context so far, including the prefix and sot_sequence tokens
+        return dict(
+            text=tokenizer.decode(all_tokens[len(initial_prompt) :]),
+            segments=all_segments,
+            language=language,
+        )
 
+    class SequenceRanker:
+        def rank(
+            self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
+        ) -> List[int]:
+            """
+            Given a list of groups of samples and their cumulative log probabilities,
+            return the indices of the samples in each group to select as the final result
+            """
+            raise NotImplementedError
+
+    class MaximumLikelihoodRanker(SequenceRanker):
+        """
+        Select the sample with the highest log probabilities, penalized using either
+        a simple length normalization or Google NMT paper's length penalty
         """
-        raise NotImplementedError
 
+        def __init__(self, length_penalty: Optional[float]):
+            self.length_penalty = length_penalty
 
-class SuppressBlank(LogitFilter):
-    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
-        self.tokenizer = tokenizer
-        self.sample_begin = sample_begin
+        def rank(
+            self, tokens: List[List[paddle.Tensor]], sum_logprobs: List[List[float]]
+        ):
+            def scores(logprobs, lengths):
+                result = []
+                for logprob, length in zip(logprobs, lengths):
+                    if self.length_penalty is None or self.length_penalty == "None":
+                        penalty = length
+                    else:
+                        # from the Google NMT paper
+                        penalty = ((5 + length) / 6) ** self.length_penalty
+                    result.append(logprob / penalty)
+                return result
+
+            # get the sequence with the highest score
+            lengths = [[len(t) for t in s] for s in tokens]
+            return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
+
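+    # Worked example (illustrative): with length_penalty = 0.6 (the Google NMT
+    # "alpha"), a 10-token candidate is scored as logprob / ((5 + 10) / 6) ** 0.6,
+    # i.e. logprob / ~1.73, whereas length_penalty = None (or "None") falls back to
+    # plain length normalization, i.e. logprob / 10.
+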
+    class TokenDecoder:
+        def reset(self):
+            """Initialize any stateful variables for decoding a new sequence"""
+
+        def update(
+            self,
+            tokens: paddle.Tensor,
+            logits: paddle.Tensor,
+            sum_logprobs: paddle.Tensor,
+        ) -> Tuple[paddle.Tensor, bool]:
+            """Specify how to select the next token, based on the current trace and logits
+            Parameters
+            ----------
+            tokens : Tensor, shape = (n_batch, current_sequence_length)
+                all tokens in the context so far, including the prefix and sot_sequence tokens
+            logits : Tensor, shape = (n_batch, vocab_size)
+                per-token logits of the probability distribution at the current step
+            sum_logprobs : Tensor, shape = (n_batch)
+                cumulative log probabilities for each sequence
+            Returns
+            -------
+            tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
+                the tokens, appended with the selected next token
+            completed : bool
+                True if all sequences have reached the end of text
+            """
+            raise NotImplementedError
+
+        def finalize(
+            self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
+        ) -> Tuple[Sequence[Sequence[paddle.Tensor]], List[List[float]]]:
+            """Finalize search and return the final candidate sequences
+            Parameters
+            ----------
+            tokens : Tensor, shape = (batch_size, beam_size, current_sequence_length)
+                all tokens in the context so far, including the prefix and sot_sequence
+            sum_logprobs : Tensor, shape = (batch_size, beam_size)
+                cumulative log probabilities for each sequence
+            Returns
+            -------
+            tokens : Sequence[Sequence[Tensor]], length = batch_size
+                sequence of Tensors containing candidate token sequences, for each audio input
+            sum_logprobs : List[List[float]], length = batch_size
+                sequence of cumulative log probabilities corresponding to the above
+            """
+            raise NotImplementedError
+
+    class GreedyDecoder(TokenDecoder):
+        def __init__(self, temperature: float, eot: int):
+            self.temperature = temperature
+            self.eot = eot
+
+        def update(
+            self,
+            tokens: paddle.Tensor,
+            logits: paddle.Tensor,
+            sum_logprobs: paddle.Tensor,
+        ) -> Tuple[paddle.Tensor, bool]:
+            temperature = self.temperature
+            if temperature == 0:
+                next_tokens = paddle.argmax(logits, axis=-1)
+            else:
+                next_tokens = paddle.distribution.Categorical(
+                    logits=logits / temperature
+                ).sample([1])
+                next_tokens = paddle.reshape(
+                    next_tokens,
+                    [
+                        next_tokens.shape[0] * next_tokens.shape[1],
+                    ],
+                )
 
-    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
-        if tokens.shape[1] == self.sample_begin:
-            logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
-                -np.inf
+            logprobs = paddle.nn.functional.log_softmax(
+                logits, axis=-1, dtype=paddle.float32
+            )
+            current_logprobs = logprobs[paddle.arange(logprobs.shape[0]), next_tokens]
+            sum_logprobs += current_logprobs * paddle.to_tensor(
+                (tokens[:, -1] != self.eot), dtype=paddle.float32
             )
 
+            next_tokens[tokens[:, -1] == self.eot] = self.eot
+            tokens = paddle.concat([tokens, next_tokens[:, None]], axis=-1)
+
+            completed = paddle.all((tokens[:, -1] == self.eot))
+            return tokens, completed
+
+        def finalize(self, tokens: paddle.Tensor, sum_logprobs: paddle.Tensor):
+            # make sure each sequence has at least one EOT token at the end
+            tokens = paddle.nn.functional.pad(
+                tokens, (0, 1), value=self.eot, data_format="NCL"
+            )
+            return tokens, sum_logprobs.tolist()
+
+    class BeamSearchDecoder(TokenDecoder):
+        def __init__(
+            self,
+            beam_size: int,
+            eot: int,
+            inference: Inference,
+            patience: Optional[float] = None,
+        ):
+            self.beam_size = beam_size
+            self.eot = eot
+            self.inference = inference
+            if patience is None or patience == "None":
+                self.patience = 1.0
+            else:
+                self.patience = patience
+            self.max_candidates: int = round(beam_size * self.patience)
+            self.finished_sequences = None
+
+            assert (
+                self.max_candidates > 0
+            ), f"Invalid beam size ({beam_size}) or patience ({patience})"
+
+        def reset(self):
+            self.finished_sequences = None
+
+        def update(
+            self,
+            tokens: paddle.Tensor,
+            logits: paddle.Tensor,
+            sum_logprobs: paddle.Tensor,
+        ) -> Tuple[paddle.Tensor, bool]:
+            if tokens.shape[0] % self.beam_size != 0:
+                raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
+
+            batch_size = tokens.shape[0] // self.beam_size
+            if self.finished_sequences is None:  # for the first update
+                self.finished_sequences = [{} for _ in range(batch_size)]
+
+            logprobs = paddle.nn.functional.log_softmax(
+                logits, axis=-1, dtype="float32"
+            )
+            next_tokens, source_indices, finished_sequences = [], [], []
+            for i in range(batch_size):
+                scores, sources, finished = {}, {}, {}
+
+                # STEP 1: calculate the cumulative log probabilities for possible candidates
+                for j in range(self.beam_size):
+                    idx = i * self.beam_size + j
+                    prefix = tokens[idx].tolist()
+                    logprob, token = paddle.topk(logprobs[idx], k=self.beam_size + 1)
+                    for logprob, token in zip(logprob, token):
+                        new_logprob = (sum_logprobs[idx] + logprob).item()
+                        sequence = tuple(prefix + [token.item()])
+                        scores[sequence] = new_logprob
+                        sources[sequence] = idx
+
+                # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
+                saved = 0
+                for sequence in sorted(scores, key=scores.get, reverse=True):
+                    if sequence[-1] == self.eot:
+                        finished[sequence] = scores[sequence]
+                    else:
+                        sum_logprobs[len(next_tokens)] = scores[sequence]
+                        next_tokens.append(sequence)
+                        source_indices.append(sources[sequence])
+
+                        saved += 1
+                        if saved == self.beam_size:
+                            break
+
+                finished_sequences.append(finished)
+
+            tokens = paddle.to_tensor(next_tokens)
+            self.inference.rearrange_kv_cache(source_indices)
+
+            # add newly finished sequences to self.finished_sequences
+            assert len(self.finished_sequences) == len(finished_sequences)
+            for previously_finished, newly_finished in zip(
+                self.finished_sequences, finished_sequences
+            ):
+                for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
+                    if len(previously_finished) >= self.max_candidates:
+                        break  # the candidate list is full
+                    previously_finished[seq] = newly_finished[seq]
+
+            # mark as completed if all audio has enough number of samples
+            completed = all(
+                len(sequences) >= self.max_candidates
+                for sequences in self.finished_sequences
+            )
+            return tokens, completed
 
-class SuppressTokens(LogitFilter):
-    def __init__(self, suppress_tokens: Sequence[int]):
-        self.suppress_tokens = list(suppress_tokens)
+        def finalize(
+            self, preceding_tokens: paddle.Tensor, sum_logprobs: paddle.Tensor
+        ):
+            # collect all finished sequences, including patience, and add unfinished ones if not enough
+            sum_logprobs = sum_logprobs.cpu()
+            for i, sequences in enumerate(self.finished_sequences):
+                if (
+                    len(sequences) < self.beam_size
+                ):  # when not enough sequences are finished
+                    for j in list(np.argsort(sum_logprobs[i]))[::-1]:
+                        sequence = preceding_tokens[i, j].tolist() + [self.eot]
+                        sequences[tuple(sequence)] = sum_logprobs[i][j].item()
+                        if len(sequences) >= self.beam_size:
+                            break
+
+            tokens: List[List[paddle.Tensor]] = [
+                [paddle.to_tensor(seq) for seq in sequences.keys()]
+                for sequences in self.finished_sequences
+            ]
+            sum_logprobs: List[List[float]] = [
+                list(sequences.values()) for sequences in self.finished_sequences
+            ]
+            return tokens, sum_logprobs
+
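+    # Note (illustrative): with beam_size = 5 and the default patience (None -> 1.0),
+    # max_candidates = round(5 * 1.0) = 5, so decoding is considered complete once
+    # every audio in the batch has collected 5 finished sequences; patience > 1.0
+    # keeps searching for proportionally more candidates (https://arxiv.org/abs/2204.05424).
+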
+    class LogitFilter:
+        def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor) -> None:
+            """Apply any filtering or masking to logits in-place
+
+            Parameters
+            ----------
+            logits : Tensor, shape = (n_batch, vocab_size)
+                per-token logits of the probability distribution at the current step
+
+            tokens : Tensor, shape = (n_batch, current_sequence_length)
+                all tokens in the context so far, including the prefix and sot_sequence tokens
+
+            """
+            raise NotImplementedError
+
+    class SuppressBlank(LogitFilter):
+        def __init__(self, tokenizer: Tokenizer, sample_begin: int):
+            self.tokenizer = tokenizer
+            self.sample_begin = sample_begin
+
+        def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+            if tokens.shape[1] == self.sample_begin:
+                logits[
+                    :, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]
+                ] = -np.inf
+
+    class SuppressTokens(LogitFilter):
+        def __init__(self, suppress_tokens: Sequence[int]):
+            self.suppress_tokens = list(suppress_tokens)
+
+        def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+            logits[:, self.suppress_tokens] = -np.inf
+
+    class ApplyTimestampRules(LogitFilter):
+        def __init__(
+            self,
+            tokenizer: Tokenizer,
+            sample_begin: int,
+            max_initial_timestamp_index: Optional[int],
+        ):
+            self.tokenizer = tokenizer
+            self.sample_begin = sample_begin
+            self.max_initial_timestamp_index = max_initial_timestamp_index
+
+        def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+            # suppress <|notimestamps|> which is handled by without_timestamps
+            if self.tokenizer.no_timestamps is not None:
+                logits[:, self.tokenizer.no_timestamps] = -np.inf
+
+            # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+            for k in range(tokens.shape[0]):
+                seq = [t for t in tokens[k, self.sample_begin :].tolist()]
+                last_was_timestamp = (
+                    len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+                )
+                penultimate_was_timestamp = (
+                    len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+                )
 
-    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
-        logits[:, self.suppress_tokens] = -np.inf
+                if last_was_timestamp:
+                    if penultimate_was_timestamp:  # has to be non-timestamp
+                        logits[k, self.tokenizer.timestamp_begin :] = -np.inf
+                    else:  # cannot be normal text tokens
+                        logits[k, : self.tokenizer.eot] = -np.inf
 
+            # apply the `max_initial_timestamp` option
+            if (
+                tokens.shape[1] == self.sample_begin
+                and self.max_initial_timestamp_index is not None
+            ):
+                last_allowed = (
+                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+                )
+                logits[:, last_allowed + 1 :] = -np.inf
 
-class ApplyTimestampRules(LogitFilter):
-    def __init__(
-        self,
-        tokenizer: Tokenizer,
-        sample_begin: int,
-        max_initial_timestamp_index: Optional[int],
-    ):
-        self.tokenizer = tokenizer
-        self.sample_begin = sample_begin
-        self.max_initial_timestamp_index = max_initial_timestamp_index
-
-    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
-        # suppress <|notimestamps|> which is handled by without_timestamps
-        if self.tokenizer.no_timestamps is not None:
-            logits[:, self.tokenizer.no_timestamps] = -np.inf
-
-        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
-        for k in range(tokens.shape[0]):
-            seq = [t for t in tokens[k, self.sample_begin :].tolist()]
-            last_was_timestamp = (
-                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
-            )
-            penultimate_was_timestamp = (
-                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+            # if sum of probability over timestamps is above any other token, sample timestamp
+            logprobs = paddle.nn.functional.log_softmax(
+                logits, axis=-1, dtype="float32"
             )
+            for k in range(tokens.shape[0]):
+                # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
+                # To bypass this issue in CI, we have decomposed the operation into separate steps.
+                # It will raise 2e-6 difference in precision.
+                # TODO: revert this after logsumexp been fixed.
+                timestamp_logprob = paddle.exp(
+                    logprobs[k, self.tokenizer.timestamp_begin :]
+                )
+                timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
+                timestamp_logprob = paddle.log(timestamp_logprob)
+                max_text_token_logprob = paddle.max(
+                    logprobs[k, : self.tokenizer.timestamp_begin]
+                )
+                if timestamp_logprob > max_text_token_logprob:
+                    logits[k, : self.tokenizer.timestamp_begin] = -np.inf
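+
+    # Pattern note (illustrative): with timestamps enabled a segment decodes roughly as
+    #   <|0.00|> text ... <|1.20|><|1.20|> text ... <|2.40|><|endoftext|>
+    # so the filter above forces a timestamp closing a text span to be followed by
+    # another timestamp (or EOT), and forces text (or EOT) right after a starting timestamp.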
 
-            if last_was_timestamp:
-                if penultimate_was_timestamp:  # has to be non-timestamp
-                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
-                else:  # cannot be normal text tokens
-                    logits[k, : self.tokenizer.eot] = -np.inf
+    class DecodingTask:
+        inference: Inference
+        sequence_ranker: SequenceRanker
+        decoder: TokenDecoder
+        logit_filters: List[LogitFilter]
 
-        # apply the `max_initial_timestamp` option
-        if (
-            tokens.shape[1] == self.sample_begin
-            and self.max_initial_timestamp_index is not None
+        def __init__(
+            self, model: "Whisper", options: DecodingOptions, resource_path: str
         ):
-            last_allowed = (
-                self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
-            )
-            logits[:, last_allowed + 1 :] = -np.inf
-
-        # if sum of probability over timestamps is above any other token, sample timestamp
-        logprobs = paddle.nn.functional.log_softmax(logits, axis=-1, dtype="float32")
-        for k in range(tokens.shape[0]):
-            # When using paddle.logsumexp on a 32GB Tesla-V100 GPU, we encountered CUDA error 700.
-            # To bypass this issue in CI, we have decomposed the operation into separate steps.
-            # It will raise 2e-6 difference in precision.
-            # TODO: revert this after logsumexp been fixed.
-            timestamp_logprob = paddle.exp(
-                logprobs[k, self.tokenizer.timestamp_begin :]
-            )
-            timestamp_logprob = paddle.sum(timestamp_logprob, axis=-1)
-            timestamp_logprob = paddle.log(timestamp_logprob)
-            max_text_token_logprob = paddle.max(
-                logprobs[k, : self.tokenizer.timestamp_begin]
+            self.model = model
+
+            language = options.language or "en"
+            tokenizer = get_tokenizer(
+                model.is_multilingual,
+                resource_path=resource_path,
+                language=language,
+                task=options.task,
             )
-            if timestamp_logprob > max_text_token_logprob:
-                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
+            self.tokenizer: Tokenizer = tokenizer
+            self.options: DecodingOptions = self._verify_options(options)
+            self.resource_path: str = resource_path
 
+            self.beam_size: int = options.beam_size or options.best_of or 1
+            self.n_ctx: int = model.dims.n_text_ctx
+            self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
 
-class DecodingTask:
-    inference: Inference
-    sequence_ranker: SequenceRanker
-    decoder: TokenDecoder
-    logit_filters: List[LogitFilter]
+            self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
+            if self.options.without_timestamps:
+                self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
 
-    def __init__(self, model: "Whisper", options: DecodingOptions, resource_path: str):
-        self.model = model
+            self.initial_tokens: Tuple[int] = self._get_initial_tokens()
+            self.sample_begin: int = len(self.initial_tokens)
+            self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
 
-        language = options.language or "en"
-        tokenizer = get_tokenizer(
-            model.is_multilingual,
-            resource_path=resource_path,
-            language=language,
-            task=options.task,
-        )
-        self.tokenizer: Tokenizer = tokenizer
-        self.options: DecodingOptions = self._verify_options(options)
-        self.resource_path: str = resource_path
+            # inference: implements the forward pass through the decoder, including kv caching
+            self.inference = WhisperInference(model, len(self.initial_tokens))
 
-        self.beam_size: int = options.beam_size or options.best_of or 1
-        self.n_ctx: int = model.dims.n_text_ctx
-        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
+            # sequence ranker: implements how to rank a group of sampled sequences
+            self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
 
-        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
-        if self.options.without_timestamps:
-            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
+            # decoder: implements how to select the next tokens, given the autoregressive distribution
+            if options.beam_size is not None:
+                self.decoder = BeamSearchDecoder(
+                    options.beam_size, tokenizer.eot, self.inference, options.patience
+                )
+            else:
+                self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
 
-        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
-        self.sample_begin: int = len(self.initial_tokens)
-        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
+            # logit filters: applies various rules to suppress or penalize certain tokens
+            self.logit_filters = []
+            if self.options.suppress_blank:
+                self.logit_filters.append(
+                    SuppressBlank(self.tokenizer, self.sample_begin)
+                )
+            if self.options.suppress_tokens:
+                self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
+            if not options.without_timestamps:
+                precision = (
+                    CHUNK_LENGTH / model.dims.n_audio_ctx
+                )  # usually 0.02 seconds
+                max_initial_timestamp_index = None
+                if options.max_initial_timestamp:
+                    max_initial_timestamp_index = round(
+                        self.options.max_initial_timestamp / precision
+                    )
+                self.logit_filters.append(
+                    ApplyTimestampRules(
+                        tokenizer, self.sample_begin, max_initial_timestamp_index
+                    )
+                )
+
+        def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
+            if options.beam_size is not None and options.best_of is not None:
+                raise ValueError("beam_size and best_of can't be given together")
+            if options.temperature == 0:
+                if options.best_of is not None:
+                    raise ValueError(
+                        "best_of with greedy sampling (T=0) is not compatible"
+                    )
+            if options.patience is not None and options.beam_size is None:
+                raise ValueError("patience requires beam_size to be given")
+            if options.length_penalty is not None and options.length_penalty != "None":
+                if not (0 <= options.length_penalty <= 1):
+                    raise ValueError(
+                        "length_penalty (alpha) should be a value between 0 and 1"
+                    )
 
-        # inference: implements the forward pass through the decoder, including kv caching
-        self.inference = WhisperInference(model, len(self.initial_tokens))
+            return options
 
-        # sequence ranker: implements how to rank a group of sampled sequences
-        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
+        def _get_initial_tokens(self) -> Tuple[int]:
+            tokens = list(self.sot_sequence)
+            prefix = self.options.prefix
+            prompt = self.options.prompt
 
-        # decoder: implements how to select the next tokens, given the autoregressive distribution
-        if options.beam_size is not None:
-            self.decoder = BeamSearchDecoder(
-                options.beam_size, tokenizer.eot, self.inference, options.patience
-            )
-        else:
-            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
-
-        # logit filters: applies various rules to suppress or penalize certain tokens
-        self.logit_filters = []
-        if self.options.suppress_blank:
-            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
-        if self.options.suppress_tokens:
-            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
-        if not options.without_timestamps:
-            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
-            max_initial_timestamp_index = None
-            if options.max_initial_timestamp:
-                max_initial_timestamp_index = round(
-                    self.options.max_initial_timestamp / precision
+            if prefix:
+                prefix_tokens = (
+                    self.tokenizer.encode(" " + prefix.strip().input_ids)
+                    if isinstance(prefix, str)
+                    else prefix
                 )
-            self.logit_filters.append(
-                ApplyTimestampRules(
-                    tokenizer, self.sample_begin, max_initial_timestamp_index
+                if self.sample_len is not None:
+                    max_prefix_len = self.n_ctx // 2 - self.sample_len
+                    prefix_tokens = prefix_tokens[-max_prefix_len:]
+                tokens = tokens + prefix_tokens
+
+            if prompt:
+                prompt_tokens = (
+                    self.tokenizer.encode(" " + prompt.strip().input_ids)
+                    if isinstance(prompt, str)
+                    else prompt
                 )
-            )
-
-    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
-        if options.beam_size is not None and options.best_of is not None:
-            raise ValueError("beam_size and best_of can't be given together")
-        if options.temperature == 0:
-            if options.best_of is not None:
-                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
-        if options.patience is not None and options.beam_size is None:
-            raise ValueError("patience requires beam_size to be given")
-        if options.length_penalty is not None and options.length_penalty != "None":
-            if not (0 <= options.length_penalty <= 1):
-                raise ValueError(
-                    "length_penalty (alpha) should be a value between 0 and 1"
+                tokens = (
+                    [self.tokenizer.sot_prev]
+                    + prompt_tokens[-(self.n_ctx // 2 - 1) :]
+                    + tokens
                 )
 
-        return options
-
-    def _get_initial_tokens(self) -> Tuple[int]:
-        tokens = list(self.sot_sequence)
-        prefix = self.options.prefix
-        prompt = self.options.prompt
+            return tuple(tokens)
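+
+        # Layout note (illustrative): the initial context assembled above is
+        #   [<|startofprev|>] + prompt tokens + sot_sequence + prefix tokens,
+        # with the prompt clipped to the last n_ctx // 2 - 1 tokens and the prefix
+        # clipped so that sample_len tokens can still be generated.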
 
-        if prefix:
-            prefix_tokens = (
-                self.tokenizer.encode(" " + prefix.strip().input_ids)
-                if isinstance(prefix, str)
-                else prefix
-            )
-            if self.sample_len is not None:
-                max_prefix_len = self.n_ctx // 2 - self.sample_len
-                prefix_tokens = prefix_tokens[-max_prefix_len:]
-            tokens = tokens + prefix_tokens
-
-        if prompt:
-            prompt_tokens = (
-                self.tokenizer.encode(" " + prompt.strip().input_ids)
-                if isinstance(prompt, str)
-                else prompt
-            )
-            tokens = (
-                [self.tokenizer.sot_prev]
-                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
-                + tokens
-            )
+        def _get_suppress_tokens(self) -> Tuple[int]:
+            suppress_tokens = self.options.suppress_tokens
 
-        return tuple(tokens)
+            if isinstance(suppress_tokens, str):
+                suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
 
-    def _get_suppress_tokens(self) -> Tuple[int]:
-        suppress_tokens = self.options.suppress_tokens
+            if -1 in suppress_tokens:
+                suppress_tokens = [t for t in suppress_tokens if t >= 0]
+                suppress_tokens.extend(self.tokenizer.non_speech_tokens)
+            elif suppress_tokens is None or len(suppress_tokens) == 0:
+                suppress_tokens = []  # interpret empty string as an empty list
+            else:
+                assert isinstance(
+                    suppress_tokens, list
+                ), "suppress_tokens must be a list"
 
-        if isinstance(suppress_tokens, str):
-            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
+            suppress_tokens.extend(
+                [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
+            )
+            if self.tokenizer.no_speech is not None:
+                # no-speech probability is collected separately
+                suppress_tokens.append(self.tokenizer.no_speech)
 
-        if -1 in suppress_tokens:
-            suppress_tokens = [t for t in suppress_tokens if t >= 0]
-            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
-        elif suppress_tokens is None or len(suppress_tokens) == 0:
-            suppress_tokens = []  # interpret empty string as an empty list
-        else:
-            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
+            return tuple(sorted(set(suppress_tokens)))
 
-        suppress_tokens.extend(
-            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
-        )
-        if self.tokenizer.no_speech is not None:
-            # no-speech probability is collected separately
-            suppress_tokens.append(self.tokenizer.no_speech)
+        def _get_audio_features(self, mel: paddle.Tensor):
 
-        return tuple(sorted(set(suppress_tokens)))
+            if mel.shape[-2:] == (
+                self.model.dims.n_audio_ctx,
+                self.model.dims.n_audio_state,
+            ):
+                # encoded audio features are given; skip audio encoding
+                audio_features = mel
+            else:
+                audio_features = self.model.encoder(mel)
 
-    def _get_audio_features(self, mel: paddle.Tensor):
+            return audio_features
 
-        if mel.shape[-2:] == (
-            self.model.dims.n_audio_ctx,
-            self.model.dims.n_audio_state,
+        def _detect_language(
+            self,
+            audio_features: paddle.Tensor,
+            tokens: paddle.Tensor,
+            resource_path: str,
         ):
-            # encoded audio features are given; skip audio encoding
-            audio_features = mel
-        else:
-            audio_features = self.model.encoder(mel)
+            languages = [self.options.language] * audio_features.shape[0]
+            lang_probs = None
 
-        return audio_features
+            if self.options.language is None or self.options.task == "lang_id":
+                lang_tokens, lang_probs = self.model.detect_language(
+                    audio_features, self.tokenizer, self.resource_path
+                )
+                languages = [max(probs, key=probs.get) for probs in lang_probs]
+                if self.options.language is None:
+                    tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
 
-    def _detect_language(
-        self, audio_features: paddle.Tensor, tokens: paddle.Tensor, resource_path: str
-    ):
-        languages = [self.options.language] * audio_features.shape[0]
-        lang_probs = None
+            return languages, lang_probs
 
-        if self.options.language is None or self.options.task == "lang_id":
-            lang_tokens, lang_probs = self.model.detect_language(
-                audio_features, self.tokenizer, self.resource_path
+        def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
+            assert audio_features.shape[0] == tokens.shape[0]
+            n_batch = tokens.shape[0]
+            sum_logprobs: paddle.Tensor = paddle.zeros(
+                paddle.to_tensor(n_batch), dtype=paddle.float32
             )
-            languages = [max(probs, key=probs.get) for probs in lang_probs]
-            if self.options.language is None:
-                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
-
-        return languages, lang_probs
-
-    def _main_loop(self, audio_features: paddle.Tensor, tokens: paddle.Tensor):
-        assert audio_features.shape[0] == tokens.shape[0]
-        n_batch = tokens.shape[0]
-        sum_logprobs: paddle.Tensor = paddle.zeros(
-            paddle.to_tensor(n_batch), dtype=paddle.float32
-        )
-        no_speech_probs = [np.nan] * n_batch
-
-        try:
-            for i in range(self.sample_len):
-                logits = self.inference.logits(tokens, audio_features)
-
-                if (
-                    i == 0 and self.tokenizer.no_speech is not None
-                ):  # save no_speech_probs
-                    probs_at_sot = paddle.nn.functional.softmax(
-                        logits[:, self.sot_index], axis=-1, dtype=paddle.float32
+            no_speech_probs = [np.nan] * n_batch
+
+            try:
+                for i in range(self.sample_len):
+                    logits = self.inference.logits(tokens, audio_features)
+
+                    if (
+                        i == 0 and self.tokenizer.no_speech is not None
+                    ):  # save no_speech_probs
+                        probs_at_sot = paddle.nn.functional.softmax(
+                            logits[:, self.sot_index], axis=-1, dtype=paddle.float32
+                        )
+                        no_speech_probs = probs_at_sot[
+                            :, self.tokenizer.no_speech
+                        ].tolist()
+
+                    # now we need to consider the logits at the last token only
+                    logits = logits[:, -1]
+
+                    # apply the logit filters, e.g. for suppressing or applying penalty to
+                    for logit_filter in self.logit_filters:
+                        logit_filter.apply(logits, tokens)
+
+                    # expand the tokens tensor with the selected next tokens
+                    tokens, completed = self.decoder.update(
+                        tokens, logits, sum_logprobs
                     )
-                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
-
-                # now we need to consider the logits at the last token only
-                logits = logits[:, -1]
-
-                # apply the logit filters, e.g. for suppressing or applying penalty to
-                for logit_filter in self.logit_filters:
-                    logit_filter.apply(logits, tokens)
+                    if completed or tokens.shape[-1] > self.n_ctx:
+                        break
+            finally:
+                self.inference.cleanup_caching()
+
+            return tokens, sum_logprobs, no_speech_probs
+
+        @paddle.no_grad()
+        def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
+            self.decoder.reset()
+            tokenizer: Tokenizer = self.tokenizer
+            batch_size: int = mel.shape[0]
+
+            audio_features: paddle.Tensor = self._get_audio_features(
+                mel
+            )  # encoder forward pass
+
+            tokens: paddle.Tensor
+            if batch_size > 1:
+                for i in range(batch_size):
+                    tokens = paddle.concat(
+                        x=[
+                            paddle.to_tensor([self.initial_tokens]),
+                            paddle.to_tensor([self.initial_tokens]),
+                        ],
+                        axis=0,
+                    )
+            elif batch_size == 1:
+                tokens = paddle.to_tensor([self.initial_tokens])
+
+            # detect language if requested, overwriting the language token
+            languages, language_probs = self._detect_language(
+                paddle.to_tensor(audio_features),
+                paddle.to_tensor(tokens),
+                self.resource_path,
+            )
 
-                # expand the tokens tensor with the selected next tokens
-                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
-                if completed or tokens.shape[-1] > self.n_ctx:
-                    break
-        finally:
-            self.inference.cleanup_caching()
+            if self.options.task == "lang_id":
+                return [
+                    DecodingResult(
+                        audio_features=features, language=language, language_probs=probs
+                    )
+                    for features, language, probs in zip(
+                        audio_features, languages, language_probs
+                    )
+                ]
 
-        return tokens, sum_logprobs, no_speech_probs
+            # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
+            audio_features = paddle.repeat_interleave(
+                audio_features, self.beam_size, axis=0
+            )
+            tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
+            # call the main sampling loop
+            tokens, sum_logprobs, no_speech_probs = self._main_loop(
+                audio_features, tokens
+            )
+            # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
+            audio_features = audio_features[:: self.beam_size]
+            no_speech_probs = no_speech_probs[:: self.beam_size]
+            assert audio_features.shape[0] == len(no_speech_probs) == batch_size
+            tokens = tokens.reshape([batch_size, self.beam_size, -1])
+            sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
+
+            # get the final candidates for each group, and slice between the first sampled token and EOT
+            tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+            tokens: List[List[paddle.Tensor]] = [
+                [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
+                for s in tokens
+            ]
 
-    @paddle.no_grad()
-    def run(self, mel: paddle.Tensor) -> List[DecodingResult]:
-        self.decoder.reset()
-        tokenizer: Tokenizer = self.tokenizer
-        batch_size: int = mel.shape[0]
+            # select the top-ranked sample in each group
+            selected = self.sequence_ranker.rank(tokens, sum_logprobs)
+            tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
+            texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
 
-        audio_features: paddle.Tensor = self._get_audio_features(
-            mel
-        )  # encoder forward pass
+            sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
+            avg_logprobs: List[float] = [
+                lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
+            ]
 
-        tokens: paddle.Tensor
-        if batch_size > 1:
-            for i in range(batch_size):
-                tokens = paddle.concat(
-                    x=[
-                        paddle.to_tensor([self.initial_tokens]),
-                        paddle.to_tensor([self.initial_tokens]),
-                    ],
-                    axis=0,
+            fields = (
+                texts,
+                languages,
+                tokens,
+                audio_features,
+                avg_logprobs,
+                no_speech_probs,
+            )
+            if len(set(map(len, fields))) != 1:
+                raise RuntimeError(
+                    f"inconsistent result lengths: {list(map(len, fields))}"
                 )
-        elif batch_size == 1:
-            tokens = paddle.to_tensor([self.initial_tokens])
-
-        # detect language if requested, overwriting the language token
-        languages, language_probs = self._detect_language(
-            paddle.to_tensor(audio_features),
-            paddle.to_tensor(tokens),
-            self.resource_path,
-        )
 
-        if self.options.task == "lang_id":
             return [
                 DecodingResult(
-                    audio_features=features, language=language, language_probs=probs
+                    audio_features=features,
+                    language=language,
+                    tokens=tokens,
+                    text=text,
+                    avg_logprob=avg_logprob,
+                    no_speech_prob=no_speech_prob,
+                    temperature=self.options.temperature,
+                    compression_ratio=compression_ratio(text),
                 )
-                for features, language, probs in zip(
-                    audio_features, languages, language_probs
+                for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
+                    *fields
                 )
             ]
 
-        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
-
-        audio_features = paddle.repeat_interleave(
-            audio_features, self.beam_size, axis=0
-        )
-        tokens = paddle.repeat_interleave(tokens, self.beam_size, axis=0)
-
-        # call the main sampling loop
-        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
-
-        # reshape the tensors to have (batch_size, beam_size) as the first two dimensions
-        audio_features = audio_features[:: self.beam_size]
-        no_speech_probs = no_speech_probs[:: self.beam_size]
-        assert audio_features.shape[0] == len(no_speech_probs) == batch_size
-
-        tokens = tokens.reshape([batch_size, self.beam_size, -1])
-        sum_logprobs = sum_logprobs.reshape([batch_size, self.beam_size])
+    @paddle.no_grad()
+    def decode(
+        model: "Whisper",
+        mel: paddle.Tensor,
+        options: DecodingOptions = DecodingOptions(),
+        resource_path=str,
+    ) -> Union[DecodingResult, List[DecodingResult]]:
+        """
+        Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
+        Parameters
+        ----------
+        model: Whisper
+            the Whisper model instance
+        mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
+            A tensor containing the Mel spectrogram(s)
+        options: DecodingOptions
+            A dataclass that contains all necessary options for decoding 30-second segments
+        Returns
+        -------
+        result: Union[DecodingResult, List[DecodingResult]]
+            The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
+        """
+        single = mel.ndim == 2
+        if single:
+            mel = mel.unsqueeze(0)
 
-        # get the final candidates for each group, and slice between the first sampled token and EOT
-        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-        tokens: List[List[paddle.Tensor]] = [
-            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
-            for s in tokens
-        ]
+        result = DecodingTask(model, options, resource_path).run(mel)
 
-        # select the top-ranked sample in each group
-        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
-        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
-        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
+        if single:
+            result = result[0]
 
-        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
-        avg_logprobs: List[float] = [
-            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
-        ]
+        return result
 
-        fields = (
-            texts,
-            languages,
-            tokens,
-            audio_features,
-            avg_logprobs,
-            no_speech_probs,
-        )
-        if len(set(map(len, fields))) != 1:
-            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
+    class Whisper(paddle.nn.Layer):
+        """
+        The `Whisper` module uses AudioEncoder and TextDecoder, and provides detect_language, transcribe, and decode.
+        """
 
-        return [
-            DecodingResult(
-                audio_features=features,
-                language=language,
-                tokens=tokens,
-                text=text,
-                avg_logprob=avg_logprob,
-                no_speech_prob=no_speech_prob,
-                temperature=self.options.temperature,
-                compression_ratio=compression_ratio(text),
+        def __init__(self, dims: ModelDimensions):
+            super().__init__()
+            self.dims = dims
+            self.encoder = AudioEncoder(
+                self.dims.n_mels,
+                self.dims.n_audio_ctx,
+                self.dims.n_audio_state,
+                self.dims.n_audio_head,
+                self.dims.n_audio_layer,
             )
-            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
-                *fields
+            self.decoder = TextDecoder(
+                self.dims.n_vocab,
+                self.dims.n_text_ctx,
+                self.dims.n_text_state,
+                self.dims.n_text_head,
+                self.dims.n_text_layer,
             )
-        ]
 
+        def embed_audio(self, mel: paddle.Tensor):
+            return self.encoder.forward(mel)
+
+        def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
+            return self.decoder.forward(tokens, audio_features)
+
+        def forward(
+            self, mel: paddle.Tensor, tokens: paddle.Tensor
+        ) -> Dict[str, paddle.Tensor]:
+            return self.decoder(tokens, self.encoder(mel))
+
+        @property
+        def device(self):
+            return paddle.device.get_device()
+
+        @property
+        def is_multilingual(self):
+            return self.dims.n_vocab == 51865
+
+        def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+            """
+            The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+            tensors calculated for the previous positions. This method returns a dictionary that stores
+            all caches, and the necessary hooks for the key and value projection modules that save the
+            intermediate tensors to be reused during later calculations.
+            Returns
+            -------
+            cache : Dict[nn.Layer, paddle.Tensor]
+            A dictionary object mapping the key/value projection modules to their caches
+            hooks : List[RemovableHandle]
+                List of RemovableHandle objects that can be used to remove the installed hooks
+            """
+            cache = {**cache} if cache is not None else {}
+            hooks = []
+
+            def save_to_cache(module, _, output):
+                if (
+                    module not in cache
+                    or output.shape[1] > self.decoder.positional_embedding.shape[0]
+                ):
+                    cache[module] = (
+                        output  # save as-is, for the first token or cross attention
+                    )
+                else:
+                    cache[module] = paddle.concat(
+                        [cache[module], output], axis=1
+                    ).detach()
+                return cache[module]
 
-@paddle.no_grad()
-def decode(
-    model: "Whisper",
-    mel: paddle.Tensor,
-    options: DecodingOptions = DecodingOptions(),
-    resource_path=str,
-) -> Union[DecodingResult, List[DecodingResult]]:
-    """
-    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
-
-    Parameters
-    ----------
-    model: Whisper
-        the Whisper model instance
-
-    mel: paddle.Tensor, shape = (80, 3000) or (*, 80, 3000)
-        A tensor containing the Mel spectrogram(s)
-
-    options: DecodingOptions
-        A dataclass that contains all necessary options for decoding 30-second segments
-
-    Returns
-    -------
-    result: Union[DecodingResult, List[DecodingResult]]
-        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
-    """
-    single = mel.ndim == 2
-    if single:
-        mel = mel.unsqueeze(0)
-
-    result = DecodingTask(model, options, resource_path).run(mel)
-
-    if single:
-        result = result[0]
-
-    return result
-
-
-class Whisper(paddle.nn.Layer):
-    """
-    The `Whisper` module use AudioEncoder and TextDecoder, and return detect_language, transcribe, decode.
-    """
-
-    def __init__(self, dims: ModelDimensions):
-        super().__init__()
-        self.dims = dims
-        self.encoder = AudioEncoder(
-            self.dims.n_mels,
-            self.dims.n_audio_ctx,
-            self.dims.n_audio_state,
-            self.dims.n_audio_head,
-            self.dims.n_audio_layer,
-        )
-        self.decoder = TextDecoder(
-            self.dims.n_vocab,
-            self.dims.n_text_ctx,
-            self.dims.n_text_state,
-            self.dims.n_text_head,
-            self.dims.n_text_layer,
-        )
+            def install_hooks(layer: paddle.nn.Layer):
+                if isinstance(layer, MultiHeadAttention):
+                    hooks.append(layer.key.register_forward_post_hook(save_to_cache))
+                    hooks.append(layer.value.register_forward_post_hook(save_to_cache))
 
-    def embed_audio(self, mel: paddle.Tensor):
-        return self.encoder.forward(mel)
+            self.decoder.apply(install_hooks)
+            return cache, hooks
 
-    def logits(self, tokens: paddle.Tensor, audio_features: paddle.Tensor):
-        return self.decoder.forward(tokens, audio_features)
+        detect_language = detect_language
+        transcribe = transcribe
+        decode = decode
 
-    def forward(
-        self, mel: paddle.Tensor, tokens: paddle.Tensor
-    ) -> Dict[str, paddle.Tensor]:
-        return self.decoder(tokens, self.encoder(mel))
+    def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+        """
+        Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+        """
+        if paddle.is_tensor(array):
+            if array.shape[axis] > length:
+                array = array.index_select(axis=axis, index=paddle.arange(length))
+
+            if array.shape[axis] < length:
+                pad_widths = [(0, 0)] * array.ndim
+                pad_widths[axis] = (0, length - array.shape[axis])
+                array = paddle.transpose(array, (1, 0))
+                array = paddle.nn.functional.pad(
+                    array,
+                    [pad for sizes in pad_widths[::-1] for pad in sizes],
+                    data_format="NLC",
+                )
+                array = paddle.transpose(array, (1, 0))
+        else:
+            if array.shape[axis] > length:
+                array = array.take(indices=range(length), axis=axis)
 
-    @property
-    def device(self):
-        return paddle.device.get_device()
+            if array.shape[axis] < length:
+                pad_widths = [(0, 0)] * array.ndim
+                pad_widths[axis] = (0, length - array.shape[axis])
+                array = paddle.transpose(array, (1, 0))
+                array = np.pad(array, pad_widths)
+                array = paddle.transpose(array, (1, 0))
 
-    @property
-    def is_multilingual(self):
-        return self.dims.n_vocab == 51865
+        return array
 
-    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+    def hann_window(n_fft: int = N_FFT):
         """
-        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
-        tensors calculated for the previous positions. This method returns a dictionary that stores
-        all caches, and the necessary hooks for the key and value projection modules that save the
-        intermediate tensors to be reused during later calculations.
+        Hann window.
+        n_fft: The number of frequency components of the discrete Fourier transform.
+        """
+        return paddle.to_tensor(
+            [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
+            dtype=paddle.float32,
+        )
 
+    @lru_cache(maxsize=None)
+    def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
+        """
+        Load the mel filterbank matrix for projecting an STFT into a Mel spectrogram.
+        This decouples the librosa dependency; the filters were saved using:
+            np.savez_compressed(
+                "mel_filters.npz",
+                mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+            )
+        """
+        assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+        with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
+            return paddle.to_tensor(f[f"mel_{n_mels}"])
+
+    @function_requires_deps("soundfile")
+    def log_mel_spectrogram(
+        audio: Union[str, np.ndarray, paddle.Tensor],
+        n_mels: int = N_MELS,
+        resource_path: str = None,
+    ):
+        """
+        Compute the log-Mel spectrogram of the input audio.
+        Parameters
+        ----------
+        audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
+            The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
+        n_mels: int
+            The number of Mel-frequency filters, only 80 is supported
         Returns
         -------
-        cache : Dict[nn.Layer, paddle.Tensor]
-            A dictionary object mapping the key/value projection modules to its cache
-        hooks : List[RemovableHandle]
-            List of PyTorch RemovableHandle objects to stop the hooks to be called
+        paddle.Tensor, shape = (80, n_frames)
+            A Tensor that contains the Mel spectrogram
         """
-        cache = {**cache} if cache is not None else {}
-        hooks = []
+        if not paddle.is_tensor(audio):
+            if isinstance(audio, str):
+                audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
+                audio = audio[:, 0]
+            audio = paddle.to_tensor(audio)
 
-        def save_to_cache(module, _, output):
-            if (
-                module not in cache
-                or output.shape[1] > self.decoder.positional_embedding.shape[0]
-            ):
-                cache[module] = (
-                    output  # save as-is, for the first token or cross attention
-                )
-            else:
-                cache[module] = paddle.concat([cache[module], output], axis=1).detach()
-            return cache[module]
-
-        def install_hooks(layer: paddle.nn.Layer):
-            if isinstance(layer, MultiHeadAttention):
-                hooks.append(layer.key.register_forward_post_hook(save_to_cache))
-                hooks.append(layer.value.register_forward_post_hook(save_to_cache))
-
-        self.decoder.apply(install_hooks)
-        return cache, hooks
-
-    detect_language = detect_language
-    transcribe = transcribe
-    decode = decode
-
-
-def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
-    """
-    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
-    """
-    if paddle.is_tensor(array):
-        if array.shape[axis] > length:
-            array = array.index_select(axis=axis, index=paddle.arange(length))
-
-        if array.shape[axis] < length:
-            pad_widths = [(0, 0)] * array.ndim
-            pad_widths[axis] = (0, length - array.shape[axis])
-            array = paddle.transpose(array, (1, 0))
-            array = paddle.nn.functional.pad(
-                array,
-                [pad for sizes in pad_widths[::-1] for pad in sizes],
-                data_format="NLC",
-            )
-            array = paddle.transpose(array, (1, 0))
-    else:
-        if array.shape[axis] > length:
-            array = array.take(indices=range(length), axis=axis)
-
-        if array.shape[axis] < length:
-            pad_widths = [(0, 0)] * array.ndim
-            pad_widths[axis] = (0, length - array.shape[axis])
-            array = paddle.transpose(array, (1, 0))
-            array = np.pad(array, pad_widths)
-            array = paddle.transpose(array, (1, 0))
-
-    return array
-
-
-def hann_window(n_fft: int = N_FFT):
-    """
-    hanning window
-    n_fft:  The number of frequency components of the discrete Fourier transform.
-    """
-    return paddle.to_tensor(
-        [0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft) for n in range(n_fft)],
-        dtype=paddle.float32,
-    )
-
-
-@lru_cache(maxsize=None)
-def mel_filters(resource_path: str, n_mels: int = N_MELS) -> paddle.Tensor:
-    """
-    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
-    Allows decoupling librosa dependency; saved using:
-
-        np.savez_compressed(
-            "mel_filters.npz",
-            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
-        )
-    """
-    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
-    with np.load(os.path.join(resource_path, "assets", "mel_filters.npz")) as f:
-        return paddle.to_tensor(f[f"mel_{n_mels}"])
-
-
-def log_mel_spectrogram(
-    audio: Union[str, np.ndarray, paddle.Tensor],
-    n_mels: int = N_MELS,
-    resource_path: str = None,
-):
-    """
-    Compute the log-Mel spectrogram of
-
-    Parameters
-    ----------
-    audio: Union[str, np.ndarray, paddle.Tensor], shape = (*)
-        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
-
-    n_mels: int
-        The number of Mel-frequency filters, only 80 is supported
-
-    Returns
-    -------
-    paddle.Tensor, shape = (80, n_frames)
-        A Tensor that contains the Mel spectrogram
-    """
-    if not paddle.is_tensor(audio):
-        if isinstance(audio, str):
-            audio, _ = soundfile.read(audio, dtype="float32", always_2d=True)
-            audio = audio[:, 0]
-        audio = paddle.to_tensor(audio)
-
-    window = hann_window(N_FFT)
-    stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
-
-    magnitudes = stft[:, :-1].abs() ** 2
-
-    filters = mel_filters(resource_path, n_mels)
-    mel_spec = filters @ magnitudes
-    mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
-
-    log_spec = paddle.clip(mel_spec, min=1e-10).log10()
-    log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
-    log_spec = (log_spec + 4.0) / 4.0
-    return log_spec
+        window = hann_window(N_FFT)
+        stft = paddle.signal.stft(audio, N_FFT, HOP_LENGTH, window=window)
+
+        magnitudes = stft[:, :-1].abs() ** 2
+
+        filters = mel_filters(resource_path, n_mels)
+        mel_spec = filters @ magnitudes
+        mel_spec = paddle.to_tensor(mel_spec.numpy().tolist())
+
+        log_spec = paddle.clip(mel_spec, min=1e-10).log10()
+        log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0)
+        log_spec = (log_spec + 4.0) / 4.0
+        return log_spec
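
A minimal usage sketch of the decoding entry points defined above. The audio file name, the `N_FRAMES` constant, and a loaded `Whisper` model with its `resource_path` are assumptions for illustration, not part of this diff:

    # Hedged usage sketch: compute a log-Mel spectrogram and decode it.
    # `model`, `resource_path`, and N_FRAMES are assumed to be provided elsewhere.
    mel = log_mel_spectrogram("sample.wav", resource_path=resource_path)  # requires soundfile
    mel = pad_or_trim(mel, N_FRAMES, axis=-1)            # pad/trim along the time axis
    options = DecodingOptions(language="en", without_timestamps=True)
    result = decode(model, mel, options, resource_path)  # equivalently: model.decode(...)
    print(result.text)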

+ 7 - 1
paddlex/inference/models/object_detection/processors.py

@@ -14,20 +14,24 @@
 
 from typing import List, Optional, Sequence, Tuple, Union
 
-import cv2
 import numpy as np
 from numpy import ndarray
 
+from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
 from ...common.reader import ReadImage as CommonReadImage
 from ...utils.benchmark import benchmark
 from ..common import Normalize as CommonNormalize
 from ..common import Resize as CommonResize
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 Boxes = List[dict]
 Number = Union[int, float]
 
 
 @benchmark.timeit_with_options(name=None, is_read_operation=True)
+@class_requires_deps("opencv-contrib-python")
 class ReadImage(CommonReadImage):
     """Reads images from a list of raw image data or file paths."""
 
@@ -311,6 +315,7 @@ def _get_3rd_point(a: ndarray, b: ndarray) -> ndarray:
     return third_pt
 
 
+@function_requires_deps("opencv-contrib-python")
 def get_affine_transform(
     center: ndarray,
     input_size: Union[Number, Tuple[Number, Number], ndarray],
@@ -368,6 +373,7 @@ def get_affine_transform(
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class WarpAffine:
     """Apply warp affine transformation to the image based on the given parameters.
 

+ 6 - 1
paddlex/inference/models/open_vocabulary_detection/processors/common.py

@@ -15,10 +15,15 @@
 
 from typing import Dict, List, Tuple
 
-import cv2
 import numpy as np
 
+from .....utils.deps import class_requires_deps, is_dep_available
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
+
+@class_requires_deps("opencv-contrib-python")
 class LetterResize(object):
     def __init__(
         self,

+ 16 - 6
paddlex/inference/models/open_vocabulary_detection/processors/groundingdino_processors.py

@@ -18,15 +18,10 @@ from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
 import PIL
 
-from .....utils.lazy_loader import LazyLoader
+from .....utils.deps import class_requires_deps
 from ....utils.benchmark import benchmark
 from ...common.tokenizer.bert_tokenizer import BertTokenizer
 
-# NOTE: LazyLoader is used to avoid conflicts between ultra-infer and Paddle
-paddle = LazyLoader("lazy_paddle", globals(), "paddle")
-T = LazyLoader("T", globals(), "paddle.vision.transforms")
-F = LazyLoader("F", globals(), "paddle.nn.functional")
-
 
 def _max_by_axis(the_list):
     maxes = the_list[0]
@@ -99,6 +94,7 @@ def _text_pad_batch_data(
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class GroundingDINOPostProcessor(object):
     """PostProcessors for GroundingDINO"""
 
@@ -129,6 +125,7 @@ class GroundingDINOPostProcessor(object):
         text_threshold=None,
         **kwargs,
     ):
+        import paddle
 
         box_threshold = self.box_threshold if box_threshold is None else box_threshold
         text_threshold = (
@@ -168,6 +165,8 @@ class GroundingDINOPostProcessor(object):
         text_threshold,
     ):
         """Post Process for prediction result of single image."""
+        import paddle
+        import paddle.nn.functional as F
 
         logits = F.sigmoid(pred_logits)
         boxes = pred_boxes
@@ -265,6 +264,7 @@ class GroundingDINOProcessor(object):
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class GroundingDinoTextProcessor(object):
     """Constructs a GroundingDino text processor."""
 
@@ -280,6 +280,8 @@ class GroundingDinoTextProcessor(object):
         special_tokens_list,
     ):
         """Preprocess the text with tokenization."""
+        import paddle
+
         tokenized_out = {}
         input_ids = _text_pad_batch_data(input_ids)
         input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64).squeeze(-1)
@@ -325,6 +327,8 @@ class GroundingDinoTextProcessor(object):
         Returns:
             torch.Tensor: attention mask between each special tokens.
         """
+        import paddle
+
         input_ids = tokenized["input_ids"]
         bs, num_token = input_ids.shape
         special_tokens_mask = paddle.zeros((bs, num_token), dtype=paddle.bool)
@@ -368,6 +372,7 @@ class GroundingDinoTextProcessor(object):
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class GroundingDinoImageProcessor(object):
     """Constructs a GroundingDino image processor."""
 
@@ -398,6 +403,7 @@ class GroundingDinoImageProcessor(object):
 
     def resize(self, image, size=None, max_size=1333):
         """Officially aligned Image resize."""
+        import paddle.vision.transforms as T
 
         def get_size_with_aspect_ratio(image_size, size, max_size=None):
             w, h = image_size
@@ -431,6 +437,8 @@ class GroundingDinoImageProcessor(object):
         return rescaled_image
 
     def nested_tensor_from_tensor_list(self, tensor_list):
+        import paddle
+
         if tensor_list[0].ndim == 3:
             max_size = _max_by_axis([list(img.shape) for img in tensor_list])
             batch_shape = [len(tensor_list)] + max_size
@@ -460,6 +468,8 @@ class GroundingDinoImageProcessor(object):
         **kwargs,
     ):
         """Preprocess an image or batch of images."""
+        import paddle.vision.transforms as T
+
         do_resize = do_resize if do_resize is not None else self.do_resize
         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
         do_nested = do_nested if do_nested is not None else self.do_nested
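
In the GroundingDINO processors, the previous `LazyLoader` indirection is replaced by `class_requires_deps("paddlepaddle")` on the class plus plain `import paddle` statements at the top of the methods that actually need it, so importing the module stays cheap. A trimmed sketch of that shape; the class and method below are illustrative only:

    # Illustrative sketch: framework imports deferred to call time.
    from paddlex.utils.deps import class_requires_deps

    @class_requires_deps("paddlepaddle")
    class ExamplePostProcessor:
        def __call__(self, logits):
            import paddle
            import paddle.nn.functional as F  # deferred until the processor is used

            probs = F.sigmoid(paddle.to_tensor(logits))
            return probs.numpy()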

+ 9 - 9
paddlex/inference/models/open_vocabulary_segmentation/processors/sam_processer.py

@@ -18,13 +18,7 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import PIL
 
-from .....utils.lazy_loader import LazyLoader
-
-# NOTE: LazyLoader is used to avoid conflicts between ultra-infer and Paddle
-paddle = LazyLoader("lazy_paddle", globals(), "paddle")
-T = LazyLoader("T", globals(), "paddle.vision.transforms")
-F = LazyLoader("F", globals(), "paddle.nn.functional")
-
+from .....utils.deps import class_requires_deps
 from ....utils.benchmark import benchmark
 
 
@@ -39,6 +33,7 @@ def _get_preprocess_shape(
     return (newh, neww)
 
 
+@class_requires_deps("paddlepaddle")
 class SAMProcessor(object):
 
     def __init__(
@@ -107,6 +102,8 @@ class SAMProcessor(object):
         return image_seg, prompt
 
     def postprocess(self, low_res_masks, mask_threshold: float = 0.0):
+        import paddle
+        import paddle.nn.functional as F
 
         if isinstance(low_res_masks, list):
             assert len(low_res_masks) == 1
@@ -183,6 +180,7 @@ class SamPromptProcessor(object):
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class SamImageProcessor(object):
     """Constructs a Sam image processor."""
 
@@ -210,6 +208,8 @@ class SamImageProcessor(object):
 
     def apply_image(self, image: np.ndarray) -> np.ndarray:
         """Expects a numpy array with shape HxWxC in uint8 format."""
+        import paddle.vision.transforms as T
+
         target_size = _get_preprocess_shape(image.shape[0], image.shape[1], self.size)
         if isinstance(image, np.ndarray):
             image = PIL.Image.fromarray(image)
@@ -226,8 +226,8 @@ class SamImageProcessor(object):
         images,
     ):
         """Preprocess an image or a batch of images with a same shape."""
-
-        self.size
+        import paddle
+        import paddle.nn.functional as F
 
         input_image = [self.apply_image(image) for image in images]
 

+ 5 - 1
paddlex/inference/models/open_vocabulary_segmentation/results/sam_result.py

@@ -15,14 +15,18 @@
 import copy
 import random
 
-import cv2
 import numpy as np
 from PIL import Image
 
+from .....utils.deps import function_requires_deps, is_dep_available
 from ....common.result import BaseCVResult, JsonMixin
 from ....utils.color_map import get_colormap
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@function_requires_deps("opencv-contrib-python")
 def draw_segm(im, masks, mask_info, alpha=0.7):
     """
     Draw segmentation on image

+ 5 - 1
paddlex/inference/models/semantic_segmentation/processors.py

@@ -14,13 +14,16 @@
 
 import math
 
-import cv2
 import numpy as np
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 from ..common.vision import funcs as F
 from ..common.vision.processors import _BaseResize
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 
 @benchmark.timeit
 class Resize(_BaseResize):
@@ -80,6 +83,7 @@ class Resize(_BaseResize):
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class SegPostProcess:
     """Semantic Segmentation PostProcess
 

+ 5 - 1
paddlex/inference/models/table_structure_recognition/result.py

@@ -15,12 +15,16 @@
 import copy
 from pathlib import Path
 
-import cv2
 import numpy as np
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...common.result import BaseCVResult, JsonMixin
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@class_requires_deps("opencv-contrib-python")
 class TableRecResult(BaseCVResult):
     """SaveTableResults"""
 

+ 9 - 2
paddlex/inference/models/text_detection/processors.py

@@ -16,15 +16,20 @@ import math
 import sys
 from typing import Union
 
-import cv2
 import numpy as np
-import pyclipper
 
 from ....utils import logging
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+if is_dep_available("pyclipper"):
+    import pyclipper
+
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class DetResizeForTest:
     """DetResizeForTest"""
 
@@ -193,6 +198,7 @@ class DetResizeForTest:
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class NormalizeImage:
     """normalize image such as substract mean, divide std"""
 
@@ -232,6 +238,7 @@ class NormalizeImage:
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python", "pyclipper")
 class DBPostProcess:
     """
     The post process for Differentiable Binarization (DB).
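
As `DBPostProcess` above shows, a single decorator can declare several optional dependencies at once, while each module-level import keeps its own guard. A condensed sketch, with the class body omitted:

    # Condensed sketch: declaring multiple optional dependencies on one class.
    from paddlex.utils.deps import class_requires_deps, is_dep_available

    if is_dep_available("opencv-contrib-python"):
        import cv2
    if is_dep_available("pyclipper"):
        import pyclipper

    @class_requires_deps("opencv-contrib-python", "pyclipper")
    class PolygonPostProcessLike:
        """Hypothetical class using cv2 for contour ops and pyclipper for polygon offsetting."""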

+ 5 - 1
paddlex/inference/models/text_detection/result.py

@@ -15,12 +15,16 @@
 import copy
 from pathlib import Path
 
-import cv2
 import numpy as np
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...common.result import BaseCVResult, JsonMixin
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@class_requires_deps("opencv-contrib-python")
 class TextDetResult(BaseCVResult):
 
     def _get_input_fn(self):

+ 5 - 1
paddlex/inference/models/text_recognition/processors.py

@@ -17,13 +17,17 @@ import math
 import re
 from typing import List
 
-import cv2
 import numpy as np
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class OCRReisizeNormImg:
     """for ocr image resize and normalization"""
 

+ 6 - 2
paddlex/inference/models/ts_anomaly_detection/result.py

@@ -15,13 +15,17 @@
 import io
 from typing import Any
 
-import matplotlib.pyplot as plt
-import pandas as pd
 from PIL import Image
 
+from ....utils.deps import function_requires_deps, is_dep_available
 from ...common.result import BaseTSResult
 
+if is_dep_available("matplotlib"):
+    import matplotlib.pyplot as plt
+import pandas as pd
+
 
+@function_requires_deps("matplotlib")
 def visualize(forecast: pd.DataFrame) -> Image.Image:
     """
     Visualizes both the time series forecast and actual results, returning them as a Pillow image.

+ 5 - 1
paddlex/inference/models/ts_classification/result.py

@@ -15,12 +15,16 @@
 import io
 from typing import Any
 
-import matplotlib.pyplot as plt
 from PIL import Image
 
+from ....utils.deps import function_requires_deps, is_dep_available
 from ...common.result import BaseTSResult
 
+if is_dep_available("matplotlib"):
+    import matplotlib.pyplot as plt
 
+
+@function_requires_deps("matplotlib")
 def visualize(predicted_label, input_ts, target_cols):
     """
     Visualize time series data and its prediction results.

+ 5 - 1
paddlex/inference/models/ts_forecasting/processors.py

@@ -14,14 +14,18 @@
 
 from typing import Any, Dict, List
 
-import joblib
 import numpy as np
 import pandas as pd
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 
+if is_dep_available("joblib"):
+    import joblib
+
 
 @benchmark.timeit
+@class_requires_deps("joblib")
 class TSDeNormalize:
     """A class to de-normalize time series prediction data using a pre-fitted scaler."""
 

+ 5 - 1
paddlex/inference/models/ts_forecasting/result.py

@@ -15,13 +15,17 @@
 import io
 from typing import Any
 
-import matplotlib.pyplot as plt
 import pandas as pd
 from PIL import Image
 
+from ....utils.deps import function_requires_deps, is_dep_available
 from ...common.result import BaseTSResult
 
+if is_dep_available("matplotlib"):
+    import matplotlib.pyplot as plt
 
+
+@function_requires_deps("matplotlib")
 def visualize(forecast: pd.DataFrame, actual_data: pd.DataFrame) -> Image.Image:
     """
     Visualizes both the time series forecast and actual results, returning them as a Pillow image.

+ 14 - 16
paddlex/inference/models/video_classification/processors.py

@@ -15,14 +15,17 @@
 
 from typing import List, Optional, Sequence, Tuple, Union
 
-import cv2
-import lazy_paddle
 import numpy as np
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.benchmark import benchmark
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class Scale:
     """Scale images."""
 
@@ -161,22 +164,16 @@ class CenterCrop:
 
         crop_imgs = []
         th, tw = self.target_size, self.target_size
-        if isinstance(imgs, lazy_paddle.Tensor):
-            h, w = imgs.shape[-2:]
+        for img in imgs:
+            h, w, _ = img.shape
+            assert (w >= self.target_size) and (
+                h >= self.target_size
+            ), "image width({}) and height({}) should be larger than crop size({})".format(
+                w, h, self.target_size
+            )
             x1 = int(round((w - tw) / 2.0)) if self.do_round else (w - tw) // 2
             y1 = int(round((h - th) / 2.0)) if self.do_round else (h - th) // 2
-            crop_imgs = imgs[:, :, y1 : y1 + th, x1 : x1 + tw]
-        else:
-            for img in imgs:
-                h, w, _ = img.shape
-                assert (w >= self.target_size) and (
-                    h >= self.target_size
-                ), "image width({}) and height({}) should be larger than crop size".format(
-                    w, h, self.target_size
-                )
-                x1 = int(round((w - tw) / 2.0)) if self.do_round else (w - tw) // 2
-                y1 = int(round((h - th) / 2.0)) if self.do_round else (h - th) // 2
-                crop_imgs.append(img[y1 : y1 + th, x1 : x1 + tw])
+            crop_imgs.append(img[y1 : y1 + th, x1 : x1 + tw])
         return crop_imgs
 
     def __call__(self, videos: List[np.ndarray]) -> List[np.ndarray]:
@@ -247,6 +244,7 @@ class Image2Array:
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class NormalizeVideo:
     """
     Normalize video frames by subtracting the mean and dividing by the standard deviation.

+ 5 - 1
paddlex/inference/models/video_classification/result.py

@@ -12,17 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import cv2
 import numpy as np
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ...common.result import BaseVideoResult
 from ...utils.color_map import get_colormap
 from ...utils.io import VideoReader
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@class_requires_deps("opencv-contrib-python")
 class TopkVideoResult(BaseVideoResult):
 
     def _to_video(self):

+ 16 - 3
paddlex/inference/models/video_detection/processors.py

@@ -15,14 +15,14 @@
 
 from typing import List
 
-import cv2
-import lazy_paddle as paddle
 import numpy as np
 
+from ....utils.deps import class_requires_deps, function_requires_deps
 from ...utils.benchmark import benchmark
 
 
 @benchmark.timeit
+@class_requires_deps("opencv-contrib-python")
 class ResizeVideo:
     """Resizes frames of a video to a specified target size.
 
@@ -57,6 +57,8 @@ class ResizeVideo:
         Raises:
             NotImplementedError: If a frame is not an instance of numpy.ndarray.
         """
+        import cv2
+
         num_seg = len(video)
         seg_len = len(video[0])
 
@@ -204,6 +206,7 @@ def convert2cpu_long(gpu_matrix):
     return int_64_g.cpu()
 
 
+@function_requires_deps("paddlepaddle")
 def get_region_boxes(
     output,
     conf_thresh=0.005,
@@ -237,6 +240,8 @@ def get_region_boxes(
     Returns:
         all_box(List[List[float]]): A list of predicted bounding boxes for each image in the batch.
     """
+    import paddle
+
     anchor_step = len(anchors) // num_anchors
     if output.dim() == 3:
         output = output.unsqueeze(0)
@@ -342,10 +347,13 @@ def get_region_boxes(
     return all_boxes
 
 
+@function_requires_deps("paddlepaddle")
 def nms(boxes, nms_thresh):
     """
     Performs non-maximum suppression on the input boxes based on their IoUs.
     """
+    import paddle
+
     if len(boxes) == 0:
         return boxes
     det_confs = paddle.zeros([len(boxes)])
@@ -365,10 +373,13 @@ def nms(boxes, nms_thresh):
     return out_boxes
 
 
+@function_requires_deps("paddlepaddle")
 def bbox_iou(box1, box2, x1y1x2y2=True):
     """
     Returns the Intersection over Union (IoU) of two bounding boxes.
     """
+    import paddle
+
     if x1y1x2y2:
         mx = min(box1[0], box2[0])
         Mx = max(box1[2], box2[2])
@@ -403,6 +414,7 @@ def bbox_iou(box1, box2, x1y1x2y2=True):
 
 
 @benchmark.timeit
+@class_requires_deps("paddlepaddle")
 class DetVideoPostProcess:
     """
     A class used to perform post-processing on detection results in videos.
@@ -421,7 +433,8 @@ class DetVideoPostProcess:
         self.labels = label_list
 
     def postprocess(self, pred: List, nms_thresh: float, score_thresh: float) -> List:
-        cv2.FONT_HERSHEY_SIMPLEX
+        import paddle
+
         num_seg = len(pred)
         pred_all = []
         for i in range(num_seg):
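
For free functions such as `get_region_boxes`, `nms`, and `bbox_iou`, the same idea applies via `function_requires_deps`, with `import paddle` moved into the function body. A short behavioral sketch; the exact error raised for a missing dependency is an assumption, not verified here:

    # Sketch: a guarded free function with its framework import resolved at call time.
    from paddlex.utils.deps import function_requires_deps

    @function_requires_deps("paddlepaddle")
    def iou_like(box1, box2):
        import paddle  # only evaluated when the function is called

        area = paddle.to_tensor([box1[2] * box1[3], box2[2] * box2[3]])
        return float(area.min() / area.max())

    # Calling iou_like(...) without paddlepaddle installed is expected to fail early
    # with a message naming the missing dependency, rather than a bare ImportError.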

+ 5 - 1
paddlex/inference/models/video_detection/result.py

@@ -14,17 +14,21 @@
 
 import random
 
-import cv2
 import numpy as np
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ...common.result import BaseVideoResult
 from ...utils.color_map import get_colormap
 from ...utils.io import VideoReader
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@class_requires_deps("opencv-contrib-python")
 class DetVideoResult(BaseVideoResult):
 
     def _to_video(self):

+ 2 - 0
paddlex/inference/pipelines/anomaly_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.anomaly_detection.result import UadResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class AnomalyDetectionPipeline(BasePipeline):
     """Image AnomalyDetectionPipeline Pipeline"""
 

+ 3 - 0
paddlex/inference/pipelines/attribute_recognition/pipeline.py

@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...utils.hpi import HPIConfig
@@ -99,9 +100,11 @@ class AttributeRecPipeline(BasePipeline):
         return AttributeRecResult(single_img_res)
 
 
+@pipeline_requires_extra("cv")
 class PedestrianAttributeRecPipeline(AttributeRecPipeline):
     entities = "pedestrian_attribute_recognition"
 
 
+@pipeline_requires_extra("cv")
 class VehicleAttributeRecPipeline(AttributeRecPipeline):
     entities = "vehicle_attribute_recognition"

+ 5 - 1
paddlex/inference/pipelines/attribute_recognition/result.py

@@ -14,14 +14,17 @@
 
 import copy
 
-import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
+from ....utils.deps import class_requires_deps, is_dep_available
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ...common.result import BaseCVResult, JsonMixin
 from ...utils.color_map import font_colormap, get_colormap
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 
 def draw_attribute_result(img, boxes):
     """
@@ -71,6 +74,7 @@ def draw_attribute_result(img, boxes):
     return img
 
 
+@class_requires_deps("opencv-contrib-python")
 class AttributeRecResult(BaseCVResult):
 
     def _to_str(self, *args, **kwargs):

+ 4 - 5
paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py

@@ -18,9 +18,11 @@ import re
 from typing import Dict
 
 from .....utils import logging
+from .....utils.deps import class_requires_deps
 from .base import BaseChat
 
 
+@class_requires_deps("openai")
 class OpenAIBotChat(BaseChat):
     """OpenAI Bot Chat"""
 
@@ -40,6 +42,8 @@ class OpenAIBotChat(BaseChat):
             api_key is None for api_type is openai.
             ValueError: If end_point is not one of ['completion', 'chat_completion'].
         """
+        from openai import OpenAI
+
         super().__init__()
         model_name = config.get("model_name", None)
         # compatible with historical model name
@@ -64,11 +68,6 @@ class OpenAIBotChat(BaseChat):
                 "end_point must be one of ['completion', 'chat_completion']"
             )
 
-        try:
-            from openai import OpenAI
-        except:
-            raise Exception("openai is not installed, please install it first.")
-
         self.client = OpenAI(base_url=base_url, api_key=api_key)
 
         self.model_name = model_name

+ 7 - 2
paddlex/inference/pipelines/components/common/crop_image_regions.py

@@ -15,14 +15,18 @@
 import copy
 from typing import List, Tuple
 
-import cv2
 import numpy as np
 from numpy.linalg import norm
-from shapely.geometry import Polygon
 
+from .....utils.deps import class_requires_deps, is_dep_available
 from .base_operator import BaseOperator
 from .seal_det_warp import AutoRectifier
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+if is_dep_available("shapely"):
+    from shapely.geometry import Polygon
+
 
 class CropByBoxes(BaseOperator):
     """Crop Image by Boxes"""
@@ -60,6 +64,7 @@ class CropByBoxes(BaseOperator):
         return output_list
 
 
+@class_requires_deps("opencv-contrib-python", "shapely")
 class CropByPolys(BaseOperator):
     """Crop Image by Polys"""
 

+ 40 - 9
paddlex/inference/pipelines/components/common/seal_det_warp.py

@@ -14,24 +14,35 @@
 
 import copy
 
-import cv2
 import numpy as np
 from numpy import arctan, cos, sin, sqrt
 
 from .....utils import logging
+from .....utils.deps import (
+    class_requires_deps,
+    function_requires_deps,
+    is_dep_available,
+)
+
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
 #### [TODO] need sunting to add explanatory notes
 
 
+@function_requires_deps("opencv-contrib-python")
 def Homography(
     image,
     img_points,
     world_width,
     world_height,
-    interpolation=cv2.INTER_CUBIC,
+    interpolation=None,
     ratio_width=1.0,
     ratio_height=1.0,
 ):
+    if interpolation is None:
+        interpolation = cv2.INTER_CUBIC
+
     _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
 
     expand_x = int(0.5 * world_width * (ratio_width - 1))
@@ -60,13 +71,14 @@ def Homography(
     return dst_img
 
 
+@class_requires_deps("opencv-contrib-python")
 class PlanB:
     def __call__(
         self,
         image,
         points,
         curveTextRectifier,
-        interpolation=cv2.INTER_LINEAR,
+        interpolation=None,
         ratio_width=1.0,
         ratio_height=1.0,
         loss_thresh=5.0,
@@ -84,6 +96,8 @@ class PlanB:
         :param square: crop square image or not. True or False. The default is False
         :return:
         """
+        if interpolation is None:
+            interpolation = cv2.INTER_LINEAR
         h, w = image.shape[:2]
         _points = np.array(points).reshape(-1, 2).astype(np.float32)
         x_min = int(np.min(_points[:, 0]))
@@ -126,6 +140,7 @@ class PlanB:
         return dst_img, loss
 
 
+@class_requires_deps("opencv-contrib-python")
 class CurveTextRectifier:
     """
     spatial transformer via monocular vision
@@ -538,7 +553,7 @@ class CurveTextRectifier:
         img_points,
         obj_points,
         is_horizontal_text,
-        interpolation=cv2.INTER_LINEAR,
+        interpolation=None,
         ratio_width=1.0,
         ratio_height=1.0,
     ):
@@ -546,6 +561,9 @@ class CurveTextRectifier:
         divide and conquer: homography
         # ratio_width and ratio_height must be 1.0 here
         """
+        if interpolation is None:
+            interpolation = cv2.INTER_LINEAR
+
         _img_points = img_points.reshape(-1, 2)
         _obj_points = obj_points.reshape(-1, 3)
 
@@ -607,10 +625,13 @@ class CurveTextRectifier:
         img_points,
         world_width,
         world_height,
-        interpolation=cv2.INTER_CUBIC,
+        interpolation=None,
         ratio_width=1.0,
         ratio_height=1.0,
     ):
+        if interpolation is None:
+            interpolation = cv2.INTER_CUBIC
+
         _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
 
         expand_x = int(0.5 * world_width * (ratio_width - 1))
@@ -642,7 +663,7 @@ class CurveTextRectifier:
         self,
         image_data,
         points,
-        interpolation=cv2.INTER_LINEAR,
+        interpolation=None,
         ratio_width=1.0,
         ratio_height=1.0,
         mode="calibration",
@@ -657,6 +678,9 @@ class CurveTextRectifier:
         :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
         :return:
         """
+        if interpolation is None:
+            interpolation = cv2.INTER_LINEAR
+
         org_h, org_w = image_data.shape[:2]
         org_size = (org_w, org_h)
         self.image = image_data
@@ -703,6 +727,7 @@ class CurveTextRectifier:
         return dst, ret
 
 
+@class_requires_deps("opencv-contrib-python")
 class AutoRectifier:
     def __init__(self):
         self.npoints = 10
@@ -710,7 +735,7 @@ class AutoRectifier:
 
     @staticmethod
     def get_rotate_crop_image(
-        img, points, interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0
+        img, points, interpolation=None, ratio_width=1.0, ratio_height=1.0
     ):
         """
         crop or homography
@@ -721,6 +746,8 @@ class AutoRectifier:
         :param ratio_height:
         :return:
         """
+        if interpolation is None:
+            interpolation = cv2.INTER_CUBIC
         h, w = img.shape[:2]
         _points = np.array(points).reshape(-1, 2).astype(np.float32)
 
@@ -796,7 +823,7 @@ class AutoRectifier:
         self,
         image_data,
         points,
-        interpolation=cv2.INTER_LINEAR,
+        interpolation=None,
         ratio_width=1.0,
         ratio_height=1.0,
         loss_thresh=5.0,
@@ -813,6 +840,8 @@ class AutoRectifier:
         :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
         :return:
         """
+        if interpolation is None:
+            interpolation = cv2.INTER_LINEAR
         _points = np.array(points).reshape(-1, 2)
         if len(_points) >= self.npoints and len(_points) % 2 == 0:
             try:
@@ -879,7 +908,7 @@ class AutoRectifier:
         self,
         image_data,
         points_list,
-        interpolation=cv2.INTER_LINEAR,
+        interpolation=None,
         ratio_width=1.0,
         ratio_height=1.0,
         loss_thresh=5.0,
@@ -903,6 +932,8 @@ class AutoRectifier:
         for points in points_list:
             if not isinstance(points, list):
                 raise ValueError
+        if interpolation is None:
+            interpolation = cv2.INTER_LINEAR
 
         if ratio_width < 1.0 or ratio_height < 1.0:
             raise ValueError(
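
One detail in seal_det_warp.py is easy to miss: the `cv2.INTER_*` constants cannot stay as default argument values, because defaults are evaluated when the function is defined, i.e. at module import time, which would defeat the lazy `cv2` import. The defaults therefore become `None` and are resolved inside the function body. A reduced sketch of the pattern (the resize call is illustrative, not taken from the file):

from paddlex.utils.deps import function_requires_deps, is_dep_available

if is_dep_available("opencv-contrib-python"):
    import cv2


@function_requires_deps("opencv-contrib-python")
def resize_image(img, size, interpolation=None):
    # Resolve the cv2 constant lazily; using it as a default value would
    # require cv2 to be importable when this module is loaded.
    if interpolation is None:
        interpolation = cv2.INTER_LINEAR
    return cv2.resize(img, size, interpolation=interpolation)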

+ 6 - 1
paddlex/inference/pipelines/components/common/warp_image.py

@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import cv2
 import numpy as np
 
+from .....utils.deps import function_requires_deps, is_dep_available
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
+
+@function_requires_deps("opencv-contrib-python")
 def rotate_image(image, angle):
     if angle < 0 or angle >= 360:
         raise ValueError("`angle` should be in range [0, 360)")

+ 6 - 1
paddlex/inference/pipelines/components/faisser.py

@@ -15,13 +15,17 @@
 import pickle
 from pathlib import Path
 
-import faiss
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import class_requires_deps, is_dep_available
 from ...utils.io import YAMLReader, YAMLWriter
 
+if is_dep_available("faiss-cpu"):
+    import faiss
 
+
+@class_requires_deps("faiss-cpu")
 class IndexData:
     VECTOR_FN = "vector"
     VECTOR_SUFFIX = ".index"
@@ -164,6 +168,7 @@ class FaissIndexer:
         return preds
 
 
+@class_requires_deps("faiss-cpu")
 class FaissBuilder:
 
     SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")

+ 13 - 9
paddlex/inference/pipelines/components/retriever/base.py

@@ -16,16 +16,20 @@ import time
 from abc import ABC, abstractmethod
 from typing import List
 
-from langchain.docstore.document import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community import vectorstores
-from langchain_community.vectorstores import FAISS
-
 from paddlex.utils import logging
 
+from .....utils.deps import class_requires_deps, is_dep_available
 from .....utils.subclass_register import AutoRegisterABCMetaClass
 
+if is_dep_available("langchain"):
+    from langchain.docstore.document import Document
+    from langchain.text_splitter import RecursiveCharacterTextSplitter
+if is_dep_available("langchain-community"):
+    from langchain_community import vectorstores
+    from langchain_community.vectorstores import FAISS
+
 
+@class_requires_deps("langchain", "langchain-community")
 class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     """Base Retriever"""
 
@@ -111,7 +115,7 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
         text_list: List[str],
         block_size: int = 300,
         separators: List[str] = ["\t", "\n", "。", "\n\n", ""],
-    ) -> FAISS:
+    ) -> "FAISS":
         """
         Generates a vector database from a list of texts.
 
@@ -141,7 +145,7 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
 
         return vectorstore
 
-    def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
+    def encode_vector_store_to_bytes(self, vectorstore: "FAISS") -> str:
         """
         Encode the vector store serialized to bytes.
 
@@ -157,7 +161,7 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
             vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
         return vectorstore
 
-    def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
+    def decode_vector_store_from_bytes(self, vectorstore: str) -> "FAISS":
         """
         Decode a vector store from bytes according to the specified API type.
 
@@ -190,7 +194,7 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     def similarity_retrieval(
         self,
         query_text_list: List[str],
-        vectorstore: FAISS,
+        vectorstore: "FAISS",
         sleep_time: float = 0.5,
         topk: int = 2,
         min_characters: int = 3500,
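
Annotations that name classes from optional packages (here `FAISS`) are rewritten as string literals, so they are never evaluated when the package is missing and the module still imports cleanly; the decorator then guards actual use. A sketch under the same assumptions (`serialize_to_bytes` is the langchain-community method already used in this file):

from paddlex.utils.deps import class_requires_deps, is_dep_available

if is_dep_available("langchain-community"):
    from langchain_community.vectorstores import FAISS


@class_requires_deps("langchain-community")
class VectorStoreCodec:
    def to_bytes(self, vectorstore: "FAISS") -> bytes:
        # The quoted annotation keeps the module importable even when
        # langchain-community is absent.
        return vectorstore.serialize_to_bytes()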

+ 84 - 81
paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py

@@ -16,10 +16,10 @@ import json
 from typing import Dict, List
 
 import requests
-from langchain_core.embeddings import Embeddings
 
 from paddlex.utils import logging
 
+from .....utils.deps import is_dep_available
 from .base import BaseRetriever
 
 
@@ -80,84 +80,87 @@ class QianFanBotRetriever(BaseRetriever):
         self.config = config
 
 
-class QianfanEmbeddings(Embeddings):
-    """`Baidu Qianfan Embeddings` embedding models."""
-
-    def __init__(
-        self,
-        api_key: str,
-        base_url: str = "https://qianfan.baidubce.com/v2",
-        model: str = "embedding-v1",
-        **kwargs,
-    ):
-        """
-        Initialize the Baidu Qianfan Embeddings class.
-
-        Args:
-            api_key (str): The Qianfan API key.
-            base_url (str): The base URL for 'qianfan' API.
-            model (str): Model name. Default is "embedding-v1",select in ["tao-8k","embedding-v1","bge-large-en","bge-large-zh"].
-            kwargs (dict): Additional keyword arguments passed to the base Embeddings class.
-        """
-        super().__init__(**kwargs)
-        chunk_size_map = {
-            "tao-8k": 1,
-            "embedding-v1": 16,
-            "bge-large-en": 16,
-            "bge-large-zh": 16,
-        }
-        self.api_key = api_key
-        self.base_url = base_url
-        self.model = model
-        self.chunk_size = chunk_size_map.get(model, 1)
-
-    def embed(self, texts: str, **kwargs) -> List[float]:
-        url = f"{self.base_url}/embeddings"
-        payload = json.dumps(
-            {"model": kwargs.get("model", self.model), "input": [f"{texts}"]}
-        )
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-        }
-
-        response = requests.request("POST", url, headers=headers, data=payload)
-        if response.status_code != 200:
-            logging.error(
-                f"Failed to call Qianfan API. Status code: {response.status_code}, Response content: {response}"
+if is_dep_available("langchain-core"):
+    from langchain_core.embeddings import Embeddings
+
+    class QianfanEmbeddings(Embeddings):
+        """`Baidu Qianfan Embeddings` embedding models."""
+
+        def __init__(
+            self,
+            api_key: str,
+            base_url: str = "https://qianfan.baidubce.com/v2",
+            model: str = "embedding-v1",
+            **kwargs,
+        ):
+            """
+            Initialize the Baidu Qianfan Embeddings class.
+
+            Args:
+                api_key (str): The Qianfan API key.
+                base_url (str): The base URL for 'qianfan' API.
+                model (str): Model name. Defaults to "embedding-v1"; choose from ["tao-8k", "embedding-v1", "bge-large-en", "bge-large-zh"].
+                kwargs (dict): Additional keyword arguments passed to the base Embeddings class.
+            """
+            super().__init__(**kwargs)
+            chunk_size_map = {
+                "tao-8k": 1,
+                "embedding-v1": 16,
+                "bge-large-en": 16,
+                "bge-large-zh": 16,
+            }
+            self.api_key = api_key
+            self.base_url = base_url
+            self.model = model
+            self.chunk_size = chunk_size_map.get(model, 1)
+
+        def embed(self, texts: str, **kwargs) -> List[float]:
+            url = f"{self.base_url}/embeddings"
+            payload = json.dumps(
+                {"model": kwargs.get("model", self.model), "input": [f"{texts}"]}
             )
-
-        return response.json()
-
-    def embed_query(self, text: str) -> List[float]:
-        resp = self.embed_documents([text])
-        return resp[0]
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """
-        Embeds a list of text documents using the AutoVOT algorithm.
-
-        Args:
-            texts (List[str]): A list of text documents to embed.
-
-        Returns:
-            List[List[float]]: A list of embeddings for each document in the input list.
-                            Each embedding is represented as a list of float values.
-        """
-        lst = []
-        for chunk in texts:
-            resp = self.embed(texts=chunk)
-            lst.extend([res["embedding"] for res in resp["data"]])
-        return lst
-
-    async def aembed_query(self, text: str) -> List[float]:
-        embeddings = await self.aembed_documents([text])
-        return embeddings[0]
-
-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        lst = []
-        for chunk in texts:
-            resp = await self.embed(texts=chunk)
-            for res in resp["data"]:
-                lst.extend([res["embedding"]])
-        return lst
+            headers = {
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.api_key}",
+            }
+
+            response = requests.request("POST", url, headers=headers, data=payload)
+            if response.status_code != 200:
+                logging.error(
+                    f"Failed to call Qianfan API. Status code: {response.status_code}, Response content: {response}"
+                )
+
+            return response.json()
+
+        def embed_query(self, text: str) -> List[float]:
+            resp = self.embed_documents([text])
+            return resp[0]
+
+        def embed_documents(self, texts: List[str]) -> List[List[float]]:
+            """
+            Embeds a list of text documents via the Qianfan embeddings API.
+
+            Args:
+                texts (List[str]): A list of text documents to embed.
+
+            Returns:
+                List[List[float]]: A list of embeddings for each document in the input list.
+                                Each embedding is represented as a list of float values.
+            """
+            lst = []
+            for chunk in texts:
+                resp = self.embed(texts=chunk)
+                lst.extend([res["embedding"] for res in resp["data"]])
+            return lst
+
+        async def aembed_query(self, text: str) -> List[float]:
+            embeddings = await self.aembed_documents([text])
+            return embeddings[0]
+
+        async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+            lst = []
+            for chunk in texts:
+                resp = await self.embed(texts=chunk)
+                for res in resp["data"]:
+                    lst.extend([res["embedding"]])
+            return lst

+ 2 - 0
paddlex/inference/pipelines/doc_preprocessor/pipeline.py

@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...utils.hpi import HPIConfig
@@ -26,6 +27,7 @@ from ..components import rotate_image
 from .result import DocPreprocessorResult
 
 
+@pipeline_requires_extra("ocr")
 class DocPreprocessorPipeline(BasePipeline):
     """Doc Preprocessor Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/face_recognition/pipeline.py

@@ -14,10 +14,12 @@
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ..pp_shitu_v2 import ShiTuV2Pipeline
 from .result import FaceRecResult
 
 
+@pipeline_requires_extra("cv")
 class FaceRecPipeline(ShiTuV2Pipeline):
     """Face Recognition Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/formula_recognition/pipeline.py

@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...models.formula_recognition.result import (
@@ -30,6 +31,7 @@ from ..components import CropByBoxes
 from .result import FormulaRecognitionResult
 
 
+@pipeline_requires_extra("ocr")
 class FormulaRecognitionPipeline(BasePipeline):
     """Formula Recognition Pipeline"""
 

+ 6 - 1
paddlex/inference/pipelines/formula_recognition/result.py

@@ -20,11 +20,11 @@ import tempfile
 from pathlib import Path
 from typing import Dict, Tuple
 
-import cv2
 import numpy as np
 from PIL import Image, ImageDraw
 
 from ....utils import logging
+from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ...common.result import BaseCVResult, JsonMixin
 from ...models.formula_recognition.result import (
@@ -37,7 +37,11 @@ from ...models.formula_recognition.result import (
     pdf2img,
 )
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@class_requires_deps("opencv-contrib-python")
 class FormulaRecognitionResult(BaseCVResult):
     """Formula Recognition Result"""
 
@@ -222,6 +226,7 @@ class FormulaRecognitionResult(BaseCVResult):
         return JsonMixin._to_json(data, *args, **kwargs)
 
 
+@function_requires_deps("opencv-contrib-python")
 def draw_box_formula_fine(
     img_size: Tuple[int, int], box: np.ndarray, formula: str, is_debug: bool = False
 ) -> np.ndarray:

+ 2 - 0
paddlex/inference/pipelines/image_classification/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.image_classification.result import TopkResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class ImageClassificationPipeline(BasePipeline):
     """Image Classification Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/image_multilabel_classification/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.image_multilabel_classification.result import MLClassResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class ImageMultiLabelClassificationPipeline(BasePipeline):
     """Image Multi Label Classification Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/instance_segmentation/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.instance_segmentation.result import InstanceSegResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class InstanceSegmentationPipeline(BasePipeline):
     """Instance Segmentation Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/keypoint_detection/pipeline.py

@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.keypoint_detection.result import KptResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
@@ -24,6 +25,7 @@ from ..base import BasePipeline
 Number = Union[int, float]
 
 
+@pipeline_requires_extra("cv")
 class KeypointDetectionPipeline(BasePipeline):
     """Keypoint Detection pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/layout_parsing/pipeline.py

@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...models.object_detection.result import DetResult
@@ -29,6 +30,7 @@ from .result import LayoutParsingResult
 from .utils import get_sub_regions_ocr_res, sorted_layout_boxes
 
 
+@pipeline_requires_extra("ocr")
 class LayoutParsingPipeline(BasePipeline):
     """Layout Parsing Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

@@ -20,6 +20,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...models.object_detection.result import DetResult
@@ -31,6 +32,7 @@ from .result_v2 import LayoutParsingResultV2
 from .utils import gather_imgs, get_single_block_parsing_res, get_sub_regions_ocr_res
 
 
+@pipeline_requires_extra("ocr")
 class LayoutParsingPipelineV2(BasePipeline):
     """Layout Parsing Pipeline V2"""
 

+ 2 - 0
paddlex/inference/pipelines/m_3d_bev_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.m_3d_bev_detection.result import BEV3DDetResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class BEVDet3DPipeline(BasePipeline):
     """3D Detection Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/multilingual_speech_recognition/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.multilingual_speech_recognition.result import WhisperResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("speech")
 class MultilingualSpeechRecognitionPipeline(BasePipeline):
     """Multilingual Speech Recognition Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/object_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.object_detection.result import DetResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class ObjectDetectionPipeline(BasePipeline):
     """Object Detection Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/ocr/pipeline.py

@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...utils.hpi import HPIConfig
@@ -32,6 +33,7 @@ from ..components import (
 from .result import OCRResult
 
 
+@pipeline_requires_extra("ocr")
 class OCRPipeline(BasePipeline):
     """OCR Pipeline"""
 

+ 6 - 1
paddlex/inference/pipelines/ocr/result.py

@@ -17,14 +17,18 @@ import random
 from pathlib import Path
 from typing import Dict
 
-import cv2
 import numpy as np
 from PIL import Image, ImageDraw
 
+from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
 from ....utils.fonts import SIMFANG_FONT_FILE_PATH, create_font
 from ...common.result import BaseCVResult, JsonMixin
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@class_requires_deps("opencv-contrib-python")
 class OCRResult(BaseCVResult):
     """OCR result"""
 
@@ -193,6 +197,7 @@ class OCRResult(BaseCVResult):
 
 
 # Adds a function comment according to Google Style Guide
+@function_requires_deps("opencv-contrib-python")
 def draw_box_txt_fine(
     img_size: tuple, box: np.ndarray, txt: str, font_path: str
 ) -> np.ndarray:

+ 2 - 0
paddlex/inference/pipelines/open_vocabulary_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.object_detection.result import DetResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("multimodal")
 class OpenVocabularyDetectionPipeline(BasePipeline):
     """Open Vocabulary Detection Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/open_vocabulary_segmentation/pipeline.py

@@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.open_vocabulary_segmentation.results import SAMSegResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
@@ -24,6 +25,7 @@ from ..base import BasePipeline
 Number = Union[int, float]
 
 
+@pipeline_requires_extra("multimodal")
 class OpenVocabularySegmentationPipeline(BasePipeline):
     """Open Vocabulary Segmentation pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py

@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ....utils.file_interface import custom_open
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
@@ -31,6 +32,7 @@ from ..layout_parsing.result import LayoutParsingResult
 from .pipeline_base import PP_ChatOCR_Pipeline
 
 
+@pipeline_requires_extra("ie")
 class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
     """PP-ChatOCR Pipeline"""
 

+ 10 - 1
paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py

@@ -19,10 +19,14 @@ import os
 import re
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import cv2
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import (
+    function_requires_deps,
+    is_dep_available,
+    pipeline_requires_extra,
+)
 from ....utils.file_interface import custom_open
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
@@ -32,7 +36,11 @@ from ..components.chat_server import BaseChat
 from ..layout_parsing.result import LayoutParsingResult
 from .pipeline_base import PP_ChatOCR_Pipeline
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
 
+
+@pipeline_requires_extra("ie")
 class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
     """PP-ChatOCRv4 Pipeline"""
 
@@ -583,6 +591,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
 
         return []
 
+    @function_requires_deps("opencv-contrib-python")
     def mllm_pred(
         self,
         input: Union[str, np.ndarray],

+ 2 - 0
paddlex/inference/pipelines/pp_shitu_v2/pipeline.py

@@ -14,6 +14,7 @@
 
 from typing import Any, Dict, Optional, Union
 
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...utils.hpi import HPIConfig
@@ -23,6 +24,7 @@ from ..components import CropByBoxes, FaissBuilder, FaissIndexer
 from .result import ShiTuResult
 
 
+@pipeline_requires_extra("cv")
 class ShiTuV2Pipeline(BasePipeline):
     """ShiTuV2 Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/rotated_object_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.object_detection.result import DetResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class RotatedObjectDetectionPipeline(BasePipeline):
     """Rotated Object Detection Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/seal_recognition/pipeline.py

@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...models.object_detection.result import DetResult
@@ -27,6 +28,7 @@ from ..components import CropByBoxes
 from .result import SealRecognitionResult
 
 
+@pipeline_requires_extra("ocr")
 class SealRecognitionPipeline(BasePipeline):
     """Seal Recognition Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/semantic_segmentation/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.semantic_segmentation.result import SegResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class SemanticSegmentationPipeline(BasePipeline):
     """Semantic Segmentation Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/small_object_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.object_detection.result import DetResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("cv")
 class SmallObjectDetectionPipeline(BasePipeline):
     """Small Object Detection Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/table_recognition/pipeline.py

@@ -18,6 +18,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from ....utils import logging
+from ....utils.deps import pipeline_requires_extra
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...models.object_detection.result import DetResult
@@ -32,6 +33,7 @@ from .table_recognition_post_processing import get_table_recognition_res
 from .utils import get_neighbor_boxes_idx
 
 
+@pipeline_requires_extra("ocr")
 class TableRecognitionPipeline(BasePipeline):
     """Table Recognition Pipeline"""
 

+ 10 - 1
paddlex/inference/pipelines/table_recognition/pipeline_v2.py

@@ -16,9 +16,13 @@ import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-from sklearn.cluster import KMeans
 
 from ....utils import logging
+from ....utils.deps import (
+    function_requires_deps,
+    is_dep_available,
+    pipeline_requires_extra,
+)
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
 from ...models.object_detection.result import DetResult
@@ -35,7 +39,11 @@ from .table_recognition_post_processing import (
 from .table_recognition_post_processing_v2 import get_table_recognition_res
 from .utils import get_neighbor_boxes_idx
 
+if is_dep_available("scikit-learn"):
+    from sklearn.cluster import KMeans
+
 
+@pipeline_requires_extra("ocr")
 class TableRecognitionPipelineV2(BasePipeline):
     """Table Recognition Pipeline"""
 
@@ -422,6 +430,7 @@ class TableRecognitionPipelineV2(BasePipeline):
             return iou
 
         # Function to combine rectangles into N rectangles
+        @function_requires_deps("scikit-learn")
         def combine_rectangles(rectangles, N):
             """
             Combine rectangles into N rectangles based on geometric proximity.
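
In pipeline_v2.py the guard is pushed down to a nested helper: `scikit-learn` is needed only by `combine_rectangles`, so decorating that inner function means the dependency is checked only when the cell-clustering branch actually runs, not on every pipeline call. A reduced sketch of the same placement (the clustering itself is simplified; boxes are assumed to be x1, y1, x2, y2):

import numpy as np

from paddlex.utils.deps import function_requires_deps, is_dep_available

if is_dep_available("scikit-learn"):
    from sklearn.cluster import KMeans


def group_boxes(boxes, n_groups):
    # The outer function itself never needs scikit-learn.
    @function_requires_deps("scikit-learn")
    def combine(rects, n):
        arr = np.asarray(rects, dtype=float)  # (N, 4) boxes
        centers = np.stack(
            [(arr[:, 0] + arr[:, 2]) / 2, (arr[:, 1] + arr[:, 3]) / 2], axis=1
        )
        return KMeans(n_clusters=n, n_init=10).fit_predict(centers)

    return combine(boxes, n_groups)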

+ 2 - 0
paddlex/inference/pipelines/ts_anomaly_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import pandas as pd
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.ts_anomaly_detection.result import TSAdResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("ts")
 class TSAnomalyDetPipeline(BasePipeline):
     """TSAnomalyDetPipeline Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/ts_classification/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import pandas as pd
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.ts_classification.result import TSClsResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("ts")
 class TSClsPipeline(BasePipeline):
     """TSClsPipeline Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/ts_forecasting/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import pandas as pd
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.ts_forecasting.result import TSFcResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("ts")
 class TSFcPipeline(BasePipeline):
     """TSFcPipeline Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/video_classification/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.video_classification.result import TopkVideoResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("video")
 class VideoClassificationPipeline(BasePipeline):
     """Video Classification Pipeline"""
 

+ 2 - 0
paddlex/inference/pipelines/video_detection/pipeline.py

@@ -16,12 +16,14 @@ from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 
+from ....utils.deps import pipeline_requires_extra
 from ...models.video_detection.result import DetVideoResult
 from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 
 
+@pipeline_requires_extra("video")
 class VideoDetectionPipeline(BasePipeline):
     """Video detection Pipeline"""
 

+ 4 - 0
paddlex/inference/serving/__init__.py

@@ -11,3 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from ...utils.deps import require_serving_plugin
+
+require_serving_plugin()

+ 20 - 11
paddlex/inference/serving/basic_serving/_app.py

@@ -28,20 +28,25 @@ from typing import (
     TypeVar,
 )
 
-import aiohttp
-import fastapi
-from fastapi.encoders import jsonable_encoder
-from fastapi.exceptions import RequestValidationError
-from fastapi.responses import JSONResponse
-from starlette.exceptions import HTTPException
 from typing_extensions import ParamSpec, TypeGuard
 
 from ....utils import logging
+from ....utils.deps import class_requires_deps, function_requires_deps, is_dep_available
 from ...pipelines import BasePipeline
 from ..infra.config import AppConfig
 from ..infra.models import NoResultResponse
 from ..infra.utils import call_async, generate_log_id
 
+if is_dep_available("aiohttp"):
+    import aiohttp
+if is_dep_available("fastapi"):
+    import fastapi
+    from fastapi.encoders import jsonable_encoder
+    from fastapi.exceptions import RequestValidationError
+    from fastapi.responses import JSONResponse
+if is_dep_available("starlette"):
+    from starlette.exceptions import HTTPException
+
 PipelineT = TypeVar("PipelineT", bound=BasePipeline)
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -64,6 +69,7 @@ def _is_error(obj: object) -> TypeGuard[_Error]:
 # for type hinting. However, I would stick with the current design, as it does
 # not introduce runtime overhead at the moment and may prove useful in the
 # future.
+@class_requires_deps("fastapi")
 class PipelineWrapper(Generic[PipelineT]):
     def __init__(self, pipeline: PipelineT) -> None:
         super().__init__()
@@ -94,6 +100,7 @@ class PipelineWrapper(Generic[PipelineT]):
             return await call_async(func, *args, **kwargs)
 
 
+@class_requires_deps("aiohttp")
 class AppContext(Generic[PipelineT]):
     def __init__(self, *, config: AppConfig) -> None:
         super().__init__()
@@ -117,21 +124,22 @@ class AppContext(Generic[PipelineT]):
         self._pipeline = val
 
     @property
-    def aiohttp_session(self) -> aiohttp.ClientSession:
+    def aiohttp_session(self) -> "aiohttp.ClientSession":
         if not self._aiohttp_session:
             raise AttributeError("`aiohttp_session` has not been set.")
         return self._aiohttp_session
 
     @aiohttp_session.setter
-    def aiohttp_session(self, val: aiohttp.ClientSession) -> None:
+    def aiohttp_session(self, val: "aiohttp.ClientSession") -> None:
         self._aiohttp_session = val
 
 
+@function_requires_deps("fastapi", "aiohttp", "starlette")
 def create_app(
     *, pipeline: PipelineT, app_config: AppConfig, app_aiohttp_session: bool = True
-) -> Tuple[fastapi.FastAPI, AppContext[PipelineT]]:
+) -> Tuple["fastapi.FastAPI", AppContext[PipelineT]]:
     @contextlib.asynccontextmanager
-    async def _app_lifespan(app: fastapi.FastAPI) -> AsyncGenerator[None, None]:
+    async def _app_lifespan(app: "fastapi.FastAPI") -> AsyncGenerator[None, None]:
         ctx.pipeline = PipelineWrapper[PipelineT](pipeline)
         if app_aiohttp_session:
             async with aiohttp.ClientSession(
@@ -197,8 +205,9 @@ def create_app(
 
 
 # TODO: Precise type hints
+@function_requires_deps("fastapi")
 def primary_operation(
-    app: fastapi.FastAPI, path: str, operation_id: str, **kwargs: Any
+    app: "fastapi.FastAPI", path: str, operation_id: str, **kwargs: Any
 ) -> Callable:
     return app.post(
         path,

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py

@@ -15,10 +15,12 @@
 import importlib
 from typing import Any, Dict
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra.config import create_app_config
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
 def _pipeline_name_to_mod_name(pipeline_name: str) -> str:
     if not pipeline_name:
@@ -31,7 +33,8 @@ def _pipeline_name_to_mod_name(pipeline_name: str) -> str:
 
 # XXX: A dynamic approach is used here for writing fewer lines of code, at the
 # cost of sacrificing some benefits of type hints.
-def create_pipeline_app(pipeline: Any, pipeline_config: Dict[str, Any]) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, pipeline_config: Dict[str, Any]) -> "FastAPI":
     pipeline_name = pipeline_config["pipeline_name"]
     mod_name = _pipeline_name_to_mod_name(pipeline_name)
     mod = importlib.import_module(f".{mod_name}", package=__package__)

+ 5 - 1
paddlex/inference/serving/basic_serving/_pipeline_apps/_common/common.py

@@ -15,13 +15,16 @@
 import os
 from typing import Dict, Optional, Tuple, Union
 
-import cv2
 import numpy as np
 from PIL.Image import Image
 
+from ......utils.deps import function_requires_deps, is_dep_available
 from ....infra import utils as serving_utils
 from ....infra.storage import Storage, SupportsGetURL
 
+if is_dep_available("opencv-contrib-python"):
+    import cv2
+
 
 def prune_result(result: dict) -> dict:
     KEYS_TO_REMOVE = ["input_path"]
@@ -39,6 +42,7 @@ def prune_result(result: dict) -> dict:
     return _process_obj(result)
 
 
+@function_requires_deps("opencv-contrib-python")
 def postprocess_image(
     image: np.ndarray,
     log_id: str,

+ 6 - 1
paddlex/inference/serving/basic_serving/_pipeline_apps/_common/ocr.py

@@ -15,15 +15,19 @@
 from typing import Final, List, Tuple, Union
 
 import numpy as np
-from fastapi import HTTPException
 from typing_extensions import Literal
 
+from ......utils.deps import function_requires_deps, is_dep_available
 from ....infra import utils as serving_utils
 from ....infra.models import ImageInfo, PDFInfo
 from ....infra.storage import SupportsGetURL, create_storage
 from ....schemas.shared.ocr import BaseInferRequest
 from ..._app import AppContext
 
+if is_dep_available("fastapi"):
+    from fastapi import HTTPException
+
+
 DEFAULT_MAX_NUM_INPUT_IMGS: Final[int] = 10
 DEFAULT_MAX_OUTPUT_IMG_SIZE: Final[Tuple[int, int]] = (2000, 2000)
 
@@ -52,6 +56,7 @@ def update_app_context(app_context: AppContext) -> None:
     )
 
 
+@function_requires_deps("fastapi")
 def get_file_type(request: BaseInferRequest) -> Literal["PDF", "IMAGE"]:
     if request.fileType is None:
         if serving_utils.is_url(request.file):

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/anomaly_detection.py

@@ -14,16 +14,19 @@
 
 from typing import Any
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
 from ...schemas.anomaly_detection import INFER_ENDPOINT, InferRequest, InferResult
 from .._app import create_app, primary_operation
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/doc_preprocessor.py

@@ -14,8 +14,7 @@
 
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
@@ -24,8 +23,12 @@ from .._app import create_app, primary_operation
 from ._common import common
 from ._common import ocr as ocr_common
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/face_recognition.py

@@ -16,8 +16,7 @@ import asyncio
 from operator import attrgetter
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ....pipelines.components import IndexData
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
@@ -26,13 +25,17 @@ from ...schemas import face_recognition as schema
 from .._app import create_app, primary_operation
 from ._common import image_recognition as ir_common
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 # XXX: Currently the implementations of the face recognition and PP-ShiTuV2
 # pipeline apps overlap significantly. We should aim to facilitate code reuse,
 # but is it acceptable to assume a strong similarity between these two
 # pipelines?
 
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/formula_recognition.py

@@ -14,8 +14,7 @@
 
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
@@ -24,8 +23,12 @@ from .._app import create_app, primary_operation
 from ._common import common
 from ._common import ocr as ocr_common
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/human_keypoint_detection.py

@@ -14,8 +14,7 @@
 
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
@@ -26,8 +25,12 @@ from ...schemas.human_keypoint_detection import (
 )
 from .._app import create_app, primary_operation
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/image_classification.py

@@ -14,16 +14,19 @@
 
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
 from ...schemas.image_classification import INFER_ENDPOINT, InferRequest, InferResult
 from .._app import create_app, primary_operation
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/image_multilabel_classification.py

@@ -14,8 +14,7 @@
 
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
@@ -26,8 +25,12 @@ from ...schemas.image_multilabel_classification import (
 )
 from .._app import create_app, primary_operation
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 9 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/instance_segmentation.py

@@ -15,22 +15,28 @@
 from typing import Any, Dict, List
 
 import numpy as np
-import pycocotools.mask as mask_util
-from fastapi import FastAPI
 
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
 from ...schemas.instance_segmentation import INFER_ENDPOINT, InferRequest, InferResult
 from .._app import create_app, primary_operation
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+if is_dep_available("pycocotools"):
+    import pycocotools.mask as mask_util
 
+
+@function_requires_deps("pycocotools")
 def _rle(mask: np.ndarray) -> str:
     rle_res = mask_util.encode(np.asarray(mask[..., None], order="F", dtype="uint8"))[0]
     return rle_res["counts"].decode("utf-8")
 
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline,
         app_config=app_config,

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/layout_parsing.py

@@ -14,8 +14,7 @@
 
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
@@ -24,8 +23,12 @@ from .._app import create_app, primary_operation
 from ._common import common
 from ._common import ocr as ocr_common
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/m_3d_bev_detection.py

@@ -15,16 +15,19 @@
 import os
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
 from ...schemas.m_3d_bev_detection import INFER_ENDPOINT, InferRequest, InferResult
 from .._app import create_app, primary_operation
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

+ 6 - 3
paddlex/inference/serving/basic_serving/_pipeline_apps/multilingual_speech_recognition.py

@@ -15,8 +15,7 @@
 import os
 from typing import Any, Dict, List
 
-from fastapi import FastAPI, HTTPException
-
+from .....utils.deps import function_requires_deps, is_dep_available
 from ...infra import utils as serving_utils
 from ...infra.config import AppConfig
 from ...infra.models import ResultResponse
@@ -27,8 +26,12 @@ from ...schemas.multilingual_speech_recognition import (
 )
 from .._app import create_app, primary_operation
 
+if is_dep_available("fastapi"):
+    from fastapi import FastAPI, HTTPException
+
 
-def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
+@function_requires_deps("fastapi")
+def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
     app, ctx = create_app(
         pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
     )

Some files are not shown because too many files changed in this diff