浏览代码

fix export (#3393)

zhuyipin 9 月之前
父节点
当前提交
834ec58fc4
共有 1 个文件被更改,包括 11 次插入0 次删除
  1. 11 0
      paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py

+ 11 - 0
paddlex/repo_apis/Paddle3D_api/bev_fusion/model.py

@@ -188,6 +188,17 @@ class BEVFusionModel(BaseModel):
         if save_dir is not None:
         if save_dir is not None:
             cli_args.append(CLIArgument("--save_dir", save_dir))
             cli_args.append(CLIArgument("--save_dir", save_dir))
 
 
+        cli_args.append(CLIArgument("--save_name", "inference"))
+        cli_args.append(CLIArgument("--save_inference_yml"))
+
+        # PDX related settings
+        uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        export_with_pir = kwargs.pop("export_with_pir", False)
+        config.update({"uniform_output_enabled": uniform_output_enabled})
+        config.update({"pdx_model_name": self.name})
+        if export_with_pir:
+            config.update({"export_with_pir": export_with_pir})
+
         self._assert_empty_kwargs(kwargs)
         self._assert_empty_kwargs(kwargs)
         with self._create_new_config_file() as config_path:
         with self._create_new_config_file() as config_path:
             config.dump(config_path)
             config.dump(config_path)