retinaface.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from .... import UltraInferModel, ModelFormat
  16. from .... import c_lib_wrap as C
  17. class RetinaFace(UltraInferModel):
  18. def __init__(
  19. self,
  20. model_file,
  21. params_file="",
  22. runtime_option=None,
  23. model_format=ModelFormat.ONNX,
  24. ):
  25. """Load a RetinaFace model exported by RetinaFace.
  26. :param model_file: (str)Path of model file, e.g ./retinaface.onnx
  27. :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
  28. :param runtime_option: (ultra_infer.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
  29. :param model_format: (ultra_infer.ModelForamt)Model format of the loaded model
  30. """
  31. # 调用基函数进行backend_option的初始化
  32. # 初始化后的option保存在self._runtime_option
  33. super(RetinaFace, self).__init__(runtime_option)
  34. self._model = C.vision.facedet.RetinaFace(
  35. model_file, params_file, self._runtime_option, model_format
  36. )
  37. # 通过self.initialized判断整个模型的初始化是否成功
  38. assert self.initialized, "RetinaFace initialize failed."
  39. def predict(self, input_image, conf_threshold=0.7, nms_iou_threshold=0.3):
  40. """Detect the location and key points of human faces from an input image
  41. :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
  42. :param conf_threshold: confidence threashold for postprocessing, default is 0.7
  43. :param nms_iou_threshold: iou threashold for NMS, default is 0.3
  44. :return: FaceDetectionResult
  45. """
  46. return self._model.predict(input_image, conf_threshold, nms_iou_threshold)
  47. # 一些跟模型有关的属性封装
  48. # 多数是预处理相关,可通过修改如model.size = [640, 480]改变预处理时resize的大小(前提是模型支持)
  49. @property
  50. def size(self):
  51. """
  52. Argument for image preprocessing step, the preprocess image size, tuple of (width, height), default (640, 640)
  53. """
  54. return self._model.size
  55. @property
  56. def variance(self):
  57. """
  58. Argument for image postprocessing step, variance in RetinaFace's prior-box(anchor) generate process, default (0.1, 0.2)
  59. """
  60. return self._model.variance
  61. @property
  62. def downsample_strides(self):
  63. """
  64. Argument for image postprocessing step, downsample strides (namely, steps) for RetinaFace to generate anchors, will take (8,16,32) as default values
  65. """
  66. return self._model.downsample_strides
  67. @property
  68. def min_sizes(self):
  69. """
  70. Argument for image postprocessing step, min sizes, width and height for each anchor, default min_sizes = [[16, 32], [64, 128], [256, 512]]
  71. """
  72. return self._model.min_sizes
  73. @property
  74. def landmarks_per_face(self):
  75. """
  76. Argument for image postprocessing step, landmarks_per_face, default 5 in RetinaFace
  77. """
  78. return self._model.landmarks_per_face
  79. @size.setter
  80. def size(self, wh):
  81. assert isinstance(
  82. wh, (list, tuple)
  83. ), "The value to set `size` must be type of tuple or list."
  84. assert (
  85. len(wh) == 2
  86. ), "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
  87. len(wh)
  88. )
  89. self._model.size = wh
  90. @variance.setter
  91. def variance(self, value):
  92. assert isinstance(
  93. value, (list, tuple)
  94. ), "The value to set `variance` must be type of tuple or list."
  95. assert (
  96. len(value) == 2
  97. ), "The value to set `variance` must contatins 2 elements".format(len(value))
  98. self._model.variance = value
  99. @downsample_strides.setter
  100. def downsample_strides(self, value):
  101. assert isinstance(
  102. value, list
  103. ), "The value to set `downsample_strides` must be type of list."
  104. self._model.downsample_strides = value
  105. @min_sizes.setter
  106. def min_sizes(self, value):
  107. assert isinstance(
  108. value, list
  109. ), "The value to set `min_sizes` must be type of list."
  110. self._model.min_sizes = value
  111. @landmarks_per_face.setter
  112. def landmarks_per_face(self, value):
  113. assert isinstance(
  114. value, int
  115. ), "The value to set `landmarks_per_face` must be type of int."
  116. self._model.landmarks_per_face = value