benchmark.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import csv
  15. import functools
  16. from types import GeneratorType
  17. import time
  18. from pathlib import Path
  19. import numpy as np
  20. from prettytable import PrettyTable
  21. from ...utils.flags import INFER_BENCHMARK_OUTPUT
  22. from ...utils import logging
  23. class Benchmark:
  24. def __init__(self, components):
  25. self._components = components
  26. self._warmup_start = None
  27. self._warmup_elapse = None
  28. self._warmup_num = None
  29. self._e2e_tic = None
  30. self._e2e_elapse = None
  31. def start(self):
  32. self._warmup_start = time.time()
  33. self._reset()
  34. def warmup_stop(self, warmup_num):
  35. self._warmup_elapse = time.time() - self._warmup_start
  36. self._warmup_num = warmup_num
  37. self._reset()
  38. def _reset(self):
  39. for name, cmp in self.iterate_cmp(self._components):
  40. cmp.timer.reset()
  41. self._e2e_tic = time.time()
  42. def iterate_cmp(self, cmps):
  43. if cmps is None:
  44. return
  45. for name, cmp in cmps.items():
  46. if cmp.sub_cmps is not None:
  47. yield from self.iterate_cmp(cmp.sub_cmps)
  48. yield name, cmp
  49. def gather(self, e2e_num):
  50. # lazy import for avoiding circular import
  51. from ..components.paddle_predictor import BasePaddlePredictor
  52. detail = []
  53. summary = {"preprocess": 0, "inference": 0, "postprocess": 0}
  54. op_tag = "preprocess"
  55. for name, cmp in self._components.items():
  56. if isinstance(cmp, BasePaddlePredictor):
  57. # TODO(gaotingquan): show by hierarchy. Now dont show xxxPredictor benchmark info to ensure mutual exclusivity between components.
  58. for name, sub_cmp in cmp.sub_cmps.items():
  59. times = sub_cmp.timer.logs
  60. counts = len(times)
  61. avg = np.mean(times) * 1000
  62. total = np.sum(times) * 1000
  63. detail.append((name, total, counts, avg))
  64. summary["inference"] += total
  65. op_tag = "postprocess"
  66. else:
  67. times = cmp.timer.logs
  68. counts = len(times)
  69. avg = np.mean(times) * 1000
  70. total = np.sum(times) * 1000
  71. detail.append((name, total, counts, avg))
  72. summary[op_tag] += total
  73. summary = [
  74. (
  75. "PreProcess",
  76. summary["preprocess"],
  77. e2e_num,
  78. summary["preprocess"] / e2e_num,
  79. ),
  80. (
  81. "Inference",
  82. summary["inference"],
  83. e2e_num,
  84. summary["inference"] / e2e_num,
  85. ),
  86. (
  87. "PostProcess",
  88. summary["postprocess"],
  89. e2e_num,
  90. summary["postprocess"] / e2e_num,
  91. ),
  92. ("End2End", self._e2e_elapse, e2e_num, self._e2e_elapse / e2e_num),
  93. ]
  94. if self._warmup_elapse:
  95. summary.append(
  96. (
  97. "WarmUp",
  98. self._warmup_elapse,
  99. self._warmup_num,
  100. self._warmup_elapse / self._warmup_num,
  101. )
  102. )
  103. return detail, summary
  104. def collect(self, e2e_num):
  105. self._e2e_elapse = time.time() - self._e2e_tic
  106. detail, summary = self.gather(e2e_num)
  107. detail_head = [
  108. "Component",
  109. "Total Time (ms)",
  110. "Number of Calls",
  111. "Avg Time Per Call (ms)",
  112. ]
  113. table = PrettyTable(detail_head)
  114. table.add_rows(
  115. [
  116. (name, f"{total:.8f}", cnts, f"{avg:.8f}")
  117. for name, total, cnts, avg in detail
  118. ]
  119. )
  120. logging.info(table)
  121. summary_head = [
  122. "Stage",
  123. "Total Time (ms)",
  124. "Number of Instances",
  125. "Avg Time Per Instance (ms)",
  126. ]
  127. table = PrettyTable(summary_head)
  128. table.add_rows(
  129. [
  130. (name, f"{total:.8f}", cnts, f"{avg:.8f}")
  131. for name, total, cnts, avg in summary
  132. ]
  133. )
  134. logging.info(table)
  135. if INFER_BENCHMARK_OUTPUT:
  136. save_dir = Path(INFER_BENCHMARK_OUTPUT)
  137. save_dir.mkdir(parents=True, exist_ok=True)
  138. csv_data = [detail_head, *detail]
  139. # csv_data.extend(detail)
  140. with open(Path(save_dir) / "detail.csv", "w", newline="") as file:
  141. writer = csv.writer(file)
  142. writer.writerows(csv_data)
  143. csv_data = [summary_head, *summary]
  144. # csv_data.extend(summary)
  145. with open(Path(save_dir) / "summary.csv", "w", newline="") as file:
  146. writer = csv.writer(file)
  147. writer.writerows(csv_data)
  148. class Timer:
  149. def __init__(self):
  150. self._tic = None
  151. self._elapses = []
  152. def watch_func(self, func):
  153. @functools.wraps(func)
  154. def wrapper(*args, **kwargs):
  155. tic = time.time()
  156. output = func(*args, **kwargs)
  157. if isinstance(output, GeneratorType):
  158. return self.watch_generator(output)
  159. else:
  160. self._update(time.time() - tic)
  161. return output
  162. return wrapper
  163. def watch_generator(self, generator):
  164. @functools.wraps(generator)
  165. def wrapper():
  166. while 1:
  167. try:
  168. tic = time.time()
  169. item = next(generator)
  170. self._update(time.time() - tic)
  171. yield item
  172. except StopIteration:
  173. break
  174. return wrapper()
  175. def reset(self):
  176. self._tic = None
  177. self._elapses = []
  178. def _update(self, elapse):
  179. self._elapses.append(elapse)
  180. @property
  181. def logs(self):
  182. return self._elapses