yolox.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from .... import UltraInferModel, ModelFormat
  16. from .... import c_lib_wrap as C
  17. class YOLOX(UltraInferModel):
  18. def __init__(
  19. self,
  20. model_file,
  21. params_file="",
  22. runtime_option=None,
  23. model_format=ModelFormat.ONNX,
  24. ):
  25. """Load a YOLOX model exported by YOLOX.
  26. :param model_file: (str)Path of model file, e.g ./yolox.onnx
  27. :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
  28. :param runtime_option: (ultra_infer.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
  29. :param model_format: (ultra_infer.ModelForamt)Model format of the loaded model
  30. """
  31. # 调用基函数进行backend_option的初始化
  32. # 初始化后的option保存在self._runtime_option
  33. super(YOLOX, self).__init__(runtime_option)
  34. self._model = C.vision.detection.YOLOX(
  35. model_file, params_file, self._runtime_option, model_format
  36. )
  37. # 通过self.initialized判断整个模型的初始化是否成功
  38. assert self.initialized, "YOLOX initialize failed."
  39. def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
  40. """Detect an input image
  41. :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
  42. :param conf_threshold: confidence threshold for postprocessing, default is 0.25
  43. :param nms_iou_threshold: iou threshold for NMS, default is 0.5
  44. :return: DetectionResult
  45. """
  46. return self._model.predict(input_image, conf_threshold, nms_iou_threshold)
  47. # 一些跟YOLOX模型有关的属性封装
  48. # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
  49. @property
  50. def size(self):
  51. """
  52. Argument for image preprocessing step, the preprocess image size, tuple of (width, height), default size = [640, 640]
  53. """
  54. return self._model.size
  55. @property
  56. def padding_value(self):
  57. # padding value, size should be the same as channels
  58. return self._model.padding_value
  59. @property
  60. def is_decode_exported(self):
  61. """
  62. whether the model_file was exported with decode module.
  63. The official YOLOX/tools/export_onnx.py script will export ONNX file without decode module.
  64. Please set it 'true' manually if the model file was exported with decode module.
  65. Default False.
  66. """
  67. return self._model.is_decode_exported
  68. @property
  69. def downsample_strides(self):
  70. """
  71. downsample strides for YOLOX to generate anchors, will take (8,16,32) as default values, might have stride=64.
  72. """
  73. return self._model.downsample_strides
  74. @property
  75. def max_wh(self):
  76. # for offsetting the boxes by classes when using NMS
  77. return self._model.max_wh
  78. @size.setter
  79. def size(self, wh):
  80. assert isinstance(
  81. wh, (list, tuple)
  82. ), "The value to set `size` must be type of tuple or list."
  83. assert (
  84. len(wh) == 2
  85. ), "The value to set `size` must contains 2 elements means [width, height], but now it contains {} elements.".format(
  86. len(wh)
  87. )
  88. self._model.size = wh
  89. @padding_value.setter
  90. def padding_value(self, value):
  91. assert isinstance(
  92. value, list
  93. ), "The value to set `padding_value` must be type of list."
  94. self._model.padding_value = value
  95. @is_decode_exported.setter
  96. def is_decode_exported(self, value):
  97. assert isinstance(
  98. value, bool
  99. ), "The value to set `is_decode_exported` must be type of bool."
  100. self._model.is_decode_exported = value
  101. @downsample_strides.setter
  102. def downsample_strides(self, value):
  103. assert isinstance(
  104. value, list
  105. ), "The value to set `downsample_strides` must be type of list."
  106. self._model.downsample_strides = value
  107. @max_wh.setter
  108. def max_wh(self, value):
  109. assert isinstance(
  110. value, float
  111. ), "The value to set `max_wh` must be type of float."
  112. self._model.max_wh = value