
add cache func (#3509)

* add cache func

* update

* select the corresponding wrapper by

* update

* update doc

* use ENV to select inference API

* update

* update
zhang-prog 8 months ago
parent
commit
eb3e60685b

+ 5 - 3
docs/module_usage/instructions/benchmark.en.md

@@ -14,10 +14,12 @@ The benchmark feature collects the average execution time per iteration for each
 
 To enable the benchmark feature, you must set the following environment variables:
 
-* `PADDLE_PDX_INFER_BENCHMARK`: When set to `True`, the benchmark feature is enabled (default is `False`);
-* `PADDLE_PDX_INFER_BENCHMARK_WARMUP`: The number of warm-up iterations before testing (default is `0`);
-* `PADDLE_PDX_INFER_BENCHMARK_ITERS`: The number of iterations for testing (default is `0`);
+* `PADDLE_PDX_INFER_BENCHMARK`: When set to `True`, the benchmark feature is enabled (default is `False`).
+* `PADDLE_PDX_INFER_BENCHMARK_WARMUP`: The number of warm-up iterations before testing (default is `0`).
+* `PADDLE_PDX_INFER_BENCHMARK_ITERS`: The number of iterations for testing (default is `0`).
 * `PADDLE_PDX_INFER_BENCHMARK_OUTPUT_DIR`: The directory where the metrics are saved (e.g., `./benchmark`). The default is `None`, meaning the benchmark metrics will not be saved.
+* `PADDLE_PDX_INFER_BENCHMARK_USE_CACHE_FOR_READ`: When set to `True`, a caching mechanism is applied to input-data read operations to avoid repeated I/O overhead; the time spent reading and caching data is not counted toward the core time (default is `False`).
+* `PADDLE_PDX_INFER_BENCHMARK_USE_NEW_INFER_API`: When set to `True`, the new inference API is enabled, providing more detailed per-operation information for inference in the benchmark results (default is `False`).
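
A minimal usage sketch of these variables (the model name and input file are placeholder assumptions; any PaddleX module works the same way). Since `paddlex/utils/flags.py` reads the variables at module level, they must be set before `paddlex` is imported:

```python
import os

# Set flags before importing paddlex; they are read once at import time.
os.environ["PADDLE_PDX_INFER_BENCHMARK"] = "True"
os.environ["PADDLE_PDX_INFER_BENCHMARK_WARMUP"] = "5"
os.environ["PADDLE_PDX_INFER_BENCHMARK_ITERS"] = "10"
os.environ["PADDLE_PDX_INFER_BENCHMARK_OUTPUT_DIR"] = "./benchmark"
os.environ["PADDLE_PDX_INFER_BENCHMARK_USE_CACHE_FOR_READ"] = "True"
os.environ["PADDLE_PDX_INFER_BENCHMARK_USE_NEW_INFER_API"] = "True"

from paddlex import create_model  # noqa: E402

model = create_model("PP-LCNet_x1_0")  # placeholder model name
result = list(model.predict("test.jpg", batch_size=1))  # placeholder input
```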
 
 **Note**:
 

+ 3 - 1
docs/module_usage/instructions/benchmark.md

@@ -17,7 +17,9 @@ The benchmark feature collects, for all operations in the model's end-to-end inference, the per-
 * `PADDLE_PDX_INFER_BENCHMARK`: when set to `True`, enables the benchmark feature; defaults to `False`;
 * `PADDLE_PDX_INFER_BENCHMARK_WARMUP`: the number of warm-up iterations before testing; defaults to `0`;
 * `PADDLE_PDX_INFER_BENCHMARK_ITERS`: the number of test iterations; defaults to `0`;
-* `PADDLE_PDX_INFER_BENCHMARK_OUTPUT_DIR`: the directory where metrics are saved, e.g. `./benchmark`; defaults to `None`, meaning benchmark metrics are not saved.
+* `PADDLE_PDX_INFER_BENCHMARK_OUTPUT_DIR`: the directory where metrics are saved, e.g. `./benchmark`; defaults to `None`, meaning benchmark metrics are not saved;
+* `PADDLE_PDX_INFER_BENCHMARK_USE_CACHE_FOR_READ`: when set to `True`, applies a caching mechanism to input-data read operations to avoid repeated I/O overhead; the time spent reading and caching data is not counted toward the core time. Defaults to `False`;
+* `PADDLE_PDX_INFER_BENCHMARK_USE_NEW_INFER_API`: when set to `True`, uses the new inference API, which provides more fine-grained per-stage results. Defaults to `False`;
 
 **Note**:
 

+ 1 - 1
paddlex/inference/common/reader/det_3d_reader.py

@@ -81,7 +81,7 @@ class Sample(_EasyDict):
         self.attrs = None
 
 
-@benchmark.timeit
+@benchmark.timeit_with_options(name=None, is_read_operation=True)
 class ReadNuscenesData:
 
     def __init__(

+ 1 - 1
paddlex/inference/common/reader/image_reader.py

@@ -19,7 +19,7 @@ from ...utils.io import ImageReader, PDFReader
 from ...utils.benchmark import benchmark
 
 
-@benchmark.timeit
+@benchmark.timeit_with_options(name=None, is_read_operation=True)
 class ReadImage:
     """Load image from the file."""
 

+ 1 - 1
paddlex/inference/common/reader/ts_reader.py

@@ -19,7 +19,7 @@ from ...utils.io import CSVReader
 from ...utils.benchmark import benchmark
 
 
-@benchmark.timeit
+@benchmark.timeit_with_options(name=None, is_read_operation=True)
 class ReadTS:
 
     def __init__(self):

+ 1 - 1
paddlex/inference/common/reader/video_reader.py

@@ -19,7 +19,7 @@ from ...utils.io import VideoReader
 from ...utils.benchmark import benchmark
 
 
-@benchmark.timeit
+@benchmark.timeit_with_options(name=None, is_read_operation=True)
 class ReadVideo:
     """Load video from the file."""
 

+ 1 - 1
paddlex/inference/models/base/predictor/basic_predictor.py

@@ -104,7 +104,7 @@ class BasicPredictor(
         self.set_predictor(batch_size, device, pp_option)
         if INFER_BENCHMARK:
             # TODO(zhang-prog): Get metadata of input data
-            @benchmark.timeit_with_name(ENTRY_POINT_NAME)
+            @benchmark.timeit_with_options(name=ENTRY_POINT_NAME)
             def _apply(input, **kwargs):
                 return list(self.apply(input, **kwargs))
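
The renamed decorator keeps the old `timeit` behavior; call sites of `timeit_with_name(NAME)` simply become `timeit_with_options(name=NAME)`. A hedged example of decorating a standalone function, assuming the benchmark is enabled via the environment variables above (if it is disabled, the decorator returns the function unchanged):

```python
from paddlex.inference.utils.benchmark import ENTRY_POINT_NAME, benchmark

@benchmark.timeit_with_options(name=ENTRY_POINT_NAME)
def run_pipeline(predictor, data):
    # Materialize the generator so the whole pipeline is timed end to end.
    return list(predictor.apply(data))
```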
 

+ 19 - 5
paddlex/inference/models/common/static_infer.py

@@ -20,7 +20,11 @@ import numpy as np
 
 from ....utils import logging
 from ....utils.device import constr_device
-from ....utils.flags import DEBUG, USE_PIR_TRT
+from ....utils.flags import (
+    DEBUG,
+    USE_PIR_TRT,
+    INFER_BENCHMARK_USE_NEW_INFER_API,
+)
 from ...utils.benchmark import benchmark, set_inference_operations
 from ...utils.hpi import get_model_paths
 from ...utils.pp_option import PaddlePredictorOption
@@ -29,7 +33,14 @@ from ...utils.trt_config import TRT_CFG
 
 CACHE_DIR = ".cache"
 
-INFERENCE_OPERATIONS = ["PaddleCopyToDevice", "PaddleCopyToHost", "PaddleModelInfer"]
+if INFER_BENCHMARK_USE_NEW_INFER_API:
+    INFERENCE_OPERATIONS = [
+        "PaddleCopyToDevice",
+        "PaddleCopyToHost",
+        "PaddleModelInfer",
+    ]
+else:
+    INFERENCE_OPERATIONS = ["PaddleInferChainLegacy"]
 set_inference_operations(INFERENCE_OPERATIONS)
 
 
@@ -299,7 +310,7 @@ class StaticInfer(object):
         self.model_file_prefix = model_prefix
         self._option = option
         self.predictor = self._create()
-        if not self._use_legacy_api:
+        if self._use_new_inference_api:
             device_type = self._option.device_type
             device_type = "gpu" if device_type == "dcu" else device_type
             copy_to_device = PaddleCopyToDevice(device_type, self._option.device_id)
@@ -310,8 +321,11 @@ class StaticInfer(object):
             self.infer = PaddleInferChainLegacy(self.predictor)
 
     @property
-    def _use_legacy_api(self):
-        return self._option.device_type not in ("cpu", "gpu", "dcu")
+    def _use_new_inference_api(self):
+        # HACK: Temp fallback to legacy API via env var
+        return INFER_BENCHMARK_USE_NEW_INFER_API
+
+        # return self._option.device_type in ("cpu", "gpu", "dcu")
 
     def __call__(self, x: Sequence[np.ndarray]) -> List[np.ndarray]:
         names = self.predictor.get_input_names()
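
A condensed, runnable sketch of what this hunk wires up. Only `PaddleCopyToDevice(device_type, device_id)` and `PaddleInferChainLegacy(predictor)` appear verbatim in the diff; the other constructors and the exact chain composition are assumptions, so toy stand-ins are used here:

```python
from typing import Callable, List, Sequence

import numpy as np

# Toy stand-ins for the operation classes named in INFERENCE_OPERATIONS;
# the real ones live in static_infer.py and wrap a Paddle predictor.
def make_stage(tag: str) -> Callable[[Sequence[np.ndarray]], List[np.ndarray]]:
    def stage(xs):
        # In PaddleX each stage is a benchmark-timed operation.
        return [np.asarray(x) for x in xs]
    return stage

def build_infer(use_new_api: bool):
    if use_new_api:
        # New API: device copy, model run, and host copy are timed separately.
        to_device = make_stage("PaddleCopyToDevice")
        run_model = make_stage("PaddleModelInfer")
        to_host = make_stage("PaddleCopyToHost")
        return lambda xs: to_host(run_model(to_device(xs)))
    # Legacy API: a single aggregate stage ("PaddleInferChainLegacy").
    return make_stage("PaddleInferChainLegacy")

infer = build_infer(use_new_api=True)
print(infer([np.zeros((2, 2))]))
```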

+ 52 - 30
paddlex/inference/utils/benchmark.py

@@ -21,7 +21,11 @@ import inspect
 import numpy as np
 from prettytable import PrettyTable
 
-from ...utils.flags import INFER_BENCHMARK, INFER_BENCHMARK_OUTPUT_DIR
+from ...utils.flags import (
+    INFER_BENCHMARK,
+    INFER_BENCHMARK_OUTPUT_DIR,
+    INFER_BENCHMARK_USE_CACHE_FOR_READ,
+)
 from ...utils import logging
 
 ENTRY_POINT_NAME = "_entry_point_"
@@ -36,9 +40,12 @@ class Benchmark:
         self._elapses = {}
         self._warmup = False
 
-    def timeit_with_name(self, name=None):
+    def timeit_with_options(self, name=None, is_read_operation=False):
         # TODO: Refactor
         def _deco(func_or_cls):
+            if not self._enabled:
+                return func_or_cls
+
             nonlocal name
             if name is None:
                 name = func_or_cls.__qualname__
@@ -52,34 +59,50 @@ class Benchmark:
                     raise TypeError
                 func = func_or_cls
 
-            location = None
-
-            @functools.wraps(func)
-            def _wrapper(*args, **kwargs):
-                nonlocal location
-
-                if not self._enabled:
-                    return func(*args, **kwargs)
-
-                if location is None:
-                    try:
-                        source_file = inspect.getsourcefile(func)
-                        source_line = inspect.getsourcelines(func)[1]
-                        location = f"{source_file}:{source_line}"
-                    except (TypeError, OSError) as e:
-                        location = "Unknown"
-                        logging.debug(
-                            f"Benchmark: failed to get source file and line number: {e}"
-                        )
-
-                tic = time.perf_counter()
-                output = func(*args, **kwargs)
-                if isinstance(output, GeneratorType):
-                    return self.watch_generator(output, f"{name}@{location}")
-                else:
-                    self._update(time.perf_counter() - tic, f"{name}@{location}")
+            try:
+                source_file = inspect.getsourcefile(func)
+                source_line = inspect.getsourcelines(func)[1]
+                location = f"{source_file}:{source_line}"
+            except (TypeError, OSError) as e:
+                location = "Unknown"
+                logging.debug(
+                    f"Benchmark: failed to get source file and line number: {e}"
+                )
+
+            use_cache = is_read_operation and INFER_BENCHMARK_USE_CACHE_FOR_READ
+            if use_cache:
+                if inspect.isgeneratorfunction(func):
+                    raise RuntimeError(
+                        "When `is_read_operation` is `True`, the wrapped function should not be a generator."
+                    )
+
+                func = functools.lru_cache(maxsize=128)(func)
+
+                @functools.wraps(func)
+                def _wrapper(*args, **kwargs):
+                    args = tuple(
+                        tuple(arg) if isinstance(arg, list) else arg for arg in args
+                    )
+                    kwargs = {
+                        k: tuple(v) if isinstance(v, list) else v
+                        for k, v in kwargs.items()
+                    }
+                    output = func(*args, **kwargs)
                     return output
 
+            else:
+
+                @functools.wraps(func)
+                def _wrapper(*args, **kwargs):
+                    operation_name = f"{name}@{location}"
+                    tic = time.perf_counter()
+                    output = func(*args, **kwargs)
+                    if isinstance(output, GeneratorType):
+                        return self.watch_generator(output, operation_name)
+                    else:
+                        self._update(time.perf_counter() - tic, operation_name)
+                        return output
+
             if isinstance(func_or_cls, type):
                 func_or_cls.__call__ = _wrapper
                 return func_or_cls
@@ -89,7 +112,7 @@ class Benchmark:
         return _deco
 
     def timeit(self, func_or_cls):
-        return self.timeit_with_name(None)(func_or_cls)
+        return self.timeit_with_options()(func_or_cls)
 
     def watch_generator(self, generator, name):
         @functools.wraps(generator)
@@ -156,7 +179,6 @@ class Benchmark:
         iters = len(base_predictor_time_list)
         instances = iters * batch_size
         summary["end_to_end"] = np.mean(base_predictor_time_list)
-
         detail_list = []
         operation_list = []
         op_tag = "preprocessing"
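
One detail in the new wrapper worth a standalone demonstration: `functools.lru_cache` rejects unhashable arguments, which is why the cached branch above converts `list` arguments (and keyword values) to `tuple` before hitting the cache. A minimal runnable example:

```python
import functools

@functools.lru_cache(maxsize=128)
def read_batch(paths):
    print(f"reading {len(paths)} file(s)")  # executes only on a cache miss
    return [p.upper() for p in paths]

def call_cached(func, *args):
    # Mirror of the wrapper's normalization: lists are unhashable and would
    # make lru_cache raise TypeError, so convert them to tuples first.
    args = tuple(tuple(a) if isinstance(a, list) else a for a in args)
    return func(*args)

call_cached(read_batch, ["a.jpg", "b.jpg"])  # miss: prints, reads
call_cached(read_batch, ["a.jpg", "b.jpg"])  # hit: silent, cached result
```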

+ 7 - 1
paddlex/utils/flags.py

@@ -51,7 +51,7 @@ USE_PIR_TRT = get_flag_from_env_var("PADDLE_PDX_USE_PIR_TRT", False)
 DISABLE_DEV_MODEL_WL = get_flag_from_env_var("PADDLE_PDX_DISABLE_DEV_MODEL_WL", False)
 
 # Inference Benchmark
-INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", None)
+INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", False)
 INFER_BENCHMARK_WARMUP = get_flag_from_env_var(
     "PADDLE_PDX_INFER_BENCHMARK_WARMUP", 0, int
 )
@@ -61,3 +61,9 @@ INFER_BENCHMARK_OUTPUT_DIR = get_flag_from_env_var(
 INFER_BENCHMARK_ITERS = get_flag_from_env_var(
     "PADDLE_PDX_INFER_BENCHMARK_ITERS", 0, int
 )
+INFER_BENCHMARK_USE_CACHE_FOR_READ = get_flag_from_env_var(
+    "PADDLE_PDX_INFER_BENCHMARK_USE_CACHE_FOR_READ", False
+)
+INFER_BENCHMARK_USE_NEW_INFER_API = get_flag_from_env_var(
+    "PADDLE_PDX_INFER_BENCHMARK_USE_NEW_INFER_API", False
+)