|
|
@@ -356,23 +356,24 @@ class BaseAPI:
|
|
|
def export_inference_model(self, save_dir):
|
|
|
test_input_names = [var.name for var in list(self.test_inputs.values())]
|
|
|
test_outputs = list(self.test_outputs.values())
|
|
|
- if self.__class__.__name__ == 'MaskRCNN':
|
|
|
- from paddlex.utils.save import save_mask_inference_model
|
|
|
- save_mask_inference_model(
|
|
|
- dirname=save_dir,
|
|
|
- executor=self.exe,
|
|
|
- params_filename='__params__',
|
|
|
- feeded_var_names=test_input_names,
|
|
|
- target_vars=test_outputs,
|
|
|
- main_program=self.test_prog)
|
|
|
- else:
|
|
|
- fluid.io.save_inference_model(
|
|
|
- dirname=save_dir,
|
|
|
- executor=self.exe,
|
|
|
- params_filename='__params__',
|
|
|
- feeded_var_names=test_input_names,
|
|
|
- target_vars=test_outputs,
|
|
|
- main_program=self.test_prog)
|
|
|
+ with fluid.scope_guard(self.scope):
|
|
|
+ if self.__class__.__name__ == 'MaskRCNN':
|
|
|
+ from paddlex.utils.save import save_mask_inference_model
|
|
|
+ save_mask_inference_model(
|
|
|
+ dirname=save_dir,
|
|
|
+ executor=self.exe,
|
|
|
+ params_filename='__params__',
|
|
|
+ feeded_var_names=test_input_names,
|
|
|
+ target_vars=test_outputs,
|
|
|
+ main_program=self.test_prog)
|
|
|
+ else:
|
|
|
+ fluid.io.save_inference_model(
|
|
|
+ dirname=save_dir,
|
|
|
+ executor=self.exe,
|
|
|
+ params_filename='__params__',
|
|
|
+ feeded_var_names=test_input_names,
|
|
|
+ target_vars=test_outputs,
|
|
|
+ main_program=self.test_prog)
|
|
|
model_info = self.get_model_info()
|
|
|
model_info['status'] = 'Infer'
|
|
|
|