Răsfoiți Sursa

add warmup for python deploy

will-jl944 4 ani în urmă
părinte
comite
0bd41d9bb5
2 a modificat fișierele cu 46 adăugiri și 25 ștergeri
  1. 40 25
      paddlex/deploy.py
  2. 6 0
      paddlex/utils/utils.py

+ 40 - 25
paddlex/deploy.py

@@ -207,24 +207,7 @@ class Predictor(object):
 
         return net_outputs
 
-    def predict(self, img_file, topk=1, transforms=None):
-        """ 图片预测
-
-            Args:
-                img_file(List[np.ndarray or str], str or np.ndarray):
-                    图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
-                topk(int): 分类预测时使用,表示预测前topk的结果。
-                transforms (paddlex.transforms): 数据预处理操作。
-        """
-        if transforms is None and not hasattr(self._model, 'test_transforms'):
-            raise Exception("Transforms need to be defined, now is None.")
-        if transforms is None:
-            transforms = self._model.test_transforms
-        if isinstance(img_file, (str, np.ndarray)):
-            images = [img_file]
-        else:
-            images = img_file
-
+    def _run(self, images, topk=1, transforms=None, repeats=1, verbose=False):
         self.timer.preprocess_time_s.start()
         preprocessed_input = self.preprocess(images, transforms)
         self.timer.preprocess_time_s.end()
@@ -235,14 +218,17 @@ class Predictor(object):
                 logging.warning(
                     "{} only supports inference with batch size equal to 1."
                     .format(self._model.__class__.__name__))
-            net_outputs = [
-                self.raw_predict(sample) for sample in preprocessed_input
-            ]
-            self.timer.inference_time_s.end(repeats=len(preprocessed_input))
+            for step in range(repeats):
+                net_outputs = [
+                    self.raw_predict(sample) for sample in preprocessed_input
+                ]
+            self.timer.inference_time_s.end(repeats=len(preprocessed_input) *
+                                            repeats)
             ori_shape = None
         else:
-            net_outputs = self.raw_predict(preprocessed_input)
-            self.timer.inference_time_s.end()
+            for step in range(repeats):
+                net_outputs = self.raw_predict(preprocessed_input)
+            self.timer.inference_time_s.end(repeats=repeats)
             ori_shape = preprocessed_input.get('ori_shape', None)
 
         self.timer.postprocess_time_s.start()
@@ -251,6 +237,35 @@ class Predictor(object):
         self.timer.postprocess_time_s.end()
 
         self.timer.img_num = len(images)
-        self.timer.info(average=True)
+        if verbose:
+            self.timer.info(average=True)
 
         return results
+
+    def predict(self,
+                img_file,
+                topk=1,
+                transforms=None,
+                warmup_iters=0,
+                repeats=1):
+        """ 图片预测
+
+            Args:
+                img_file(List[np.ndarray or str], str or np.ndarray):
+                    图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
+                topk(int): 分类预测时使用,表示预测前topk的结果。
+                transforms (paddlex.transforms): 数据预处理操作。
+        """
+        if transforms is None and not hasattr(self._model, 'test_transforms'):
+            raise Exception("Transforms need to be defined, now is None.")
+        if transforms is None:
+            transforms = self._model.test_transforms
+        if isinstance(img_file, (str, np.ndarray)):
+            images = [img_file]
+        else:
+            images = img_file
+
+        for step in range(warmup_iters):
+            self._run(
+                images=images, topk=topk, transforms=transforms, verbose=False)
+        self.timer.reset()

+ 6 - 0
paddlex/utils/utils.py

@@ -216,3 +216,9 @@ class Timer(Times):
         ) + self.inference_time_s.value() + self.postprocess_time_s.value()
         dic['total_time_s'] = round(total_time, 4)
         return dic
+
+    def reset(self):
+        self.preprocess_time_s.reset()
+        self.inference_time_s.reset()
+        self.postprocess_time_s.reset()
+        self.img_num = 0