# benchmark.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import copy
  15. import csv
  16. import functools
  17. import inspect
  18. import time
  19. import uuid
  20. from pathlib import Path
  21. from types import GeneratorType
  22. import numpy as np
  23. from prettytable import PrettyTable
  24. from ...utils import logging
  25. from ...utils.flags import (
  26. INFER_BENCHMARK,
  27. INFER_BENCHMARK_OUTPUT_DIR,
  28. INFER_BENCHMARK_USE_CACHE_FOR_READ,
  29. PIPELINE_BENCHMARK,
  30. )
  31. ENTRY_POINT_NAME = "_entry_point_"
  32. # XXX: Global mutable state
  33. _inference_operations = []
  34. _is_measuring_time = False
  35. PIPELINE_FUNC_BLACK_LIST = ["inintial_predictor"]
  36. _step = 0
  37. _level = 0
  38. _top_func = None
  39. class Benchmark:
  40. def __init__(self, enabled):
  41. self._enabled = enabled
  42. self._elapses = {}
  43. self._warmup = False
  44. self._detail_list = []
  45. self._summary_list = []
  46. self._operation_list = []
  47. def timeit_with_options(self, name=None, is_read_operation=False):
  48. # TODO: Refactor
  49. def _deco(func_or_cls):
  50. if not self._enabled:
  51. return func_or_cls
  52. nonlocal name
  53. if name is None:
  54. name = func_or_cls.__qualname__
  55. if isinstance(func_or_cls, type):
  56. if not hasattr(func_or_cls, "__call__"):
  57. raise TypeError
  58. func = func_or_cls.__call__
  59. else:
  60. if not callable(func_or_cls):
  61. raise TypeError
  62. func = func_or_cls
  63. try:
  64. source_file = inspect.getsourcefile(func)
  65. source_line = inspect.getsourcelines(func)[1]
  66. location = f"{source_file}:{source_line}"
  67. except (TypeError, OSError) as e:
  68. location = uuid.uuid4().hex
  69. logging.debug(
  70. f"Benchmark: failed to get source file and line number: {e}"
  71. )
  72. use_cache = is_read_operation and INFER_BENCHMARK_USE_CACHE_FOR_READ
  73. if use_cache:
  74. if inspect.isgeneratorfunction(func):
  75. raise RuntimeError(
  76. "When `is_read_operation` is `True`, the wrapped function should not be a generator."
  77. )
  78. func = functools.lru_cache(maxsize=128)(func)
  79. @functools.wraps(func)
  80. def _wrapper(*args, **kwargs):
  81. args = tuple(
  82. tuple(arg) if isinstance(arg, list) else arg for arg in args
  83. )
  84. kwargs = {
  85. k: tuple(v) if isinstance(v, list) else v
  86. for k, v in kwargs.items()
  87. }
  88. output = func(*args, **kwargs)
  89. output = copy.deepcopy(output)
  90. return output
  91. else:
  92. if INFER_BENCHMARK:
  93. @functools.wraps(func)
  94. def _wrapper(*args, **kwargs):
  95. global _is_measuring_time
  96. operation_name = f"{name}@{location}"
  97. if _is_measuring_time:
  98. raise RuntimeError(
  99. "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
  100. )
  101. if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
  102. _is_measuring_time = True
  103. tic = time.perf_counter()
  104. try:
  105. output = func(*args, **kwargs)
  106. finally:
  107. if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
  108. _is_measuring_time = False
  109. if isinstance(output, GeneratorType):
  110. return self.watch_generator(output, operation_name)
  111. else:
  112. self._update(time.perf_counter() - tic, operation_name)
  113. return output
  114. elif PIPELINE_BENCHMARK:
  115. @functools.wraps(func)
  116. def _wrapper(*args, **kwargs):
  117. global _step, _level, _top_func
  118. _step += 1
  119. _level += 1
  120. if _level == 1:
  121. if _top_func is None:
  122. _top_func = f"{name}@{location}"
  123. elif _top_func != f"{name}@{location}":
  124. raise RuntimeError(
  125. f"Multiple top-level function calls detected:\n"
  126. f" Function 1: {_top_func.split('@')[0]}\n"
  127. f" Location: {_top_func.split('@')[1]}\n"
  128. f" Function 2: {name}\n"
  129. f" Location: {location}\n"
  130. "Only one top-level function can be tracked at a time.\n"
  131. "Please call 'benchmark.reset()' between top-level function calls."
  132. )
  133. operation_name = f"{_step}@{_level}@{name}@{location}"
  134. tic = time.perf_counter()
  135. output = func(*args, **kwargs)
  136. if isinstance(output, GeneratorType):
  137. return self.watch_generator_simple(output, operation_name)
  138. else:
  139. self._update(time.perf_counter() - tic, operation_name)
  140. _level -= 1
  141. return output
  142. if isinstance(func_or_cls, type):
  143. func_or_cls.__call__ = _wrapper
  144. return func_or_cls
  145. else:
  146. return _wrapper
  147. return _deco
  148. def timeit(self, func_or_cls):
  149. return self.timeit_with_options()(func_or_cls)
  150. def _is_public_method(self, name):
  151. return not name.startswith("_")
  152. def time_methods(self, cls):
  153. for name, func in cls.__dict__.items():
  154. if (
  155. callable(func)
  156. and self._is_public_method(name)
  157. and not name.startswith("__")
  158. and name not in PIPELINE_FUNC_BLACK_LIST
  159. ):
  160. setattr(cls, name, self.timeit(func))
  161. return cls
  162. def watch_generator(self, generator, name):
  163. @functools.wraps(generator)
  164. def wrapper():
  165. global _is_measuring_time
  166. while True:
  167. try:
  168. if _is_measuring_time:
  169. raise RuntimeError(
  170. "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
  171. )
  172. if not name.startswith(f"{ENTRY_POINT_NAME}@"):
  173. _is_measuring_time = True
  174. tic = time.perf_counter()
  175. try:
  176. item = next(generator)
  177. finally:
  178. if not name.startswith(f"{ENTRY_POINT_NAME}@"):
  179. _is_measuring_time = False
  180. self._update(time.perf_counter() - tic, name)
  181. yield item
  182. except StopIteration:
  183. break
  184. return wrapper()
  185. def watch_generator_simple(self, generator, name):
  186. @functools.wraps(generator)
  187. def wrapper():
  188. global _level
  189. try:
  190. while True:
  191. tic = time.perf_counter()
  192. try:
  193. item = next(generator)
  194. except StopIteration:
  195. break
  196. self._update(time.perf_counter() - tic, name)
  197. yield item
  198. finally:
  199. _level -= 1
  200. return wrapper()
  201. def reset(self):
  202. global _step, _level, _top_func
  203. _step = 0
  204. _level = 0
  205. _top_func = None
  206. self._elapses = {}
  207. self._detail_list = []
  208. self._summary_list = []
  209. self._operation_list = []
  210. def _update(self, elapse, name):
  211. elapse = elapse * 1000
  212. if name in self._elapses:
  213. self._elapses[name].append(elapse)
  214. else:
  215. self._elapses[name] = [elapse]
  216. @property
  217. def logs(self):
  218. return self._elapses
  219. def start_timing(self):
  220. self._enabled = True
  221. def stop_timing(self):
  222. self._enabled = False
  223. def start_warmup(self):
  224. self._warmup = True
  225. def stop_warmup(self):
  226. self._warmup = False
  227. self.reset()
  228. def gather(self, batch_size):
  229. # NOTE: The gathering logic here is based on the following assumptions:
  230. # 1. The operations are performed sequentially.
  231. # 2. An operation is performed only once at each iteration.
  232. # 3. Operations do not nest, except that the entry point operation
  233. # contains all other operations.
  234. # 4. The input batch size for each operation is `batch_size`.
  235. # 5. Preprocessing operations are always performed before inference
  236. # operations, and inference operations are completed before
  237. # postprocessing operations. There is no interleaving among these
  238. # stages.
  239. logs = {k: v for k, v in self.logs.items()}
  240. summary = {"preprocessing": 0, "inference": 0, "postprocessing": 0}
  241. for key in logs:
  242. if key.startswith(f"{ENTRY_POINT_NAME}@"):
  243. base_predictor_time_list = logs.pop(key)
  244. break
  245. iters = len(base_predictor_time_list)
  246. instances = iters * batch_size
  247. summary["end_to_end"] = np.mean(base_predictor_time_list)
  248. detail_list = []
  249. operation_list = []
  250. op_tag = "preprocessing"
  251. for name, time_list in logs.items():
  252. assert len(time_list) == iters
  253. avg = np.mean(time_list)
  254. operation_name = name.split("@")[0]
  255. location = name.split("@")[1]
  256. if ":" not in location:
  257. location = "Unknown"
  258. detail_list.append(
  259. (iters, batch_size, instances, operation_name, avg, avg / batch_size)
  260. )
  261. operation_list.append((operation_name, location))
  262. if operation_name in _inference_operations:
  263. summary["inference"] += avg
  264. op_tag = "postprocessing"
  265. else:
  266. summary[op_tag] += avg
  267. summary["core"] = (
  268. summary["preprocessing"] + summary["inference"] + summary["postprocessing"]
  269. )
  270. summary["other"] = summary["end_to_end"] - summary["core"]
  271. summary_list = [
  272. (
  273. iters,
  274. batch_size,
  275. instances,
  276. "Preprocessing",
  277. summary["preprocessing"],
  278. summary["preprocessing"] / batch_size,
  279. ),
  280. (
  281. iters,
  282. batch_size,
  283. instances,
  284. "Inference",
  285. summary["inference"],
  286. summary["inference"] / batch_size,
  287. ),
  288. (
  289. iters,
  290. batch_size,
  291. instances,
  292. "Postprocessing",
  293. summary["postprocessing"],
  294. summary["postprocessing"] / batch_size,
  295. ),
  296. (
  297. iters,
  298. batch_size,
  299. instances,
  300. "Core",
  301. summary["core"],
  302. summary["core"] / batch_size,
  303. ),
  304. (
  305. iters,
  306. batch_size,
  307. instances,
  308. "Other",
  309. summary["other"],
  310. summary["other"] / batch_size,
  311. ),
  312. (
  313. iters,
  314. batch_size,
  315. instances,
  316. "End-to-End",
  317. summary["end_to_end"],
  318. summary["end_to_end"] / batch_size,
  319. ),
  320. ]
  321. return detail_list, summary_list, operation_list
  322. def collect(self, batch_size):
  323. detail_list, summary_list, operation_list = self.gather(batch_size)
  324. if self._warmup:
  325. summary_head = [
  326. "Iters",
  327. "Batch Size",
  328. "Instances",
  329. "Type",
  330. "Avg Time Per Iter (ms)",
  331. "Avg Time Per Instance (ms)",
  332. ]
  333. table = PrettyTable(summary_head)
  334. summary_list = [
  335. i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
  336. ]
  337. table.add_rows(summary_list)
  338. table_title = "Warmup Data".center(len(str(table).split("\n")[0]), " ")
  339. logging.info(table_title)
  340. logging.info(table)
  341. else:
  342. operation_head = [
  343. "Operation",
  344. "Source Code Location",
  345. ]
  346. table = PrettyTable(operation_head)
  347. table.add_rows(operation_list)
  348. table_title = "Operation Info".center(len(str(table).split("\n")[0]), " ")
  349. logging.info(table_title)
  350. logging.info(table)
  351. detail_head = [
  352. "Iters",
  353. "Batch Size",
  354. "Instances",
  355. "Operation",
  356. "Avg Time Per Iter (ms)",
  357. "Avg Time Per Instance (ms)",
  358. ]
  359. table = PrettyTable(detail_head)
  360. detail_list = [i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list]
  361. table.add_rows(detail_list)
  362. table_title = "Detail Data".center(len(str(table).split("\n")[0]), " ")
  363. logging.info(table_title)
  364. logging.info(table)
  365. summary_head = [
  366. "Iters",
  367. "Batch Size",
  368. "Instances",
  369. "Type",
  370. "Avg Time Per Iter (ms)",
  371. "Avg Time Per Instance (ms)",
  372. ]
  373. table = PrettyTable(summary_head)
  374. summary_list = [
  375. i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
  376. ]
  377. table.add_rows(summary_list)
  378. table_title = "Summary Data".center(len(str(table).split("\n")[0]), " ")
  379. logging.info(table_title)
  380. logging.info(table)
  381. if INFER_BENCHMARK_OUTPUT_DIR:
  382. save_dir = Path(INFER_BENCHMARK_OUTPUT_DIR)
  383. save_dir.mkdir(parents=True, exist_ok=True)
  384. csv_data = [detail_head, *detail_list]
  385. with open(Path(save_dir) / "detail.csv", "w", newline="") as file:
  386. writer = csv.writer(file)
  387. writer.writerows(csv_data)
  388. csv_data = [summary_head, *summary_list]
  389. with open(Path(save_dir) / "summary.csv", "w", newline="") as file:
  390. writer = csv.writer(file)
  391. writer.writerows(csv_data)
  392. def gather_pipeline(self):
  393. info_list = []
  394. detail_list = []
  395. operation_list = set()
  396. summary_list = []
  397. max_level = 0
  398. loop_num = 0
  399. for name, time_list in self.logs.items():
  400. op_time = np.sum(time_list)
  401. parts = name.split("@")
  402. step = int(parts[0])
  403. level = int(parts[1])
  404. operation_name = parts[2]
  405. location = parts[3]
  406. if ":" not in location:
  407. location = "Unknown"
  408. operation_list.add((operation_name, location))
  409. max_level = max(level, max_level)
  410. if level == 1:
  411. loop_num += 1
  412. format_operation_name = operation_name
  413. else:
  414. format_operation_name = " " * int(level - 1) + "-> " + operation_name
  415. info_list.append(
  416. (step, level, operation_name, format_operation_name, op_time)
  417. )
  418. operation_list = list(operation_list)
  419. info_list.sort(key=lambda x: x[0])
  420. step_num = int(len(info_list) / loop_num)
  421. for idx in range(step_num):
  422. step = info_list[idx][0]
  423. format_operation_name = info_list[idx][3]
  424. op_time = (
  425. np.sum(
  426. [info_list[pos][4] for pos in range(idx, len(info_list), step_num)]
  427. )
  428. / loop_num
  429. )
  430. detail_list.append([step, format_operation_name, op_time])
  431. level_time_list = [[0] for _ in range(max_level)]
  432. for idx, info in enumerate(info_list):
  433. step = info[0]
  434. level = info[1]
  435. operation_name = info[2]
  436. op_time = info[4]
  437. # The total time consumed by all operations on this layer
  438. if level > info_list[idx - 1][1]:
  439. level_time_list[level - 1].append(info_list[idx - 1][4])
  440. # The total time consumed by each operation on this layer
  441. while len(summary_list) < level:
  442. summary_list.append([len(summary_list) + 1, {}])
  443. if summary_list[level - 1][1].get(operation_name, None) is None:
  444. summary_list[level - 1][1][operation_name] = [op_time]
  445. else:
  446. summary_list[level - 1][1][operation_name].append(op_time)
  447. new_summary_list = []
  448. for i in range(len(summary_list)):
  449. level = summary_list[i][0]
  450. op_dict = summary_list[i][1]
  451. ops_all_time = 0.0
  452. op_info_list = []
  453. for idx, (name, time_list) in enumerate(op_dict.items()):
  454. op_all_time = np.sum(time_list) / loop_num
  455. op_info_list.append([level if i + idx == 0 else "", name, op_all_time])
  456. ops_all_time += op_all_time
  457. if i > 0:
  458. new_summary_list.append(["", "", ""])
  459. new_summary_list.append(
  460. [level, "Layer", np.sum(level_time_list[i]) / loop_num]
  461. )
  462. new_summary_list.append(["", "Core", ops_all_time])
  463. new_summary_list.append(
  464. ["", "Other", np.sum(level_time_list[i]) / loop_num - ops_all_time]
  465. )
  466. new_summary_list += op_info_list
  467. return detail_list, new_summary_list, operation_list
  468. def _initialize_pipeline_data(self):
  469. if not (self._operation_list and self._detail_list and self._summary_list):
  470. self._detail_list, self._summary_list, self._operation_list = (
  471. self.gather_pipeline()
  472. )
  473. def print_pipeline_data(self):
  474. self._initialize_pipeline_data()
  475. self.print_operation_info()
  476. self.print_detail_data()
  477. self.print_summary_data()
  478. def print_operation_info(self):
  479. self._initialize_pipeline_data()
  480. operation_head = [
  481. "Operation",
  482. "Source Code Location",
  483. ]
  484. table = PrettyTable(operation_head)
  485. table.add_rows(self._operation_list)
  486. table_title = "Operation Info".center(len(str(table).split("\n")[0]), " ")
  487. logging.info(table_title)
  488. logging.info(table)
  489. def print_detail_data(self):
  490. self._initialize_pipeline_data()
  491. detail_head = [
  492. "Step",
  493. "Operation",
  494. "Time (ms)",
  495. ]
  496. table = PrettyTable(detail_head)
  497. table.add_rows(self._detail_list)
  498. table_title = "Detail Data".center(len(str(table).split("\n")[0]), " ")
  499. table.align["Operation"] = "l"
  500. table.align["Time (ms)"] = "l"
  501. logging.info(table_title)
  502. logging.info(table)
  503. def print_summary_data(self):
  504. self._initialize_pipeline_data()
  505. summary_head = [
  506. "Level",
  507. "Operation",
  508. "Time (ms)",
  509. ]
  510. table = PrettyTable(summary_head)
  511. table.add_rows(self._summary_list)
  512. table_title = "Summary Data".center(len(str(table).split("\n")[0]), " ")
  513. table.align["Operation"] = "l"
  514. table.align["Time (ms)"] = "l"
  515. logging.info(table_title)
  516. logging.info(table)
  517. def save_pipeline_data(self, save_path):
  518. self._initialize_pipeline_data()
  519. save_dir = Path(save_path)
  520. save_dir.mkdir(parents=True, exist_ok=True)
  521. detail_head = [
  522. "Step",
  523. "Operation",
  524. "Time (ms)",
  525. ]
  526. csv_data = [detail_head, *self._detail_list]
  527. with open(Path(save_dir) / "detail.csv", "w", newline="") as file:
  528. writer = csv.writer(file)
  529. writer.writerows(csv_data)
  530. summary_head = [
  531. "Level",
  532. "Operation",
  533. "Time (ms)",
  534. ]
  535. csv_data = [summary_head, *self._summary_list]
  536. with open(Path(save_dir) / "summary.csv", "w", newline="") as file:
  537. writer = csv.writer(file)
  538. writer.writerows(csv_data)
  539. def get_inference_operations():
  540. return _inference_operations
  541. def set_inference_operations(val):
  542. global _inference_operations
  543. _inference_operations = val
  544. if INFER_BENCHMARK or PIPELINE_BENCHMARK:
  545. benchmark = Benchmark(enabled=True)
  546. else:
  547. benchmark = Benchmark(enabled=False)