Sfoglia il codice sorgente

fix scope program for export inference

jiangjiajun 5 anni fa
parent
commit
a3f3e830ec
1 ha cambiato i file con 18 aggiunte e 17 eliminazioni
  1. 18 17
      paddlex/cv/models/base.py

+ 18 - 17
paddlex/cv/models/base.py

@@ -356,23 +356,24 @@ class BaseAPI:
     def export_inference_model(self, save_dir):
         test_input_names = [var.name for var in list(self.test_inputs.values())]
         test_outputs = list(self.test_outputs.values())
-        if self.__class__.__name__ == 'MaskRCNN':
-            from paddlex.utils.save import save_mask_inference_model
-            save_mask_inference_model(
-                dirname=save_dir,
-                executor=self.exe,
-                params_filename='__params__',
-                feeded_var_names=test_input_names,
-                target_vars=test_outputs,
-                main_program=self.test_prog)
-        else:
-            fluid.io.save_inference_model(
-                dirname=save_dir,
-                executor=self.exe,
-                params_filename='__params__',
-                feeded_var_names=test_input_names,
-                target_vars=test_outputs,
-                main_program=self.test_prog)
+        with fluid.scope_guard(self.scope):
+            if self.__class__.__name__ == 'MaskRCNN':
+                from paddlex.utils.save import save_mask_inference_model
+                save_mask_inference_model(
+                    dirname=save_dir,
+                    executor=self.exe,
+                    params_filename='__params__',
+                    feeded_var_names=test_input_names,
+                    target_vars=test_outputs,
+                    main_program=self.test_prog)
+            else:
+                fluid.io.save_inference_model(
+                    dirname=save_dir,
+                    executor=self.exe,
+                    params_filename='__params__',
+                    feeded_var_names=test_input_names,
+                    target_vars=test_outputs,
+                    main_program=self.test_prog)
         model_info = self.get_model_info()
         model_info['status'] = 'Infer'