model.cc 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "ultra_infer/text/uie/model.h"

#include <algorithm>
#include <codecvt>
#include <locale>
#include <numeric>
#include <queue>
#include <sstream>

#include "fast_tokenizer/pretokenizers/pretokenizer.h"
#include "fast_tokenizer/utils/utf8.h"
#include "ultra_infer/function/concat.h"
#include "ultra_infer/function/split.h"
  24. namespace ultra_infer {
  25. namespace text {
  26. static std::string DBC2SBC(const std::string &content) {
  27. std::string result;
  28. size_t content_utf8_len = 0;
  29. while (content_utf8_len < content.length()) {
  30. uint32_t content_char;
  31. auto content_char_width = fast_tokenizer::utils::UTF8ToUInt32(
  32. content.data() + content_utf8_len, &content_char);
  33. content_char = fast_tokenizer::utils::UTF8ToUnicode(content_char);
  34. if (content_char == 0x3000) {
  35. content_char = 0x0020;
  36. } else {
  37. content_char -= 0xfee0;
  38. }
  39. if (!(content_char >= 0x0021 && content_char <= 0x7e)) {
  40. result.append(content.data() + content_utf8_len, content_char_width);
  41. } else {
  42. char dst_char[5] = {0};
  43. uint32_t utf8_uint32 = fast_tokenizer::utils::UnicodeToUTF8(content_char);
  44. uint32_t utf8_char_count =
  45. fast_tokenizer::utils::UnicodeToUTF8Char(utf8_uint32, dst_char);
  46. result.append(dst_char, utf8_char_count);
  47. }
  48. content_utf8_len += content_char_width;
  49. }
  50. return result;
  51. }
  52. static std::ostream &PrintResult(std::ostream &os, const UIEResult &result,
  53. int tab_size) {
  54. constexpr int TAB_OFFSET = 4;
  55. // Print text
  56. for (int i = 0; i < tab_size; ++i) {
  57. os << " ";
  58. }
  59. os << "text: " << result.text_ << "\n";
  60. // Print probability
  61. for (int i = 0; i < tab_size; ++i) {
  62. os << " ";
  63. }
  64. os << "probability: " << result.probability_ << "\n";
  65. if (result.start_ != 0 || result.end_ != 0) {
  66. // Print start
  67. for (int i = 0; i < tab_size; ++i) {
  68. os << " ";
  69. }
  70. os << "start: " << result.start_ << "\n";
  71. // Print end
  72. for (int i = 0; i < tab_size; ++i) {
  73. os << " ";
  74. }
  75. os << "end: " << result.end_ << "\n";
  76. }
  77. // Print relation
  78. if (result.relation_.size() > 0) {
  79. for (int i = 0; i < tab_size; ++i) {
  80. os << " ";
  81. }
  82. os << "relation:\n";
  83. for (auto &&curr_relation : result.relation_) {
  84. for (int i = 0; i < tab_size + TAB_OFFSET; ++i) {
  85. os << " ";
  86. }
  87. os << curr_relation.first << ":\n";
  88. for (int i = 0; i < curr_relation.second.size(); ++i) {
  89. PrintResult(os, curr_relation.second[i],
  90. tab_size + TAB_OFFSET + TAB_OFFSET);
  91. }
  92. }
  93. }
  94. os << "\n";
  95. return os;
  96. }
  97. std::ostream &operator<<(std::ostream &os, const UIEResult &result) {
  98. return PrintResult(os, result, 0);
  99. }
  100. std::ostream &operator<<(
  101. std::ostream &os,
  102. const std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>
  103. &results) {
  104. os << "The result:\n";
  105. for (int i = 0; i < results.size(); ++i) {
  106. for (auto &&curr_result : results[i]) {
  107. os << curr_result.first << ": \n";
  108. for (auto &&uie_result : curr_result.second) {
  109. PrintResult(os, uie_result, 4);
  110. }
  111. }
  112. os << std::endl;
  113. }
  114. return os;
  115. }
  116. std::string UIEResult::Str() const {
  117. std::ostringstream oss;
  118. oss << *this;
  119. return oss.str();
  120. }
// Creates the root node of the schema tree with the given name; all
// extraction targets are attached as children of this node.
void Schema::CreateRoot(const std::string &name) {
  root_ = ultra_infer::utils::make_unique<SchemaNode>(name);
}
// Builds a schema with a single named extraction target under the root.
Schema::Schema(const std::string &schema, const std::string &name) {
  CreateRoot(name);
  root_->AddChild(schema);
}
// Builds a schema from a list of target names, each becoming a direct
// child of the root node.
Schema::Schema(const std::vector<std::string> &schema_list,
               const std::string &name) {
  CreateRoot(name);
  for (const auto &schema : schema_list) {
    root_->AddChild(schema);
  }
}
// Builds a schema from pre-constructed SchemaNode subtrees (allows nested
// relation schemas), each attached under the root node.
Schema::Schema(const std::vector<SchemaNode> &schema_list,
               const std::string &name) {
  CreateRoot(name);
  for (const auto &schema : schema_list) {
    root_->AddChild(schema);
  }
}
// Builds a schema from a single SchemaNode subtree attached under the root.
Schema::Schema(const SchemaNode &schema, const std::string &name) {
  CreateRoot(name);
  root_->AddChild(schema);
}
// Constructs a UIE model whose schema is given as a list of target names.
// Loads the runtime from model_file/params_file (format model_format,
// options custom_option) and the tokenizer vocabulary from vocab_file.
// position_prob is the threshold for accepting start/end positions,
// max_length the maximum token length per sequence, batch_size the
// inference batch size, and schema_language selects ZH/EN prompt style.
UIEModel::UIEModel(const std::string &model_file,
                   const std::string &params_file,
                   const std::string &vocab_file, float position_prob,
                   size_t max_length, const std::vector<std::string> &schema,
                   int batch_size,
                   const ultra_infer::RuntimeOption &custom_option,
                   const ultra_infer::ModelFormat &model_format,
                   SchemaLanguage schema_language)
    : max_length_(max_length), position_prob_(position_prob),
      schema_language_(schema_language), batch_size_(batch_size),
      tokenizer_(vocab_file) {
  runtime_option = custom_option;
  runtime_option.SetModelPath(model_file, params_file, model_format);
  initialized = Initialize();
  SetSchema(schema);
  // Truncate from the right, longest sequence first, so prompt + text
  // stays within max_length tokens.
  tokenizer_.EnableTruncMethod(
      max_length, 0, fast_tokenizer::core::Direction::RIGHT,
      fast_tokenizer::core::TruncStrategy::LONGEST_FIRST);
}
// Overload: constructs a UIE model whose schema is given as a list of
// SchemaNode subtrees (supports nested relation extraction). Otherwise
// identical setup: runtime, tokenizer, threshold, and truncation.
UIEModel::UIEModel(const std::string &model_file,
                   const std::string &params_file,
                   const std::string &vocab_file, float position_prob,
                   size_t max_length, const std::vector<SchemaNode> &schema,
                   int batch_size,
                   const ultra_infer::RuntimeOption &custom_option,
                   const ultra_infer::ModelFormat &model_format,
                   SchemaLanguage schema_language)
    : max_length_(max_length), position_prob_(position_prob),
      schema_language_(schema_language), batch_size_(batch_size),
      tokenizer_(vocab_file) {
  runtime_option = custom_option;
  runtime_option.SetModelPath(model_file, params_file, model_format);
  initialized = Initialize();
  SetSchema(schema);
  // Truncate from the right, longest sequence first, so prompt + text
  // stays within max_length tokens.
  tokenizer_.EnableTruncMethod(
      max_length, 0, fast_tokenizer::core::Direction::RIGHT,
      fast_tokenizer::core::TruncStrategy::LONGEST_FIRST);
}
// Overload: constructs a UIE model whose schema is a single SchemaNode
// subtree. Otherwise identical setup: runtime, tokenizer, threshold,
// and truncation.
UIEModel::UIEModel(const std::string &model_file,
                   const std::string &params_file,
                   const std::string &vocab_file, float position_prob,
                   size_t max_length, const SchemaNode &schema, int batch_size,
                   const ultra_infer::RuntimeOption &custom_option,
                   const ultra_infer::ModelFormat &model_format,
                   SchemaLanguage schema_language)
    : max_length_(max_length), position_prob_(position_prob),
      schema_language_(schema_language), batch_size_(batch_size),
      tokenizer_(vocab_file) {
  runtime_option = custom_option;
  runtime_option.SetModelPath(model_file, params_file, model_format);
  initialized = Initialize();
  SetSchema(schema);
  // Truncate from the right, longest sequence first, so prompt + text
  // stays within max_length tokens.
  tokenizer_.EnableTruncMethod(
      max_length, 0, fast_tokenizer::core::Direction::RIGHT,
      fast_tokenizer::core::TruncStrategy::LONGEST_FIRST);
}
// Registers the usable backends and initializes the inference runtime.
// Returns true when the runtime is ready for inference.
bool UIEModel::Initialize() {
  SetValidBackend();
  return InitRuntime();
}
// Declares which inference backends may run this model on CPU and GPU.
void UIEModel::SetValidBackend() {
  valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER,
                        Backend::LITE};
  valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
}
// Replaces the current schema with one built from a list of target names.
void UIEModel::SetSchema(const std::vector<std::string> &schema) {
  schema_ = ultra_infer::utils::make_unique<Schema>(schema);
}
// Replaces the current schema with one built from SchemaNode subtrees.
void UIEModel::SetSchema(const std::vector<SchemaNode> &schema) {
  schema_ = ultra_infer::utils::make_unique<Schema>(schema);
}
// Replaces the current schema with one built from a single SchemaNode.
void UIEModel::SetSchema(const SchemaNode &schema) {
  schema_ = ultra_infer::utils::make_unique<Schema>(schema);
}
// Splits each text whose unicode length exceeds max_length into
// consecutive chunks of at most max_length characters.
// short_texts receives all (possibly split) pieces in order;
// input_mapping maps each original text index to the indices of its
// pieces inside short_texts.
void UIEModel::AutoSplitter(const std::vector<std::string> &texts,
                            size_t max_length,
                            std::vector<std::string> *short_texts,
                            std::vector<std::vector<size_t>> *input_mapping) {
  size_t cnt_org = 0;    // index of the current original text
  size_t cnt_short = 0;  // number of short texts emitted so far
  for (auto &text : texts) {
    // Lengths are measured in unicode characters, not bytes.
    auto text_len = fast_tokenizer::utils::GetUnicodeLenFromUTF8(text.c_str(),
                                                                 text.length());
    if (text_len <= max_length) {
      // Short enough: forward unchanged, record its single index.
      short_texts->push_back(text);
      if (input_mapping->size() <= cnt_org) {
        input_mapping->push_back({cnt_short});
      } else {
        (*input_mapping)[cnt_org].push_back(cnt_short);
      }
      cnt_short += 1;
    } else {
      // Slice by character range, converting to byte offsets for substr.
      fast_tokenizer::pretokenizers::CharToBytesOffsetConverter converter(text);
      for (size_t start = 0; start < text_len; start += max_length) {
        size_t end = start + max_length;
        if (end > text_len) {
          end = text_len;
        }
        fast_tokenizer::core::Offset byte_offset;
        converter.convert({start, end}, &byte_offset);
        short_texts->emplace_back(text.data() + byte_offset.first,
                                  byte_offset.second - byte_offset.first);
      }
      auto short_idx = cnt_short;
      // ceil(text_len / max_length) pieces were appended above.
      cnt_short += text_len / max_length;
      if (text_len % max_length != 0) {
        ++cnt_short;
      }
      // Map the original text to the run of freshly appended indices.
      std::vector<size_t> temp_text_id(cnt_short - short_idx);
      std::iota(temp_text_id.begin(), temp_text_id.end(), short_idx);
      if (input_mapping->size() <= cnt_org) {
        input_mapping->push_back(std::move(temp_text_id));
      } else {
        (*input_mapping)[cnt_org].insert((*input_mapping)[cnt_org].end(),
                                         temp_text_id.begin(),
                                         temp_text_id.end());
      }
    }
    cnt_org += 1;
  }
}
  268. void UIEModel::GetCandidateIdx(
  269. const float *probs, int64_t batch_size, int64_t seq_len,
  270. std::vector<std::vector<std::pair<int64_t, float>>> *candidate_idx_prob,
  271. float threshold) const {
  272. for (int i = 0; i < batch_size; ++i) {
  273. candidate_idx_prob->push_back({});
  274. for (int j = 0; j < seq_len; ++j) {
  275. if (probs[i * seq_len + j] > threshold) {
  276. candidate_idx_prob->back().push_back({j, probs[i * seq_len + j]});
  277. }
  278. }
  279. }
  280. }
  281. bool UIEModel::IdxProbCmp::operator()(
  282. const std::pair<IDX_PROB, IDX_PROB> &lhs,
  283. const std::pair<IDX_PROB, IDX_PROB> &rhs) const {
  284. if (lhs.first.first == rhs.first.first) {
  285. return lhs.second.first < rhs.second.first;
  286. }
  287. return lhs.first.first < rhs.first.first;
  288. }
// Pairs candidate start positions with candidate end positions to form
// spans, inserting each pair into span_set. Both inputs are ordered by
// token index; a two-pointer sweep matches each start with the nearest
// end at or after it.
void UIEModel::GetSpan(const std::vector<IDX_PROB> &start_idx_prob,
                       const std::vector<IDX_PROB> &end_idx_prob,
                       SPAN_SET *span_set) const {
  size_t start_pointer = 0;
  size_t end_pointer = 0;
  size_t len_start = start_idx_prob.size();
  size_t len_end = end_idx_prob.size();
  while (start_pointer < len_start && end_pointer < len_end) {
    if (start_idx_prob[start_pointer].first ==
        end_idx_prob[end_pointer].first) {
      // Start and end on the same token: single-token span; advance both.
      span_set->insert(std::make_pair(start_idx_prob[start_pointer],
                                      end_idx_prob[end_pointer]));
      ++start_pointer;
      ++end_pointer;
    } else if (start_idx_prob[start_pointer].first <
               end_idx_prob[end_pointer].first) {
      // Start precedes the current end: record the span, then try the
      // next start against the same end.
      span_set->insert(std::make_pair(start_idx_prob[start_pointer],
                                      end_idx_prob[end_pointer]));
      ++start_pointer;
    } else {
      // This end precedes the current start and can close no more spans.
      ++end_pointer;
    }
  }
}
// Converts token-level spans into character-offset spans using the
// tokenizer's offset mapping, and computes each span's probability as
// the product of its start and end probabilities.
// Tokens before the first (0, 0) offset after position 0 — presumably
// the prompt segment terminated by a separator token; confirm against
// the tokenizer's postprocessing — are flagged is_prompt_ so they are
// later sliced from the prompt string rather than the input text.
void UIEModel::GetSpanIdxAndProbs(
    const SPAN_SET &span_set,
    const std::vector<fast_tokenizer::core::Offset> &offset_mapping,
    std::vector<SpanIdx> *span_idxs, std::vector<float> *probs) const {
  // Locate the first special-token offset (0, 0) after the leading token.
  auto first_sep_idx =
      std::find_if(offset_mapping.begin() + 1, offset_mapping.end(),
                   [](const fast_tokenizer::core::Offset &offset) {
                     return offset == fast_tokenizer::core::Offset(0, 0);
                   });
  auto prompt_end_token_id =
      std::distance(offset_mapping.begin(), first_sep_idx) - 1;
  for (auto &&span_item : span_set) {
    // Joint probability of this (start, end) pair.
    probs->push_back(span_item.first.second * span_item.second.second);
    auto start_id = offset_mapping[span_item.first.first].first;
    auto end_id = offset_mapping[span_item.second.first].second;
    // The span belongs to the prompt when its end token falls inside the
    // prompt segment (position 0 is the leading special token).
    bool is_prompt = span_item.second.first <= prompt_end_token_id &&
                     span_item.second.first > 0;
    span_idxs->push_back({{start_id, end_id}, is_prompt});
  }
}
  333. void UIEModel::ConvertSpanToUIEResult(
  334. const std::vector<std::string> &texts,
  335. const std::vector<std::string> &prompts,
  336. const std::vector<std::vector<SpanIdx>> &span_idxs,
  337. const std::vector<std::vector<float>> &probs,
  338. std::vector<std::vector<UIEResult>> *results) const {
  339. auto batch_size = texts.size();
  340. for (int i = 0; i < batch_size; ++i) {
  341. std::vector<UIEResult> result_list;
  342. if (span_idxs[i].size() == 0) {
  343. results->push_back({});
  344. continue;
  345. }
  346. auto &&text = texts[i];
  347. auto &&prompt = prompts[i];
  348. for (int j = 0; j < span_idxs[i].size(); ++j) {
  349. auto start = span_idxs[i][j].offset_.first;
  350. auto end = span_idxs[i][j].offset_.second;
  351. std::string span_text;
  352. std::vector<uint32_t> offset_mapping;
  353. if (span_idxs[i][j].is_prompt_) {
  354. fast_tokenizer::pretokenizers::CharToBytesOffsetConverter converter(
  355. prompt);
  356. fast_tokenizer::core::Offset byte_offset;
  357. converter.convert({start, end}, &byte_offset);
  358. span_text = prompt.substr(byte_offset.first,
  359. byte_offset.second - byte_offset.first);
  360. // Indicate cls task
  361. start = 0;
  362. end = 0;
  363. } else {
  364. fast_tokenizer::pretokenizers::CharToBytesOffsetConverter converter(
  365. text);
  366. fast_tokenizer::core::Offset byte_offset;
  367. converter.convert({start, end}, &byte_offset);
  368. span_text = text.substr(byte_offset.first,
  369. byte_offset.second - byte_offset.first);
  370. }
  371. result_list.emplace_back(start, end, probs[i][j], span_text);
  372. }
  373. results->push_back(result_list);
  374. }
  375. }
// Merges per-chunk results back into one result list per original text
// (the inverse of AutoSplitter). Classification results (start == end
// == 0) are aggregated by voting across chunks; extraction results have
// their offsets shifted by the accumulated unicode lengths of the
// preceding chunks of the same text.
void UIEModel::AutoJoiner(const std::vector<std::string> &short_texts,
                          const std::vector<std::vector<size_t>> &input_mapping,
                          std::vector<std::vector<UIEResult>> *results) {
  bool is_cls_task = false;
  // 1. Detect if it's a cls task
  // The first non-empty result decides: start == end == 0 marks cls.
  for (auto &&short_result : *results) {
    if (short_result.size() == 0) {
      continue;
    } else if (short_result[0].start_ == 0 && short_result[0].end_ == 0) {
      is_cls_task = true;
      break;
    } else {
      break;
    }
  }
  // 2. Get the final result
  std::vector<std::vector<UIEResult>> final_result;
  if (is_cls_task) {
    for (auto &&input_mapping_item : input_mapping) {
      // label text -> (vote count, summed probability)
      std::unordered_map<std::string, std::pair<int, float>> cls_options;
      for (auto &&result_idx : input_mapping_item) {
        if ((*results)[result_idx].size() == 0) {
          continue;
        }
        auto &&text = (*results)[result_idx].front().text_;
        auto &&probability = (*results)[result_idx].front().probability_;
        if (cls_options.count(text) == 0) {
          cls_options[text] = std::make_pair(1, probability);
        } else {
          cls_options[text].first += 1;
          cls_options[text].second += probability;
        }
      }
      std::vector<UIEResult> result_list;
      if (cls_options.size() > 0) {
        // Keep the label with the highest summed probability; report its
        // average probability over the chunks that voted for it.
        auto max_iter = std::max_element(
            cls_options.begin(), cls_options.end(),
            [](const std::pair<std::string, std::pair<int, float>> &lhs,
               const std::pair<std::string, std::pair<int, float>> &rhs) {
              return lhs.second.second < rhs.second.second;
            });
        result_list.emplace_back(
            0, 0, max_iter->second.second / max_iter->second.first,
            max_iter->first);
      }
      final_result.push_back(result_list);
    }
  } else {
    for (auto &&input_mapping_item : input_mapping) {
      size_t offset = 0;  // unicode chars consumed by earlier chunks
      std::vector<UIEResult> result_list;
      for (auto &&result_idx : input_mapping_item) {
        if (result_idx == 0) {
          // Very first chunk overall: its shift would be zero, so its
          // results can be moved wholesale.
          result_list = std::move((*results)[result_idx]);
          offset += fast_tokenizer::utils::GetUnicodeLenFromUTF8(
              short_texts[result_idx].c_str(), short_texts[result_idx].size());
        } else {
          // Shift spans into the coordinate space of the original text.
          for (auto &&curr_result : (*results)[result_idx]) {
            curr_result.start_ += offset;
            curr_result.end_ += offset;
          }
          offset += fast_tokenizer::utils::GetUnicodeLenFromUTF8(
              short_texts[result_idx].c_str(), short_texts[result_idx].size());
          result_list.insert(result_list.end(), (*results)[result_idx].begin(),
                             (*results)[result_idx].end());
        }
      }
      final_result.push_back(result_list);
    }
  }
  *results = std::move(final_result);
}
// Builds the (text, prompt) pairs for one schema node. Without prefixes
// every raw text gets the node name as its prompt; with prefixes one
// pair is created per (text, prefix) combination. All texts are then
// split so that text + the longest prompt fits within max_length_.
// input_mapping_with_raw_texts maps raw-text index -> pair indices
// (before splitting); input_mapping maps pair index -> short-text
// indices (after splitting). Returns false when no prompt was produced.
// NOTE(review): node_prefix is taken by value (copied per call); the
// declaration likely intended a const reference — confirm in model.h
// before changing the signature.
bool UIEModel::ConstructTextsAndPrompts(
    const std::vector<std::string> &raw_texts, const std::string &node_name,
    const std::vector<std::vector<std::string>> node_prefix,
    std::vector<std::string> *input_texts, std::vector<std::string> *prompts,
    std::vector<std::vector<size_t>> *input_mapping_with_raw_texts,
    std::vector<std::vector<size_t>> *input_mapping) {
  size_t idx = 0;
  if (node_prefix.empty()) {
    // Root-level node: prompt is just the (normalized) node name.
    for (int i = 0; i < raw_texts.size(); ++i) {
      input_texts->push_back(raw_texts[i]);
      prompts->push_back(DBC2SBC(node_name));
      input_mapping_with_raw_texts->push_back({idx});
      idx += 1;
    }
  } else {
    // Child node: one prompt per parent-result prefix per text.
    for (int i = 0; i < raw_texts.size(); ++i) {
      if (node_prefix[i].size() == 0) {
        input_mapping_with_raw_texts->push_back({});
      } else {
        for (auto &&pre : node_prefix[i]) {
          input_texts->push_back(raw_texts[i]);
          prompts->push_back(DBC2SBC(pre + node_name));
        }
        auto prefix_len = node_prefix[i].size();
        input_mapping_with_raw_texts->push_back({});
        input_mapping_with_raw_texts->back().resize(prefix_len);
        std::iota(input_mapping_with_raw_texts->back().begin(),
                  input_mapping_with_raw_texts->back().end(), idx);
        idx += prefix_len;
      }
    }
  }
  if (prompts->size() == 0) {
    return false;
  }
  // Shorten the input texts and prompts
  // The split budget is driven by the longest prompt in the batch.
  auto max_prompt_iter = std::max_element(
      prompts->begin(), prompts->end(),
      [](const std::string &lhs, const std::string &rhs) {
        auto lhs_ulen = fast_tokenizer::utils::GetUnicodeLenFromUTF8(
            lhs.c_str(), lhs.length());
        auto rhs_ulen = fast_tokenizer::utils::GetUnicodeLenFromUTF8(
            rhs.c_str(), rhs.length());
        return lhs_ulen < rhs_ulen;
      });
  auto max_prompt_len = fast_tokenizer::utils::GetUnicodeLenFromUTF8(
      max_prompt_iter->c_str(), max_prompt_iter->length());
  // The 3 presumably reserves room for special tokens; NOTE(review):
  // this size_t arithmetic underflows if max_prompt_len > max_length_ - 3
  // — confirm prompts are bounded upstream.
  auto max_predict_len = max_length_ - 3 - max_prompt_len;
  std::vector<std::string> short_texts;
  AutoSplitter(*input_texts, max_predict_len, &short_texts, input_mapping);
  // Replicate each prompt once per chunk of its text.
  std::vector<std::string> short_texts_prompts;
  for (int i = 0; i < input_mapping->size(); ++i) {
    short_texts_prompts.insert(short_texts_prompts.end(),
                               (*input_mapping)[i].size(), (*prompts)[i]);
  }
  (*input_texts) = std::move(short_texts);
  (*prompts) = std::move(short_texts_prompts);
  return true;
}
// Tokenizes (prompt, text) pairs and packs the encodings into the input
// tensors expected by the runtime, each shaped [batch_size, seq_len].
// NOTE(review): seq_len is taken from the first encoding only — this
// assumes every encoding in the batch has the same length (e.g. via
// padding/truncation); confirm against the tokenizer configuration.
void UIEModel::Preprocess(
    const std::vector<std::string> &input_texts,
    const std::vector<std::string> &prompts,
    std::vector<fast_tokenizer::core::Encoding> *encodings,
    std::vector<ultra_infer::FDTensor> *inputs) {
  // 1. Tokenize the short texts and short prompts
  std::vector<fast_tokenizer::core::EncodeInput> text_pair_input;
  for (int i = 0; i < input_texts.size(); ++i) {
    text_pair_input.emplace_back(
        std::pair<std::string, std::string>(prompts[i], input_texts[i]));
  }
  tokenizer_.EncodeBatchStrings(text_pair_input, encodings);
  // 2. Construct the input vector tensor
  // 2.1 Allocate input tensor
  int64_t batch_size = input_texts.size();
  int64_t seq_len = 0;
  if (batch_size > 0) {
    seq_len = (*encodings)[0].GetIds().size();
  }
  inputs->resize(NumInputsOfRuntime());
  for (int i = 0; i < NumInputsOfRuntime(); ++i) {
    (*inputs)[i].Allocate({batch_size, seq_len}, ultra_infer::FDDataType::INT64,
                          InputInfoOfRuntime(i).name);
  }
  // 2.2 Set the value of data
  // NOTE(review): the indices below hard-code the runtime input order
  // (input ids, token type ids, position ids, attention mask) and assume
  // NumInputsOfRuntime() == 4.
  size_t start = 0;
  int64_t *input_ids_ptr =
      reinterpret_cast<int64_t *>((*inputs)[0].MutableData());
  int64_t *type_ids_ptr =
      reinterpret_cast<int64_t *>((*inputs)[1].MutableData());
  int64_t *pos_ids_ptr =
      reinterpret_cast<int64_t *>((*inputs)[2].MutableData());
  int64_t *attn_mask_ptr =
      reinterpret_cast<int64_t *>((*inputs)[3].MutableData());
  for (int i = 0; i < encodings->size(); ++i) {
    auto &&curr_input_ids = (*encodings)[i].GetIds();
    auto &&curr_type_ids = (*encodings)[i].GetTypeIds();
    auto &&curr_attn_mask = (*encodings)[i].GetAttentionMask();
    std::copy(curr_input_ids.begin(), curr_input_ids.end(),
              input_ids_ptr + start);
    std::copy(curr_type_ids.begin(), curr_type_ids.end(), type_ids_ptr + start);
    // Position ids are 0..seq_len-1 for every sequence.
    std::iota(pos_ids_ptr + start, pos_ids_ptr + start + seq_len, 0);
    std::copy(curr_attn_mask.begin(), curr_attn_mask.end(),
              attn_mask_ptr + start);
    start += seq_len;
  }
}
// Converts the raw inference outputs into UIEResult lists per original
// text: outputs[0]/outputs[1] hold start/end position probabilities of
// shape [batch, seq_len]; positions above position_prob_ are paired into
// spans, mapped back to character offsets, materialized as results, and
// finally re-joined across split text chunks.
void UIEModel::Postprocess(
    const std::vector<ultra_infer::FDTensor> &outputs,
    const std::vector<fast_tokenizer::core::Encoding> &encodings,
    const std::vector<std::string> &short_input_texts,
    const std::vector<std::string> &short_prompts,
    const std::vector<std::vector<size_t>> &input_mapping_with_short_text,
    std::vector<std::vector<UIEResult>> *results) {
  auto *start_prob = reinterpret_cast<const float *>(outputs[0].Data());
  auto *end_prob = reinterpret_cast<const float *>(outputs[1].Data());
  // Candidate (token index, probability) pairs above the threshold.
  std::vector<std::vector<std::pair<int64_t, float>>> start_candidate_idx_prob,
      end_candidate_idx_prob;
  GetCandidateIdx(start_prob, outputs[0].shape[0], outputs[0].shape[1],
                  &start_candidate_idx_prob, position_prob_);
  GetCandidateIdx(end_prob, outputs[1].shape[0], outputs[1].shape[1],
                  &end_candidate_idx_prob, position_prob_);
  // Token -> character offset mapping for every encoding in the batch.
  std::vector<std::vector<fast_tokenizer::core::Offset>> offset_mapping;
  for (int i = 0; i < encodings.size(); ++i) {
    auto &&curr_offsets = encodings[i].GetOffsets();
    offset_mapping.push_back(curr_offsets);
  }
  SPAN_SET span_set;
  auto batch_size = outputs[0].shape[0];
  std::vector<std::vector<float>> probs(batch_size);
  std::vector<std::vector<SpanIdx>> span_idxs(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    // span_set is reused across iterations and cleared each time.
    GetSpan(start_candidate_idx_prob[i], end_candidate_idx_prob[i], &span_set);
    GetSpanIdxAndProbs(span_set, offset_mapping[i], &span_idxs[i], &probs[i]);
    span_set.clear();
  }
  ConvertSpanToUIEResult(short_input_texts, short_prompts, span_idxs, probs,
                         results);
  AutoJoiner(short_input_texts, input_mapping_with_short_text, results);
}
  587. void UIEModel::ConstructChildPromptPrefix(
  588. const std::vector<std::vector<size_t>> &input_mapping_with_raw_texts,
  589. const std::vector<std::vector<UIEResult>> &results_list,
  590. std::vector<std::vector<std::string>> *prefix) {
  591. prefix->resize(input_mapping_with_raw_texts.size());
  592. for (int i = 0; i < input_mapping_with_raw_texts.size(); ++i) {
  593. auto &&input_mapping_item = input_mapping_with_raw_texts[i];
  594. for (auto &&idx : input_mapping_item) {
  595. for (int j = 0; j < results_list[idx].size(); ++j) {
  596. std::string prefix_str;
  597. if (schema_language_ == SchemaLanguage::ZH) {
  598. // Note(zhoushunjie): It means "of" in Chinese.
  599. prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
  600. } else {
  601. prefix_str = " of " + results_list[idx][j].text_;
  602. }
  603. (*prefix)[i].push_back(prefix_str);
  604. }
  605. }
  606. }
  607. }
// Attaches the current node's results to the result tree and returns,
// through new_relations, pointers to the freshly inserted UIEResult
// objects so that child nodes can link onto them.
// With no old relations (top-level node) results are stored directly in
// *results keyed by node_name; otherwise they are appended to
// relation_[node_name] of the parent results referenced by old_relations.
// NOTE(review): pointers collected in new_relations reference elements of
// vectors owned by *results; they remain valid only while those vectors
// are not reallocated.
void UIEModel::ConstructChildRelations(
    const std::vector<std::vector<UIEResult *>> &old_relations,
    const std::vector<std::vector<size_t>> &input_mapping_with_raw_texts,
    const std::vector<std::vector<UIEResult>> &results_list,
    const std::string &node_name,
    std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>
        *results,
    std::vector<std::vector<UIEResult *>> *new_relations) {
  new_relations->resize(input_mapping_with_raw_texts.size());
  if (old_relations.size() == 0) {
    // Top-level node: merge results into the per-text result maps.
    for (int i = 0; i < input_mapping_with_raw_texts.size(); ++i) {
      auto &&input_mapping_item = input_mapping_with_raw_texts[i];
      auto &curr_result = (*results)[i];
      for (auto &&idx : input_mapping_item) {
        if (results_list[idx].size() == 0) {
          continue;
        }
        if (curr_result.count(node_name) == 0) {
          curr_result[node_name] = results_list[idx];
        } else {
          curr_result[node_name].insert(curr_result[node_name].end(),
                                        results_list[idx].begin(),
                                        results_list[idx].end());
        }
      }
      // Collect pointers only after all insertions, so they stay stable.
      if (curr_result.count(node_name) > 0) {
        for (auto &&curr_result_ref : curr_result[node_name]) {
          (*new_relations)[i].push_back(&curr_result_ref);
        }
      }
    }
  } else {
    // Child node: hang results off the parent results' relation maps.
    auto &curr_relations = old_relations;
    for (int i = 0; i < input_mapping_with_raw_texts.size(); ++i) {
      auto &&input_mapping_item = input_mapping_with_raw_texts[i];
      for (int j = 0; j < input_mapping_item.size(); ++j) {
        auto idx = input_mapping_item[j];
        if (results_list[idx].size() == 0) {
          continue;
        }
        if (curr_relations[i][j]->relation_.count(node_name) == 0) {
          curr_relations[i][j]->relation_[node_name] = results_list[idx];
        } else {
          auto &curr_result = curr_relations[i][j]->relation_[node_name];
          curr_result.insert(curr_result.end(), results_list[idx].begin(),
                             results_list[idx].end());
        }
      }
    }
    // Second pass: gather pointers after all insertions are finished.
    for (int i = 0; i < curr_relations.size(); ++i) {
      for (int j = 0; j < curr_relations[i].size(); ++j) {
        if (curr_relations[i][j]->relation_.count(node_name)) {
          auto &curr_relation = curr_relations[i][j]->relation_[node_name];
          for (auto &&curr_result_ref : curr_relation) {
            (*new_relations)[i].push_back(&curr_result_ref);
          }
        }
      }
    }
  }
}
// Runs schema-guided information extraction over texts, traversing the
// schema tree breadth-first. For each node: build prompts, tokenize,
// run the runtime in batches of batch_size_, post-process into
// UIEResults, then propagate relations and prompt prefixes to children.
// results holds one schema-key -> result-list map per input text.
void UIEModel::Predict(
    const std::vector<std::string> &texts,
    std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>
        *results) {
  std::queue<SchemaNode> nodes;
  for (auto &node : schema_->root_->children_) {
    nodes.push(node);
  }
  results->resize(texts.size());
  while (!nodes.empty()) {
    auto node = nodes.front();
    nodes.pop();
    std::vector<std::vector<size_t>> input_mapping_with_raw_texts;
    std::vector<std::vector<size_t>> input_mapping_with_short_text;
    std::vector<std::string> short_input_texts;
    std::vector<std::string> short_prompts;
    // 1. Construct texts and prompts from raw text
    bool has_prompt = ConstructTextsAndPrompts(
        texts, node.name_, node.prefix_, &short_input_texts, &short_prompts,
        &input_mapping_with_raw_texts, &input_mapping_with_short_text);
    std::vector<std::vector<UIEResult>> results_list;
    if (has_prompt) {
      // 2. Convert texts and prompts to FDTensor
      std::vector<FDTensor> inputs;
      std::vector<fast_tokenizer::core::Encoding> encodings;
      Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
      // Split the full batch into chunks of at most batch_size_ rows.
      std::vector<std::vector<FDTensor>> inputs_vec(NumInputsOfRuntime());
      int encoding_size = encodings.size();
      std::vector<int> num_or_sections;
      for (int i = 0; i < encoding_size; i += batch_size_) {
        // Parenthesized (std::min) avoids clashing with a min() macro.
        int actual_batch_size = (std::min)(batch_size_, encoding_size - i);
        num_or_sections.push_back(actual_batch_size);
      }
      for (int i = 0; i < NumInputsOfRuntime(); ++i) {
        function::Split(inputs[i], num_or_sections, &inputs_vec[i]);
      }
      // 3. Infer
      std::vector<ultra_infer::FDTensor> outputs(NumOutputsOfRuntime());
      std::vector<ultra_infer::FDTensor> outputs0, outputs1;
      for (int i = 0; i < inputs_vec[0].size(); ++i) {
        std::vector<ultra_infer::FDTensor> curr_inputs(NumInputsOfRuntime());
        std::vector<ultra_infer::FDTensor> curr_outputs(NumOutputsOfRuntime());
        for (int j = 0; j < NumInputsOfRuntime(); ++j) {
          curr_inputs[j] = std::move(inputs_vec[j][i]);
          curr_inputs[j].name = inputs[j].name;
        }
        if (!Infer(curr_inputs, &curr_outputs)) {
          FDERROR << "Failed to inference while using model:" << ModelName()
                  << "." << std::endl;
        }
        // Keep start/end probability tensors of every chunk for concat.
        outputs0.push_back(curr_outputs[0]);
        outputs1.push_back(curr_outputs[1]);
      }
      function::Concat(outputs0, &outputs[0]);
      function::Concat(outputs1, &outputs[1]);
      // 4. Convert FDTensor to UIEResult
      Postprocess(outputs, encodings, short_input_texts, short_prompts,
                  input_mapping_with_short_text, &results_list);
    }
    // 5. Construct the new relation of the UIEResult
    std::vector<std::vector<UIEResult *>> relations;
    ConstructChildRelations(node.relations_, input_mapping_with_raw_texts,
                            results_list, node.name_, results, &relations);
    // 6. Construct the next prompt prefix
    std::vector<std::vector<std::string>> prefix(texts.size());
    ConstructChildPromptPrefix(input_mapping_with_raw_texts, results_list,
                               &prefix);
    // Enqueue children with their relations/prefixes for the next pass.
    for (auto &node_child : node.children_) {
      node_child.relations_ = relations;
      node_child.prefix_ = prefix;
      nodes.push(node_child);
    }
  }
}
  743. } // namespace text
  744. } // namespace ultra_infer