Browse Source

adapt encryption in linux and windows

jack 5 years ago
parent
commit
8aa1cc833c
1 changed files with 17 additions and 6 deletions
  1. 17 6
      deploy/cpp/src/paddlex.cpp

+ 17 - 6
deploy/cpp/src/paddlex.cpp

@@ -13,6 +13,7 @@
 // limitations under the License.
 #include <omp.h>
 #include <algorithm>
+#include <fstream>
 #include <cstring>
 #include "include/paddlex/paddlex.h"
 namespace PaddleX {
@@ -27,17 +28,28 @@ void Model::create_predictor(const std::string& model_dir,
   std::string model_file = model_dir + OS_PATH_SEP + "__model__";
   std::string params_file = model_dir + OS_PATH_SEP + "__params__";
   std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
+  std::string yaml_input = "";
 #ifdef WITH_ENCRYPTION
   if (key != "") {
     model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
     params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
-    std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
+    yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
     paddle_security_load_model(
         &config, key.c_str(), model_file.c_str(), params_file.c_str());
+    yaml_input = decrypt_file(yaml_file.c_str(), key.c_str());
   }
 #endif
-  // 读取配置文件
-  if (!load_config(yaml_file)) {
+  if (yaml_input == "") {
+    // 读取配置文件
+    std::ifstream yaml_fin(yaml_file);
+    yaml_fin.seekg(0, std::ios::end);
+    size_t yaml_file_size = yaml_fin.tellg();
+    yaml_input.assign(yaml_file_size, ' ');
+    yaml_fin.seekg(0);
+    yaml_fin.read(&yaml_input[0], yaml_file_size);
+  }
+  // 读取配置文件内容
+  if (!load_config(yaml_input)) {
     std::cerr << "Parse file 'model.yml' failed!" << std::endl;
     exit(-1);
   }
@@ -67,9 +79,8 @@ void Model::create_predictor(const std::string& model_dir,
   inputs_batch_.assign(batch_size, ImageBlob());
 }
 
-bool Model::load_config(const std::string& yaml_file) {
-  // std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
-  YAML::Node config = YAML::LoadFile(yaml_file);
+bool Model::load_config(const std::string& yaml_input) {
+  YAML::Node config = YAML::Load(yaml_input);
   type = config["_Attributes"]["model_type"].as<std::string>();
   name = config["Model"].as<std::string>();
   std::string version = config["version"].as<std::string>();