benchmark.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import functools
import time
from pathlib import Path
from types import GeneratorType

import numpy as np
from prettytable import PrettyTable

from ...utils import logging
from ...utils.flags import INFER_BENCHMARK, INFER_BENCHMARK_OUTPUT

class Benchmark:
    """Collect per-operation wall-clock timings for inference benchmarking."""

    def __init__(self, enabled):
        self._enabled = enabled
        self._elapses = {}
        self._warmup = False

    def timeit(self, func):
        """Decorator that records the elapsed time of each call to `func`."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not self._enabled:
                return func(*args, **kwargs)
            name = func.__qualname__
            tic = time.time()
            output = func(*args, **kwargs)
            if isinstance(output, GeneratorType):
                # A generator runs lazily, so time the production of each
                # item rather than the call that merely created the generator.
                return self.watch_generator(output, name)
            else:
                self._update(time.time() - tic, name)
                return output

        return wrapper

    def watch_generator(self, generator, name):
        """Wrap `generator` so that each `next()` call is timed."""

        @functools.wraps(generator)
        def wrapper():
            while True:
                try:
                    tic = time.time()
                    item = next(generator)
                    self._update(time.time() - tic, name)
                    yield item
                except StopIteration:
                    break

        return wrapper()

    def reset(self):
        self._elapses = {}

    def _update(self, elapse, name):
        # Record the elapsed time in milliseconds.
        elapse = elapse * 1000
        if name in self._elapses:
            self._elapses[name].append(elapse)
        else:
            self._elapses[name] = [elapse]

    @property
    def logs(self):
        return self._elapses

    def start_timing(self):
        self._enabled = True

    def stop_timing(self):
        self._enabled = False

    def start_warmup(self):
        self._warmup = True

    def stop_warmup(self):
        self._warmup = False
        # Discard any timings recorded during warmup.
        self.reset()
    def gather(self, batch_size):
        """Aggregate raw timings into per-operation and per-stage statistics."""
        # Group timings by the first component of the qualified name
        # (the class or function name).
        logs = {k.split(".")[0]: v for k, v in self.logs.items()}
        iters = len(logs["Infer"])
        instances = iters * batch_size
        detail_list = []
        summary = {"preprocess": 0, "inference": 0, "postprocess": 0}
        op_tag = "preprocess"
        for name, time_list in logs.items():
            avg = np.mean(time_list)
            detail_list.append(
                (iters, batch_size, instances, name, avg, avg / batch_size)
            )
            if name in ["Copy2GPU", "Infer", "Copy2CPU"]:
                summary["inference"] += avg
                # Operations recorded after the inference ops count as
                # postprocessing.
                op_tag = "postprocess"
            else:
                summary[op_tag] += avg
        summary["end2end"] = (
            summary["preprocess"] + summary["inference"] + summary["postprocess"]
        )
        summary_list = [
            (
                iters,
                batch_size,
                instances,
                "PreProcess",
                summary["preprocess"],
                summary["preprocess"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Inference",
                summary["inference"],
                summary["inference"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "PostProcess",
                summary["postprocess"],
                summary["postprocess"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "End2End",
                summary["end2end"],
                summary["end2end"] / batch_size,
            ),
        ]
        return detail_list, summary_list
    def collect(self, batch_size):
        """Log the gathered statistics as tables and optionally dump CSVs."""
        detail_list, summary_list = self.gather(batch_size)
        if self._warmup:
            # During warmup, only the per-stage summary is logged.
            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Stage",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            header = "WarmUp Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(header)
            logging.info(table)
        else:
            detail_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Operation",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(detail_head)
            detail_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list
            ]
            table.add_rows(detail_list)
            header = "Detail Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(header)
            logging.info(table)

            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Stage",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            header = "Summary Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(header)
            logging.info(table)

            if INFER_BENCHMARK_OUTPUT:
                # Dump both tables as CSV files to the requested directory.
                save_dir = Path(INFER_BENCHMARK_OUTPUT)
                save_dir.mkdir(parents=True, exist_ok=True)
                csv_data = [detail_head, *detail_list]
                with open(save_dir / "detail.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)
                csv_data = [summary_head, *summary_list]
                with open(save_dir / "summary.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)

if INFER_BENCHMARK:
    benchmark = Benchmark(enabled=True)
else:
    benchmark = Benchmark(enabled=False)
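
# A minimal usage sketch (illustrative only, not part of the module): `Infer`
# below is a hypothetical stand-in for a predictor component that would be
# decorated with `benchmark.timeit`. Its qualified name must start with
# "Infer", since `gather()` keys the iteration count off the "Infer" bucket.
#
#     benchmark.start_timing()
#
#     @benchmark.timeit
#     def Infer(x):
#         time.sleep(0.01)  # pretend to run the model
#         return x
#
#     for _ in range(8):
#         Infer(None)
#
#     benchmark.collect(batch_size=1)  # logs "Detail Data" / "Summary Data"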