paddle_model_encrypt.cpp 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  #include <algorithm>
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <memory>
  #include <string>
  #include <vector>
  #include <string.h>
  #include "paddle_model_encrypt.h"
  #include "model_code.h"
  #include "../util/system_utils.h"
  #include "../util/io_utils.h"
  #include "../constant/constant_model.h"
  #include "../util/crypto/aes_gcm.h"
  #include "../util/crypto/sha256_utils.h"
  #include "../util/crypto/base64.h"
  #include "../util/log.h"
  15. std::string paddle_generate_random_key() {
  16. std::string tmp = util::SystemUtils::random_key_iv(AES_GCM_KEY_LENGTH);
  17. // return util::crypto::Base64Utils::encode(tmp);
  18. return baidu::base::base64::base64_encode(tmp);
  19. }
  20. int paddle_encrypt_dir(const char* keydata, const char* src_dir, const char* dst_dir) {
  21. std::vector<std::string> files;
  22. int ret_files = ioutil::read_dir_files(src_dir, files);
  23. if (ret_files == -1) {
  24. return CODE_NOT_EXIST_DIR;
  25. }
  26. if (ret_files == 0) {
  27. return CODE_FILES_EMPTY_WITH_DIR;
  28. }
  29. // check model.yml, __model__, __params__ exist or not
  30. if (util::SystemUtils::check_pattern_exist(files, "model.yml")) {
  31. return CODE_MODEL_YML_FILE_NOT_EXIST;
  32. }
  33. if (util::SystemUtils::check_pattern_exist(files, "__model__")) {
  34. return CODE_MODEL_FILE_NOT_EXIST;
  35. }
  36. if (util::SystemUtils::check_pattern_exist(files, "__params__")) {
  37. return CODE_PARAMS_FILE_NOT_EXIST;
  38. }
  39. std::string src_str(src_dir);
  40. if (src_str[src_str.length() - 1] != '/') {
  41. src_str.append("/");
  42. }
  43. std::string dst_str(dst_dir);
  44. if (dst_str[dst_str.length() - 1] != '/') {
  45. dst_str.append("/");
  46. }
  47. int ret = CODE_OK;
  48. ret = ioutil::dir_exist_or_mkdir(dst_str.c_str());
  49. for (int i = 0; i < files.size(); ++i) {
  50. if (strcmp(files[i].c_str(), "__model__") == 0 || strcmp(files[i].c_str(), "__params__") == 0 || strcmp(files[i].c_str(), "model.yml") == 0) {
  51. std::string infile = src_str + files[i];
  52. std::string outfile = dst_str + files[i] + ".encrypted";
  53. ret = paddle_encrypt_model(keydata, infile.c_str(), outfile.c_str());
  54. } else {
  55. std::string infile = src_str + files[i];
  56. std::string outfile = dst_str + files[i];
  57. ret = ioutil::read_file_to_file(infile.c_str(), outfile.c_str());
  58. }
  59. if (ret != CODE_OK) {
  60. return ret;
  61. }
  62. }
  63. files.clear();
  64. return ret;
  65. }
  66. int paddle_encrypt_model(const char* keydata, const char* infile, const char* outfile) {
  67. // std::string key_str = util::crypto::Base64Utils::decode(std::string(keydata));
  68. std::string key_str = baidu::base::base64::base64_decode(std::string(keydata));
  69. if (key_str.length() != 32) {
  70. return CODE_KEY_LENGTH_ABNORMAL;
  71. }
  72. unsigned char* plain = NULL;
  73. size_t plain_len = 0;
  74. int ret_read = ioutil::read_file(infile, &plain, &plain_len);
  75. if (ret_read != CODE_OK) {
  76. return ret_read;
  77. }
  78. unsigned char* aes_key = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_KEY_LENGTH);
  79. unsigned char* aes_iv = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_IV_LENGTH);
  80. memcpy(aes_key, key_str.c_str(), AES_GCM_KEY_LENGTH);
  81. memcpy(aes_iv, key_str.c_str() + 16, AES_GCM_IV_LENGTH);
  82. unsigned char* cipher = (unsigned char*) malloc(sizeof(unsigned char) * (plain_len + AES_GCM_TAG_LENGTH));
  83. size_t cipher_len = 0;
  84. int ret_encrypt =
  85. util::crypto::AesGcm::encrypt_aes_gcm(plain,
  86. plain_len,
  87. aes_key,
  88. aes_iv,
  89. cipher,
  90. reinterpret_cast<int&>(cipher_len));
  91. free(aes_key);
  92. free(aes_iv);
  93. if (ret_encrypt != CODE_OK) {
  94. LOGD("[M]aes encrypt ret code: %d", ret_encrypt);
  95. free(plain);
  96. free(cipher);
  97. return CODE_AES_GCM_ENCRYPT_FIALED;
  98. }
  99. std::string randstr = util::SystemUtils::random_str(constant::TAG_LEN);
  100. std::string aes_key_iv(key_str);
  101. std::string sha256_key_iv = util::crypto::SHA256Utils::sha256_string(aes_key_iv);
  102. for (int i = 0; i < 64; ++i) {
  103. randstr[i] = sha256_key_iv[i];
  104. }
  105. size_t header_len = constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN + constant::TAG_LEN;
  106. unsigned char* header = (unsigned char*) malloc(sizeof(unsigned char) * header_len);
  107. memcpy(header, constant::MAGIC_NUMBER.c_str(), constant::MAGIC_NUMBER_LEN);
  108. memcpy(header + constant::MAGIC_NUMBER_LEN, constant::VERSION.c_str(), constant::VERSION_LEN);
  109. memcpy(header + constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN, randstr.c_str(), constant::TAG_LEN);
  110. int ret_write_file = ioutil::write_file(outfile, header, header_len);
  111. ret_write_file = ioutil::append_file(outfile, cipher, cipher_len);
  112. free(header);
  113. free(cipher);
  114. return ret_write_file;
  115. }
  116. int encrypt_stream(const std::string &keydata, std::istream &in_stream, std::ostream &out_stream) {
  117. std::string key_str = baidu::base::base64::base64_decode(keydata);
  118. if (key_str.length() != 32) {
  119. return CODE_KEY_LENGTH_ABNORMAL;
  120. }
  121. in_stream.seekg(0, std::ios::beg);
  122. in_stream.seekg(0, std::ios::end);
  123. size_t plain_len = in_stream.tellg();
  124. in_stream.seekg(0, std::ios::beg);
  125. std::unique_ptr<unsigned char[]> plain(new unsigned char[plain_len]);
  126. in_stream.read(reinterpret_cast<char *>(plain.get()), plain_len);
  127. std::string aes_key = key_str.substr(0, AES_GCM_KEY_LENGTH);
  128. std::string aes_iv = key_str.substr(16, AES_GCM_IV_LENGTH);
  129. std::unique_ptr<unsigned char[]> cipher(new unsigned char[plain_len + AES_GCM_TAG_LENGTH]);
  130. size_t cipher_len = 0;
  131. int ret_encrypt = util::crypto::AesGcm::encrypt_aes_gcm(plain.get(),
  132. plain_len,
  133. reinterpret_cast<const unsigned char*>(aes_key.c_str()),
  134. reinterpret_cast<const unsigned char*>(aes_iv.c_str()),
  135. cipher.get(),
  136. reinterpret_cast<int&>(cipher_len));
  137. if (ret_encrypt != CODE_OK) {
  138. LOGD("[M]aes encrypt ret code: %d", ret_encrypt);
  139. return CODE_AES_GCM_ENCRYPT_FIALED;
  140. }
  141. std::string randstr = util::SystemUtils::random_str(constant::TAG_LEN);
  142. std::string aes_key_iv(key_str);
  143. std::string sha256_key_iv = util::crypto::SHA256Utils::sha256_string(aes_key_iv);
  144. for (int i = 0; i < 64; ++i) {
  145. randstr[i] = sha256_key_iv[i];
  146. }
  147. size_t header_len = constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN + constant::TAG_LEN;
  148. out_stream.write(constant::MAGIC_NUMBER.c_str(), constant::MAGIC_NUMBER_LEN);
  149. out_stream.write(constant::VERSION.c_str(), constant::VERSION_LEN);
  150. out_stream.write(randstr.c_str(), constant::TAG_LEN);
  151. out_stream.write(reinterpret_cast<char *>(cipher.get()), cipher_len);
  152. return CODE_OK;
  153. }