benchmark.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import csv
  15. import functools
  16. from types import GeneratorType
  17. import time
  18. import numpy as np
  19. from prettytable import PrettyTable
  20. from ...utils.flags import INFER_BENCHMARK_OUTPUT
  21. from ...utils import logging
  22. class Benchmark:
  23. def __init__(self, components):
  24. self._components = components
  25. self._warmup_start = None
  26. self._warmup_elapse = None
  27. self._warmup_num = None
  28. self._e2e_tic = None
  29. self._e2e_elapse = None
  30. def start(self):
  31. self._warmup_start = time.time()
  32. self._reset()
  33. def warmup_stop(self, warmup_num):
  34. self._warmup_elapse = time.time() - self._warmup_start
  35. self._warmup_num = warmup_num
  36. self._reset()
  37. def _reset(self):
  38. for name in self._components:
  39. cmp = self._components[name]
  40. cmp.timer.reset()
  41. self._e2e_tic = time.time()
  42. def gather(self, e2e_num):
  43. # lazy import for avoiding circular import
  44. from ..components.paddle_predictor import BasePaddlePredictor
  45. detail = []
  46. summary = {"preprocess": 0, "inference": 0, "postprocess": 0}
  47. op_tag = "preprocess"
  48. for name in self._components:
  49. cmp = self._components[name]
  50. times = cmp.timer.logs
  51. counts = len(times)
  52. avg = np.mean(times)
  53. total = np.sum(times)
  54. detail.append((name, total, counts, avg))
  55. if isinstance(cmp, BasePaddlePredictor):
  56. summary["inference"] += total
  57. op_tag = "postprocess"
  58. else:
  59. summary[op_tag] += total
  60. summary = [
  61. (
  62. "PreProcess",
  63. summary["preprocess"],
  64. e2e_num,
  65. summary["preprocess"] / e2e_num,
  66. ),
  67. (
  68. "Inference",
  69. summary["inference"],
  70. e2e_num,
  71. summary["inference"] / e2e_num,
  72. ),
  73. (
  74. "PostProcess",
  75. summary["postprocess"],
  76. e2e_num,
  77. summary["postprocess"] / e2e_num,
  78. ),
  79. ("End2End", self._e2e_elapse, e2e_num, self._e2e_elapse / e2e_num),
  80. ]
  81. if self._warmup_elapse:
  82. summary.append(
  83. (
  84. "WarmUp",
  85. self._warmup_elapse,
  86. self._warmup_num,
  87. self._warmup_elapse / self._warmup_num,
  88. )
  89. )
  90. return detail, summary
  91. def collect(self, e2e_num):
  92. self._e2e_elapse = time.time() - self._e2e_tic
  93. detail, summary = self.gather(e2e_num)
  94. table = PrettyTable(
  95. ["Component", "Total Time (ms)", "Call Counts", "Avg Time Per Call (ms)"]
  96. )
  97. table.add_rows(
  98. [
  99. (name, f"{total * 1000:.8f}", cnts, f"{avg * 1000:.8f}")
  100. for name, total, cnts, avg in detail
  101. ]
  102. )
  103. logging.info(table)
  104. table = PrettyTable(
  105. [
  106. "Stage",
  107. "Total Time (ms)",
  108. "Num of Instances",
  109. "Avg Time Per Instance (ms)",
  110. ]
  111. )
  112. table.add_rows(
  113. [
  114. (name, f"{total * 1000:.8f}", cnts, f"{avg * 1000:.8f}")
  115. for name, total, cnts, avg in summary
  116. ]
  117. )
  118. logging.info(table)
  119. if INFER_BENCHMARK_OUTPUT:
  120. csv_data = [["Stage", "Total Time", "Num", "Avg Time"]]
  121. csv_data.extend(detail)
  122. csv_data.extend(summary)
  123. with open("benchmark.csv", "w", newline="") as file:
  124. writer = csv.writer(file)
  125. writer.writerows(csv_data)
  126. class Timer:
  127. def __init__(self):
  128. self._tic = None
  129. self._elapses = []
  130. def watch_func(self, func):
  131. @functools.wraps(func)
  132. def wrapper(*args, **kwargs):
  133. tic = time.time()
  134. output = func(*args, **kwargs)
  135. if isinstance(output, GeneratorType):
  136. return self.watch_generator(output)
  137. else:
  138. self._update(time.time() - tic)
  139. return output
  140. return wrapper
  141. def watch_generator(self, generator):
  142. @functools.wraps(generator)
  143. def wrapper():
  144. while 1:
  145. try:
  146. tic = time.time()
  147. item = next(generator)
  148. self._update(time.time() - tic)
  149. yield item
  150. except StopIteration:
  151. break
  152. return wrapper()
  153. def reset(self):
  154. self._tic = None
  155. self._elapses = []
  156. def _update(self, elapse):
  157. self._elapses.append(elapse)
  158. @property
  159. def logs(self):
  160. return self._elapses