# benchmark.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import csv
  16. import functools
  17. import inspect
  18. import time
  19. import uuid
  20. from pathlib import Path
  21. from types import GeneratorType
  22. import numpy as np
  23. from prettytable import PrettyTable
  24. from ...utils import logging
  25. from ...utils.flags import (
  26. INFER_BENCHMARK,
  27. INFER_BENCHMARK_OUTPUT_DIR,
  28. INFER_BENCHMARK_USE_CACHE_FOR_READ,
  29. PIPELINE_BENCHMARK,
  30. )
# Name prefix identifying the single entry-point operation; entry-point calls
# are exempt from the nested-call guard and supply the end-to-end time.
ENTRY_POINT_NAME = "_entry_point_"

# XXX: Global mutable state
# Operation names counted as "inference" when aggregating in `Benchmark.gather`.
_inference_operations = []
# Re-entrancy guard: True while an INFER_BENCHMARK-timed call is in progress,
# used to detect (and reject) nested timed calls that would double-count.
_is_measuring_time = False

# Method names that `Benchmark.time_methods` must never wrap.
# NOTE(review): "inintial_predictor" looks misspelled ("initial"?) — it must
# match the actual method name elsewhere in the project; confirm before fixing.
PIPELINE_FUNC_BLACK_LIST = ["inintial_predictor"]

# PIPELINE_BENCHMARK bookkeeping: monotonically increasing call counter,
# current call-nesting depth, and the "name@location" of the sole allowed
# top-level timed function (reset via `Benchmark.reset`).
_step = 0
_level = 0
_top_func = None
class Benchmark:
    """Collect and report per-operation timing statistics.

    Two mutually exclusive modes are selected by environment flags:

    * ``INFER_BENCHMARK`` — flat per-operation timing with a nested-call
      guard and a single entry-point operation whose mean gives the
      end-to-end time (see ``gather``/``collect``).
    * ``PIPELINE_BENCHMARK`` — hierarchical timing keyed by call step and
      nesting level (see ``gather_pipeline`` and the ``*_pipeline`` helpers).

    Elapsed times are stored in milliseconds in ``self._elapses``, keyed by
    an ``@``-joined operation name (see ``_update``).
    """

    def __init__(self, enabled):
        # Master switch; when falsy, `timeit_with_options` returns the
        # wrapped object unchanged.
        self._enabled = enabled
        # Operation name -> list of elapsed times (ms), one entry per call.
        self._elapses = {}
        # True while warmup iterations run; `collect` then prints a reduced
        # report and skips CSV output.
        self._warmup = False
        # Caches filled lazily by `_initialize_pipeline_data`.
        self._detail_list = []
        self._summary_list = []
        self._operation_list = []

    def timeit_with_options(self, name=None, is_read_operation=False):
        """Return a decorator that times a callable or a class's ``__call__``.

        Args:
            name: Operation name; defaults to the wrapped object's
                ``__qualname__``.
            is_read_operation: When True and
                ``INFER_BENCHMARK_USE_CACHE_FOR_READ`` is set, the wrapped
                function is memoized via ``functools.lru_cache`` instead of
                being timed.
        """
        # TODO: Refactor
        def _deco(func_or_cls):
            if not self._enabled:
                return func_or_cls
            nonlocal name
            if name is None:
                name = func_or_cls.__qualname__
            if isinstance(func_or_cls, type):
                # For classes, wrap the __call__ method.
                if not hasattr(func_or_cls, "__call__"):
                    raise TypeError
                func = func_or_cls.__call__
            else:
                if not callable(func_or_cls):
                    raise TypeError
                func = func_or_cls
            try:
                # Source location disambiguates operations sharing a name.
                source_file = inspect.getsourcefile(func)
                source_line = inspect.getsourcelines(func)[1]
                location = f"{source_file}:{source_line}"
            except (TypeError, OSError) as e:
                # Built-ins etc. have no retrievable source; fall back to a
                # random unique tag so the key stays distinct.
                location = uuid.uuid4().hex
                logging.debug(
                    f"Benchmark: failed to get source file and line number: {e}"
                )
            use_cache = is_read_operation and INFER_BENCHMARK_USE_CACHE_FOR_READ
            if use_cache:
                if inspect.isgeneratorfunction(func):
                    raise RuntimeError(
                        "When `is_read_operation` is `True`, the wrapped function should not be a generator."
                    )
                func = functools.lru_cache(maxsize=128)(func)

                @functools.wraps(func)
                def _wrapper(*args, **kwargs):
                    # Lists are unhashable; convert to tuples so lru_cache
                    # can key on them.
                    args = tuple(
                        tuple(arg) if isinstance(arg, list) else arg for arg in args
                    )
                    kwargs = {
                        k: tuple(v) if isinstance(v, list) else v
                        for k, v in kwargs.items()
                    }
                    output = func(*args, **kwargs)
                    # Deep-copy so callers cannot mutate the cached value.
                    output = copy.deepcopy(output)
                    return output

            else:
                if INFER_BENCHMARK:

                    @functools.wraps(func)
                    def _wrapper(*args, **kwargs):
                        global _is_measuring_time
                        operation_name = f"{name}@{location}"
                        if _is_measuring_time:
                            raise RuntimeError(
                                "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
                            )
                        # The entry point legitimately contains other timed
                        # operations, so it does not set the guard.
                        if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
                            _is_measuring_time = True
                        tic = time.perf_counter()
                        try:
                            output = func(*args, **kwargs)
                        finally:
                            if not operation_name.startswith(f"{ENTRY_POINT_NAME}@"):
                                _is_measuring_time = False
                        if isinstance(output, GeneratorType):
                            # Time each `next()` call instead of the (cheap)
                            # generator construction.
                            return self.watch_generator(output, operation_name)
                        else:
                            self._update(time.perf_counter() - tic, operation_name)
                            return output

                elif PIPELINE_BENCHMARK:

                    @functools.wraps(func)
                    def _wrapper(*args, **kwargs):
                        global _step, _level, _top_func
                        _step += 1
                        _level += 1
                        if _level == 1:
                            if _top_func is None:
                                _top_func = f"{name}@{location}"
                            elif _top_func != f"{name}@{location}":
                                raise RuntimeError(
                                    f"Multiple top-level function calls detected:\n"
                                    f" Function 1: {_top_func.split('@')[0]}\n"
                                    f" Location: {_top_func.split('@')[1]}\n"
                                    f" Function 2: {name}\n"
                                    f" Location: {location}\n"
                                    "Only one top-level function can be tracked at a time.\n"
                                    "Please call 'benchmark.reset()' between top-level function calls."
                                )
                        operation_name = f"{_step}@{_level}@{name}@{location}"
                        tic = time.perf_counter()
                        # NOTE(review): if `func` raises, `_level` is never
                        # decremented (no try/finally here) — confirm callers
                        # call `benchmark.reset()` after failures.
                        output = func(*args, **kwargs)
                        if isinstance(output, GeneratorType):
                            # `watch_generator_simple` decrements `_level`
                            # once the generator is exhausted.
                            return self.watch_generator_simple(output, operation_name)
                        else:
                            self._update(time.perf_counter() - tic, operation_name)
                        _level -= 1
                        return output

            # NOTE(review): if `self._enabled` is truthy but neither
            # INFER_BENCHMARK nor PIPELINE_BENCHMARK is set (and `use_cache`
            # is False), `_wrapper` is unbound here and this raises
            # NameError — confirm `benchmark` is only enabled with a flag.
            if isinstance(func_or_cls, type):
                func_or_cls.__call__ = _wrapper
                return func_or_cls
            else:
                return _wrapper

        return _deco

    def timeit(self, func_or_cls):
        """Shorthand for ``timeit_with_options()`` with default options."""
        return self.timeit_with_options()(func_or_cls)

    def _is_public_method(self, name):
        # Public == no leading underscore.
        return not name.startswith("_")

    def time_methods(self, cls):
        """Wrap every public, non-blacklisted method of ``cls`` with ``timeit``.

        Returns ``cls`` (modified in place) so it can be used as a class
        decorator.
        """
        for name, func in cls.__dict__.items():
            if (
                callable(func)
                and self._is_public_method(name)
                # Redundant with `_is_public_method` (single "_" already
                # excludes dunders), kept as in the original condition.
                and not name.startswith("__")
                and name not in PIPELINE_FUNC_BLACK_LIST
            ):
                setattr(cls, name, self.timeit(func))
        return cls

    def watch_generator(self, generator, name):
        """Wrap ``generator`` so each ``next()`` is timed (INFER mode).

        Applies the same nested-call guard / entry-point exemption as the
        INFER_BENCHMARK wrapper in ``timeit_with_options``.
        """
        @functools.wraps(generator)
        def wrapper():
            global _is_measuring_time
            while True:
                try:
                    if _is_measuring_time:
                        raise RuntimeError(
                            "Nested calls detected: Check the timed modules and exclude nested calls to prevent double-counting."
                        )
                    if not name.startswith(f"{ENTRY_POINT_NAME}@"):
                        _is_measuring_time = True
                    tic = time.perf_counter()
                    try:
                        item = next(generator)
                    finally:
                        # Always release the guard, even on StopIteration.
                        if not name.startswith(f"{ENTRY_POINT_NAME}@"):
                            _is_measuring_time = False
                    self._update(time.perf_counter() - tic, name)
                    yield item
                except StopIteration:
                    break
        return wrapper()

    def watch_generator_simple(self, generator, name):
        """Wrap ``generator`` timing each ``next()`` (PIPELINE mode)."""
        @functools.wraps(generator)
        def wrapper():
            global _level
            while True:
                tic = time.perf_counter()
                try:
                    item = next(generator)
                except StopIteration:
                    break
                self._update(time.perf_counter() - tic, name)
                yield item
            # The producing operation's nesting level ends only once the
            # generator is exhausted.
            _level -= 1
        return wrapper()

    def reset(self):
        """Clear all collected data and the pipeline bookkeeping globals."""
        global _step, _level, _top_func
        _step = 0
        _level = 0
        _top_func = None
        self._elapses = {}
        self._detail_list = []
        self._summary_list = []
        self._operation_list = []

    def _update(self, elapse, name):
        # `elapse` arrives in seconds (perf_counter delta); store ms.
        elapse = elapse * 1000
        if name in self._elapses:
            self._elapses[name].append(elapse)
        else:
            self._elapses[name] = [elapse]

    @property
    def logs(self):
        # Raw mapping: operation name -> list of elapsed times (ms).
        return self._elapses

    def start_timing(self):
        """Enable timing collection."""
        self._enabled = True

    def stop_timing(self):
        """Disable timing collection."""
        self._enabled = False

    def start_warmup(self):
        """Mark subsequent iterations as warmup (reported separately)."""
        self._warmup = True

    def stop_warmup(self):
        """End warmup and discard the measurements collected during it."""
        self._warmup = False
        self.reset()

    def gather(self, batch_size):
        """Aggregate INFER_BENCHMARK logs.

        Returns a ``(detail_list, summary_list, operation_list)`` tuple of
        rows suitable for tabular display / CSV output.
        """
        # NOTE: The gathering logic here is based on the following assumptions:
        # 1. The operations are performed sequentially.
        # 2. An operation is performed only once at each iteration.
        # 3. Operations do not nest, except that the entry point operation
        # contains all other operations.
        # 4. The input batch size for each operation is `batch_size`.
        # 5. Preprocessing operations are always performed before inference
        # operations, and inference operations are completed before
        # postprocessing operations. There is no interleaving among these
        # stages.
        logs = {k: v for k, v in self.logs.items()}  # shallow copy
        summary = {"preprocessing": 0, "inference": 0, "postprocessing": 0}
        # Pull out the entry-point record; its mean is the end-to-end time.
        for key in logs:
            if key.startswith(f"{ENTRY_POINT_NAME}@"):
                base_predictor_time_list = logs.pop(key)
                break
        # NOTE(review): if no entry-point record exists, the next line raises
        # NameError — confirm callers always time an entry point.
        iters = len(base_predictor_time_list)
        instances = iters * batch_size
        summary["end_to_end"] = np.mean(base_predictor_time_list)
        detail_list = []
        operation_list = []
        op_tag = "preprocessing"
        for name, time_list in logs.items():
            assert len(time_list) == iters
            avg = np.mean(time_list)
            operation_name = name.split("@")[0]
            location = name.split("@")[1]
            if ":" not in location:
                # A uuid placeholder, not a real "file:line" location.
                location = "Unknown"
            detail_list.append(
                (iters, batch_size, instances, operation_name, avg, avg / batch_size)
            )
            operation_list.append((operation_name, location))
            if operation_name in _inference_operations:
                summary["inference"] += avg
                # Per assumption 5: everything after inference counts as
                # postprocessing.
                op_tag = "postprocessing"
            else:
                summary[op_tag] += avg
        summary["core"] = (
            summary["preprocessing"] + summary["inference"] + summary["postprocessing"]
        )
        # Entry-point time not attributable to any timed stage.
        summary["other"] = summary["end_to_end"] - summary["core"]
        summary_list = [
            (
                iters,
                batch_size,
                instances,
                "Preprocessing",
                summary["preprocessing"],
                summary["preprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Inference",
                summary["inference"],
                summary["inference"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Postprocessing",
                summary["postprocessing"],
                summary["postprocessing"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Core",
                summary["core"],
                summary["core"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "Other",
                summary["other"],
                summary["other"] / batch_size,
            ),
            (
                iters,
                batch_size,
                instances,
                "End-to-End",
                summary["end_to_end"],
                summary["end_to_end"] / batch_size,
            ),
        ]
        return detail_list, summary_list, operation_list

    def collect(self, batch_size):
        """Log the gathered INFER_BENCHMARK data as tables.

        During warmup only the summary is logged; otherwise operation info,
        detail, and summary tables are logged and, when
        ``INFER_BENCHMARK_OUTPUT_DIR`` is set, detail/summary are also
        written to CSV files in that directory.
        """
        detail_list, summary_list, operation_list = self.gather(batch_size)
        if self._warmup:
            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            # Format the two timing columns to fixed precision for display.
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_title = "Warmup Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)
        else:
            operation_head = [
                "Operation",
                "Source Code Location",
            ]
            table = PrettyTable(operation_head)
            table.add_rows(operation_list)
            table_title = "Operation Info".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)
            detail_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Operation",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(detail_head)
            detail_list = [i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list]
            table.add_rows(detail_list)
            table_title = "Detail Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)
            summary_head = [
                "Iters",
                "Batch Size",
                "Instances",
                "Type",
                "Avg Time Per Iter (ms)",
                "Avg Time Per Instance (ms)",
            ]
            table = PrettyTable(summary_head)
            summary_list = [
                i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
            ]
            table.add_rows(summary_list)
            table_title = "Summary Data".center(len(str(table).split("\n")[0]), " ")
            logging.info(table_title)
            logging.info(table)
            if INFER_BENCHMARK_OUTPUT_DIR:
                save_dir = Path(INFER_BENCHMARK_OUTPUT_DIR)
                save_dir.mkdir(parents=True, exist_ok=True)
                # CSVs contain the display-formatted (stringified) values.
                csv_data = [detail_head, *detail_list]
                with open(Path(save_dir) / "detail.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)
                csv_data = [summary_head, *summary_list]
                with open(Path(save_dir) / "summary.csv", "w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerows(csv_data)

    def gather_pipeline(self):
        """Aggregate PIPELINE_BENCHMARK logs.

        Returns ``(detail_list, new_summary_list, operation_list)`` where
        detail rows are per-step averages across loops and summary rows are
        per-level Layer/Core/Other plus per-operation totals.
        """
        info_list = []
        detail_list = []
        operation_list = set()
        summary_list = []
        max_level = 0
        loop_num = 0
        for name, time_list in self.logs.items():
            op_time = np.sum(time_list)
            # Keys look like "step@level@name@location" (see the decorator).
            parts = name.split("@")
            step = int(parts[0])
            level = int(parts[1])
            operation_name = parts[2]
            location = parts[3]
            if ":" not in location:
                # A uuid placeholder, not a real "file:line" location.
                location = "Unknown"
            operation_list.add((operation_name, location))
            max_level = max(level, max_level)
            if level == 1:
                # Each top-level call marks one loop over the pipeline.
                loop_num += 1
                format_operation_name = operation_name
            else:
                # Indent by nesting depth for display.
                format_operation_name = " " * int(level - 1) + "-> " + operation_name
            info_list.append(
                (step, level, operation_name, format_operation_name, op_time)
            )
        operation_list = list(operation_list)
        info_list.sort(key=lambda x: x[0])
        # Assumes each loop repeats the same sequence of steps.
        step_num = int(len(info_list) / loop_num)
        for idx in range(step_num):
            step = info_list[idx][0]
            format_operation_name = info_list[idx][3]
            # Average this step's time across all loops.
            op_time = (
                np.sum(
                    [info_list[pos][4] for pos in range(idx, len(info_list), step_num)]
                )
                / loop_num
            )
            detail_list.append([step, format_operation_name, op_time])
        level_time_list = [[0] for _ in range(max_level)]
        for idx, info in enumerate(info_list):
            step = info[0]
            level = info[1]
            operation_name = info[2]
            op_time = info[4]
            # The total time consumed by all operations on this layer
            # NOTE(review): at idx == 0 this compares against info_list[-1]
            # (Python wrap-around) — confirm that is intentional.
            if level > info_list[idx - 1][1]:
                level_time_list[level - 1].append(info_list[idx - 1][4])
            # The total time consumed by each operation on this layer
            while len(summary_list) < level:
                summary_list.append([len(summary_list) + 1, {}])
            if summary_list[level - 1][1].get(operation_name, None) is None:
                summary_list[level - 1][1][operation_name] = [op_time]
            else:
                summary_list[level - 1][1][operation_name].append(op_time)
        new_summary_list = []
        for i in range(len(summary_list)):
            level = summary_list[i][0]
            op_dict = summary_list[i][1]
            ops_all_time = 0.0
            op_info_list = []
            for idx, (name, time_list) in enumerate(op_dict.items()):
                op_all_time = np.sum(time_list) / loop_num
                # Show the level number only on the very first row overall.
                op_info_list.append([level if i + idx == 0 else "", name, op_all_time])
                ops_all_time += op_all_time
            if i > 0:
                # Blank separator row between levels.
                new_summary_list.append(["", "", ""])
            new_summary_list.append(
                [level, "Layer", np.sum(level_time_list[i]) / loop_num]
            )
            new_summary_list.append(["", "Core", ops_all_time])
            new_summary_list.append(
                ["", "Other", np.sum(level_time_list[i]) / loop_num - ops_all_time]
            )
            new_summary_list += op_info_list
        return detail_list, new_summary_list, operation_list

    def _initialize_pipeline_data(self):
        # Gather once and cache on the instance; cleared by `reset()`.
        if not (self._operation_list and self._detail_list and self._summary_list):
            self._detail_list, self._summary_list, self._operation_list = (
                self.gather_pipeline()
            )

    def print_pipeline_data(self):
        """Log operation info, detail, and summary pipeline tables."""
        self._initialize_pipeline_data()
        self.print_operation_info()
        self.print_detail_data()
        self.print_summary_data()

    def print_operation_info(self):
        """Log the operation-name / source-location table."""
        self._initialize_pipeline_data()
        operation_head = [
            "Operation",
            "Source Code Location",
        ]
        table = PrettyTable(operation_head)
        table.add_rows(self._operation_list)
        table_title = "Operation Info".center(len(str(table).split("\n")[0]), " ")
        logging.info(table_title)
        logging.info(table)

    def print_detail_data(self):
        """Log the per-step detail table."""
        self._initialize_pipeline_data()
        detail_head = [
            "Step",
            "Operation",
            "Time",
        ]
        table = PrettyTable(detail_head)
        table.add_rows(self._detail_list)
        table_title = "Detail Data".center(len(str(table).split("\n")[0]), " ")
        table.align["Operation"] = "l"
        table.align["Time"] = "l"
        logging.info(table_title)
        logging.info(table)

    def print_summary_data(self):
        """Log the per-level summary table."""
        self._initialize_pipeline_data()
        summary_head = [
            "Level",
            "Operation",
            "Time",
        ]
        table = PrettyTable(summary_head)
        table.add_rows(self._summary_list)
        table_title = "Summary Data".center(len(str(table).split("\n")[0]), " ")
        table.align["Operation"] = "l"
        table.align["Time"] = "l"
        logging.info(table_title)
        logging.info(table)

    def save_pipeline_data(self, save_path):
        """Write pipeline detail and summary data as CSV files under ``save_path``."""
        self._initialize_pipeline_data()
        save_dir = Path(save_path)
        save_dir.mkdir(parents=True, exist_ok=True)
        detail_head = [
            "Step",
            "Operation",
            "Time",
        ]
        csv_data = [detail_head, *self._detail_list]
        with open(Path(save_dir) / "detail.csv", "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerows(csv_data)
        summary_head = [
            "Level",
            "Operation",
            "Time",
        ]
        csv_data = [summary_head, *self._summary_list]
        with open(Path(save_dir) / "summary.csv", "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerows(csv_data)
def get_inference_operations():
    """Return the (mutable) module-level list of inference operation names."""
    return _inference_operations
def set_inference_operations(val):
    """Replace the module-level list of inference operation names with ``val``."""
    global _inference_operations
    _inference_operations = val
  542. if INFER_BENCHMARK or PIPELINE_BENCHMARK:
  543. benchmark = Benchmark(enabled=True)
  544. else:
  545. benchmark = Benchmark(enabled=False)