object_detection.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ...utils.func_register import FuncRegister
  16. from ...modules.object_detection.model_list import MODELS
  17. from ..components import *
  18. from ..results import DetResult
  19. from .base import BasicPredictor
  20. class DetPredictor(BasicPredictor):
  21. entities = MODELS
  22. _FUNC_MAP = {}
  23. register = FuncRegister(_FUNC_MAP)
  24. def _build_components(self):
  25. self._add_component(ReadImage(format="RGB"))
  26. for cfg in self.config["Preprocess"]:
  27. tf_key = cfg["type"]
  28. func = self._FUNC_MAP[tf_key]
  29. cfg.pop("type")
  30. args = cfg
  31. op = func(self, **args) if args else func(self)
  32. self._add_component(op)
  33. predictor = ImageDetPredictor(
  34. model_dir=self.model_dir,
  35. model_prefix=self.MODEL_FILE_PREFIX,
  36. option=self.pp_option,
  37. )
  38. model_names = ["DETR", "RCNN", "YOLOv3", "CenterNet", "PP-DocLayout-L"]
  39. if any(name in self.model_name for name in model_names):
  40. predictor.set_inputs(
  41. {
  42. "img": "img",
  43. "scale_factors": "scale_factors",
  44. "img_size": "img_size",
  45. }
  46. )
  47. if self.model_name in ["BlazeFace", "BlazeFace-FPN-SSH"]:
  48. predictor.set_inputs(
  49. {
  50. "img": "img",
  51. "img_size": "img_size",
  52. }
  53. )
  54. self._add_component(
  55. [
  56. predictor,
  57. DetPostProcess(
  58. threshold=self.config["draw_threshold"],
  59. labels=self.config["label_list"],
  60. layout_postprocess=self.config.get("layout_postprocess", False),
  61. ),
  62. ]
  63. )
  64. @register("Resize")
  65. def build_resize(self, target_size, keep_ratio=False, interp=2):
  66. assert target_size
  67. if isinstance(interp, int):
  68. interp = {
  69. 0: "NEAREST",
  70. 1: "LINEAR",
  71. 2: "CUBIC",
  72. 3: "AREA",
  73. 4: "LANCZOS4",
  74. }[interp]
  75. op = Resize(target_size=target_size[::-1], keep_ratio=keep_ratio, interp=interp)
  76. return op
  77. @register("NormalizeImage")
  78. def build_normalize(
  79. self,
  80. norm_type=None,
  81. mean=[0.485, 0.456, 0.406],
  82. std=[0.229, 0.224, 0.225],
  83. is_scale=True,
  84. ):
  85. if is_scale:
  86. scale = 1.0 / 255.0
  87. else:
  88. scale = 1
  89. if not norm_type or norm_type == "none":
  90. norm_type = "mean_std"
  91. if norm_type != "mean_std":
  92. mean = 0
  93. std = 1
  94. return Normalize(scale=scale, mean=mean, std=std)
  95. @register("Permute")
  96. def build_to_chw(self):
  97. return ToCHWImage()
  98. @register("Pad")
  99. def build_pad(self, fill_value=None, size=None):
  100. if fill_value is None:
  101. fill_value = [127.5, 127.5, 127.5]
  102. if size is None:
  103. size = [3, 640, 640]
  104. return DetPad(size=size, fill_value=fill_value)
  105. @register("PadStride")
  106. def build_pad_stride(self, stride=32):
  107. return PadStride(stride=stride)
  108. @register("WarpAffine")
  109. def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
  110. return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)
  111. def _pack_res(self, single):
  112. keys = ["input_path", "boxes"]
  113. return DetResult({key: single[key] for key in keys})