object_detection.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. from ...utils.func_register import FuncRegister
  16. from ...modules.object_detection.model_list import MODELS
  17. from ..components import *
  18. from ..results import DetResult
  19. from .base import BasicPredictor
  20. class DetPredictor(BasicPredictor):
  21. entities = MODELS
  22. _FUNC_MAP = {}
  23. register = FuncRegister(_FUNC_MAP)
  24. def _build_components(self):
  25. self._add_component(ReadImage(format="RGB"))
  26. for cfg in self.config["Preprocess"]:
  27. tf_key = cfg["type"]
  28. func = self._FUNC_MAP[tf_key]
  29. cfg.pop("type")
  30. args = cfg
  31. op = func(self, **args) if args else func(self)
  32. self._add_component(op)
  33. predictor = ImageDetPredictor(
  34. model_dir=self.model_dir,
  35. model_prefix=self.MODEL_FILE_PREFIX,
  36. option=self.pp_option,
  37. )
  38. model_names = ["DETR", "RCNN", "YOLOv3", "CenterNet"]
  39. if any(name in self.model_name for name in model_names):
  40. predictor.set_inputs(
  41. {
  42. "img": "img",
  43. "scale_factors": "scale_factors",
  44. "img_size": "img_size",
  45. }
  46. )
  47. if self.model_name in ["BlazeFace", "BlazeFace-FPN-SSH"]:
  48. predictor.set_inputs(
  49. {
  50. "img": "img",
  51. "img_size": "img_size",
  52. }
  53. )
  54. self._add_component(
  55. [
  56. predictor,
  57. DetPostProcess(
  58. threshold=self.config["draw_threshold"],
  59. labels=self.config["label_list"],
  60. ),
  61. ]
  62. )
  63. @register("Resize")
  64. def build_resize(self, target_size, keep_ratio=False, interp=2):
  65. assert target_size
  66. if isinstance(interp, int):
  67. interp = {
  68. 0: "NEAREST",
  69. 1: "LINEAR",
  70. 2: "CUBIC",
  71. 3: "AREA",
  72. 4: "LANCZOS4",
  73. }[interp]
  74. op = Resize(target_size=target_size[::-1], keep_ratio=keep_ratio, interp=interp)
  75. return op
  76. @register("NormalizeImage")
  77. def build_normalize(
  78. self,
  79. norm_type=None,
  80. mean=[0.485, 0.456, 0.406],
  81. std=[0.229, 0.224, 0.225],
  82. is_scale=True,
  83. ):
  84. if is_scale:
  85. scale = 1.0 / 255.0
  86. else:
  87. scale = 1
  88. if not norm_type or norm_type == "none":
  89. norm_type = "mean_std"
  90. if norm_type != "mean_std":
  91. mean = 0
  92. std = 1
  93. return Normalize(scale=scale, mean=mean, std=std)
  94. @register("Permute")
  95. def build_to_chw(self):
  96. return ToCHWImage()
  97. @register("Pad")
  98. def build_pad(self, fill_value=None, size=None):
  99. if fill_value is None:
  100. fill_value = [127.5, 127.5, 127.5]
  101. if size is None:
  102. size = [3, 640, 640]
  103. return DetPad(size=size, fill_value=fill_value)
  104. @register("PadStride")
  105. def build_pad_stride(self, stride=32):
  106. return PadStride(stride=stride)
  107. @register("WarpAffine")
  108. def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
  109. return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)
  110. def _pack_res(self, single):
  111. keys = ["input_path", "boxes"]
  112. return DetResult({key: single[key] for key in keys})