|
|
@@ -207,24 +207,7 @@ class Predictor(object):
|
|
|
|
|
|
return net_outputs
|
|
|
|
|
|
- def predict(self, img_file, topk=1, transforms=None):
|
|
|
- """ 图片预测
|
|
|
-
|
|
|
- Args:
|
|
|
- img_file(List[np.ndarray or str], str or np.ndarray):
|
|
|
- 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
|
|
|
- topk(int): 分类预测时使用,表示预测前topk的结果。
|
|
|
- transforms (paddlex.transforms): 数据预处理操作。
|
|
|
- """
|
|
|
- if transforms is None and not hasattr(self._model, 'test_transforms'):
|
|
|
- raise Exception("Transforms need to be defined, now is None.")
|
|
|
- if transforms is None:
|
|
|
- transforms = self._model.test_transforms
|
|
|
- if isinstance(img_file, (str, np.ndarray)):
|
|
|
- images = [img_file]
|
|
|
- else:
|
|
|
- images = img_file
|
|
|
-
|
|
|
+ def _run(self, images, topk=1, transforms=None, repeats=1, verbose=False):
|
|
|
self.timer.preprocess_time_s.start()
|
|
|
preprocessed_input = self.preprocess(images, transforms)
|
|
|
self.timer.preprocess_time_s.end()
|
|
|
@@ -235,14 +218,17 @@ class Predictor(object):
|
|
|
logging.warning(
|
|
|
"{} only supports inference with batch size equal to 1."
|
|
|
.format(self._model.__class__.__name__))
|
|
|
- net_outputs = [
|
|
|
- self.raw_predict(sample) for sample in preprocessed_input
|
|
|
- ]
|
|
|
- self.timer.inference_time_s.end(repeats=len(preprocessed_input))
|
|
|
+ for step in range(repeats):
|
|
|
+ net_outputs = [
|
|
|
+ self.raw_predict(sample) for sample in preprocessed_input
|
|
|
+ ]
|
|
|
+ self.timer.inference_time_s.end(repeats=len(preprocessed_input) *
|
|
|
+ repeats)
|
|
|
ori_shape = None
|
|
|
else:
|
|
|
- net_outputs = self.raw_predict(preprocessed_input)
|
|
|
- self.timer.inference_time_s.end()
|
|
|
+ for step in range(repeats):
|
|
|
+ net_outputs = self.raw_predict(preprocessed_input)
|
|
|
+ self.timer.inference_time_s.end(repeats=repeats)
|
|
|
ori_shape = preprocessed_input.get('ori_shape', None)
|
|
|
|
|
|
self.timer.postprocess_time_s.start()
|
|
|
@@ -251,6 +237,35 @@ class Predictor(object):
|
|
|
self.timer.postprocess_time_s.end()
|
|
|
|
|
|
self.timer.img_num = len(images)
|
|
|
- self.timer.info(average=True)
|
|
|
+ if verbose:
|
|
|
+ self.timer.info(average=True)
|
|
|
|
|
|
return results
|
|
|
+
|
|
|
+ def predict(self,
|
|
|
+ img_file,
|
|
|
+ topk=1,
|
|
|
+ transforms=None,
|
|
|
+ warmup_iters=0,
|
|
|
+ repeats=1):
|
|
|
+ """ 图片预测
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img_file(List[np.ndarray or str], str or np.ndarray):
|
|
|
+ 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
|
|
|
+ topk(int): 分类预测时使用,表示预测前topk的结果。
|
|
|
+ transforms (paddlex.transforms): 数据预处理操作。
|
|
|
+ """
|
|
|
+ if transforms is None and not hasattr(self._model, 'test_transforms'):
|
|
|
+ raise Exception("Transforms need to be defined, now is None.")
|
|
|
+ if transforms is None:
|
|
|
+ transforms = self._model.test_transforms
|
|
|
+ if isinstance(img_file, (str, np.ndarray)):
|
|
|
+ images = [img_file]
|
|
|
+ else:
|
|
|
+ images = img_file
|
|
|
+
|
|
|
+ for step in range(warmup_iters):
|
|
|
+ self._run(
|
|
|
+ images=images, topk=topk, transforms=transforms, verbose=False)
|
|
|
+ self.timer.reset()
|