iengine.h 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
#include <cstdint>
#include <string>
#include <utility>
#include <vector>
// from pytorch
#include "ATen/core/interned_strings.h" // NOLINT
#include "torch/csrc/jit/ir/ir.h" // NOLINT
#include "torch/script.h" // NOLINT
// project
#include "plugin_create.h" // NOLINT
  21. namespace baidu {
  22. namespace mirana {
  23. namespace poros {
  24. struct PorosGraph {
  25. torch::jit::Graph *graph = NULL;
  26. torch::jit::Node *node = NULL;
  27. };
  /// Integer identifier type for engines (stored in IEngine::_id).
  using EngineID = std::uint64_t;
  29. class IEngine : public IPlugin, public torch::CustomClassHolder {
  30. public:
  31. virtual ~IEngine() {}
  32. /**
  33. * @brief init, initialization must be successful if the init is successful
  34. * @return int
  35. * @retval 0 => success, <0 => fail
  36. **/
  37. virtual int init() = 0;
  38. /**
  39. * @brief During compilation, the subgraph is converted into the graph
  40. *structure of the corresponding engine and stored inside the engine, so that
  41. *the execute_engine at runtime can be called
  42. * @param [in] sub_graph : subgraph
  43. * @return [res]int
  44. * @retval 0 => success, <0 => fail
  45. **/
  46. virtual int transform(const PorosGraph &sub_graph) = 0;
  47. /**
  48. * @brief Subgraph execution period logic
  49. * @param [in] inputs : input tensor
  50. * @return [res] output tensor
  51. **/
  52. virtual std::vector<at::Tensor>
  53. excute_engine(const std::vector<at::Tensor> &inputs) = 0;
  54. virtual void register_module_attribute(const std::string &name,
  55. torch::jit::Module &module) = 0;
  56. // Logo
  57. virtual const std::string who_am_i() = 0;
  58. // Whether the node is supported by the current engine
  59. bool is_node_supported(const torch::jit::Node *node);
  60. public:
  61. std::pair<uint64_t, uint64_t> _num_io; // Number of input/output parameters
  62. EngineID _id;
  63. };
  64. } // namespace poros
  65. } // namespace mirana
  66. } // namespace baidu