فهرست منبع

use program cache in predict

jiangjiajun 5 سال پیش
والد
کامیت
29d7de56ff

+ 2 - 1
paddlex/cv/models/classifier.py

@@ -264,7 +264,8 @@ class BaseClassifier(BaseAPI):
             im = self.test_transforms(img_file)
         result = self.exe.run(self.test_prog,
                               feed={'image': im},
-                              fetch_list=list(self.test_outputs.values()))
+                              fetch_list=list(self.test_outputs.values()),
+                              use_program_cache=True)
         pred_label = np.argsort(result[0][0])[::-1][:true_topk]
         res = [{
             'category_id': l,

+ 2 - 1
paddlex/cv/models/deeplabv3p.py

@@ -398,7 +398,8 @@ class DeepLabv3p(BaseAPI):
         im = np.expand_dims(im, axis=0)
         result = self.exe.run(self.test_prog,
                               feed={'image': im},
-                              fetch_list=list(self.test_outputs.values()))
+                              fetch_list=list(self.test_outputs.values()),
+                              use_program_cache=True)
         pred = result[0]
         pred = np.squeeze(pred).astype('uint8')
         logit = result[1]

+ 2 - 1
paddlex/cv/models/faster_rcnn.py

@@ -389,7 +389,8 @@ class FasterRCNN(BaseAPI):
                                    'im_shape': im_shape
                                },
                                fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False)
+                               return_numpy=False,
+                               use_program_cache=True)
         res = {
             k: (np.array(v), v.recursive_sequence_lengths())
             for k, v in zip(list(self.test_outputs.keys()), outputs)

+ 2 - 1
paddlex/cv/models/mask_rcnn.py

@@ -357,7 +357,8 @@ class MaskRCNN(FasterRCNN):
                                    'im_shape': im_shape
                                },
                                fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False)
+                               return_numpy=False,
+                               use_program_cache=True)
         res = {
             k: (np.array(v), v.recursive_sequence_lengths())
             for k, v in zip(list(self.test_outputs.keys()), outputs)

+ 2 - 1
paddlex/cv/models/yolo_v3.py

@@ -363,7 +363,8 @@ class YOLOv3(BaseAPI):
                                feed={'image': im,
                                      'im_size': im_size},
                                fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False)
+                               return_numpy=False,
+                               use_program_cache=True)
         res = {
             k: (np.array(v), v.recursive_sequence_lengths())
             for k, v in zip(list(self.test_outputs.keys()), outputs)