Channingss 5 years ago
parent
commit
a19614d882
1 changed files with 24 additions and 5 deletions
  1. 24 5
      paddlex/cv/models/base.py

+ 24 - 5
paddlex/cv/models/base.py

@@ -329,11 +329,29 @@ class BaseAPI:
             "Model for inference deploy saved in {}.".format(save_dir))
 
     def export_onnx_model(self, save_dir, onnx_name=None):
-        from fluid.utils import op_io_info, init_name_prefix
-        from onnx import helper, checker
-        import fluid_onnx.ops as ops
-        from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
-        from debug.model_check import debug_model, Tracker
+        support_list = ['ResNet18','ResNet34','ResNet50','ResNet101','ResNet50_vd',
+                        'ResNet101_vd','ResNet50_vd_ssld','ResNet101_vd_ssld','DarkNet53',
+                        'MobileNetV1','MobileNetV2','MobileNetV3_large','MobileNetV3_small',
+                        'MobileNetV3_large_ssld','MobileNetV3_small_ssld','Xception41',
+                        'Xception65','DenseNet121','DenseNet161','DenseNet201','ShuffleNetV2'] 
+        unsupport_list = []
+        if self.model_type in unsupport_list:
+            raise Exception("Model: {} unsupport export to ONNX"
+                            .format(self.model_type)
+        try:
+            from fluid.utils import op_io_info, init_name_prefix
+            from onnx import helper, checker
+            import fluid_onnx.ops as ops
+            from fluid_onnx.variables import paddle_variable_to_onnx_tensor, paddle_onnx_weight
+            from debug.model_check import debug_model, Tracker
+        except Exception as e:
+            print(e)
+            print(
+                "Import Module Failed! Please install paddle2onnx. Related requirements
+                 see https://github.com/PaddlePaddle/paddle2onnx"
+            )
+            sys.exit(-1)
+
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         inference_scope = fluid.global_scope()
@@ -392,6 +410,7 @@ class BaseAPI:
                             op_check_list.append(op.type)
             print('The operator sets to run test case.')
             print(set(op_check_list))
+
             # Create outputs
             # Get the new names for outputs if they've been renamed in nodes' making
             renamed_outputs = op_io_info.get_all_renamed_outputs()