paddle_model_encrypt.cpp 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <string.h>
  15. #include <iostream>
  16. #include <string>
  17. #include <memory>
  18. #include <vector>
  19. #include "encryption/include/model_code.h"
  20. #include "encryption/include/paddle_model_encrypt.h"
  21. #include "encryption/util/include/constant/constant_model.h"
  22. #include "encryption/util/include/crypto/aes_gcm.h"
  23. #include "encryption/util/include/crypto/sha256_utils.h"
  24. #include "encryption/util/include/crypto/base64.h"
  25. #include "encryption/util/include/system_utils.h"
  26. #include "encryption/util/include/io_utils.h"
  27. #include "encryption/util/include/log.h"
  28. std::string paddle_generate_random_key() {
  29. std::string tmp = util::SystemUtils::random_key_iv(AES_GCM_KEY_LENGTH);
  30. // return util::crypto::Base64Utils::encode(tmp);
  31. return baidu::base::base64::base64_encode(tmp);
  32. }
  33. int paddle_encrypt_dir(const char* keydata,
  34. const char* src_dir,
  35. const char* dst_dir) {
  36. std::vector<std::string> files;
  37. int ret_files = ioutil::read_dir_files(src_dir, files);
  38. if (ret_files == -1) {
  39. return CODE_NOT_EXIST_DIR;
  40. }
  41. if (ret_files == 0) {
  42. return CODE_FILES_EMPTY_WITH_DIR;
  43. }
  44. // check model.yml, __model__, __params__ exist or not
  45. if (util::SystemUtils::check_pattern_exist(files, "model.yml")) {
  46. return CODE_MODEL_YML_FILE_NOT_EXIST;
  47. }
  48. if (util::SystemUtils::check_pattern_exist(files, "__model__")) {
  49. return CODE_MODEL_FILE_NOT_EXIST;
  50. }
  51. if (util::SystemUtils::check_pattern_exist(files, "__params__")) {
  52. return CODE_PARAMS_FILE_NOT_EXIST;
  53. }
  54. std::string src_str(src_dir);
  55. if (src_str[src_str.length() - 1] != '/') {
  56. src_str.append("/");
  57. }
  58. std::string dst_str(dst_dir);
  59. if (dst_str[dst_str.length() - 1] != '/') {
  60. dst_str.append("/");
  61. }
  62. int ret = CODE_OK;
  63. ret = ioutil::dir_exist_or_mkdir(dst_str.c_str());
  64. for (int i = 0; i < files.size(); ++i) {
  65. if (strcmp(files[i].c_str(), "__model__") == 0
  66. || strcmp(files[i].c_str(), "__params__") == 0
  67. || strcmp(files[i].c_str(), "model.yml") == 0) {
  68. std::string infile = src_str + files[i];
  69. std::string outfile = dst_str + files[i] + ".encrypted";
  70. ret = paddle_encrypt_model(keydata, infile.c_str(),
  71. outfile.c_str());
  72. } else {
  73. std::string infile = src_str + files[i];
  74. std::string outfile = dst_str + files[i];
  75. ret = ioutil::read_file_to_file(infile.c_str(), outfile.c_str());
  76. }
  77. if (ret != CODE_OK) {
  78. return ret;
  79. }
  80. }
  81. files.clear();
  82. return ret;
  83. }
  84. int paddle_encrypt_model(const char* keydata,
  85. const char* infile,
  86. const char* outfile) {
  87. std::string key_str =
  88. baidu::base::base64::base64_decode(std::string(keydata));
  89. if (key_str.length() != 32) {
  90. return CODE_KEY_LENGTH_ABNORMAL;
  91. }
  92. unsigned char* plain = NULL;
  93. size_t plain_len = 0;
  94. int ret_read = ioutil::read_file(infile, &plain, &plain_len);
  95. if (ret_read != CODE_OK) {
  96. return ret_read;
  97. }
  98. unsigned char* aes_key =
  99. (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_KEY_LENGTH);
  100. unsigned char* aes_iv =
  101. (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_IV_LENGTH);
  102. memcpy(aes_key, key_str.c_str(), AES_GCM_KEY_LENGTH);
  103. memcpy(aes_iv, key_str.c_str() + 16, AES_GCM_IV_LENGTH);
  104. unsigned char* cipher = (unsigned char*) malloc(sizeof(unsigned char) *
  105. (plain_len + AES_GCM_TAG_LENGTH));
  106. size_t cipher_len = 0;
  107. int ret_encrypt =
  108. util::crypto::AesGcm::encrypt_aes_gcm(plain,
  109. plain_len,
  110. aes_key,
  111. aes_iv,
  112. cipher,
  113. reinterpret_cast<int&>(cipher_len));
  114. free(aes_key);
  115. free(aes_iv);
  116. if (ret_encrypt != CODE_OK) {
  117. LOGD("[M]aes encrypt ret code: %d", ret_encrypt);
  118. free(plain);
  119. free(cipher);
  120. return CODE_AES_GCM_ENCRYPT_FIALED;
  121. }
  122. std::string randstr = util::SystemUtils::random_str(constant::TAG_LEN);
  123. std::string aes_key_iv(key_str);
  124. std::string sha256_key_iv =
  125. util::crypto::SHA256Utils::sha256_string(aes_key_iv);
  126. for (int i = 0; i < 64; ++i) {
  127. randstr[i] = sha256_key_iv[i];
  128. }
  129. size_t header_len = constant::MAGIC_NUMBER_LEN +
  130. constant::VERSION_LEN + constant::TAG_LEN;
  131. unsigned char* header =
  132. (unsigned char*) malloc(sizeof(unsigned char) * header_len);
  133. memcpy(header, constant::MAGIC_NUMBER.c_str(), constant::MAGIC_NUMBER_LEN);
  134. memcpy(header + constant::MAGIC_NUMBER_LEN,
  135. constant::VERSION.c_str(),
  136. constant::VERSION_LEN);
  137. memcpy(header + constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN,
  138. randstr.c_str(), constant::TAG_LEN);
  139. int ret_write_file = ioutil::write_file(outfile, header, header_len);
  140. ret_write_file = ioutil::append_file(outfile, cipher, cipher_len);
  141. free(header);
  142. free(cipher);
  143. return ret_write_file;
  144. }
  145. int encrypt_stream(const std::string &keydata,
  146. std::istream &in_stream, std::ostream &out_stream) {
  147. std::string key_str = baidu::base::base64::base64_decode(keydata);
  148. if (key_str.length() != 32) {
  149. return CODE_KEY_LENGTH_ABNORMAL;
  150. }
  151. in_stream.seekg(0, std::ios::beg);
  152. in_stream.seekg(0, std::ios::end);
  153. size_t plain_len = in_stream.tellg();
  154. in_stream.seekg(0, std::ios::beg);
  155. std::unique_ptr<unsigned char[]> plain(new unsigned char[plain_len]);
  156. in_stream.read(reinterpret_cast<char *>(plain.get()), plain_len);
  157. std::string aes_key = key_str.substr(0, AES_GCM_KEY_LENGTH);
  158. std::string aes_iv = key_str.substr(16, AES_GCM_IV_LENGTH);
  159. std::unique_ptr<unsigned char[]> cipher(
  160. new unsigned char[plain_len + AES_GCM_TAG_LENGTH]);
  161. size_t cipher_len = 0;
  162. int ret_encrypt = util::crypto::AesGcm::encrypt_aes_gcm(
  163. plain.get(),
  164. plain_len,
  165. reinterpret_cast<const unsigned char*>(aes_key.c_str()),
  166. reinterpret_cast<const unsigned char*>(aes_iv.c_str()),
  167. cipher.get(),
  168. reinterpret_cast<int&>(cipher_len));
  169. if (ret_encrypt != CODE_OK) {
  170. LOGD("[M]aes encrypt ret code: %d", ret_encrypt);
  171. return CODE_AES_GCM_ENCRYPT_FIALED;
  172. }
  173. std::string randstr = util::SystemUtils::random_str(constant::TAG_LEN);
  174. std::string aes_key_iv(key_str);
  175. std::string sha256_key_iv =
  176. util::crypto::SHA256Utils::sha256_string(aes_key_iv);
  177. for (int i = 0; i < 64; ++i) {
  178. randstr[i] = sha256_key_iv[i];
  179. }
  180. size_t header_len = constant::MAGIC_NUMBER_LEN +
  181. constant::VERSION_LEN + constant::TAG_LEN;
  182. out_stream.write(constant::MAGIC_NUMBER.c_str(),
  183. constant::MAGIC_NUMBER_LEN);
  184. out_stream.write(constant::VERSION.c_str(), constant::VERSION_LEN);
  185. out_stream.write(randstr.c_str(), constant::TAG_LEN);
  186. out_stream.write(reinterpret_cast<char *>(cipher.get()), cipher_len);
  187. return CODE_OK;
  188. }