model.h 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include "fast_tokenizer/tokenizers/ernie_fast_tokenizer.h"
  16. #include "ultra_infer/ultra_infer_model.h"
  17. #include "ultra_infer/utils/unique_ptr.h"
  18. #include <ostream>
  19. #include <set>
  20. #include <string>
  21. #include <unordered_map>
  22. #include <vector>
  23. using namespace paddlenlp;
  24. namespace ultra_infer {
  25. namespace text {
  // One extraction result record: two size_t positions, a probability, a text
  // and a map from a name to nested UIEResult lists.
  //
  // NOTE(review): start_/end_ presumably delimit text_ inside the input text
  // and relation_ is presumably keyed by a schema node name -- confirm against
  // the implementation file; this header only fixes the field types.
  struct ULTRAINFER_DECL UIEResult {
    // Default member initializers: `UIEResult() = default` performs no
    // initialization of scalars on its own, so without these a plain
    // `UIEResult r;` would hold indeterminate start_/end_/probability_.
    size_t start_ = 0;
    size_t end_ = 0;
    double probability_ = 0.0;
    std::string text_;
    std::unordered_map<std::string, std::vector<UIEResult>> relation_;

    // Yields zeroed positions/probability, empty text_ and empty relation_.
    UIEResult() = default;

    // Field-wise constructor; relation_ starts empty.
    // `text` is taken by value, so its buffer is swapped into text_ instead of
    // being copied a second time.
    UIEResult(size_t start, size_t end, double probability, std::string text)
        : start_(start), end_(end), probability_(probability) {
      text_.swap(text);
    }

    // Returns a string built from this result (defined outside this header).
    std::string Str() const;
  };

  // Writes a single result to `os` and returns `os` (defined outside this
  // header).
  ULTRAINFER_DECL std::ostream &operator<<(std::ostream &os,
                                           const UIEResult &result);

  // Writes a list of name -> results maps to `os` and returns `os` (defined
  // outside this header). The parameter type matches the `results` output of
  // UIEModel::Predict declared later in this file.
  ULTRAINFER_DECL std::ostream &operator<<(
      std::ostream &os,
      const std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>
          &results);
  // One node of a schema tree: a name, per-node prefix and relation tables,
  // and the child nodes held by value.
  //
  // relations_ stores raw UIEResult pointers; nothing in this struct deletes
  // them, so whoever fills the table keeps ownership of the pointees.
  // NOTE(review): prefix_ and relations_ are not written by any member here;
  // presumably they are filled in during prediction -- confirm in the
  // implementation file.
  struct ULTRAINFER_DECL SchemaNode {
    std::string name_;
    std::vector<std::vector<std::string>> prefix_;
    std::vector<std::vector<UIEResult *>> relations_;
    std::vector<SchemaNode> children_;

    SchemaNode() = default;
    // All copy/move operations are defaulted explicitly. Declaring only the
    // copy constructor (as before) suppresses the implicit move operations, so
    // growing a std::vector<SchemaNode> deep-copied every subtree.
    SchemaNode(const SchemaNode &) = default;
    SchemaNode(SchemaNode &&) = default;
    SchemaNode &operator=(const SchemaNode &) = default;
    SchemaNode &operator=(SchemaNode &&) = default;
    ~SchemaNode() = default;

    // Builds a node called `name` whose children are a copy of `children`.
    explicit SchemaNode(const std::string &name,
                        const std::vector<SchemaNode> &children = {})
        : name_(name), children_(children) {}

    // Appends a childless node named `schema`.
    void AddChild(const std::string &schema) { children_.emplace_back(schema); }

    // Appends a copy of `schema` (including its subtree).
    void AddChild(const SchemaNode &schema) { children_.push_back(schema); }

    // Appends a node named `schema` with one childless grandchild per entry of
    // `children`.
    void AddChild(const std::string &schema,
                  const std::vector<std::string> &children) {
      // Build the grandchildren first so `children` is fully read before
      // children_ can reallocate, then hand the buffer over without copying.
      std::vector<SchemaNode> leaves;
      leaves.reserve(children.size());
      for (const auto &leaf_name : children) {
        leaves.emplace_back(leaf_name);
      }
      children_.push_back(SchemaNode(schema));
      children_.back().children_.swap(leaves);
    }

    // Appends a node named `schema` whose children are a copy of `children`.
    void AddChild(const std::string &schema,
                  const std::vector<SchemaNode> &children) {
      // The temporary is complete before push_back runs (safe even when
      // `children` is this node's own children_) and is then moved in, so the
      // subtree is copied once instead of twice.
      children_.push_back(SchemaNode(schema, children));
    }
  };
  // Language selector passed to the UIEModel constructors and stored in
  // UIEModel::schema_language_.
  // Kept as an unscoped enum: the enumerators stay reachable both as `ZH` and
  // as `SchemaLanguage::ZH`, and the numeric values are the ones the
  // compiler assigned implicitly before.
  enum SchemaLanguage {
    ZH = 0,  // Chinese
    EN = 1   // English
  };
  74. struct Schema {
  75. explicit Schema(const std::string &schema, const std::string &name = "root");
  76. explicit Schema(const std::vector<std::string> &schema_list,
  77. const std::string &name = "root");
  78. explicit Schema(const std::vector<SchemaNode> &schema_list,
  79. const std::string &name = "root");
  80. explicit Schema(const SchemaNode &schema, const std::string &name = "root");
  81. private:
  82. void CreateRoot(const std::string &name);
  83. std::unique_ptr<SchemaNode> root_;
  84. friend class UIEModel;
  85. };
  86. struct ULTRAINFER_DECL UIEModel : public UltraInferModel {
  87. public:
  88. UIEModel(const std::string &model_file, const std::string &params_file,
  89. const std::string &vocab_file, float position_prob,
  90. size_t max_length, const std::vector<std::string> &schema,
  91. int batch_size,
  92. const ultra_infer::RuntimeOption &custom_option =
  93. ultra_infer::RuntimeOption(),
  94. const ultra_infer::ModelFormat &model_format =
  95. ultra_infer::ModelFormat::PADDLE,
  96. SchemaLanguage schema_language = SchemaLanguage::ZH);
  97. UIEModel(const std::string &model_file, const std::string &params_file,
  98. const std::string &vocab_file, float position_prob,
  99. size_t max_length, const SchemaNode &schema, int batch_size,
  100. const ultra_infer::RuntimeOption &custom_option =
  101. ultra_infer::RuntimeOption(),
  102. const ultra_infer::ModelFormat &model_format =
  103. ultra_infer::ModelFormat::PADDLE,
  104. SchemaLanguage schema_language = SchemaLanguage::ZH);
  105. UIEModel(const std::string &model_file, const std::string &params_file,
  106. const std::string &vocab_file, float position_prob,
  107. size_t max_length, const std::vector<SchemaNode> &schema,
  108. int batch_size,
  109. const ultra_infer::RuntimeOption &custom_option =
  110. ultra_infer::RuntimeOption(),
  111. const ultra_infer::ModelFormat &model_format =
  112. ultra_infer::ModelFormat::PADDLE,
  113. SchemaLanguage schema_language = SchemaLanguage::ZH);
  114. virtual std::string ModelName() const { return "UIEModel"; }
  115. void SetSchema(const std::vector<std::string> &schema);
  116. void SetSchema(const std::vector<SchemaNode> &schema);
  117. void SetSchema(const SchemaNode &schema);
  118. bool ConstructTextsAndPrompts(
  119. const std::vector<std::string> &raw_texts, const std::string &node_name,
  120. const std::vector<std::vector<std::string>> node_prefix,
  121. std::vector<std::string> *input_texts, std::vector<std::string> *prompts,
  122. std::vector<std::vector<size_t>> *input_mapping_with_raw_texts,
  123. std::vector<std::vector<size_t>> *input_mapping_with_short_text);
  124. void Preprocess(const std::vector<std::string> &input_texts,
  125. const std::vector<std::string> &prompts,
  126. std::vector<fast_tokenizer::core::Encoding> *encodings,
  127. std::vector<ultra_infer::FDTensor> *inputs);
  128. void Postprocess(
  129. const std::vector<ultra_infer::FDTensor> &outputs,
  130. const std::vector<fast_tokenizer::core::Encoding> &encodings,
  131. const std::vector<std::string> &short_input_texts,
  132. const std::vector<std::string> &short_prompts,
  133. const std::vector<std::vector<size_t>> &input_mapping_with_short_text,
  134. std::vector<std::vector<UIEResult>> *results);
  135. void ConstructChildPromptPrefix(
  136. const std::vector<std::vector<size_t>> &input_mapping_with_raw_texts,
  137. const std::vector<std::vector<UIEResult>> &results_list,
  138. std::vector<std::vector<std::string>> *prefix);
  139. void ConstructChildRelations(
  140. const std::vector<std::vector<UIEResult *>> &old_relations,
  141. const std::vector<std::vector<size_t>> &input_mapping_with_raw_texts,
  142. const std::vector<std::vector<UIEResult>> &results_list,
  143. const std::string &node_name,
  144. std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>
  145. *results,
  146. std::vector<std::vector<UIEResult *>> *new_relations);
  147. void
  148. Predict(const std::vector<std::string> &texts,
  149. std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>
  150. *results);
  151. protected:
  152. using IDX_PROB = std::pair<int64_t, float>;
  153. struct IdxProbCmp {
  154. bool operator()(const std::pair<IDX_PROB, IDX_PROB> &lhs,
  155. const std::pair<IDX_PROB, IDX_PROB> &rhs) const;
  156. };
  157. using SPAN_SET = std::set<std::pair<IDX_PROB, IDX_PROB>, IdxProbCmp>;
  158. struct SpanIdx {
  159. fast_tokenizer::core::Offset offset_;
  160. bool is_prompt_;
  161. };
  162. void SetValidBackend();
  163. bool Initialize();
  164. void AutoSplitter(const std::vector<std::string> &texts, size_t max_length,
  165. std::vector<std::string> *short_texts,
  166. std::vector<std::vector<size_t>> *input_mapping);
  167. void AutoJoiner(const std::vector<std::string> &short_texts,
  168. const std::vector<std::vector<size_t>> &input_mapping,
  169. std::vector<std::vector<UIEResult>> *results);
  170. // Get idx of the last dimension in probability arrays, which is greater than
  171. // a limitation.
  172. void GetCandidateIdx(const float *probs, int64_t batch_size, int64_t seq_len,
  173. std::vector<std::vector<IDX_PROB>> *candidate_idx_prob,
  174. float threshold = 0.5) const;
  175. void GetSpan(const std::vector<IDX_PROB> &start_idx_prob,
  176. const std::vector<IDX_PROB> &end_idx_prob,
  177. SPAN_SET *span_set) const;
  178. void GetSpanIdxAndProbs(
  179. const SPAN_SET &span_set,
  180. const std::vector<fast_tokenizer::core::Offset> &offset_mapping,
  181. std::vector<SpanIdx> *span_idxs, std::vector<float> *probs) const;
  182. void
  183. ConvertSpanToUIEResult(const std::vector<std::string> &texts,
  184. const std::vector<std::string> &prompts,
  185. const std::vector<std::vector<SpanIdx>> &span_idxs,
  186. const std::vector<std::vector<float>> &probs,
  187. std::vector<std::vector<UIEResult>> *results) const;
  188. std::unique_ptr<Schema> schema_;
  189. size_t max_length_;
  190. float position_prob_;
  191. int batch_size_;
  192. SchemaLanguage schema_language_;
  193. fast_tokenizer::tokenizers_impl::ErnieFastTokenizer tokenizer_;
  194. };
  195. } // namespace text
  196. } // namespace ultra_infer