// compile.h (5.0 KB)
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  #include <algorithm>
  #include <atomic>
  #include <memory>
  #include <set>
  #include <string>
  #include <unordered_map>
  #include <vector>

  #include "iengine.h" // NOLINT
  #include "poros_module.h" // NOLINT
  #include "torch/script.h" // NOLINT
  22. namespace baidu {
  23. namespace mirana {
  24. namespace poros {
  25. /**
  26. * @brief compile graph
  27. *
  28. * @param [in] module : original module
  29. * @param [in] input_ivalues : prewarm datas
  30. * @param [in] options : Inference options
  31. * @return porosmodule
  32. * @retval !nullptr => succeed nullptr => failed
  33. **/
  34. std::unique_ptr<PorosModule>
  35. Compile(const torch::jit::Module &module,
  36. const std::vector<std::vector<c10::IValue>> &prewarm_datas,
  37. const PorosOptions &options);
  38. class Compiler {
  39. public:
  40. typedef std::unordered_map<const torch::jit::Node *, IEngine *> engine_map_t;
  41. typedef std::vector<std::vector<c10::IValue>> ivalue_vec_t;
  42. Compiler() : _origin_module(NULL) {}
  43. ~Compiler();
  44. /**
  45. * @brief initial Compiler
  46. *
  47. * @param [in] options : poros options
  48. * @return int
  49. * @retval 0 => succeed <0 => failed
  50. **/
  51. int init(const PorosOptions &options);
  52. /**
  53. * @brief compile whole graph
  54. *
  55. * @param [in] origin_module
  56. * @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
  57. * @param [out] optimized_module : optimized graph
  58. * @return int
  59. * @retval 0 => succeed <0 => failed
  60. **/
  61. int compile(const torch::jit::Module &origin_module,
  62. const ivalue_vec_t &prewarm_datas,
  63. torch::jit::Module *optimized_module);
  64. private:
  65. /**
  66. * @brief preprocess this calculation graph
  67. *
  68. * @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
  69. * @param [out] graph : preprcessed graph
  70. * @return int
  71. * @retval 0 => succeed <0 => failed
  72. **/
  73. int preprocess_graph(const ivalue_vec_t &prewarm_datas,
  74. std::shared_ptr<torch::jit::Graph> &graph);
  75. /**
  76. * @brief segment this calculation graph
  77. *
  78. * @param [in/out] graph
  79. * @return int
  80. * @retval 0 => succeed <0 => failed
  81. **/
  82. int segment_graph(std::shared_ptr<torch::jit::Graph> &graph);
  83. // Split subgraph(block)
  84. // The divided subgraph, as a subgraph, is associated with the block
  85. int segment_block(torch::jit::Block &block, IEngine *engine,
  86. int current_depth);
  87. // Subgraph optimization
  88. /**
  89. * @brief Subgraph optimization
  90. *
  91. * @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
  92. * @param [in] opt_graph : ivalue_vec_t, vector of IValue
  93. * @param [out] optimized_module : optimized graph
  94. * @return int
  95. * @retval 0 => succeed <0 => failed
  96. **/
  97. int optimize_subgraph(const ivalue_vec_t &prewarm_datas,
  98. const std::shared_ptr<torch::jit::Graph> &opt_graph,
  99. torch::jit::Module *optimized_module);
  100. // Subgraph optimization(block)
  101. int optimize_subblock(torch::jit::Block *block,
  102. torch::jit::Module *optimized_module);
  103. /**
  104. * @brief Compile the subgraph into a new graph based on the engine
  105. *
  106. * @param [in] engine : The engine used by the subgraph
  107. * @param [in] subgraph_node : Subgraph node
  108. * @return [out] module : Transformed model
  109. * @retval 0 => succeed <0 => failed
  110. **/
  111. int transform(IEngine *engine, torch::jit::Node &subgraph_node,
  112. torch::jit::Module &module);
  113. /**
  114. * @brief Select engine based on subgraph and options
  115. *
  116. * @param [in] node : Jit Node
  117. * @return int
  118. * @retval 0 => succeed <0 => failed
  119. **/
  120. IEngine *select_engine(const torch::jit::Node *n);
  121. /**
  122. * @brief destroy
  123. *
  124. * @return void
  125. **/
  126. void close();
  127. private:
  128. int _max_segment_depth{5}; // Maximum subgraph segmentation depth
  129. ivalue_vec_t _prewarm_datas; // Prewarm datas
  130. PorosOptions _options;
  131. engine_map_t _engine_map; // The engine used to record the subgraph
  132. const torch::jit::Module *_origin_module; // Origin_module
  133. std::atomic<int> _engine_index = {0}; // Record engine index
  134. };
  135. /**
  136. * @brief compile graph, internal use
  137. *
  138. * @param [in] module : Origin module
  139. * @param [in] input_ivalues : Prewarm datas
  140. * @param [in] options : Inference options
  141. * @return optimized_module
  142. * @retval !nullptr => succeed nullptr => failed
  143. **/
  144. std::unique_ptr<torch::jit::Module>
  145. CompileGraph(const torch::jit::Module &module,
  146. const std::vector<std::vector<c10::IValue>> &prewarm_datas,
  147. const PorosOptions &options);
  148. } // namespace poros
  149. } // namespace mirana
  150. } // namespace baidu