Bläddra i källkod

support warmup elapse

gaotingquan 1 år sedan
förälder
incheckning
18ba0d79f4
2 ändrade filer med 70 tillägg och 23 borttagningar
  1. 2 1
      paddlex/inference/models/base/basic_predictor.py
  2. 68 22
      paddlex/inference/utils/benchmark.py

+ 2 - 1
paddlex/inference/models/base/basic_predictor.py

@@ -54,11 +54,12 @@ class BasicPredictor(
     def __call__(self, input, **kwargs):
         self.set_predictor(**kwargs)
         if self.benchmark:
+            self.benchmark.start()
             if INFER_BENCHMARK_WARMUP > 0:
                 output = super().__call__(input)
                 for _ in range(INFER_BENCHMARK_WARMUP):
                     next(output)
-            self.benchmark.reset()
+                self.benchmark.warmup_stop(INFER_BENCHMARK_WARMUP)
             output = list(super().__call__(input))
             self.benchmark.collect(len(output))
         else:

+ 68 - 22
paddlex/inference/utils/benchmark.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import csv
 import functools
 from types import GeneratorType
 import time
@@ -25,10 +26,22 @@ from ...utils import logging
 class Benchmark:
     def __init__(self, components):
         self._components = components
+        self._warmup_start = None
+        self._warmup_elapse = None
+        self._warmup_num = None
         self._e2e_tic = None
         self._e2e_elapse = None
 
-    def reset(self):
+    def start(self):
+        self._warmup_start = time.time()
+        self._reset()
+
+    def warmup_stop(self, warmup_num):
+        self._warmup_elapse = time.time() - self._warmup_start
+        self._warmup_num = warmup_num
+        self._reset()
+
+    def _reset(self):
         for name in self._components:
             cmp = self._components[name]
             cmp.timer.reset()
@@ -47,7 +60,7 @@ class Benchmark:
             counts = len(times)
             avg = np.mean(times)
             total = np.sum(times)
-            detail.append((name, counts, avg))
+            detail.append((name, total, counts, avg))
             if isinstance(cmp, BasePaddlePredictor):
                 summary["inference"] += total
                 op_tag = "postprocess"
@@ -55,42 +68,75 @@ class Benchmark:
                 summary[op_tag] += total
 
         summary = [
-            ("PreProcess", e2e_num, summary["preprocess"] / e2e_num),
-            ("Inference", e2e_num, summary["inference"] / e2e_num),
-            ("PostProcess", e2e_num, summary["postprocess"] / e2e_num),
-            ("End2End", e2e_num, self._e2e_elapse / e2e_num),
+            (
+                "PreProcess",
+                summary["preprocess"],
+                e2e_num,
+                summary["preprocess"] / e2e_num,
+            ),
+            (
+                "Inference",
+                summary["inference"],
+                e2e_num,
+                summary["inference"] / e2e_num,
+            ),
+            (
+                "PostProcess",
+                summary["postprocess"],
+                e2e_num,
+                summary["postprocess"] / e2e_num,
+            ),
+            ("End2End", self._e2e_elapse, e2e_num, self._e2e_elapse / e2e_num),
         ]
+        if self._warmup_elapse:
+            summary.append(
+                (
+                    "WarmUp",
+                    self._warmup_elapse,
+                    self._warmup_num,
+                    self._warmup_elapse / self._warmup_num,
+                )
+            )
         return detail, summary
 
     def collect(self, e2e_num):
         self._e2e_elapse = time.time() - self._e2e_tic
         detail, summary = self.gather(e2e_num)
 
-        table = PrettyTable(["Component", "Call Counts", "Avg Time Per Call (ms)"])
+        table = PrettyTable(
+            ["Component", "Total Time (ms)", "Call Counts", "Avg Time Per Call (ms)"]
+        )
         table.add_rows(
-            [(name, cnts, f"{avg * 1000:.8f}") for name, cnts, avg in detail]
+            [
+                (name, f"{total * 1000:.8f}", cnts, f"{avg * 1000:.8f}")
+                for name, total, cnts, avg in detail
+            ]
         )
         logging.info(table)
 
-        table = PrettyTable(["Stage", "Num of Instances", "Avg Time Per Instance (ms)"])
+        table = PrettyTable(
+            [
+                "Stage",
+                "Total Time (ms)",
+                "Num of Instances",
+                "Avg Time Per Instance (ms)",
+            ]
+        )
         table.add_rows(
-            [(name, cnts, f"{avg * 1000:.8f}") for name, cnts, avg in summary]
+            [
+                (name, f"{total * 1000:.8f}", cnts, f"{avg * 1000:.8f}")
+                for name, total, cnts, avg in summary
+            ]
         )
         logging.info(table)
 
         if INFER_BENCHMARK_OUTPUT:
-            str_ = "Component, Call Counts, Avg Time Per Call (ms)\n"
-            str_ += "\n".join(
-                [f"{name}, {cnts}, {avg * 1000:.18f}" for name, cnts, avg in detail]
-            )
-            str_ += "\n" + "*" * 100 + "\n"
-            str_ += "Stage, Num of Instances, Avg Time Per Instance (ms)\n"
-            str_ += "\n".join(
-                [f"{name}, {cnts}, {avg * 1000:.18f}" for name, cnts, avg in summary]
-            )
-
-            with open(INFER_BENCHMARK_OUTPUT, "w") as f:
-                f.write(str_)
+            csv_data = [["Stage", "Total Time", "Num", "Avg Time"]]
+            csv_data.extend(detail)
+            csv_data.extend(summary)
+            with open("benchmark.csv", "w", newline="") as file:
+                writer = csv.writer(file)
+                writer.writerows(csv_data)
 
 
 class Timer: