@@ -17,6 +17,7 @@ import functools
 from types import GeneratorType
 import time
 from pathlib import Path
+import inspect
 import numpy as np
 from prettytable import PrettyTable
 
@@ -51,17 +52,32 @@ class Benchmark:
             raise TypeError
         func = func_or_cls
 
+        location = None
+
         @functools.wraps(func)
         def _wrapper(*args, **kwargs):
+            nonlocal location
+
             if not self._enabled:
                 return func(*args, **kwargs)
 
+            if location is None:
+                try:
+                    source_file = inspect.getsourcefile(func)
+                    source_line = inspect.getsourcelines(func)[1]
+                    location = f"{source_file}:{source_line}"
+                except (TypeError, OSError) as e:
+                    location = "Unknown"
+                    logging.debug(
+                        f"Benchmark: failed to get source file and line number: {e}"
+                    )
+
             tic = time.perf_counter()
             output = func(*args, **kwargs)
             if isinstance(output, GeneratorType):
-                return self.watch_generator(output, name)
+                return self.watch_generator(output, f"{name}@{location}")
             else:
-                self._update(time.perf_counter() - tic, name)
+                self._update(time.perf_counter() - tic, f"{name}@{location}")
                 return output
 
         if isinstance(func_or_cls, type):
@@ -133,22 +149,29 @@ class Benchmark:
         logs = {k: v for k, v in self.logs.items()}
 
         summary = {"preprocessing": 0, "inference": 0, "postprocessing": 0}
-        base_predictor_time_list = logs.pop(ENTRY_POINT_NAME)
+        for key in list(logs):
+            if key.startswith(f"{ENTRY_POINT_NAME}@"):
+                base_predictor_time_list = logs.pop(key)
+                break
         iters = len(base_predictor_time_list)
         instances = iters * batch_size
         summary["end_to_end"] = np.mean(base_predictor_time_list)
 
         detail_list = []
+        operation_list = []
         op_tag = "preprocessing"
 
         for name, time_list in logs.items():
             assert len(time_list) == iters
             avg = np.mean(time_list)
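+            # Keys are "{name}@{location}"; split once so any "@" in the path stays in location.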
+            operation_name, location = name.split("@", 1)
             detail_list.append(
-                (iters, batch_size, instances, name, avg, avg / batch_size)
+                (iters, batch_size, instances, operation_name, avg, avg / batch_size)
             )
+            operation_list.append((operation_name, location))
 
-            if name in _inference_operations:
+            if operation_name in _inference_operations:
                 summary["inference"] += avg
                 op_tag = "postprocessing"
             else:
@@ -211,10 +234,10 @@ class Benchmark:
             ),
         ]
 
-        return detail_list, summary_list
+        return detail_list, summary_list, operation_list
 
     def collect(self, batch_size):
-        detail_list, summary_list = self.gather(batch_size)
+        detail_list, summary_list, operation_list = self.gather(batch_size)
 
         if self._warmup:
             summary_head = [
@@ -230,11 +253,21 @@ class Benchmark:
                 i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
             ]
             table.add_rows(summary_list)
-            header = "WarmUp Data".center(len(str(table).split("\n")[0]), " ")
-            logging.info(header)
+            table_name = "WarmUp Data".center(len(str(table).split("\n")[0]), " ")
+            logging.info(table_name)
             logging.info(table)
 
         else:
+            operation_head = [
+                "Operation",
+                "Source Code Location",
+            ]
+            table = PrettyTable(operation_head)
+            table.add_rows(operation_list)
+            table_name = "Operation Info".center(len(str(table).split("\n")[0]), " ")
+            logging.info(table_name)
+            logging.info(table)
+
             detail_head = [
                 "Iters",
                 "Batch Size",
@@ -246,8 +279,8 @@ class Benchmark:
             table = PrettyTable(detail_head)
             detail_list = [i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in detail_list]
             table.add_rows(detail_list)
-            header = "Detail Data".center(len(str(table).split("\n")[0]), " ")
-            logging.info(header)
+            table_name = "Detail Data".center(len(str(table).split("\n")[0]), " ")
+            logging.info(table_name)
             logging.info(table)
 
             summary_head = [
@@ -263,8 +296,8 @@ class Benchmark:
                 i[:4] + (f"{i[4]:.8f}", f"{i[5]:.8f}") for i in summary_list
             ]
             table.add_rows(summary_list)
-            header = "Summary Data".center(len(str(table).split("\n")[0]), " ")
-            logging.info(header)
+            table_name = "Summary Data".center(len(str(table).split("\n")[0]), " ")
+            logging.info(table_name)
             logging.info(table)
 
         if INFER_BENCHMARK_OUTPUT_DIR:
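
Note: the patch resolves each wrapped function's definition site lazily, on the first timed call, and caches the result in the decorator closure, so the `inspect` lookups are paid once per operation rather than on every invocation. Below is a minimal standalone sketch of that technique; the `watch` decorator, `timings` dict, and `read_image` function are illustrative stand-ins, not part of the Benchmark class or this patch.

import functools
import inspect
import time

timings = {}  # illustrative stand-in for Benchmark.logs


def watch(name):
    def deco(func):
        location = None  # resolved lazily, once per decorated function

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal location
            if location is None:
                try:
                    # getsourcelines() returns (source_lines, starting_line_number)
                    source_file = inspect.getsourcefile(func)
                    source_line = inspect.getsourcelines(func)[1]
                    location = f"{source_file}:{source_line}"
                except (TypeError, OSError):
                    # builtins and dynamically created code have no source file
                    location = "Unknown"
            tic = time.perf_counter()
            result = func(*args, **kwargs)
            timings.setdefault(f"{name}@{location}", []).append(
                time.perf_counter() - tic
            )
            return result

        return wrapper

    return deco


@watch("ReadImage")
def read_image():
    time.sleep(0.01)


read_image()
for key, samples in timings.items():
    operation_name, _, location = key.partition("@")
    print(operation_name, location, f"{sum(samples) / len(samples):.6f}")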