yolox.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import logging
  16. from .... import UltraInferModel, ModelFormat
  17. from .... import c_lib_wrap as C
  18. class YOLOX(UltraInferModel):
  19. def __init__(
  20. self,
  21. model_file,
  22. params_file="",
  23. runtime_option=None,
  24. model_format=ModelFormat.ONNX,
  25. ):
  26. """Load a YOLOX model exported by YOLOX.
  27. :param model_file: (str)Path of model file, e.g ./yolox.onnx
  28. :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
  29. :param runtime_option: (ultra_infer.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
  30. :param model_format: (ultra_infer.ModelForamt)Model format of the loaded model
  31. """
  32. # 调用基函数进行backend_option的初始化
  33. # 初始化后的option保存在self._runtime_option
  34. super(YOLOX, self).__init__(runtime_option)
  35. self._model = C.vision.detection.YOLOX(
  36. model_file, params_file, self._runtime_option, model_format
  37. )
  38. # 通过self.initialized判断整个模型的初始化是否成功
  39. assert self.initialized, "YOLOX initialize failed."
  40. def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5):
  41. """Detect an input image
  42. :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
  43. :param conf_threshold: confidence threashold for postprocessing, default is 0.25
  44. :param nms_iou_threshold: iou threashold for NMS, default is 0.5
  45. :return: DetectionResult
  46. """
  47. return self._model.predict(input_image, conf_threshold, nms_iou_threshold)
  48. # 一些跟YOLOX模型有关的属性封装
  49. # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持)
  50. @property
  51. def size(self):
  52. """
  53. Argument for image preprocessing step, the preprocess image size, tuple of (width, height), default size = [640, 640]
  54. """
  55. return self._model.size
  56. @property
  57. def padding_value(self):
  58. # padding value, size should be the same as channels
  59. return self._model.padding_value
  60. @property
  61. def is_decode_exported(self):
  62. """
  63. whether the model_file was exported with decode module.
  64. The official YOLOX/tools/export_onnx.py script will export ONNX file without decode module.
  65. Please set it 'true' manually if the model file was exported with decode module.
  66. Defalut False.
  67. """
  68. return self._model.is_decode_exported
  69. @property
  70. def downsample_strides(self):
  71. """
  72. downsample strides for YOLOX to generate anchors, will take (8,16,32) as default values, might have stride=64.
  73. """
  74. return self._model.downsample_strides
  75. @property
  76. def max_wh(self):
  77. # for offseting the boxes by classes when using NMS
  78. return self._model.max_wh
  79. @size.setter
  80. def size(self, wh):
  81. assert isinstance(
  82. wh, (list, tuple)
  83. ), "The value to set `size` must be type of tuple or list."
  84. assert (
  85. len(wh) == 2
  86. ), "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
  87. len(wh)
  88. )
  89. self._model.size = wh
  90. @padding_value.setter
  91. def padding_value(self, value):
  92. assert isinstance(
  93. value, list
  94. ), "The value to set `padding_value` must be type of list."
  95. self._model.padding_value = value
  96. @is_decode_exported.setter
  97. def is_decode_exported(self, value):
  98. assert isinstance(
  99. value, bool
  100. ), "The value to set `is_decode_exported` must be type of bool."
  101. self._model.is_decode_exported = value
  102. @downsample_strides.setter
  103. def downsample_strides(self, value):
  104. assert isinstance(
  105. value, list
  106. ), "The value to set `downsample_strides` must be type of list."
  107. self._model.downsample_strides = value
  108. @max_wh.setter
  109. def max_wh(self, value):
  110. assert isinstance(
  111. value, float
  112. ), "The value to set `max_wh` must be type of float."
  113. self._model.max_wh = value