
support benchmark

gaotingquan 1 year ago
Commit
f20662bc7e

+ 59 - 0
docs/module_usage/instructions/benchmark.md

@@ -0,0 +1,59 @@
+# Model Inference Benchmark
+
+PaddleX supports collecting model inference timing statistics. This is configured through environment variables:
+
+* `PADDLE_PDX_INFER_BENCHMARK`: set to `True` to enable the benchmark; defaults to `False`.
+* `PADDLE_PDX_INFER_BENCHMARK_WARMUP`: number of warm-up iterations run on random data before timing starts; defaults to `0`.
+* `PADDLE_PDX_INFER_BENCHMARK_DATA_SIZE`: edge length of the random input images; defaults to `1024`.
+* `PADDLE_PDX_INFER_BENCHMARK_ITER`: number of benchmark iterations run on random data; random data is used only when the input is `None`; defaults to `10`.
+* `PADDLE_PDX_INFER_BENCHMARK_OUTPUT`: path of a `txt` file in which to save the benchmark metrics, e.g. `./benchmark.txt`; defaults to `None`, meaning the metrics are not saved.
+
+Example usage:
+
+```bash
+PADDLE_PDX_INFER_BENCHMARK=True \
+PADDLE_PDX_INFER_BENCHMARK_WARMUP=5 \
+PADDLE_PDX_INFER_BENCHMARK_DATA_SIZE=320 \
+PADDLE_PDX_INFER_BENCHMARK_ITER=10 \
+PADDLE_PDX_INFER_BENCHMARK_OUTPUT=./benchmark.txt \
+python main.py \
+    -c ./paddlex/configs/object_detection/PicoDet-XS.yaml \
+    -o Global.mode=predict \
+    -o Predict.model_dir=None \
+    -o Predict.input=None
+```
+
+With the benchmark enabled, the benchmark metrics are printed automatically:
+
+```
++-------------------+--------+------------------+
+|     Component     | Counts | Average Time(ms) |
++-------------------+--------+------------------+
+|      ReadCmp      |   10   |    7.86035061    |
+|       Resize      |   10   |    1.38545036    |
+|     Normalize     |   10   |    3.77433300    |
+|     ToCHWImage    |   10   |    0.00545979    |
+| ImageDetPredictor |   10   |   14.97282982    |
+|   DetPostProcess  |   10   |    0.06134510    |
+|  ***************  | ****** | ***************  |
+|     PreProcess    |   \    |   13.02559376    |
+|     Inference     |   \    |   14.97282982    |
+|    PostProcess    |   \    |    0.06134510    |
++-------------------+--------+------------------+
+```
+
+The benchmark result lists, for every component (`Component`) of the model, its number of calls (`Counts`) and its average execution time (`Average Time`, in milliseconds), followed by the execution times aggregated into preprocessing (`PreProcess`), model inference (`Inference`), and postprocessing (`PostProcess`). The metrics are also saved to the local `./benchmark.txt` file:
+
+```
+Component, Counts, Average Time(ms)
+ReadCmp, 10, 7.860350608825682706
+Resize, 10, 1.385450363159179688
+Normalize, 10, 3.774333000183105469
+ToCHWImage, 10, 0.005459785461425781
+ImageDetPredictor, 10, 14.972829818725585938
+DetPostProcess, 10, 0.061345100402832031
+***************, ***, ***************
+PreProcess, \, 13.025593757629394531
+Inference, \, 14.972829818725585938
+PostProcess, \, 0.061345100402832031
+```
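
The saved file is plain comma-separated text and can be consumed directly. As an illustration, a minimal Python sketch for loading it (the `***` separator row and the `\` placeholders in the summary rows are taken from the sample above; the path is assumed):

```python
# Sketch: parse the file saved via PADDLE_PDX_INFER_BENCHMARK_OUTPUT.
import csv

def load_benchmark(path="./benchmark.txt"):
    detail, summary = {}, {}
    with open(path) as f:
        reader = csv.reader(f, skipinitialspace=True)
        next(reader)  # skip the "Component, Counts, Average Time(ms)" header
        target = detail
        for name, _counts, avg in reader:
            if name.startswith("*"):  # separator between the two sections
                target = summary
                continue
            target[name] = float(avg)  # average time in ms
    return detail, summary

detail, summary = load_benchmark()
print(f"Inference: {summary['Inference']:.3f} ms")
```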

+ 2 - 2
paddlex/engine.py

@@ -48,7 +48,7 @@ class Engine(object):
         elif self._mode == "predict":
             for res in self._model.predict():
                 res.print(json_format=False)
-            if self._output:
-                res.save_all(save_path=self._output)
+                if self._output:
+                    res.save_all(save_path=self._output)
         else:
             raise_unsupported_api_error(f"{self._mode}", self.__class__)
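
The indentation change above moves saving into the loop body: previously `res.save_all` was referenced after the `for` loop, so only the last yielded result was saved, and an empty prediction stream would leave `res` undefined. A minimal sketch of the pitfall (hypothetical names):

```python
# Hypothetical illustration of the loop-variable pitfall fixed above.
def predictions():
    return iter(())  # an empty prediction stream

for res in predictions():
    print(res)

# Referencing `res` here raises NameError for an empty stream, and for a
# non-empty stream it only sees the final item:
try:
    res
except NameError:
    print("res was never bound")
```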

+ 6 - 0
paddlex/inference/components/base.py

@@ -17,7 +17,9 @@ from abc import ABC, abstractmethod
 from copy import deepcopy
 from types import GeneratorType
 
+from ...utils.flags import INFER_BENCHMARK
 from ...utils import logging
+from ..utils.benchmark import Timer
 
 
 class BaseComponent(ABC):
@@ -33,6 +35,10 @@ class BaseComponent(ABC):
         self.inputs = self.DEAULT_INPUTS if hasattr(self, "DEAULT_INPUTS") else {}
         self.outputs = self.DEAULT_OUTPUTS if hasattr(self, "DEAULT_OUTPUTS") else {}
 
+        if INFER_BENCHMARK:
+            self.timer = Timer()
+            self.apply = self.timer.watch_func(self.apply)
+
     def __call__(self, input_list):
         # use list type for batched data
         if not isinstance(input_list, list):
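
The pattern here is per-instance wrapping: when the benchmark flag is set, the component's bound `apply` method is replaced at construction time with a timed wrapper, so no call sites change. A self-contained sketch of the same idea (names are illustrative, not the PaddleX API):

```python
import functools
import time

class Timer:
    """Minimal stand-in for the Timer added by this commit."""

    def __init__(self):
        self.logs = []

    def watch_func(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            tic = time.time()
            out = func(*args, **kwargs)
            self.logs.append(time.time() - tic)
            return out

        return wrapper

class Component:
    def __init__(self, benchmark=False):
        if benchmark:
            self.timer = Timer()
            # Rebind only on this instance; other instances stay unwrapped.
            self.apply = self.timer.watch_func(self.apply)

    def apply(self, x):
        return x * 2

c = Component(benchmark=True)
c.apply(3)
print(len(c.timer.logs))  # 1 recorded elapse
```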

+ 1 - 0
paddlex/inference/components/paddle_predictor/__init__.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .predictor import (
+    BasePaddlePredictor,
     ImagePredictor,
     ImageDetPredictor,
     TSPPPredictor,

+ 23 - 10
paddlex/inference/components/transforms/image/common.py

@@ -19,6 +19,7 @@ from copy import deepcopy
 import numpy as np
 import cv2
 
+from .....utils.flags import INFER_BENCHMARK, INFER_BENCHMARK_DATA_SIZE
 from .....utils.cache import CACHE_DIR, temp_file_manager
 from ....utils.io import ImageReader, ImageWriter, PDFReader
 from ...base import BaseComponent
@@ -89,21 +90,33 @@ class ReadImage(_BaseRead):
 
     def apply(self, img):
         """apply"""
-        if isinstance(img, np.ndarray):
+
+        def process_ndarray(img):
             with temp_file_manager.temp_file_context(suffix=".png") as temp_file:
                 img_path = Path(temp_file.name)
                 self._writer.write(img_path, img)
                 if self.format == "RGB":
                     img = img[:, :, ::-1]
-                yield [
-                    {
-                        "input_path": img_path,
-                        "img": img,
-                        "img_size": [img.shape[1], img.shape[0]],
-                        "ori_img": deepcopy(img),
-                        "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
-                    }
-                ]
+                return {
+                    "input_path": img_path,
+                    "img": img,
+                    "img_size": [img.shape[1], img.shape[0]],
+                    "ori_img": deepcopy(img),
+                    "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
+                }
+
+        if INFER_BENCHMARK and img is None:
+            size = int(INFER_BENCHMARK_DATA_SIZE)
+            yield [
+                process_ndarray(
+                    np.random.randint(0, 256, (size, size, 3), dtype=np.uint8)
+                )
+                for _ in range(self.batch_size)
+            ]
+
+        elif isinstance(img, np.ndarray):
+            yield [process_ndarray(img)]
+
         elif isinstance(img, str):
             file_path = img
             file_path = self._download_from_url(file_path)
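
When the benchmark is enabled and the input is `None`, `ReadImage` fabricates a batch of square RGB images with edge length `PADDLE_PDX_INFER_BENCHMARK_DATA_SIZE`. A standalone sketch of just the fabricated batch (the real component additionally writes a temporary PNG and records size metadata):

```python
import numpy as np

size, batch_size = 320, 2  # e.g. PADDLE_PDX_INFER_BENCHMARK_DATA_SIZE=320
batch = [
    np.random.randint(0, 256, (size, size, 3), dtype=np.uint8)
    for _ in range(batch_size)
]
print([img.shape for img in batch])  # [(320, 320, 3), (320, 320, 3)]
```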

+ 2 - 0
paddlex/inference/models/base/base_predictor.py

@@ -41,6 +41,8 @@ class BasePredictor(BaseComponent):
         # alias predict() to the __call__()
         self.predict = self.__call__
 
+        self.benchmark = None
+
     def __call__(self, input, **kwargs):
         self.set_predictor(**kwargs)
         for res in super().__call__(input):

+ 23 - 0
paddlex/inference/models/base/basic_predictor.py

@@ -16,10 +16,16 @@ from abc import abstractmethod
 import inspect
 
 from ....utils.subclass_register import AutoRegisterABCMetaClass
+from ....utils.flags import (
+    INFER_BENCHMARK,
+    INFER_BENCHMARK_WARMUP,
+    INFER_BENCHMARK_ITER,
+)
 from ....utils import logging
 from ...components.base import BaseComponent, ComponentsEngine
 from ...utils.pp_option import PaddlePredictorOption
 from ...utils.process_hook import generatorable_method
+from ...utils.benchmark import Benchmark
 from .base_predictor import BasePredictor
 
 
@@ -43,6 +49,23 @@ class BasicPredictor(
         self.engine = ComponentsEngine(self.components)
         logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
 
+        if INFER_BENCHMARK:
+            self.benchmark = Benchmark(self.components)
+
+    def __call__(self, input, **kwargs):
+        if self.benchmark:
+            for _ in range(INFER_BENCHMARK_WARMUP):
+                list(super().__call__(None))
+            self.benchmark.reset()
+            if input is None:
+                for _ in range(INFER_BENCHMARK_ITER):
+                    list(super().__call__(input))
+            else:
+                list(super().__call__(input))
+            self.benchmark.collect()
+        else:
+            yield from super().__call__(input)
+
     def apply(self, input):
         """predict"""
         yield from self._generate_res(self.engine(input))
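
Note that `predict()` yields results lazily, so the benchmark branch drains each run with `list(...)`: the per-component timers only fire when items are actually pulled through the pipeline. A small sketch of why draining matters:

```python
def pipeline():
    # Work (and therefore timing) happens lazily, per pulled item.
    for i in range(3):
        yield i * i

gen = pipeline()  # nothing has executed yet; no timings would be recorded
list(gen)         # exhausting the generator runs every stage
```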

+ 122 - 0
paddlex/inference/utils/benchmark.py

@@ -0,0 +1,122 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from types import GeneratorType
+import time
+import numpy as np
+from prettytable import PrettyTable
+
+from ...utils.flags import INFER_BENCHMARK_OUTPUT
+from ...utils import logging
+
+
+class Benchmark:
+    def __init__(self, components):
+        self._components = components
+
+    def reset(self):
+        for name in self._components:
+            cmp = self._components[name]
+            cmp.timer.reset()
+
+    def gather(self):
+        # lazy import for avoiding circular import
+        from ..components.paddle_predictor import BasePaddlePredictor
+
+        detail = []
+        summary = {"preprocess": 0, "inference": 0, "postprocess": 0}
+        op_tag = "preprocess"
+        for name in self._components:
+            cmp = self._components[name]
+            times = cmp.timer.logs
+            counts = len(times)
+            avg = np.mean(times) * 1000
+            detail.append((name, counts, avg))
+            if isinstance(cmp, BasePaddlePredictor):
+                summary["inference"] += avg
+                op_tag = "postprocess"
+            else:
+                summary[op_tag] += avg
+        return detail, summary
+
+    def collect(self):
+        detail, summary = self.gather()
+        table = PrettyTable(["Component", "Counts", "Average Time(ms)"])
+        table.add_rows([(name, cnts, f"{avg:.8f}") for name, cnts, avg in detail])
+        table.add_row(("***************", "******", "***************"))
+        table.add_row(("PreProcess", "\\", f"{summary['preprocess']:.8f}"))
+        table.add_row(("Inference", "\\", f"{summary['inference']:.8f}"))
+        table.add_row(("PostProcess", "\\", f"{summary['postprocess']:.8f}"))
+        logging.info(table)
+
+        if INFER_BENCHMARK_OUTPUT:
+            str_ = "Component, Counts, Average Time(ms)\n"
+            str_ += "\n".join(
+                [f"{name}, {cnts}, {avg:.18f}" for name, cnts, avg in detail]
+            )
+            str_ += "\n***************, ***, ***************\n"
+            str_ += "\n".join(
+                [
+                    f"PreProcess, \, {summary['preprocess']:.18f}",
+                    f"Inference, \, {summary['inference']:.18f}",
+                    f"PostProcess, \, {summary['postprocess']:.18f}",
+                ]
+            )
+            with open(INFER_BENCHMARK_OUTPUT, "w") as f:
+                f.write(str_)
+
+
+class Timer:
+    def __init__(self):
+        self._tic = None
+        self._elapses = []
+
+    def watch_func(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            tic = time.time()
+            output = func(*args, **kwargs)
+            if isinstance(output, GeneratorType):
+                return self.watch_generator(output)
+            else:
+                self._update(time.time() - tic)
+                return output
+
+        return wrapper
+
+    def watch_generator(self, generator):
+        @functools.wraps(generator)
+        def wrapper():
+            while 1:
+                try:
+                    tic = time.time()
+                    item = next(generator)
+                    self._update(time.time() - tic)
+                    yield item
+                except StopIteration:
+                    break
+
+        return wrapper()
+
+    def reset(self):
+        self._tic = None
+        self._elapses = []
+
+    def _update(self, elapse):
+        self._elapses.append(elapse)
+
+    @property
+    def logs(self):
+        return self._elapses
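
Under the definitions above, `Timer.watch_func` times a plain call as a whole, while for a generator function it times each `next()` step of the returned generator instead. A quick usage sketch (importing the module added by this commit):

```python
import time

from paddlex.inference.utils.benchmark import Timer

timer = Timer()

@timer.watch_func
def step(x):
    time.sleep(0.01)
    return x + 1

@timer.watch_func
def stream(n):
    for i in range(n):
        time.sleep(0.01)
        yield i

step(1)          # records 1 elapse
list(stream(3))  # records 1 elapse per yielded item
print(len(timer.logs))  # 4
```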

+ 25 - 6
paddlex/utils/flags.py

@@ -20,19 +20,25 @@ __all__ = [
     "DRY_RUN",
     "CHECK_OPTS",
     "EAGER_INITIALIZATION",
+    "INFER_BENCHMARK",
+    "INFER_BENCHMARK_ITER",
+    "INFER_BENCHMARK_WARMUP",
+    "INFER_BENCHMARK_OUTPUT",
+    "INFER_BENCHMARK_DATA_SIZE",
     "FLAGS_json_format_model",
 ]
 
 
-def get_flag_from_env_var(name, default):
+def get_flag_from_env_var(name, default, format_func=str):
     """get_flag_from_env_var"""
-    env_var = os.environ.get(name, None)
-    if env_var in ("True", "true", "TRUE", "1"):
+    env_var = os.environ.get(name, default)
+    if env_var in (True, "True", "true", "TRUE", "1"):
         return True
-    elif env_var in ("False", "false", "FALSE", "0"):
+    elif env_var in (False, "False", "false", "FALSE", "0"):
         return False
-    else:
-        return default
+    elif env_var in (None, "None", "none", "Null", "null"):
+        return None
+    return format_func(env_var)
 
 
 DEBUG = get_flag_from_env_var("PADDLE_PDX_DEBUG", False)
@@ -40,3 +46,16 @@ DRY_RUN = get_flag_from_env_var("PADDLE_PDX_DRY_RUN", False)
 CHECK_OPTS = get_flag_from_env_var("PADDLE_PDX_CHECK_OPTS", False)
 EAGER_INITIALIZATION = get_flag_from_env_var("PADDLE_PDX_EAGER_INIT", True)
 FLAGS_json_format_model = get_flag_from_env_var("FLAGS_json_format_model", None)
+
+# Inference Benchmark
+INFER_BENCHMARK = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", None)
+INFER_BENCHMARK_WARMUP = get_flag_from_env_var(
+    "PADDLE_PDX_INFER_BENCHMARK_WARMUP", 0, int
+)
+INFER_BENCHMARK_OUTPUT = get_flag_from_env_var(
+    "PADDLE_PDX_INFER_BENCHMARK_OUTPUT", None
+)
+INFER_BENCHMARK_ITER = get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK_ITER", 10, int)
+INFER_BENCHMARK_DATA_SIZE = get_flag_from_env_var(
+    "PADDLE_PDX_INFER_BENCHMARK_DATA_SIZE", 1024
+)
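
The updated helper first maps boolean-like and null-like strings, and only then applies `format_func` to whatever remains; a short demonstration of the resulting parsing (assuming the function as defined in the module above):

```python
import os

from paddlex.utils.flags import get_flag_from_env_var

os.environ["PADDLE_PDX_INFER_BENCHMARK"] = "True"
os.environ["PADDLE_PDX_INFER_BENCHMARK_ITER"] = "25"

print(get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK", None))          # True
print(get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK_ITER", 10, int))  # 25
print(get_flag_from_env_var("PADDLE_PDX_INFER_BENCHMARK_OUTPUT", None))   # None
```

One subtlety: because `0 == False` in Python, an integer default of `0` (as used for `WARMUP`) comes back as `False` rather than `0`; `range(False)` still iterates zero times, so the warm-up behavior is unchanged.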

+ 1 - 0
requirements.txt

@@ -1,3 +1,4 @@
+prettytable # only for benchmark
 imagesize
 colorlog
 PyYAML