predictor.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple, Optional, Callable
  15. import numpy as np
  16. import inspect
  17. from ....utils.func_register import FuncRegister
  18. from ....modules.open_vocabulary_detection.model_list import MODELS
  19. from ...common.batch_sampler import ImageBatchSampler
  20. from ...common.reader import ReadImage
  21. from .processors import GroundingDINOProcessor, GroundingDINOPostProcessor
  22. from ..common import StaticInfer
  23. from ..base import BasicPredictor
  24. from ..object_detection.result import DetResult
  class OVDetPredictor(BasicPredictor):
      """Predictor for open-vocabulary detection models.

      The predictor chains three stages: a list of preprocessing ops (image
      reading followed by the ops declared in ``config["Preprocess"]``), a
      ``StaticInfer`` inference engine, and a model-specific postprocessor.
      """

      entities = MODELS

      # Maps a preprocess "type" name from the config to the builder method
      # registered for it through the ``@register(...)`` decorator.
      _FUNC_MAP = {}
      register = FuncRegister(_FUNC_MAP)

      def __init__(
          self, *args, thresholds: Optional[Union[Dict, float]] = None, **kwargs
      ):
          """Initializes OVDetPredictor.

          Args:
              *args: Arbitrary positional arguments passed to the superclass.
              thresholds (Optional[Union[Dict, float]], optional): The thresholds
                  for filtering out low-confidence predictions. A single number
                  is stored as ``{"threshold": value}``; a dict can record
                  multiple thresholds. Defaults to None.
              **kwargs: Arbitrary keyword arguments passed to the superclass.
          """
          super().__init__(*args, **kwargs)

          # Accept any plain real number (e.g. ``1`` as well as ``1.0``); a bare
          # number left unwrapped would later break ``dict.update`` in
          # ``_parse_current_thresholds``.
          # NOTE(review): the "threshold" key is forwarded to the postprocessor
          # as a keyword argument - confirm the postprocessor accepts it.
          is_number = isinstance(thresholds, (int, float)) and not isinstance(
              thresholds, bool
          )
          if is_number:
              thresholds = {"threshold": thresholds}
          self.thresholds = thresholds

          self.pre_ops, self.infer, self.post_op = self._build()

      def _build_batch_sampler(self):
          """Returns the batch sampler used by this predictor."""
          return ImageBatchSampler()

      def _get_result_class(self):
          """Returns the result class used to wrap predictions."""
          return DetResult

      def _build(self):
          """Builds the preprocessing ops, the inference engine and the postprocessor.

          Returns:
              tuple: ``(pre_ops, infer, post_op)``.

          Raises:
              KeyError: If a preprocess entry has no ``"type"`` key or names a
                  type that has no registered builder.
          """
          # build model preprocess ops
          pre_ops = [ReadImage(format="RGB")]
          for cfg in self.config["Preprocess"]:
              # Work on a shallow copy so that removing "type" does not mutate
              # the shared config (which would break any later rebuild).
              op_kwargs = dict(cfg)
              tf_key = op_kwargs.pop("type")
              func = self._FUNC_MAP[tf_key]
              op = func(self, **op_kwargs)
              if op:
                  pre_ops.append(op)

          # build infer
          infer = StaticInfer(
              model_dir=self.model_dir,
              model_prefix=self.MODEL_FILE_PREFIX,
              option=self.pp_option,
          )

          # build postprocess op
          post_op = self.build_postprocess(pre_ops=pre_ops)

          return pre_ops, infer, post_op

      def process(
          self, batch_data: List[Any], prompt: str, thresholds: Optional[dict] = None
      ):
          """
          Process a batch of data through the preprocessing, inference, and postprocessing.

          Args:
              batch_data (List[str]): A batch of input data (e.g., image file paths).
              prompt (str): Text prompt for open vocabulary detection.
              thresholds (Optional[dict]): thresholds used for postprocess. Values
                  given here override the ones passed at construction time.

          Returns:
              dict: A dictionary with the keys 'input_path' (the input batch),
                  'input_img' (the output of the image reader) and 'boxes'
                  (the output of the postprocessor).
          """
          src_images = self.pre_ops[0](batch_data)

          # preprocess: every op between the reader and the final processor
          datas = src_images
          for pre_op in self.pre_ops[1:-1]:
              datas = pre_op(datas)

          # use Model-specific preprocessor to format batch inputs
          batch_inputs = self.pre_ops[-1](datas, prompt)

          # do infer
          batch_preds = self.infer(batch_inputs)

          # postprocess
          current_thresholds = self._parse_current_thresholds(
              self.post_op, self.thresholds, thresholds
          )
          boxes = self.post_op(
              *batch_preds, prompt=prompt, src_images=src_images, **current_thresholds
          )

          return {
              "input_path": batch_data,
              "input_img": src_images,
              "boxes": boxes,
          }

      def _parse_current_thresholds(self, func, init_thresholds, process_thresholds):
          """Resolves the threshold keyword arguments for one postprocess call.

          Every parameter of ``func`` whose name contains "threshold" starts as
          ``None``; ``init_thresholds`` and then ``process_thresholds`` are
          layered on top, so per-call values take precedence.

          Args:
              func (Callable): The callable whose signature is inspected.
              init_thresholds (Optional[dict]): Thresholds given at construction.
              process_thresholds (Optional[dict]): Thresholds given per call.

          Returns:
              dict: Mapping from threshold name to its resolved value.

          Raises:
              TypeError: If ``func`` is not callable.
          """
          # Explicit check instead of ``assert`` so it survives ``python -O``.
          if not callable(func):
              raise TypeError(f"`func` must be callable, got {type(func).__name__}.")

          thr2val = {
              name: None
              for name in inspect.signature(func).parameters
              if "threshold" in name
          }
          if init_thresholds is not None:
              thr2val.update(init_thresholds)
          if process_thresholds is not None:
              thr2val.update(process_thresholds)
          return thr2val

      def build_postprocess(self, **kwargs):
          """Builds the postprocessor matching ``self.model_name``.

          Args:
              **kwargs: Expects ``pre_ops``, the list of preprocessing ops; the
                  tokenizer is taken from its last element.

          Returns:
              GroundingDINOPostProcessor: The postprocessor, configured with the
                  box and text thresholds read from the config.

          Raises:
              NotImplementedError: If the model name does not contain
                  "GroundingDINO".
          """
          if "GroundingDINO" not in self.model_name:
              raise NotImplementedError(
                  f"No postprocessor is implemented for model {self.model_name!r}."
              )

          pre_ops = kwargs.get("pre_ops")
          return GroundingDINOPostProcessor(
              tokenizer=pre_ops[-1].tokenizer,
              box_threshold=self.config["box_threshold"],
              text_threshold=self.config["text_threshold"],
          )

      @register("GroundingDINOProcessor")
      def build_grounding_dino_preprocessor(
          self, text_max_words=256, target_size=(800, 1333)
      ):
          """Builds the GroundingDINO processor from the model directory.

          Args:
              text_max_words (int, optional): Forwarded to the processor.
                  Defaults to 256.
              target_size (tuple, optional): Forwarded to the processor.
                  Defaults to (800, 1333).

          Returns:
              GroundingDINOProcessor: The configured processor.
          """
          return GroundingDINOProcessor(
              model_dir=self.model_dir,
              text_max_words=text_max_words,
              target_size=target_size,
          )