# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import functools
import time
from pathlib import Path
from types import GeneratorType

import numpy as np
from prettytable import PrettyTable

from ...utils import logging
from ...utils.flags import INFER_BENCHMARK, INFER_BENCHMARK_OUTPUT_DIR

ENTRY_POINT_NAME = "_entry_point_"

# XXX: Global mutable state
_inference_operations = []


class Benchmark:
    def __init__(self, enabled):
        self._enabled = enabled
        self._elapses = {}
        self._warmup = False

    def timeit_with_name(self, name=None):
        # TODO: Refactor
        def _deco(func_or_cls):
            nonlocal name
            if name is None:
                name = func_or_cls.__qualname__

            if isinstance(func_or_cls, type):
                # Decorating a class: time its `__call__` method.
                if not hasattr(func_or_cls, "__call__"):
                    raise TypeError
                func = func_or_cls.__call__
            else:
                if not callable(func_or_cls):
                    raise TypeError
                func = func_or_cls

            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                if not self._enabled:
                    return func(*args, **kwargs)
                tic = time.perf_counter()
                output = func(*args, **kwargs)
                if isinstance(output, GeneratorType):
                    # Generators are lazy; time each `next()` call instead of
                    # the call that merely created the generator.
                    return self.watch_generator(output, name)
                else:
                    self._update(time.perf_counter() - tic, name)
                    return output

            if isinstance(func_or_cls, type):
                func_or_cls.__call__ = _wrapper
                return func_or_cls
            else:
                return _wrapper

        return _deco

    def timeit(self, func_or_cls):
        return self.timeit_with_name(None)(func_or_cls)
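
    # Usage sketch (illustrative, not part of this module): `timeit` and
    # `timeit_with_name` wrap callables or callable classes so that each
    # invocation is recorded under a name. `MyPreprocessor` and `postprocess`
    # below are hypothetical and exist only for illustration.
    #
    #     @benchmark.timeit_with_name("Preprocess")
    #     class MyPreprocessor:
    #         def __call__(self, batch):
    #             return [x * 2 for x in batch]
    #
    #     @benchmark.timeit
    #     def postprocess(batch):
    #         return batch
    #
    # Decorating the class patches `MyPreprocessor.__call__`, so every call on
    # an instance adds one entry to the "Preprocess" time list; the plain
    # function is recorded under its qualified name ("postprocess").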

    def watch_generator(self, generator, name):
        @functools.wraps(generator)
        def wrapper():
            while True:
                try:
                    tic = time.perf_counter()
                    item = next(generator)
                    self._update(time.perf_counter() - tic, name)
                    yield item
                except StopIteration:
                    break

        return wrapper()

    def reset(self):
        self._elapses = {}

    def _update(self, elapse, name):
        # Record the elapsed time in milliseconds under the operation name.
        elapse = elapse * 1000
        if name in self._elapses:
            self._elapses[name].append(elapse)
        else:
            self._elapses[name] = [elapse]

    @property
    def logs(self):
        return self._elapses
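
    # A minimal sketch of what `logs` might hold after a few timed iterations
    # (the operation names are hypothetical; values are elapsed times in
    # milliseconds, one entry per call):
    #
    #     {
    #         "_entry_point_": [35.1, 34.8, 35.3],
    #         "ReadImage": [2.1, 2.0, 2.2],
    #         "Infer": [30.4, 30.1, 30.6],
    #         "PostOp": [1.3, 1.2, 1.3],
    #     }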

    def start_timing(self):
        self._enabled = True

    def stop_timing(self):
        self._enabled = False

    def start_warmup(self):
        self._warmup = True

    def stop_warmup(self):
        self._warmup = False
        self.reset()
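
    # One plausible driver sequence, sketched from the methods above
    # (`run_pipeline` is hypothetical; `benchmark` is the module-level
    # instance, already enabled when the INFER_BENCHMARK flag is set):
    #
    #     benchmark.start_warmup()
    #     run_pipeline()            # warm-up iterations, still timed
    #     benchmark.collect(1)      # prints the "WarmUp Data" table
    #     benchmark.stop_warmup()   # discards the warm-up timings
    #     run_pipeline()            # measured iterations
    #     benchmark.collect(1)      # prints detail and summary tables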

    def gather(self, batch_size):
        # NOTE: The gathering logic here is based on the following assumptions:
        # 1. The operations are performed sequentially.
        # 2. An operation is performed only once at each iteration.
        # 3. Operations do not nest, except that the entry point operation
        #    contains all other operations.
        # 4. The input batch size for each operation is `batch_size`.
        # 5. Inference operations are always performed, while preprocessing and
        #    postprocessing operations are optional.
        # 6. If present, preprocessing operations are always performed before
        #    inference operations, and inference operations are completed before
        #    any postprocessing operations. There is no interleaving among these
        #    stages.
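        #
        # As a concrete (hypothetical) illustration: with batch_size=2,
        # _inference_operations == ["Infer"], and logs of
        #
        #     {"_entry_point_": [35.0], "ReadImage": [2.0],
        #      "Infer": [30.0], "PostOp": [1.0]}
        #
        # the buckets come out as preprocessing=2.0, inference=30.0,
        # postprocessing=1.0 (PostOp follows an inference operation),
        # core=33.0, end_to_end=35.0, other=2.0; per-instance times are the
        # per-iteration averages divided by batch_size.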

        logs = {k: v for k, v in self.logs.items()}

        summary = {"preprocessing": 0, "inference": 0, "postprocessing": 0}
        # The entry point operation wraps all other operations, so its average
        # is the end-to-end time per iteration.
        base_predictor_time_list = logs.pop(ENTRY_POINT_NAME)
        iters = len(base_predictor_time_list)
        instances = iters * batch_size
        summary["end_to_end"] = np.mean(base_predictor_time_list)

        detail_list = []
        op_tag = "preprocessing"

        for name, time_list in logs.items():
            assert len(time_list) == iters
            avg = np.mean(time_list)
            detail_list.append(
                (iters, batch_size, instances, name, avg, avg / batch_size)
            )
            if name in _inference_operations:
                summary["inference"] += avg
                # Operations recorded after an inference operation count as
                # postprocessing.
                op_tag = "postprocessing"
            else:
                summary[op_tag] += avg

        summary["core"] = (
            summary["preprocessing"] + summary["inference"] + summary["postprocessing"]
        )
        summary["other"] = summary["end_to_end"] - summary["core"]

        summary_list = [
            (
                iters,
                batch_size,
                instances,
                "Preprocessing",
                summary["preprocessing"],
                summary["preprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Inference",
                summary["inference"],
                summary["inference"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Postprocessing",
                summary["postprocessing"],
                summary["postprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Core",
                summary["core"],
                summary["core"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Other",
                summary["other"],
                summary["other"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "End-to-End",
                summary["end_to_end"],
                summary["end_to_end"] / batch_size,
            ),
        ]

        return detail_list, summary_list
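
    # Each row produced by `gather()` has the layout
    # (iters, batch_size, instances, name, avg_time_per_iter_ms,
    # avg_time_per_instance_ms), which is what `collect()` below formats into
    # tables and CSV files.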

    def collect(self, batch_size):
        detail_list, summary_list = self.gather(batch_size)

        if self._warmup:
            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            header = "WarmUp Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(header)
            logging.info(table)
        else:
            detail_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Operation",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(detail_head)
            detail_list = [i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list]
            table.add_rows(detail_list)
            header = "Detail Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(header)
            logging.info(table)

            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            header = "Summary Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(header)
            logging.info(table)

            if INFER_BENCHMARK_OUTPUT_DIR:
                save_dir = Path(INFER_BENCHMARK_OUTPUT_DIR)
                save_dir.mkdir(parents=True, exist_ok=True)
                csv_data = [detail_head, *detail_list]
                with open(save_dir / "detail.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)
                csv_data = [summary_head, *summary_list]
                with open(save_dir / "summary.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)


def get_inference_operations():
    return _inference_operations


def set_inference_operations(val):
    global _inference_operations
    _inference_operations = val


if INFER_BENCHMARK:
    benchmark = Benchmark(enabled=True)
else:
    benchmark = Benchmark(enabled=False)
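
# End-to-end usage sketch (illustrative; the operation names and callables are
# hypothetical, and the INFER_BENCHMARK flag is assumed to be set so that
# `benchmark` is enabled):
#
#     set_inference_operations(["Infer"])
#
#     @benchmark.timeit_with_name("Infer")
#     def infer(batch):
#         return batch
#
#     @benchmark.timeit_with_name(ENTRY_POINT_NAME)
#     def predict(batch):
#         return infer(batch)
#
#     for _ in range(10):
#         predict([0, 1])
#     benchmark.collect(batch_size=2)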