benchmark.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import csv
  15. import functools
  16. from types import GeneratorType
  17. import time
  18. import numpy as np
  19. from prettytable import PrettyTable
  20. from ...utils.flags import INFER_BENCHMARK_OUTPUT
  21. from ...utils import logging
  22. class Benchmark:
  23. def __init__(self, components):
  24. self._components = components
  25. self._warmup_start = None
  26. self._warmup_elapse = None
  27. self._warmup_num = None
  28. self._e2e_tic = None
  29. self._e2e_elapse = None
  30. def start(self):
  31. self._warmup_start = time.time()
  32. self._reset()
  33. def warmup_stop(self, warmup_num):
  34. self._warmup_elapse = time.time() - self._warmup_start
  35. self._warmup_num = warmup_num
  36. self._reset()
  37. def _reset(self):
  38. for name, cmp in self.iterate_cmp(self._components):
  39. cmp.timer.reset()
  40. self._e2e_tic = time.time()
  41. def iterate_cmp(self, cmps):
  42. if cmps is None:
  43. return
  44. for name, cmp in cmps.items():
  45. if cmp.sub_cmps is not None:
  46. yield from self.iterate_cmp(cmp.sub_cmps)
  47. yield name, cmp
  48. def gather(self, e2e_num):
  49. # lazy import for avoiding circular import
  50. from ..components.paddle_predictor import BasePaddlePredictor
  51. detail = []
  52. summary = {"preprocess": 0, "inference": 0, "postprocess": 0}
  53. op_tag = "preprocess"
  54. for name, cmp in self._components.items():
  55. if isinstance(cmp, BasePaddlePredictor):
  56. # TODO(gaotingquan): show by hierarchy. Now dont show xxxPredictor benchmark info to ensure mutual exclusivity between components.
  57. for name, sub_cmp in cmp.sub_cmps.items():
  58. times = sub_cmp.timer.logs
  59. counts = len(times)
  60. avg = np.mean(times)
  61. total = np.sum(times)
  62. detail.append((name, total, counts, avg))
  63. summary["inference"] += total
  64. op_tag = "postprocess"
  65. else:
  66. times = cmp.timer.logs
  67. counts = len(times)
  68. avg = np.mean(times)
  69. total = np.sum(times)
  70. detail.append((name, total, counts, avg))
  71. summary[op_tag] += total
  72. summary = [
  73. (
  74. "PreProcess",
  75. summary["preprocess"],
  76. e2e_num,
  77. summary["preprocess"] / e2e_num,
  78. ),
  79. (
  80. "Inference",
  81. summary["inference"],
  82. e2e_num,
  83. summary["inference"] / e2e_num,
  84. ),
  85. (
  86. "PostProcess",
  87. summary["postprocess"],
  88. e2e_num,
  89. summary["postprocess"] / e2e_num,
  90. ),
  91. ("End2End", self._e2e_elapse, e2e_num, self._e2e_elapse / e2e_num),
  92. ]
  93. if self._warmup_elapse:
  94. summary.append(
  95. (
  96. "WarmUp",
  97. self._warmup_elapse,
  98. self._warmup_num,
  99. self._warmup_elapse / self._warmup_num,
  100. )
  101. )
  102. return detail, summary
  103. def collect(self, e2e_num):
  104. self._e2e_elapse = time.time() - self._e2e_tic
  105. detail, summary = self.gather(e2e_num)
  106. table_head = ["Stage", "Total Time (ms)", "Nums", "Avg Time (ms)"]
  107. table = PrettyTable(table_head)
  108. table.add_rows(
  109. [
  110. (name, f"{total * 1000:.8f}", cnts, f"{avg * 1000:.8f}")
  111. for name, total, cnts, avg in detail
  112. ]
  113. )
  114. logging.info(table)
  115. table = PrettyTable(table_head)
  116. table.add_rows(
  117. [
  118. (name, f"{total * 1000:.8f}", cnts, f"{avg * 1000:.8f}")
  119. for name, total, cnts, avg in summary
  120. ]
  121. )
  122. logging.info(table)
  123. if INFER_BENCHMARK_OUTPUT:
  124. csv_data = [table_head]
  125. csv_data.extend(detail)
  126. csv_data.extend(summary)
  127. with open("benchmark.csv", "w", newline="") as file:
  128. writer = csv.writer(file)
  129. writer.writerows(csv_data)
  130. class Timer:
  131. def __init__(self):
  132. self._tic = None
  133. self._elapses = []
  134. def watch_func(self, func):
  135. @functools.wraps(func)
  136. def wrapper(*args, **kwargs):
  137. tic = time.time()
  138. output = func(*args, **kwargs)
  139. if isinstance(output, GeneratorType):
  140. return self.watch_generator(output)
  141. else:
  142. self._update(time.time() - tic)
  143. return output
  144. return wrapper
  145. def watch_generator(self, generator):
  146. @functools.wraps(generator)
  147. def wrapper():
  148. while 1:
  149. try:
  150. tic = time.time()
  151. item = next(generator)
  152. self._update(time.time() - tic)
  153. yield item
  154. except StopIteration:
  155. break
  156. return wrapper()
  157. def reset(self):
  158. self._tic = None
  159. self._elapses = []
  160. def _update(self, elapse):
  161. self._elapses.append(elapse)
  162. @property
  163. def logs(self):
  164. return self._elapses