x_model.cpp 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "model_deploy/paddlex/include/x_model.h"

#include <fstream>
#include <iostream>
#include <memory>
#include <string>

#include "model_deploy/paddlex/include/x_standard_config.h"
  17. namespace PaddleDeploy {
  18. bool PaddleXModel::GenerateTransformsConfig(const YAML::Node& src) {
  19. XEssential(src, &yaml_config_);
  20. for (const auto& op : src["Transforms"]) {
  21. std::string op_name = op.begin()->first.as<std::string>();
  22. if (op_name == "Normalize") {
  23. if (src["version"].as<std::string>() >= "2.0.0") {
  24. yaml_config_["transforms"]["Convert"]["dtype"] = "float";
  25. }
  26. XNormalize(op.begin()->second, &yaml_config_);
  27. } else if (op_name == "ResizeByShort") {
  28. XResizeByShort(op.begin()->second, &yaml_config_);
  29. } else if (op_name == "ResizeByLong") {
  30. XResizeByLong(op.begin()->second, &yaml_config_);
  31. } else if (op_name == "Padding") {
  32. if (src["version"].as<std::string>() >= "2.0.0") {
  33. XPaddingV2(op.begin()->second, &yaml_config_);
  34. } else {
  35. XPadding(op.begin()->second, &yaml_config_);
  36. }
  37. } else if (op_name == "CenterCrop") {
  38. XCenterCrop(op.begin()->second, &yaml_config_);
  39. } else if (op_name == "Resize") {
  40. XResize(op.begin()->second, &yaml_config_);
  41. } else {
  42. std::cerr << "Unexpected transforms op name: '"
  43. << op_name << "'" << std::endl;
  44. return false;
  45. }
  46. }
  47. yaml_config_["transforms"]["Permute"] = YAML::Null;
  48. return true;
  49. }
  50. bool PaddleXModel::YamlConfigInit(const std::string& cfg_file) {
  51. YAML::Node x_config = YAML::LoadFile(cfg_file);
  52. yaml_config_["model_format"] = "Paddle";
  53. yaml_config_["toolkit"] = "PaddleX";
  54. yaml_config_["version"] = x_config["version"].as<std::string>();
  55. yaml_config_["model_type"] =
  56. x_config["_Attributes"]["model_type"].as<std::string>();
  57. yaml_config_["model_name"] = x_config["Model"].as<std::string>();
  58. int i = 0;
  59. for (const auto& label : x_config["_Attributes"]["labels"]) {
  60. yaml_config_["labels"][i] = label.as<std::string>();
  61. i++;
  62. }
  63. // Generate Standard Transforms Configuration
  64. if (!GenerateTransformsConfig(x_config)) {
  65. std::cerr << "Fail to generate standard configuration "
  66. << "of tranforms" << std::endl;
  67. return false;
  68. }
  69. return true;
  70. }
  71. bool PaddleXModel::PreprocessInit() {
  72. preprocess_ = std::make_shared<XPreprocess>();
  73. if (!preprocess_->Init(yaml_config_)) {
  74. return false;
  75. }
  76. return true;
  77. }
  78. bool PaddleXModel::PostprocessInit() {
  79. postprocess_ = std::make_shared<XPostprocess>();
  80. if (!postprocess_->Init(yaml_config_)) {
  81. return false;
  82. }
  83. return true;
  84. }
  85. } // namespace PaddleDeploy