|
|
@@ -13,6 +13,7 @@
|
|
|
// limitations under the License.
|
|
|
#include <omp.h>
|
|
|
#include <algorithm>
|
|
|
+#include <fstream>
|
|
|
#include <cstring>
|
|
|
#include "include/paddlex/paddlex.h"
|
|
|
namespace PaddleX {
|
|
|
@@ -27,17 +28,28 @@ void Model::create_predictor(const std::string& model_dir,
|
|
|
std::string model_file = model_dir + OS_PATH_SEP + "__model__";
|
|
|
std::string params_file = model_dir + OS_PATH_SEP + "__params__";
|
|
|
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
|
|
|
+ std::string yaml_input = "";
|
|
|
#ifdef WITH_ENCRYPTION
|
|
|
if (key != "") {
|
|
|
model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
|
|
|
params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
|
|
|
- std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
|
|
|
+ yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
|
|
|
paddle_security_load_model(
|
|
|
&config, key.c_str(), model_file.c_str(), params_file.c_str());
|
|
|
+ yaml_input = decrypt_file(yaml_file.c_str(), key.c_str());
|
|
|
}
|
|
|
#endif
|
|
|
- // 读取配置文件
|
|
|
- if (!load_config(yaml_file)) {
|
|
|
+ if (yaml_input == "") {
|
|
|
+ // 读取配置文件
|
|
|
+ std::ifstream yaml_fin(yaml_file);
|
|
|
+ yaml_fin.seekg(0, std::ios::end);
|
|
|
+ size_t yaml_file_size = yaml_fin.tellg();
|
|
|
+ yaml_input.assign(yaml_file_size, ' ');
|
|
|
+ yaml_fin.seekg(0);
|
|
|
+ yaml_fin.read(&yaml_input[0], yaml_file_size);
|
|
|
+ }
|
|
|
+ // 读取配置文件内容
|
|
|
+ if (!load_config(yaml_input)) {
|
|
|
std::cerr << "Parse file 'model.yml' failed!" << std::endl;
|
|
|
exit(-1);
|
|
|
}
|
|
|
@@ -67,9 +79,8 @@ void Model::create_predictor(const std::string& model_dir,
|
|
|
inputs_batch_.assign(batch_size, ImageBlob());
|
|
|
}
|
|
|
|
|
|
-bool Model::load_config(const std::string& yaml_file) {
|
|
|
- // std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
|
|
|
- YAML::Node config = YAML::LoadFile(yaml_file);
|
|
|
+bool Model::load_config(const std::string& yaml_input) {
|
|
|
+ YAML::Node config = YAML::Load(yaml_input);
|
|
|
type = config["_Attributes"]["model_type"].as<std::string>();
|
|
|
name = config["Model"].as<std::string>();
|
|
|
std::string version = config["version"].as<std::string>();
|