benchmark.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import functools
import inspect
import time
from pathlib import Path
from types import GeneratorType

import numpy as np
from prettytable import PrettyTable

from ...utils import logging
from ...utils.flags import INFER_BENCHMARK, INFER_BENCHMARK_OUTPUT_DIR

ENTRY_POINT_NAME = "_entry_point_"

# XXX: Global mutable state
_inference_operations = []


class Benchmark:
    def __init__(self, enabled):
        self._enabled = enabled
        self._elapses = {}
        self._warmup = False

    def timeit_with_name(self, name=None):
        """Return a decorator that times a function or a callable class.

        If `name` is None, the qualified name of the decorated object is used.
        """
        # TODO: Refactor
        def _deco(func_or_cls):
            nonlocal name
            if name is None:
                name = func_or_cls.__qualname__
            if isinstance(func_or_cls, type):
                if not hasattr(func_or_cls, "__call__"):
                    raise TypeError
                func = func_or_cls.__call__
            else:
                if not callable(func_or_cls):
                    raise TypeError
                func = func_or_cls

            location = None

            @functools.wraps(func)
            def _wrapper(*args, **kwargs):
                nonlocal location
                if not self._enabled:
                    return func(*args, **kwargs)
                # Resolve the source location lazily, on the first timed call.
                if location is None:
                    try:
                        source_file = inspect.getsourcefile(func)
                        source_line = inspect.getsourcelines(func)[1]
                        location = f"{source_file}:{source_line}"
                    except (TypeError, OSError) as e:
                        location = "Unknown"
                        logging.debug(
                            f"Benchmark: failed to get source file and line number: {e}"
                        )
                tic = time.perf_counter()
                output = func(*args, **kwargs)
                if isinstance(output, GeneratorType):
                    # Generators are timed per yielded item, not per call.
                    return self.watch_generator(output, f"{name}@{location}")
                else:
                    self._update(time.perf_counter() - tic, f"{name}@{location}")
                    return output

            if isinstance(func_or_cls, type):
                func_or_cls.__call__ = _wrapper
                return func_or_cls
            else:
                return _wrapper

        return _deco
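
    # Usage sketch (hypothetical names, assuming the module-level `benchmark`
    # instance below): both plain functions and callable classes can be
    # decorated, and samples are keyed as "<name>@<file>:<line>".
    #
    #     @benchmark.timeit_with_name("Normalize")
    #     def normalize(batch):
    #         ...
    #
    #     @benchmark.timeit  # name defaults to the qualified name
    #     class Resize:
    #         def __call__(self, batch):
    #             ...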

    def timeit(self, func_or_cls):
        return self.timeit_with_name(None)(func_or_cls)

    def watch_generator(self, generator, name):
        @functools.wraps(generator)
        def wrapper():
            while True:
                try:
                    tic = time.perf_counter()
                    item = next(generator)
                    self._update(time.perf_counter() - tic, name)
                    yield item
                except StopIteration:
                    break

        return wrapper()
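
    # Example (hypothetical): when the decorated callable returns a generator,
    # each successful `next()` is timed separately, so one sample is recorded
    # per yielded item rather than one per call.
    #
    #     @benchmark.timeit
    #     def stream_batches():
    #         for batch in range(3):
    #             yield batch
    #
    #     list(stream_batches())  # records 3 samples under "stream_batches@<loc>"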

    def reset(self):
        self._elapses = {}

    def _update(self, elapse, name):
        # Store elapsed times in milliseconds.
        elapse = elapse * 1000
        if name in self._elapses:
            self._elapses[name].append(elapse)
        else:
            self._elapses[name] = [elapse]

    @property
    def logs(self):
        return self._elapses

    def start_timing(self):
        self._enabled = True

    def stop_timing(self):
        self._enabled = False

    def start_warmup(self):
        self._warmup = True

    def stop_warmup(self):
        self._warmup = False
        self.reset()

    def gather(self, batch_size):
        # NOTE: The gathering logic here is based on the following assumptions:
        # 1. The operations are performed sequentially.
        # 2. An operation is performed only once at each iteration.
        # 3. Operations do not nest, except that the entry point operation
        #    contains all other operations.
        # 4. The input batch size for each operation is `batch_size`.
        # 5. Inference operations are always performed, while preprocessing and
        #    postprocessing operations are optional.
        # 6. If present, preprocessing operations are always performed before
        #    inference operations, and inference operations are completed before
        #    any postprocessing operations. There is no interleaving among these
        #    stages.
        logs = dict(self.logs)
        summary = {"preprocessing": 0, "inference": 0, "postprocessing": 0}
        for key in logs:
            if key.startswith(f"{ENTRY_POINT_NAME}@"):
                base_predictor_time_list = logs.pop(key)
                break
        iters = len(base_predictor_time_list)
        instances = iters * batch_size
        summary["end_to_end"] = np.mean(base_predictor_time_list)

        detail_list = []
        operation_list = []
        op_tag = "preprocessing"
        for name, time_list in logs.items():
            assert len(time_list) == iters
            avg = np.mean(time_list)
            operation_name, location = name.split("@", 1)
            detail_list.append(
                (iters, batch_size, instances, operation_name, avg, avg / batch_size)
            )
            operation_list.append((operation_name, location))
            if operation_name in _inference_operations:
                summary["inference"] += avg
                # Everything after the inference stage counts as postprocessing.
                op_tag = "postprocessing"
            else:
                summary[op_tag] += avg

        summary["core"] = (
            summary["preprocessing"] + summary["inference"] + summary["postprocessing"]
        )
        summary["other"] = summary["end_to_end"] - summary["core"]

        summary_list = [
            (
                iters,
                batch_size,
                instances,
                type_name,
                summary[key],
                summary[key] / batch_size,
            )
            for type_name, key in (
                ("Preprocessing", "preprocessing"),
                ("Inference", "inference"),
                ("Postprocessing", "postprocessing"),
                ("Core", "core"),
                ("Other", "other"),
                ("End-to-End", "end_to_end"),
            )
        ]
        return detail_list, summary_list, operation_list
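
    # Worked example of the decomposition `gather` returns (illustrative
    # numbers): with per-iteration averages preprocessing=2.0 ms,
    # inference=10.0 ms, postprocessing=1.0 ms, and an end-to-end average of
    # 14.0 ms, the summary rows are
    #     core  = 2.0 + 10.0 + 1.0 = 13.0 ms
    #     other = 14.0 - 13.0 = 1.0 ms
    # and each "Avg Time Per Instance" column divides the per-iteration value
    # by `batch_size`.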

    def collect(self, batch_size):
        detail_list, summary_list, operation_list = self.gather(batch_size)

        if self._warmup:
            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_name = "WarmUp Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)
        else:
            operation_head = [
                "Operation",
                "Source Code Location",
            ]
            table = PrettyTable(operation_head)
            table.add_rows(operation_list)
            table_name = "Operation Info".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)

            detail_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Operation",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(detail_head)
            detail_list = [i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list]
            table.add_rows(detail_list)
            table_name = "Detail Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)

            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_name = "Summary Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)

            if INFER_BENCHMARK_OUTPUT_DIR:
                save_dir = Path(INFER_BENCHMARK_OUTPUT_DIR)
                save_dir.mkdir(parents=True, exist_ok=True)
                csv_data = [detail_head, *detail_list]
                with open(save_dir / "detail.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)
                csv_data = [summary_head, *summary_list]
                with open(save_dir / "summary.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)


def get_inference_operations():
    return _inference_operations


def set_inference_operations(val):
    global _inference_operations
    _inference_operations = val


if INFER_BENCHMARK:
    benchmark = Benchmark(enabled=True)
else:
    benchmark = Benchmark(enabled=False)
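
# Minimal end-to-end sketch (hypothetical pipeline; assumes the INFER_BENCHMARK
# flag is set so that `benchmark` is enabled). The entry point is registered
# under ENTRY_POINT_NAME so that `gather` can separate it from the per-stage
# operations, and inference operation names are declared up front.
#
#     set_inference_operations(["Infer"])
#
#     @benchmark.timeit_with_name(ENTRY_POINT_NAME)
#     def predict(batch):              # hypothetical entry point
#         ...
#
#     benchmark.start_warmup()
#     predict(batch)                   # warm-up iterations
#     benchmark.collect(batch_size=1)  # prints the "WarmUp Data" table
#     benchmark.stop_warmup()          # also resets recorded timings
#
#     predict(batch)                   # measured iterations
#     benchmark.collect(batch_size=1)  # prints operation/detail/summary tables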