
add warmup for python deploy

will-jl944, 4 years ago
commit 077bc33bbf
2 files changed, 43 insertions(+), 33 deletions(-)
  1. paddlex/deploy.py  +23 -17
  2. paddlex/utils/utils.py  +20 -16

paddlex/deploy.py  +23 -17

@@ -207,38 +207,31 @@ class Predictor(object):
 
         return net_outputs
 
-    def _run(self, images, topk=1, transforms=None, repeats=1, verbose=False):
+    def _run(self, images, topk=1, transforms=None):
         self.timer.preprocess_time_s.start()
         preprocessed_input = self.preprocess(images, transforms)
-        self.timer.preprocess_time_s.end()
+        self.timer.preprocess_time_s.end(iter_num=len(images))
 
+        ori_shape = None
         self.timer.inference_time_s.start()
         if 'RCNN' in self._model.__class__.__name__:
             if len(preprocessed_input) > 1:
                 logging.warning(
                     "{} only supports inference with batch size equal to 1."
                     .format(self._model.__class__.__name__))
-            for step in range(repeats):
-                net_outputs = [
-                    self.raw_predict(sample) for sample in preprocessed_input
-                ]
-            self.timer.inference_time_s.end(repeats=len(preprocessed_input) *
-                                            repeats)
-            ori_shape = None
+            net_outputs = [
+                self.raw_predict(sample) for sample in preprocessed_input
+            ]
+            self.timer.inference_time_s.end(iter_num=len(images))
         else:
-            for step in range(repeats):
-                net_outputs = self.raw_predict(preprocessed_input)
-            self.timer.inference_time_s.end(repeats=repeats)
+            net_outputs = self.raw_predict(preprocessed_input)
+            self.timer.inference_time_s.end(iter_num=1)
             ori_shape = preprocessed_input.get('ori_shape', None)
 
         self.timer.postprocess_time_s.start()
         results = self.postprocess(
             net_outputs, topk, ori_shape=ori_shape, transforms=transforms)
-        self.timer.postprocess_time_s.end()
-
-        self.timer.img_num = len(images)
-        if verbose:
-            self.timer.info(average=True)
+        self.timer.postprocess_time_s.end(iter_num=len(images))
 
         return results
 
@@ -255,7 +248,11 @@ class Predictor(object):
                    Image path(s), or a decoded array in (H, W, C) layout with float32 dtype in BGR format.
                topk(int): Used in classification prediction; return the top-k predictions.
                transforms (paddlex.transforms): Data preprocessing operators.
+                warmup_iters (int): Number of warm-up iterations. Defaults to 0.
+                repeats (int): Number of repetitions, used to benchmark the speed of model inference and pre/post-processing. If greater than 1, prediction runs repeats times and the reported times are averaged.
         """
+        if repeats < 1:
+            logging.error("`repeats` must be at least 1.", exit=True)
         if transforms is None and not hasattr(self._model, 'test_transforms'):
             raise Exception("Transforms need to be defined, now is None.")
         if transforms is None:
@@ -269,3 +266,12 @@ class Predictor(object):
             self._run(
                 images=images, topk=topk, transforms=transforms)
         self.timer.reset()
+
+        for step in range(repeats):
+            results = self._run(
+                images=images, topk=topk, transforms=transforms)
+
+        self.timer.repeats = repeats
+        self.timer.info(average=True)
+
+        return results

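Taken together, the deploy.py changes let a caller warm the predictor up before timing and average the reported times over repeated runs. A minimal usage sketch under that assumption follows; the model directory, image path, and parameter values are illustrative placeholders, not part of this commit:

    import paddlex as pdx

    # Hypothetical model directory and test image, for illustration only.
    predictor = pdx.deploy.Predictor('./inference_model', use_gpu=True)

    # The first warmup_iters runs are executed and then discarded via
    # timer.reset(); prediction is then repeated `repeats` times and the
    # per-repeat average times are printed by timer.info(average=True).
    result = predictor.predict(
        img_file='test.jpg',
        warmup_iters=10,
        repeats=100)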
paddlex/utils/utils.py  +20 -16

@@ -152,12 +152,12 @@ class Times(object):
     def start(self):
         self.st = time.time()
 
-    def end(self, repeats=1, accumulative=True):
+    def end(self, iter_num=1, accumulative=True):
         self.et = time.time()
         if accumulative:
-            self.time += (self.et - self.st) / repeats
+            self.time += (self.et - self.st) / iter_num
         else:
-            self.time = (self.et - self.st) / repeats
+            self.time = (self.et - self.st) / iter_num
 
     def reset(self):
         self.time = 0.
@@ -175,46 +175,49 @@ class Timer(Times):
         self.inference_time_s = Times()
         self.postprocess_time_s = Times()
         self.img_num = 0
+        self.repeats = 0
 
     def info(self, average=False):
         total_time = self.preprocess_time_s.value(
         ) + self.inference_time_s.value() + self.postprocess_time_s.value()
         total_time = round(total_time, 4)
         print("------------------ Inference Time Info ----------------------")
-        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
-                                                       self.img_num))
+        print("total_time(ms): {}, img_num: {}, batch_size: {}".format(
+            total_time * 1000, self.img_num, self.img_num / self.repeats))
         preprocess_time = round(
-            self.preprocess_time_s.value() / self.img_num,
+            self.preprocess_time_s.value() / self.repeats,
             4) if average else self.preprocess_time_s.value()
         postprocess_time = round(
-            self.postprocess_time_s.value() / self.img_num,
+            self.postprocess_time_s.value() / self.repeats,
             4) if average else self.postprocess_time_s.value()
-        inference_time = round(self.inference_time_s.value() / self.img_num,
+        inference_time = round(self.inference_time_s.value() / self.repeats,
                                4) if average else self.inference_time_s.value()
 
-        average_latency = total_time / self.img_num
+        average_latency = total_time / self.repeats
         print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
             average_latency * 1000, 1 / average_latency))
-        print(
-            "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
-            format(preprocess_time * 1000, inference_time * 1000,
-                   postprocess_time * 1000))
+        print("preprocess_time_per_im(ms): {:.2f}, "
+              "inference_time_per_batch(ms): {:.2f}, "
+              "postprocess_time_per_im(ms): {:.2f}".format(
+                  preprocess_time * 1000, inference_time * 1000,
+                  postprocess_time * 1000))
 
     def report(self, average=False):
         dic = {}
         dic['preprocess_time_s'] = round(
-            self.preprocess_time_s.value() / self.img_num,
+            self.preprocess_time_s.value() / self.repeats,
             4) if average else self.preprocess_time_s.value()
         dic['postprocess_time_s'] = round(
-            self.postprocess_time_s.value() / self.img_num,
+            self.postprocess_time_s.value() / self.repeats,
             4) if average else self.postprocess_time_s.value()
         dic['inference_time_s'] = round(
-            self.inference_time_s.value() / self.img_num,
+            self.inference_time_s.value() / self.repeats,
             4) if average else self.inference_time_s.value()
         dic['img_num'] = self.img_num
         total_time = self.preprocess_time_s.value(
         ) + self.inference_time_s.value() + self.postprocess_time_s.value()
         dic['total_time_s'] = round(total_time, 4)
+        dic['batch_size'] = self.img_num / self.repeats
         return dic
 
     def reset(self):
@@ -222,3 +225,4 @@ class Timer(Times):
         self.inference_time_s.reset()
         self.postprocess_time_s.reset()
         self.img_num = 0
+        self.repeats = 0
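For reference, a small self-contained sketch of the timing semantics the utils.py changes rely on; the class below is a simplified stand-in for paddlex.utils.utils.Times, not the full implementation. Each end(iter_num=n) call accumulates elapsed/n, so the stored value is a per-image (or per-batch) time summed across repeats, and dividing by repeats, as Timer.info(average=True) now does, yields the per-repeat average:

    import time

    class Times(object):
        # Simplified stand-in for paddlex.utils.utils.Times.
        def __init__(self):
            self.time = 0.
            self.st = 0.
            self.et = 0.

        def start(self):
            self.st = time.time()

        def end(self, iter_num=1, accumulative=True):
            self.et = time.time()
            if accumulative:
                # Accumulate the elapsed time normalized per iteration.
                self.time += (self.et - self.st) / iter_num
            else:
                self.time = (self.et - self.st) / iter_num

    # Accumulate per-image preprocessing time over 3 repeats of a 4-image
    # batch, then average over repeats as Timer.info(average=True) does.
    preprocess = Times()
    repeats, batch_size = 3, 4
    for _ in range(repeats):
        preprocess.start()
        time.sleep(0.01)  # stands in for real preprocessing work
        preprocess.end(iter_num=batch_size)
    print("preprocess_time_per_im(ms): {:.2f}".format(
        preprocess.time / repeats * 1000))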