# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import csv
import functools
import inspect
import time
import uuid
from pathlib import Path
from types import GeneratorType

import numpy as np
from prettytable import PrettyTable

from ...utils import logging
from ...utils.flags import (
    INFER_BENCHMARK,
    INFER_BENCHMARK_OUTPUT_DIR,
    INFER_BENCHMARK_USE_CACHE_FOR_READ,
)

ENTRY_POINT_NAME = "_entry_point_"

# XXX: Global mutable state
_inference_operations = []

_is_measuring_time = False


class Benchmark:
    def __init__(self, enabled):
        self._enabled = enabled
        self._elapses = {}
        self._warmup = False

    def timeit_with_options(self, name=None, is_read_operation=False):
        # TODO: Refactor
        def _deco(func_or_cls):
            if not self._enabled:
                return func_or_cls

            nonlocal name
            if name is None:
                name = func_or_cls.__qualname__

            if isinstance(func_or_cls, type):
                if not hasattr(func_or_cls, "__call__"):
                    raise TypeError
                func = func_or_cls.__call__
            else:
                if not callable(func_or_cls):
                    raise TypeError
                func = func_or_cls

            try:
                source_file = inspect.getsourcefile(func)
                source_line = inspect.getsourcelines(func)[1]
                location = f"{source_file}:{source_line}"
            except (TypeError, OSError) as e:
                # Fall back to a unique placeholder so operations remain
                # distinguishable even without source information.
                location = uuid.uuid4().hex
                logging.debug(
                    f"Benchmark: failed to get source file and line number: {e}"
                )

            use_cache = is_read_operation and INFER_BENCHMARK_USE_CACHE_FOR_READ

            if use_cache:
                if inspect.isgeneratorfunction(func):
                    raise RuntimeError(
                        "When `is_read_operation` is `True`, the wrapped function should not be a generator."
                    )
                func = functools.lru_cache(maxsize=128)(func)

                @functools.wraps(func)
                def _wrapper(*args, **kwargs):
                    # Convert lists to tuples so the arguments are hashable for
                    # `lru_cache`; deep-copy the output so cached results cannot
                    # be mutated by the caller.
                    args = tuple(
                        tuple(arg) if isinstance(arg, list) else arg for arg in args
                    )
                    kwargs = {
                        k: tuple(v) if isinstance(v, list) else v
                        for k, v in kwargs.items()
                    }
                    output = func(*args, **kwargs)
                    output = copy.deepcopy(output)
                    return output

            else:

                @functools.wraps(func)
                def _wrapper(*args, **kwargs):
                    global _is_measuring_time
                    operation_name = f"{name}@{location}"
                    if _is_measuring_time:
                        raise RuntimeError(
                            "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
                        )
                    if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
                        _is_measuring_time = True
                    tic = time.perf_counter()
                    try:
                        output = func(*args, **kwargs)
                    finally:
                        if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
                            _is_measuring_time = False
                    if isinstance(output, GeneratorType):
                        # Defer timing to the generator watcher so that each
                        # yielded item is measured individually.
                        return self.watch_generator(output, operation_name)
                    else:
                        self._update(time.perf_counter() - tic, operation_name)
                        return output

            if isinstance(func_or_cls, type):
                func_or_cls.__call__ = _wrapper
                return func_or_cls
            else:
                return _wrapper

        return _deco
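
    # Illustrative usage sketch (not part of the original module; the names
    # below are hypothetical). The factory above is applied as a decorator:
    #
    #     @benchmark.timeit_with_options(name="ReadImage", is_read_operation=True)
    #     def read_image(path):
    #         ...  # results may be served from the LRU cache when caching is on
    #
    #     @benchmark.timeit
    #     class MyOperator:
    #         def __call__(self, batch):
    #             ...  # the class's __call__ is replaced with the timing wrapper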

    def timeit(self, func_or_cls):
        return self.timeit_with_options()(func_or_cls)

    def watch_generator(self, generator, name):
        @functools.wraps(generator)
        def wrapper():
            global _is_measuring_time
            while True:
                try:
                    if _is_measuring_time:
                        raise RuntimeError(
                            "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
                        )
                    if not name.startswith(f"{ENTRY_POINT_NAME}@"):
                        _is_measuring_time = True
                    tic = time.perf_counter()
                    try:
                        item = next(generator)
                    finally:
                        if not name.startswith(f"{ENTRY_POINT_NAME}@"):
                            _is_measuring_time = False
                    self._update(time.perf_counter() - tic, name)
                    yield item
                except StopIteration:
                    break

        return wrapper()
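
    # Note (added commentary, not in the original): each `next()` on the wrapped
    # generator is timed separately and recorded under the same operation name,
    # so lazy pipelines are charged per yielded item rather than once at creation.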

    def reset(self):
        self._elapses = {}

    def _update(self, elapse, name):
        # Record elapsed time in milliseconds.
        elapse = elapse * 1000
        if name in self._elapses:
            self._elapses[name].append(elapse)
        else:
            self._elapses[name] = [elapse]

    @property
    def logs(self):
        return self._elapses

    def start_timing(self):
        self._enabled = True

    def stop_timing(self):
        self._enabled = False

    def start_warmup(self):
        self._warmup = True

    def stop_warmup(self):
        self._warmup = False
        self.reset()
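
    # Illustrative driver-side protocol (hypothetical, not part of the module):
    #
    #     benchmark.start_warmup()
    #     run_pipeline()                 # warmup iterations
    #     benchmark.collect(batch_size)  # logged as "Warmup Data"
    #     benchmark.stop_warmup()        # also clears warmup timings via reset()
    #     run_pipeline()                 # measured iterations
    #     benchmark.collect(batch_size)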

    def gather(self, batch_size):
        # NOTE: The gathering logic here is based on the following assumptions:
        # 1. The operations are performed sequentially.
        # 2. An operation is performed only once at each iteration.
        # 3. Operations do not nest, except that the entry point operation
        #    contains all other operations.
        # 4. The input batch size for each operation is `batch_size`.
        # 5. Preprocessing operations are always performed before inference
        #    operations, and inference operations are completed before
        #    postprocessing operations. There is no interleaving among these
        #    stages.
        logs = {k: v for k, v in self.logs.items()}

        summary = {"preprocessing": 0, "inference": 0, "postprocessing": 0}
        for key in logs:
            if key.startswith(f"{ENTRY_POINT_NAME}@"):
                base_predictor_time_list = logs.pop(key)
                break
        iters = len(base_predictor_time_list)
        instances = iters * batch_size
        summary["end_to_end"] = np.mean(base_predictor_time_list)

        detail_list = []
        operation_list = []
        op_tag = "preprocessing"

        for name, time_list in logs.items():
            assert len(time_list) == iters
            avg = np.mean(time_list)
            operation_name = name.split("@")[0]
            location = name.split("@")[1]
            if ":" not in location:
                location = "Unknown"
            detail_list.append(
                (iters, batch_size, instances, operation_name, avg, avg / batch_size)
            )
            operation_list.append((operation_name, location))
            if operation_name in _inference_operations:
                summary["inference"] += avg
                # Operations seen after the inference stage count as postprocessing.
                op_tag = "postprocessing"
            else:
                summary[op_tag] += avg

        summary["core"] = (
            summary["preprocessing"] + summary["inference"] + summary["postprocessing"]
        )
        summary["other"] = summary["end_to_end"] - summary["core"]

        summary_list = [
            (
                iters,
                batch_size,
                instances,
                "Preprocessing",
                summary["preprocessing"],
                summary["preprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Inference",
                summary["inference"],
                summary["inference"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Postprocessing",
                summary["postprocessing"],
                summary["postprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Core",
                summary["core"],
                summary["core"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Other",
                summary["other"],
                summary["other"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "End-to-End",
                summary["end_to_end"],
                summary["end_to_end"] / batch_size,
            ),
        ]

        return detail_list, summary_list, operation_list
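
    # Worked example of the summary arithmetic above (hypothetical numbers):
    # with batch_size=2 and per-iteration averages of 1.5 ms preprocessing,
    # 4.0 ms inference, 0.5 ms postprocessing, and 6.5 ms end-to-end:
    #     core  = 1.5 + 4.0 + 0.5 = 6.0 ms
    #     other = 6.5 - 6.0       = 0.5 ms
    # Per-instance columns divide each figure by batch_size, e.g. 6.5 / 2 = 3.25 ms.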

    def collect(self, batch_size):
        detail_list, summary_list, operation_list = self.gather(batch_size)

        if self._warmup:
            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_title = "Warmup Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)
        else:
            operation_head = [
                "Operation",
                "Source Code Location",
            ]
            table = PrettyTable(operation_head)
            table.add_rows(operation_list)
            table_title = "Operation Info".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)

            detail_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Operation",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(detail_head)
            detail_list = [i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list]
            table.add_rows(detail_list)
            table_title = "Detail Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)

            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_title = "Summary Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)

            # The CSV export lives in the non-warmup branch, where `detail_head`
            # and `summary_head` are defined.
            if INFER_BENCHMARK_OUTPUT_DIR:
                save_dir = Path(INFER_BENCHMARK_OUTPUT_DIR)
                save_dir.mkdir(parents=True, exist_ok=True)
                csv_data = [detail_head, *detail_list]
                with open(save_dir / "detail.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)
                csv_data = [summary_head, *summary_list]
                with open(save_dir / "summary.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)
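
    # Illustrative follow-up (hypothetical, not part of the module): the saved
    # summary can be reloaded for programmatic comparison, e.g.
    #
    #     import csv
    #     with open("output/summary.csv", newline="") as f:
    #         rows = list(csv.reader(f))   # rows[0] is the header row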


def get_inference_operations():
    return _inference_operations


def set_inference_operations(val):
    global _inference_operations
    _inference_operations = val


if INFER_BENCHMARK:
    benchmark = Benchmark(enabled=True)
else:
    benchmark = Benchmark(enabled=False)
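

# Illustrative end-to-end sketch (hypothetical names, not part of the module):
#
#     set_inference_operations(["PaddleInfer"])   # names counted as "inference"
#
#     @benchmark.timeit_with_options(name=ENTRY_POINT_NAME)
#     def predict(batch): ...                     # entry point; exempt from the nesting check
#
#     @benchmark.timeit_with_options(name="PaddleInfer")
#     def infer(batch): ...                       # contributes to the "inference" stage
#
#     predict(data)
#     benchmark.collect(batch_size=len(data))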