|
|
@@ -247,13 +247,16 @@ class Predictor:
|
|
|
[output_tensor.copy_to_cpu(), output_tensor_lod])
|
|
|
return output_results
|
|
|
|
|
|
- def predict(self, image, topk=1):
|
|
|
+ def predict(self, image, topk=1, transforms=None):
|
|
|
""" 图片预测
|
|
|
|
|
|
Args:
|
|
|
image(str|np.ndarray): 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
|
|
|
- topk(int): 分类预测时使用,表示预测前topk的结果
|
|
|
+ topk(int): 分类预测时使用,表示预测前topk的结果。
|
|
|
+ transforms (paddlex.cls.transforms): 数据预处理操作。
|
|
|
"""
|
|
|
+ if transforms is not None:
|
|
|
+ self.transforms = transforms
|
|
|
preprocessed_input = self.preprocess([image])
|
|
|
model_pred = self.raw_predict(preprocessed_input)
|
|
|
im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
|
|
|
@@ -269,15 +272,18 @@ class Predictor:
|
|
|
|
|
|
return results[0]
|
|
|
|
|
|
- def batch_predict(self, image_list, topk=1):
|
|
|
+ def batch_predict(self, image_list, topk=1, transforms=None):
|
|
|
""" 图片预测
|
|
|
|
|
|
Args:
|
|
|
image_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
|
|
|
也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
|
|
|
|
|
|
- topk(int): 分类预测时使用,表示预测前topk的结果
|
|
|
+ topk(int): 分类预测时使用,表示预测前topk的结果。
|
|
|
+ transforms (paddlex.cls.transforms): 数据预处理操作。
|
|
|
"""
|
|
|
+ if transforms is not None:
|
|
|
+ self.transforms = transforms
|
|
|
preprocessed_input = self.preprocess(image_list, self.thread_pool)
|
|
|
model_pred = self.raw_predict(preprocessed_input)
|
|
|
im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
|