__init__.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import logging
  16. from ... import RuntimeOption, UltraInferModel, ModelFormat
  17. from ... import c_lib_wrap as C
  18. class SchemaLanguage(object):
  19. ZH = 0
  20. EN = 1
  21. class SchemaNode(object):
  22. def __init__(self, name, children=[]):
  23. schema_node_children = []
  24. if isinstance(children, str):
  25. children = [children]
  26. for child in children:
  27. if isinstance(child, str):
  28. schema_node_children += [C.text.SchemaNode(child, [])]
  29. elif isinstance(child, dict):
  30. for key, val in child.items():
  31. schema_node_child = SchemaNode(key, val)
  32. schema_node_children += [schema_node_child._schema_node]
  33. else:
  34. assert "The type of child of SchemaNode should be str or dict."
  35. self._schema_node = C.text.SchemaNode(name, schema_node_children)
  36. self._schema_node_children = schema_node_children
  37. class UIEModel(UltraInferModel):
  38. def __init__(
  39. self,
  40. model_file,
  41. params_file,
  42. vocab_file,
  43. position_prob=0.5,
  44. max_length=128,
  45. schema=[],
  46. batch_size=64,
  47. runtime_option=RuntimeOption(),
  48. model_format=ModelFormat.PADDLE,
  49. schema_language=SchemaLanguage.ZH,
  50. ):
  51. if isinstance(schema, list):
  52. schema = SchemaNode("", schema)._schema_node_children
  53. elif isinstance(schema, dict):
  54. schema_tmp = []
  55. for key, val in schema.items():
  56. schema_tmp += [SchemaNode(key, val)._schema_node]
  57. schema = schema_tmp
  58. else:
  59. assert "The type of schema should be list or dict."
  60. schema_language = C.text.SchemaLanguage(schema_language)
  61. self._model = C.text.UIEModel(
  62. model_file,
  63. params_file,
  64. vocab_file,
  65. position_prob,
  66. max_length,
  67. schema,
  68. batch_size,
  69. runtime_option._option,
  70. model_format,
  71. schema_language,
  72. )
  73. assert self.initialized, "UIEModel initialize failed."
  74. def set_schema(self, schema):
  75. if isinstance(schema, list):
  76. schema = SchemaNode("", schema)._schema_node_children
  77. elif isinstance(schema, dict):
  78. schema_tmp = []
  79. for key, val in schema.items():
  80. schema_tmp += [SchemaNode(key, val)._schema_node]
  81. schema = schema_tmp
  82. self._model.set_schema(schema)
  83. def predict(self, texts, return_dict=False):
  84. results = self._model.predict(texts)
  85. if not return_dict:
  86. return results
  87. new_results = []
  88. for result in results:
  89. uie_result = dict()
  90. for key, uie_results in result.items():
  91. uie_result[key] = list()
  92. for uie_res in uie_results:
  93. uie_result[key].append(uie_res.get_dict())
  94. new_results += [uie_result]
  95. return new_results