object_detection.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ...utils.func_register import FuncRegister
  16. from ...modules.object_detection.model_list import MODELS
  17. from ..components import *
  18. from ..results import DetResult
  19. from .base import BasicPredictor
  20. class DetPredictor(BasicPredictor):
  21. entities = MODELS
  22. _FUNC_MAP = {}
  23. register = FuncRegister(_FUNC_MAP)
  24. def _build_components(self):
  25. self._add_component(ReadImage(format="RGB"))
  26. for cfg in self.config["Preprocess"]:
  27. tf_key = cfg["type"]
  28. func = self._FUNC_MAP[tf_key]
  29. cfg.pop("type")
  30. args = cfg
  31. op = func(self, **args) if args else func(self)
  32. self._add_component(op)
  33. predictor = ImageDetPredictor(
  34. model_dir=self.model_dir,
  35. model_prefix=self.MODEL_FILE_PREFIX,
  36. option=self.pp_option,
  37. )
  38. if "DETR" in self.model_name or "RCNN" in self.model_name:
  39. predictor.set_inputs(
  40. {
  41. "img": "img",
  42. "scale_factors": "scale_factors",
  43. "img_size": "img_size",
  44. }
  45. )
  46. self._add_component(
  47. [
  48. predictor,
  49. DetPostProcess(
  50. threshold=self.config["draw_threshold"],
  51. labels=self.config["label_list"],
  52. ),
  53. ]
  54. )
  55. @register("Resize")
  56. def build_resize(self, target_size, keep_ratio=False, interp=2):
  57. assert target_size
  58. if isinstance(interp, int):
  59. interp = {
  60. 0: "NEAREST",
  61. 1: "LINEAR",
  62. 2: "CUBIC",
  63. 3: "AREA",
  64. 4: "LANCZOS4",
  65. }[interp]
  66. op = Resize(target_size=target_size, keep_ratio=keep_ratio, interp=interp)
  67. return op
  68. @register("NormalizeImage")
  69. def build_normalize(
  70. self,
  71. norm_type=None,
  72. mean=[0.485, 0.456, 0.406],
  73. std=[0.229, 0.224, 0.225],
  74. is_scale=None,
  75. ):
  76. if is_scale:
  77. scale = 1.0 / 255.0
  78. else:
  79. scale = 1
  80. if not norm_type or norm_type == "none":
  81. norm_type = "mean_std"
  82. if norm_type != "mean_std":
  83. mean = 0
  84. std = 1
  85. return Normalize(mean=mean, std=std)
  86. @register("Permute")
  87. def build_to_chw(self):
  88. return ToCHWImage()
  89. @register("Pad")
  90. def build_pad(self, fill_value=None, size=None):
  91. if fill_value is None:
  92. fill_value = [127.5, 127.5, 127.5]
  93. if size is None:
  94. size = [3, 640, 640]
  95. return DetPad(size=size, fill_value=fill_value)
  96. @register("PadStride")
  97. def build_pad_stride(self, stride=32):
  98. return PadStride(stride=stride)
  99. def _pack_res(self, single):
  100. keys = ["img_path", "boxes"]
  101. return DetResult({key: single[key] for key in keys})