benchmark.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import functools
from types import GeneratorType
import time
from pathlib import Path
import inspect

import numpy as np
from prettytable import PrettyTable

from ...utils.flags import (
    INFER_BENCHMARK,
    INFER_BENCHMARK_OUTPUT_DIR,
    INFER_BENCHMARK_USE_CACHE_FOR_READ,
)
from ...utils import logging

ENTRY_POINT_NAME = "_entry_point_"

# XXX: Global mutable state
_inference_operations = []

class Benchmark:
    """Collects per-operation timing data for inference benchmarking and
    reports it as tables and, optionally, CSV files."""

    def __init__(self, enabled):
        self._enabled = enabled
        self._elapses = {}
        self._warmup = False

    def timeit_with_options(self, name=None, is_read_operation=False):
        """Return a decorator that records the elapsed time of a callable (or
        of a class's ``__call__``) under ``name``; read operations may be
        cached instead of timed."""
        # TODO: Refactor
        def _deco(func_or_cls):
            if not self._enabled:
                return func_or_cls

            nonlocal name
            if name is None:
                name = func_or_cls.__qualname__

            if isinstance(func_or_cls, type):
                if not hasattr(func_or_cls, "__call__"):
                    raise TypeError
                func = func_or_cls.__call__
            else:
                if not callable(func_or_cls):
                    raise TypeError
                func = func_or_cls

            try:
                source_file = inspect.getsourcefile(func)
                source_line = inspect.getsourcelines(func)[1]
                location = f"{source_file}:{source_line}"
            except (TypeError, OSError) as e:
                location = "Unknown"
                logging.debug(
                    f"Benchmark: failed to get source file and line number: {e}"
                )

            use_cache = is_read_operation and INFER_BENCHMARK_USE_CACHE_FOR_READ
            if use_cache:
                if inspect.isgeneratorfunction(func):
                    raise RuntimeError(
                        "When `is_read_operation` is `True`, the wrapped function should not be a generator."
                    )
                func = functools.lru_cache(maxsize=128)(func)

                @functools.wraps(func)
                def _wrapper(*args, **kwargs):
                    # Convert lists to tuples so the arguments are hashable
                    # for `lru_cache`.
                    args = tuple(
                        tuple(arg) if isinstance(arg, list) else arg for arg in args
                    )
                    kwargs = {
                        k: tuple(v) if isinstance(v, list) else v
                        for k, v in kwargs.items()
                    }
                    output = func(*args, **kwargs)
                    return output

            else:

                @functools.wraps(func)
                def _wrapper(*args, **kwargs):
                    operation_name = f"{name}@{location}"
                    tic = time.perf_counter()
                    output = func(*args, **kwargs)
                    if isinstance(output, GeneratorType):
                        return self.watch_generator(output, operation_name)
                    else:
                        self._update(time.perf_counter() - tic, operation_name)
                        return output

            if isinstance(func_or_cls, type):
                func_or_cls.__call__ = _wrapper
                return func_or_cls
            else:
                return _wrapper

        return _deco

    def timeit(self, func_or_cls):
        return self.timeit_with_options()(func_or_cls)

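    # Illustrative decoration, a minimal sketch with hypothetical callables:
    # `timeit` times every call of a function or of a class's `__call__`, while
    # `timeit_with_options` lets the caller override the recorded name or
    # enable caching for read operations.
    #
    #     @benchmark.timeit
    #     class Preprocessor:
    #         def __call__(self, batch): ...
    #
    #     @benchmark.timeit_with_options(name="ReadImage", is_read_operation=True)
    #     def read_image(path): ...
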
    def watch_generator(self, generator, name):
        @functools.wraps(generator)
        def wrapper():
            while True:
                try:
                    tic = time.perf_counter()
                    item = next(generator)
                    self._update(time.perf_counter() - tic, name)
                    yield item
                except StopIteration:
                    break

        return wrapper()

    def reset(self):
        self._elapses = {}

    def _update(self, elapse, name):
        # Store the elapsed time in milliseconds.
        elapse = elapse * 1000
        if name in self._elapses:
            self._elapses[name].append(elapse)
        else:
            self._elapses[name] = [elapse]

    @property
    def logs(self):
        return self._elapses

    def start_timing(self):
        self._enabled = True

    def stop_timing(self):
        self._enabled = False

    def start_warmup(self):
        self._warmup = True

    def stop_warmup(self):
        self._warmup = False
        self.reset()

    def gather(self, batch_size):
        # NOTE: The gathering logic here is based on the following assumptions:
        # 1. The operations are performed sequentially.
        # 2. An operation is performed only once at each iteration.
        # 3. Operations do not nest, except that the entry point operation
        #    contains all other operations.
        # 4. The input batch size for each operation is `batch_size`.
        # 5. Inference operations are always performed, while preprocessing and
        #    postprocessing operations are optional.
        # 6. If present, preprocessing operations are always performed before
        #    inference operations, and inference operations are completed before
        #    any postprocessing operations. There is no interleaving among these
        #    stages.
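        # Illustrative shape of `self.logs` under these assumptions (the keys
        # and values are hypothetical; each key is "<operation name>@<source
        # location>" and each value is a list of per-iteration times in ms):
        #
        #     {
        #         "_entry_point_@/path/to/predictor.py:10": [35.1, 34.8, ...],
        #         "ReadImage@/path/to/ops.py:42": [1.2, 1.1, ...],
        #     }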
        logs = {k: v for k, v in self.logs.items()}

        summary = {"preprocessing": 0, "inference": 0, "postprocessing": 0}
        for key in logs:
            if key.startswith(f"{ENTRY_POINT_NAME}@"):
                base_predictor_time_list = logs.pop(key)
                break
        iters = len(base_predictor_time_list)
        instances = iters * batch_size
        summary["end_to_end"] = np.mean(base_predictor_time_list)

        detail_list = []
        operation_list = []
        op_tag = "preprocessing"
        for name, time_list in logs.items():
            assert len(time_list) == iters
            avg = np.mean(time_list)
            operation_name = name.split("@")[0]
            location = name.split("@")[1]
            detail_list.append(
                (iters, batch_size, instances, operation_name, avg, avg / batch_size)
            )
            operation_list.append((operation_name, location))
            if operation_name in _inference_operations:
                summary["inference"] += avg
                op_tag = "postprocessing"
            else:
                summary[op_tag] += avg

        summary["core"] = (
            summary["preprocessing"] + summary["inference"] + summary["postprocessing"]
        )
        summary["other"] = summary["end_to_end"] - summary["core"]
        summary_list = [
            (
                iters,
                batch_size,
                instances,
                "Preprocessing",
                summary["preprocessing"],
                summary["preprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Inference",
                summary["inference"],
                summary["inference"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Postprocessing",
                summary["postprocessing"],
                summary["postprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Core",
                summary["core"],
                summary["core"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Other",
                summary["other"],
                summary["other"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "End-to-End",
                summary["end_to_end"],
                summary["end_to_end"] / batch_size,
            ),
        ]
        return detail_list, summary_list, operation_list

    def collect(self, batch_size):
        detail_list, summary_list, operation_list = self.gather(batch_size)

        if self._warmup:
            # During warm-up, only the summary table is reported.
            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_name = "WarmUp Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)
        else:
            operation_head = [
                "Operation",
                "Source Code Location",
            ]
            table = PrettyTable(operation_head)
            table.add_rows(operation_list)
            table_name = "Operation Info".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)

            detail_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Operation",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(detail_head)
            detail_list = [i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list]
            table.add_rows(detail_list)
            table_name = "Detail Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)

            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_name = "Summary Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_name)
            logging.info(table)

            if INFER_BENCHMARK_OUTPUT_DIR:
                save_dir = Path(INFER_BENCHMARK_OUTPUT_DIR)
                save_dir.mkdir(parents=True, exist_ok=True)
                csv_data = [detail_head, *detail_list]
                with open(save_dir / "detail.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)
                csv_data = [summary_head, *summary_list]
                with open(save_dir / "summary.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)

def get_inference_operations():
    return _inference_operations


def set_inference_operations(val):
    global _inference_operations
    _inference_operations = val

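# Illustrative registration, a minimal sketch: `gather` classifies an operation
# as "inference" only if its name appears in this registry. The operation name
# "PaddleInfer" below is hypothetical.
#
#     set_inference_operations(get_inference_operations() + ["PaddleInfer"])
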
if INFER_BENCHMARK:
    benchmark = Benchmark(enabled=True)
else:
    benchmark = Benchmark(enabled=False)
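
# Illustrative end-to-end usage, a minimal sketch: the callables and the
# "Infer" operation name are hypothetical, and the decorators are only active
# when benchmarking is enabled (i.e. when `INFER_BENCHMARK` is set).
#
#     set_inference_operations(["Infer"])
#
#     @benchmark.timeit_with_options(name="Infer")
#     def infer(batch): ...
#
#     @benchmark.timeit_with_options(name=ENTRY_POINT_NAME)
#     def predict(batch):
#         return infer(batch)
#
#     benchmark.start_warmup()
#     predict(batch)
#     benchmark.stop_warmup()  # discards warm-up timings
#     for _ in range(10):
#         predict(batch)
#     benchmark.collect(batch_size=1)  # logs tables; writes CSVs if configured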