Quellcode durchsuchen

add scope for model

jiangjiajun vor 5 Jahren
Ursprung
Commit
3cb494f3d2

+ 0 - 38
docs/apis/deploy.md

@@ -1,38 +0,0 @@
-# 预测部署-paddlex.deploy
-
-使用Paddle Inference进行高性能的Python预测部署。更多关于Paddle Inference信息请参考[Paddle Inference文档](https://paddle-inference.readthedocs.io/en/latest/#)
-
-## Predictor类
-
-```
-paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, use_trt=False, use_glog=False, memory_optimize=True)
-```
-
-> **参数**
-
-> > * **model_dir**: 训练过程中保存的模型路径, 注意需要使用导出的inference模型
-> > * **use_gpu**: 是否使用GPU进行预测
-> > * **gpu_id**: 使用的GPU序列号
-> > * **use_mkl**: 是否使用mkldnn加速库
-> > * **use_trt**: 是否使用TensorRT预测引擎
-> > * **use_glog**: 是否打印中间日志
-> > * **memory_optimize**: 是否优化内存使用
-
-> > ### 示例
-> >
-> > ```
-> > import paddlex
-> >
-> > model = paddlex.deploy.Predictor(model_dir, use_gpu=True)
-> > result = model.predict(image_file)
-> > ```
-
-### predict 接口
-> ```
-> predict(image, topk=1)
-> ```
-
-> **参数
-
-* **image(str|np.ndarray)**: 待预测的图片路径或np.ndarray,若为后者需注意为BGR格式
-* **topk(int)**: 图像分类时使用的参数,表示预测前topk个可能的分类

+ 0 - 1
docs/apis/index.rst

@@ -10,4 +10,3 @@ API接口说明
    slim.md
    load_model.md
    visualize.md
-   deploy.md

+ 14 - 14
paddlex/cv/models/base.py

@@ -73,6 +73,7 @@ class BaseAPI:
         self.status = 'Normal'
         # 已完成迭代轮数,为恢复训练时的起始轮数
         self.completed_epochs = 0
+        self.scope = fluid.global_scope()
 
     def _get_single_card_bs(self, batch_size):
         if batch_size % len(self.places) == 0:
@@ -84,6 +85,10 @@ class BaseAPI:
                                 'place']))
 
     def build_program(self):
+        if hasattr(paddlex, 'model_built') and paddlex.model_built:
+            logging.error(
+                "Function model.train() only can be called once in your code.")
+        paddlex.model_built = True
         # 构建训练网络
         self.train_inputs, self.train_outputs = self.build_net(mode='train')
         self.train_prog = fluid.default_main_program()
@@ -155,7 +160,7 @@ class BaseAPI:
             outputs=self.test_outputs,
             batch_size=batch_size,
             batch_nums=batch_num,
-            scope=None,
+            scope=self.scope,
             algo='KL',
             quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
             is_full_quantize=False,
@@ -244,8 +249,8 @@ class BaseAPI:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
-            paddlex.utils.utils.load_pretrain_weights(
-                self.exe, self.train_prog, pretrain_weights, fuse_bn)
+            paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
+                                                      pretrain_weights, fuse_bn)
         # 进行裁剪
         if sensitivities_file is not None:
             import paddleslim
@@ -349,9 +354,7 @@ class BaseAPI:
         logging.info("Model saved in {}.".format(save_dir))
 
     def export_inference_model(self, save_dir):
-        test_input_names = [
-            var.name for var in list(self.test_inputs.values())
-        ]
+        test_input_names = [var.name for var in list(self.test_inputs.values())]
         test_outputs = list(self.test_outputs.values())
         if self.__class__.__name__ == 'MaskRCNN':
             from paddlex.utils.save import save_mask_inference_model
@@ -388,8 +391,7 @@ class BaseAPI:
 
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(
-            save_dir))
+        logging.info("Model for inference deploy saved in {}.".format(save_dir))
 
     def train_loop(self,
                    num_epochs,
@@ -513,13 +515,11 @@ class BaseAPI:
                         eta = ((num_epochs - i) * total_num_steps - step - 1
                                ) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (
-                            total_eval_times - i // save_interval_epochs
-                        ) * time_eval_one_epoch
+                        eval_eta = (total_eval_times - i // save_interval_epochs
+                                    ) * time_eval_one_epoch
                     else:
-                        eval_eta = (
-                            total_eval_times - i // save_interval_epochs
-                        ) * total_num_steps_eval * avg_step_time
+                        eval_eta = (total_eval_times - i // save_interval_epochs
+                                    ) * total_num_steps_eval * avg_step_time
                     eta_str = seconds_to_hms(eta + eval_eta)
 
                     logging.info(

+ 17 - 13
paddlex/cv/models/classifier.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -227,9 +227,10 @@ class BaseClassifier(BaseAPI):
         true_labels = list()
         pred_scores = list()
         if not hasattr(self, 'parallel_test_prog'):
-            self.parallel_test_prog = fluid.CompiledProgram(
-                self.test_prog).with_data_parallel(
-                    share_vars_from=self.parallel_train_prog)
+            with fluid.scope_guard(self.scope):
+                self.parallel_test_prog = fluid.CompiledProgram(
+                    self.test_prog).with_data_parallel(
+                        share_vars_from=self.parallel_train_prog)
         batch_size_each_gpu = self._get_single_card_bs(batch_size)
         logging.info("Start to evaluating(total_samples={}, total_steps={})...".
                      format(eval_dataset.num_samples, total_steps))
@@ -242,9 +243,11 @@ class BaseClassifier(BaseAPI):
                 num_pad_samples = batch_size - num_samples
                 pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
                 images = np.concatenate([images, pad_images])
-            outputs = self.exe.run(self.parallel_test_prog,
-                                   feed={'image': images},
-                                   fetch_list=list(self.test_outputs.values()))
+            with fluid.scope_guard(self.scope):
+                outputs = self.exe.run(
+                    self.parallel_test_prog,
+                    feed={'image': images},
+                    fetch_list=list(self.test_outputs.values()))
             outputs = [outputs[0][:num_samples]]
             true_labels.extend(labels)
             pred_scores.extend(outputs[0].tolist())
@@ -286,10 +289,11 @@ class BaseClassifier(BaseAPI):
             self.arrange_transforms(
                 transforms=self.test_transforms, mode='test')
             im = self.test_transforms(img_file)
-        result = self.exe.run(self.test_prog,
-                              feed={'image': im},
-                              fetch_list=list(self.test_outputs.values()),
-                              use_program_cache=True)
+        with fluid.scope_guard(self.scope):
+            result = self.exe.run(self.test_prog,
+                                  feed={'image': im},
+                                  fetch_list=list(self.test_outputs.values()),
+                                  use_program_cache=True)
         pred_label = np.argsort(result[0][0])[::-1][:true_topk]
         res = [{
             'category_id': l,

+ 21 - 19
paddlex/cv/models/deeplabv3p.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -317,19 +317,18 @@ class DeepLabv3p(BaseAPI):
             tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details),
                 包含关键字:'confusion_matrix',表示评估的混淆矩阵。
         """
-        self.arrange_transforms(
-            transforms=eval_dataset.transforms, mode='eval')
+        self.arrange_transforms(transforms=eval_dataset.transforms, mode='eval')
         total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
         conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
         if not hasattr(self, 'parallel_test_prog'):
-            self.parallel_test_prog = fluid.CompiledProgram(
-                self.test_prog).with_data_parallel(
-                    share_vars_from=self.parallel_train_prog)
-        logging.info(
-            "Start to evaluating(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples, total_steps))
+            with fluid.scope_guard(self.scope):
+                self.parallel_test_prog = fluid.CompiledProgram(
+                    self.test_prog).with_data_parallel(
+                        share_vars_from=self.parallel_train_prog)
+        logging.info("Start to evaluating(total_samples={}, total_steps={})...".
+                     format(eval_dataset.num_samples, total_steps))
         for step, data in tqdm.tqdm(
                 enumerate(data_generator()), total=total_steps):
             images = np.array([d[0] for d in data])
@@ -350,10 +349,12 @@ class DeepLabv3p(BaseAPI):
                 pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
                 images = np.concatenate([images, pad_images])
             feed_data = {'image': images}
-            outputs = self.exe.run(self.parallel_test_prog,
-                                   feed=feed_data,
-                                   fetch_list=list(self.test_outputs.values()),
-                                   return_numpy=True)
+            with fluid.scope_guard(self.scope):
+                outputs = self.exe.run(
+                    self.parallel_test_prog,
+                    feed=feed_data,
+                    fetch_list=list(self.test_outputs.values()),
+                    return_numpy=True)
             pred = outputs[0]
             if num_samples < batch_size:
                 pred = pred[0:num_samples]
@@ -399,10 +400,11 @@ class DeepLabv3p(BaseAPI):
                 transforms=self.test_transforms, mode='test')
             im, im_info = self.test_transforms(im_file)
         im = np.expand_dims(im, axis=0)
-        result = self.exe.run(self.test_prog,
-                              feed={'image': im},
-                              fetch_list=list(self.test_outputs.values()),
-                              use_program_cache=True)
+        with fluid.scope_guard(self.scope):
+            result = self.exe.run(self.test_prog,
+                                  feed={'image': im},
+                                  fetch_list=list(self.test_outputs.values()),
+                                  use_program_cache=True)
         pred = result[0]
         pred = np.squeeze(pred).astype('uint8')
         logit = result[1]

+ 19 - 16
paddlex/cv/models/faster_rcnn.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -325,10 +325,12 @@ class FasterRCNN(BaseAPI):
                 'im_info': im_infos,
                 'im_shape': im_shapes,
             }
-            outputs = self.exe.run(self.test_prog,
-                                   feed=[feed_data],
-                                   fetch_list=list(self.test_outputs.values()),
-                                   return_numpy=False)
+            with fluid.scope_guard(self.scope):
+                outputs = self.exe.run(
+                    self.test_prog,
+                    feed=[feed_data],
+                    fetch_list=list(self.test_outputs.values()),
+                    return_numpy=False)
             res = {
                 'bbox': (np.array(outputs[0]),
                          outputs[0].recursive_sequence_lengths())
@@ -388,15 +390,16 @@ class FasterRCNN(BaseAPI):
         im = np.expand_dims(im, axis=0)
         im_resize_info = np.expand_dims(im_resize_info, axis=0)
         im_shape = np.expand_dims(im_shape, axis=0)
-        outputs = self.exe.run(self.test_prog,
-                               feed={
-                                   'image': im,
-                                   'im_info': im_resize_info,
-                                   'im_shape': im_shape
-                               },
-                               fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False,
-                               use_program_cache=True)
+        with fluid.scope_guard(self.scope):
+            outputs = self.exe.run(self.test_prog,
+                                   feed={
+                                       'image': im,
+                                       'im_info': im_resize_info,
+                                       'im_shape': im_shape
+                                   },
+                                   fetch_list=list(self.test_outputs.values()),
+                                   return_numpy=False,
+                                   use_program_cache=True)
         res = {
             k: (np.array(v), v.recursive_sequence_lengths())
             for k, v in zip(list(self.test_outputs.keys()), outputs)

+ 36 - 32
paddlex/cv/models/load_model.py

@@ -24,6 +24,7 @@ import paddlex.utils.logging as logging
 
 
 def load_model(model_dir, fixed_input_shape=None):
+    model_scope = fluid.Scope()
     if not osp.exists(osp.join(model_dir, "model.yml")):
         raise Exception("There's not model.yml in {}".format(model_dir))
     with open(osp.join(model_dir, "model.yml")) as f:
@@ -51,38 +52,40 @@ def load_model(model_dir, fixed_input_shape=None):
                              format(fixed_input_shape))
                 model.fixed_input_shape = fixed_input_shape
 
-    if status == "Normal" or \
-            status == "Prune" or status == "fluid.save":
-        startup_prog = fluid.Program()
-        model.test_prog = fluid.Program()
-        with fluid.program_guard(model.test_prog, startup_prog):
-            with fluid.unique_name.guard():
-                model.test_inputs, model.test_outputs = model.build_net(
-                    mode='test')
-        model.test_prog = model.test_prog.clone(for_test=True)
-        model.exe.run(startup_prog)
-        if status == "Prune":
-            from .slim.prune import update_program
-            model.test_prog = update_program(model.test_prog, model_dir,
-                                             model.places[0])
-        import pickle
-        with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
-            load_dict = pickle.load(f)
-        fluid.io.set_program_state(model.test_prog, load_dict)
-
-    elif status == "Infer" or \
-            status == "Quant" or status == "fluid.save_inference_model":
-        [prog, input_names, outputs] = fluid.io.load_inference_model(
-            model_dir, model.exe, params_filename='__params__')
-        model.test_prog = prog
-        test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
-        model.test_inputs = OrderedDict()
-        model.test_outputs = OrderedDict()
-        for name in input_names:
-            model.test_inputs[name] = model.test_prog.global_block().var(name)
-        for i, out in enumerate(outputs):
-            var_desc = test_outputs_info[i]
-            model.test_outputs[var_desc[0]] = out
+    with fluid.scope_guard(model_scope):
+        if status == "Normal" or \
+                status == "Prune" or status == "fluid.save":
+            startup_prog = fluid.Program()
+            model.test_prog = fluid.Program()
+            with fluid.program_guard(model.test_prog, startup_prog):
+                with fluid.unique_name.guard():
+                    model.test_inputs, model.test_outputs = model.build_net(
+                        mode='test')
+            model.test_prog = model.test_prog.clone(for_test=True)
+            model.exe.run(startup_prog)
+            if status == "Prune":
+                from .slim.prune import update_program
+                model.test_prog = update_program(model.test_prog, model_dir,
+                                                 model.places[0])
+            import pickle
+            with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
+                load_dict = pickle.load(f)
+            fluid.io.set_program_state(model.test_prog, load_dict)
+
+        elif status == "Infer" or \
+                status == "Quant" or status == "fluid.save_inference_model":
+            [prog, input_names, outputs] = fluid.io.load_inference_model(
+                model_dir, model.exe, params_filename='__params__')
+            model.test_prog = prog
+            test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
+            model.test_inputs = OrderedDict()
+            model.test_outputs = OrderedDict()
+            for name in input_names:
+                model.test_inputs[name] = model.test_prog.global_block().var(
+                    name)
+            for i, out in enumerate(outputs):
+                var_desc = test_outputs_info[i]
+                model.test_outputs[var_desc[0]] = out
     if 'Transforms' in info:
         transforms_mode = info.get('TransformsMode', 'RGB')
         # 固定模型的输入shape
@@ -107,6 +110,7 @@ def load_model(model_dir, fixed_input_shape=None):
                 model.__dict__[k] = v
 
     logging.info("Model[{}] loaded.".format(info['Model']))
+    model.scope = model_scope
     model.trainable = False
     model.status = status
     return model

+ 19 - 16
paddlex/cv/models/mask_rcnn.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -286,10 +286,12 @@ class MaskRCNN(FasterRCNN):
                 'im_info': im_infos,
                 'im_shape': im_shapes,
             }
-            outputs = self.exe.run(self.test_prog,
-                                   feed=[feed_data],
-                                   fetch_list=list(self.test_outputs.values()),
-                                   return_numpy=False)
+            with fluid.scope_guard(self.scope):
+                outputs = self.exe.run(
+                    self.test_prog,
+                    feed=[feed_data],
+                    fetch_list=list(self.test_outputs.values()),
+                    return_numpy=False)
             res = {
                 'bbox': (np.array(outputs[0]),
                          outputs[0].recursive_sequence_lengths()),
@@ -356,15 +358,16 @@ class MaskRCNN(FasterRCNN):
         im = np.expand_dims(im, axis=0)
         im_resize_info = np.expand_dims(im_resize_info, axis=0)
         im_shape = np.expand_dims(im_shape, axis=0)
-        outputs = self.exe.run(self.test_prog,
-                               feed={
-                                   'image': im,
-                                   'im_info': im_resize_info,
-                                   'im_shape': im_shape
-                               },
-                               fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False,
-                               use_program_cache=True)
+        with fluid.scope_guard(self.scope):
+            outputs = self.exe.run(self.test_prog,
+                                   feed={
+                                       'image': im,
+                                       'im_info': im_resize_info,
+                                       'im_shape': im_shape
+                                   },
+                                   fetch_list=list(self.test_outputs.values()),
+                                   return_numpy=False,
+                                   use_program_cache=True)
         res = {
             k: (np.array(v), v.recursive_sequence_lengths())
             for k, v in zip(list(self.test_outputs.keys()), outputs)

+ 40 - 35
paddlex/cv/models/slim/post_quantization.py

@@ -85,13 +85,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._support_quantize_op_type = \
             list(set(QuantizationTransformPass._supported_quantizable_op_type +
                 AddQuantDequantPass._supported_quantizable_op_type))
-        
+
         # Check inputs
         assert executor is not None, "The executor cannot be None."
         assert batch_size > 0, "The batch_size should be greater than 0."
         assert algo in self._support_algo_type, \
             "The algo should be KL, abs_max or min_max."
-        
+
         self._executor = executor
         self._dataset = dataset
         self._batch_size = batch_size
@@ -154,20 +154,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         logging.info("Start to run batch!")
         for data in self._data_loader():
             start = time.time()
-            self._executor.run(
-                program=self._program,
-                feed=data,
-                fetch_list=self._fetch_list,
-                return_numpy=False)
+            with fluid.scope_guard(self._scope):
+                self._executor.run(program=self._program,
+                                   feed=data,
+                                   fetch_list=self._fetch_list,
+                                   return_numpy=False)
             if self._algo == "KL":
                 self._sample_data(batch_id)
             else:
                 self._sample_threshold()
             end = time.time()
-            logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
-                str(batch_id + 1),
-                str(batch_ct),
-                str(end-start)))
+            logging.debug(
+                '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
+                    str(batch_id + 1), str(batch_ct), str(end - start)))
             batch_id += 1
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
@@ -194,15 +193,16 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         Returns:
             None
         '''
-        feed_vars_names = [var.name for var in self._feed_list]
-        fluid.io.save_inference_model(
-            dirname=save_model_path,
-            feeded_var_names=feed_vars_names,
-            target_vars=self._fetch_list,
-            executor=self._executor,
-            params_filename='__params__',
-            main_program=self._program)
-        
+        with fluid.scope_guard(self._scope):
+            feed_vars_names = [var.name for var in self._feed_list]
+            fluid.io.save_inference_model(
+                dirname=save_model_path,
+                feeded_var_names=feed_vars_names,
+                target_vars=self._fetch_list,
+                executor=self._executor,
+                params_filename='__params__',
+                main_program=self._program)
+
     def _load_model_data(self):
         '''
         Set data loader.
@@ -212,7 +212,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._data_loader = fluid.io.DataLoader.from_generator(
             feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
         self._data_loader.set_sample_list_generator(
-            self._dataset.generator(self._batch_size, drop_last=True),
+            self._dataset.generator(
+                self._batch_size, drop_last=True),
             places=self._place)
 
     def _calculate_kl_threshold(self):
@@ -235,10 +236,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                     weight_threshold.append(abs_max_value)
             self._quantized_var_kl_threshold[var_name] = weight_threshold
             end = time.time()
-            logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
-                str(ct),
-                str(len(self._quantized_weight_var_name)),
-                str(end-start)))
+            logging.debug(
+                '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.
+                format(
+                    str(ct),
+                    str(len(self._quantized_weight_var_name)), str(end -
+                                                                   start)))
             ct += 1
 
         ct = 1
@@ -257,10 +260,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 self._quantized_var_kl_threshold[var_name] = \
                     self._get_kl_scaling_factor(np.abs(sampling_data))
                 end = time.time()
-                logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
-                    str(ct),
-                    str(len(self._quantized_act_var_name)),
-                    str(end-start)))
+                logging.debug(
+                    '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.
+                    format(
+                        str(ct),
+                        str(len(self._quantized_act_var_name)),
+                        str(end - start)))
                 ct += 1
         else:
             for var_name in self._quantized_act_var_name:
@@ -270,10 +275,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 self._quantized_var_kl_threshold[var_name] = \
                     self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
                 end = time.time()
-                logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
-                    str(ct),
-                    str(len(self._quantized_act_var_name)),
-                    str(end-start)))
+                logging.debug(
+                    '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.
+                    format(
+                        str(ct),
+                        str(len(self._quantized_act_var_name)),
+                        str(end - start)))
                 ct += 1
-
-                

+ 16 - 13
paddlex/cv/models/yolo_v3.py

@@ -1,11 +1,11 @@
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -313,10 +313,12 @@ class YOLOv3(BaseAPI):
             images = np.array([d[0] for d in data])
             im_sizes = np.array([d[1] for d in data])
             feed_data = {'image': images, 'im_size': im_sizes}
-            outputs = self.exe.run(self.test_prog,
-                                   feed=[feed_data],
-                                   fetch_list=list(self.test_outputs.values()),
-                                   return_numpy=False)
+            with fluid.scope_guard(self.scope):
+                outputs = self.exe.run(
+                    self.test_prog,
+                    feed=[feed_data],
+                    fetch_list=list(self.test_outputs.values()),
+                    return_numpy=False)
             res = {
                 'bbox': (np.array(outputs[0]),
                          outputs[0].recursive_sequence_lengths())
@@ -366,12 +368,13 @@ class YOLOv3(BaseAPI):
             im, im_size = self.test_transforms(img_file)
         im = np.expand_dims(im, axis=0)
         im_size = np.expand_dims(im_size, axis=0)
-        outputs = self.exe.run(self.test_prog,
-                               feed={'image': im,
-                                     'im_size': im_size},
-                               fetch_list=list(self.test_outputs.values()),
-                               return_numpy=False,
-                               use_program_cache=True)
+        with fluid.scope_guard(self.scope):
+            outputs = self.exe.run(self.test_prog,
+                                   feed={'image': im,
+                                         'im_size': im_size},
+                                   fetch_list=list(self.test_outputs.values()),
+                                   return_numpy=False,
+                                   use_program_cache=True)
         res = {
             k: (np.array(v), v.recursive_sequence_lengths())
             for k, v in zip(list(self.test_outputs.keys()), outputs)