浏览代码

add transforms in predict/bath_predict in deploy.py

FlyingQianMM 5 年之前
父节点
当前提交
8a23df1637
共有 1 个文件被更改,包括 10 次插入4 次删除
  1. 10 4
      paddlex/deploy.py

+ 10 - 4
paddlex/deploy.py

@@ -247,13 +247,16 @@ class Predictor:
                 [output_tensor.copy_to_cpu(), output_tensor_lod])
                 [output_tensor.copy_to_cpu(), output_tensor_lod])
         return output_results
         return output_results
 
 
-    def predict(self, image, topk=1):
+    def predict(self, image, topk=1, transforms=None):
         """ 图片预测
         """ 图片预测
 
 
             Args:
             Args:
                 image(str|np.ndarray): 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
                 image(str|np.ndarray): 图像路径;或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
-                topk(int): 分类预测时使用,表示预测前topk的结果
+                topk(int): 分类预测时使用,表示预测前topk的结果。
+                transforms (paddlex.cls.transforms): 数据预处理操作。
         """
         """
+        if transforms is not None:
+            self.transforms = transforms
         preprocessed_input = self.preprocess([image])
         preprocessed_input = self.preprocess([image])
         model_pred = self.raw_predict(preprocessed_input)
         model_pred = self.raw_predict(preprocessed_input)
         im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
         im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
@@ -269,15 +272,18 @@ class Predictor:
 
 
         return results[0]
         return results[0]
 
 
-    def batch_predict(self, image_list, topk=1):
+    def batch_predict(self, image_list, topk=1, transforms=None):
         """ 图片预测
         """ 图片预测
 
 
             Args:
             Args:
                 image_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
                 image_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
                     也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
                     也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
 
 
-                topk(int): 分类预测时使用,表示预测前topk的结果
+                topk(int): 分类预测时使用,表示预测前topk的结果。
+                transforms (paddlex.cls.transforms): 数据预处理操作。
         """
         """
+        if transforms is not None:
+            self.transforms = transforms
         preprocessed_input = self.preprocess(image_list, self.thread_pool)
         preprocessed_input = self.preprocess(image_list, self.thread_pool)
         model_pred = self.raw_predict(preprocessed_input)
         model_pred = self.raw_predict(preprocessed_input)
         im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
         im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[