object_detection.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ...utils.func_register import FuncRegister
  16. from ...modules.object_detection.model_list import MODELS
  17. from ..components import *
  18. from ..results import DetResult
  19. from .base import BasicPredictor
  20. class DetPredictor(BasicPredictor):
  21. entities = MODELS
  22. _FUNC_MAP = {}
  23. register = FuncRegister(_FUNC_MAP)
  24. def _build_components(self):
  25. self._add_component(ReadImage(format="RGB"))
  26. for cfg in self.config["Preprocess"]:
  27. tf_key = cfg["type"]
  28. func = self._FUNC_MAP[tf_key]
  29. cfg.pop("type")
  30. args = cfg
  31. op = func(self, **args) if args else func(self)
  32. self._add_component(op)
  33. predictor = ImageDetPredictor(
  34. model_dir=self.model_dir,
  35. model_prefix=self.MODEL_FILE_PREFIX,
  36. option=self.pp_option,
  37. )
  38. model_names = ["DETR", "RCNN", "YOLOv3", "CenterNet"]
  39. if any(name in self.model_name for name in model_names):
  40. predictor.set_inputs(
  41. {
  42. "img": "img",
  43. "scale_factors": "scale_factors",
  44. "img_size": "img_size",
  45. }
  46. )
  47. self._add_component(
  48. [
  49. predictor,
  50. DetPostProcess(
  51. threshold=self.config["draw_threshold"],
  52. labels=self.config["label_list"],
  53. ),
  54. ]
  55. )
  56. @register("Resize")
  57. def build_resize(self, target_size, keep_ratio=False, interp=2):
  58. assert target_size
  59. if isinstance(interp, int):
  60. interp = {
  61. 0: "NEAREST",
  62. 1: "LINEAR",
  63. 2: "CUBIC",
  64. 3: "AREA",
  65. 4: "LANCZOS4",
  66. }[interp]
  67. op = Resize(target_size=target_size, keep_ratio=keep_ratio, interp=interp)
  68. return op
  69. @register("NormalizeImage")
  70. def build_normalize(
  71. self,
  72. norm_type=None,
  73. mean=[0.485, 0.456, 0.406],
  74. std=[0.229, 0.224, 0.225],
  75. is_scale=None,
  76. ):
  77. if is_scale:
  78. scale = 1.0 / 255.0
  79. else:
  80. scale = 1
  81. if not norm_type or norm_type == "none":
  82. norm_type = "mean_std"
  83. if norm_type != "mean_std":
  84. mean = 0
  85. std = 1
  86. return Normalize(mean=mean, std=std)
  87. @register("Permute")
  88. def build_to_chw(self):
  89. return ToCHWImage()
  90. @register("Pad")
  91. def build_pad(self, fill_value=None, size=None):
  92. if fill_value is None:
  93. fill_value = [127.5, 127.5, 127.5]
  94. if size is None:
  95. size = [3, 640, 640]
  96. return DetPad(size=size, fill_value=fill_value)
  97. @register("PadStride")
  98. def build_pad_stride(self, stride=32):
  99. return PadStride(stride=stride)
  100. @register("WarpAffine")
  101. def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
  102. return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)
  103. def _pack_res(self, single):
  104. keys = ["input_path", "boxes"]
  105. return DetResult({key: single[key] for key in keys})