model.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import abc
  16. import logging
  17. from . import c_lib_wrap as C
  18. class BaseUltraInferModel(metaclass=abc.ABCMeta):
  19. @abc.abstractmethod
  20. def model_name(self):
  21. raise NotImplementedError
  22. @abc.abstractmethod
  23. def num_inputs_of_runtime(self):
  24. raise NotImplementedError
  25. @abc.abstractmethod
  26. def num_outputs_of_runtime(self):
  27. raise NotImplementedError
  28. class UltraInferModel(BaseUltraInferModel):
  29. def __init__(self, option):
  30. self._model = None
  31. if option is None:
  32. self._runtime_option = C.RuntimeOption()
  33. else:
  34. self._runtime_option = option._option
  35. def model_name(self):
  36. return self._model.model_name()
  37. def num_inputs_of_runtime(self):
  38. return self._model.num_inputs_of_runtime()
  39. def num_outputs_of_runtime(self):
  40. return self._model.num_outputs_of_runtime()
  41. def input_info_of_runtime(self, index):
  42. assert (
  43. index < self.num_inputs_of_runtime()
  44. ), "The index:{} must be less than number of inputs:{}.".format(
  45. index, self.num_inputs_of_runtime()
  46. )
  47. return self._model.input_info_of_runtime(index)
  48. def output_info_of_runtime(self, index):
  49. assert (
  50. index < self.num_outputs_of_runtime()
  51. ), "The index:{} must be less than number of outputs:{}.".format(
  52. index, self.num_outputs_of_runtime()
  53. )
  54. return self._model.output_info_of_runtime(index)
  55. def enable_record_time_of_runtime(self):
  56. self._model.enable_record_time_of_runtime()
  57. def disable_record_time_of_runtime(self):
  58. self._model.disable_record_time_of_runtime()
  59. def print_statis_info_of_runtime(self):
  60. return self._model.print_statis_info_of_runtime()
  61. def get_profile_time(self):
  62. """Get profile time of Runtime after the profile process is done."""
  63. return self._model.get_profile_time()
  64. @property
  65. def runtime_option(self):
  66. return self._model.runtime_option if self._model is not None else None
  67. @property
  68. def initialized(self):
  69. if self._model is None:
  70. return False
  71. return self._model.initialized()