predictor.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple, Optional, Callable
  15. import numpy as np
  16. import inspect
  17. from ....utils.func_register import FuncRegister
  18. from ....modules.open_vocabulary_detection.model_list import MODELS
  19. from ...common.batch_sampler import ImageBatchSampler
  20. from ...common.reader import ReadImage
  21. from .processors import (
  22. GroundingDINOProcessor,
  23. GroundingDINOPostProcessor
  24. )
  25. from ..common import StaticInfer
  26. from ..base import BasicPredictor
  27. from ..object_detection.result import DetResult
  28. class OVDetPredictor(BasicPredictor):
  29. entities = MODELS
  30. _FUNC_MAP = {}
  31. register = FuncRegister(_FUNC_MAP)
  32. def __init__(self, *args, thresholds: Optional[Union[Dict, float]] = None, **kwargs):
  33. """Initializes DetPredictor.
  34. Args:
  35. *args: Arbitrary positional arguments passed to the superclass.
  36. thresholds (Optional[Union[Dict, float]], optional): The thresholds for filtering out low-confidence predictions, using a dict to record multiple thresholds
  37. Defaults to None.
  38. **kwargs: Arbitrary keyword arguments passed to the superclass.
  39. """
  40. super().__init__(*args, **kwargs)
  41. if isinstance(thresholds, float):
  42. thresholds = {"threshold": thresholds}
  43. self.thresholds = thresholds
  44. self.pre_ops, self.infer, self.post_op = self._build()
  45. def _build_batch_sampler(self):
  46. return ImageBatchSampler()
  47. def _get_result_class(self):
  48. return DetResult
  49. def _build(self):
  50. # build model preprocess ops
  51. pre_ops = [ReadImage(format="RGB")]
  52. for cfg in self.config["Preprocess"]:
  53. tf_key = cfg["type"]
  54. func = self._FUNC_MAP[tf_key]
  55. cfg.pop("type")
  56. args = cfg
  57. op = func(self, **args) if args else func(self)
  58. if op:
  59. pre_ops.append(op)
  60. # build infer
  61. infer = StaticInfer(
  62. model_dir=self.model_dir,
  63. model_prefix=self.MODEL_FILE_PREFIX,
  64. option=self.pp_option,
  65. )
  66. # build postprocess op
  67. post_op = self.build_postprocess(pre_ops = pre_ops)
  68. return pre_ops, infer, post_op
  69. def process(self, batch_data: List[Any], prompt: str, thresholds: Optional[dict] = None):
  70. """
  71. Process a batch of data through the preprocessing, inference, and postprocessing.
  72. Args:
  73. batch_data (List[str]): A batch of input data (e.g., image file paths).
  74. prompt (str): Text prompt for open vocabulary detection.
  75. thresholds (Optional[dict]): thresholds used for postprocess.
  76. Returns:
  77. dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
  78. for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
  79. """
  80. image_paths = batch_data
  81. src_images = self.pre_ops[0](batch_data)
  82. datas = src_images
  83. # preprocess
  84. for pre_op in self.pre_ops[1:-1]:
  85. datas = pre_op(datas)
  86. # use Model-specific preprocessor to format batch inputs
  87. batch_inputs = self.pre_ops[-1](datas, prompt)
  88. # do infer
  89. batch_preds = self.infer(batch_inputs)
  90. # postprocess
  91. current_thresholds = self._parse_current_thresholds(
  92. self.post_op, self.thresholds, thresholds
  93. )
  94. boxes = self.post_op(
  95. *batch_preds, prompt=prompt, src_images=src_images, **current_thresholds
  96. )
  97. return {
  98. "input_path": image_paths,
  99. "input_img": src_images,
  100. "boxes": boxes,
  101. }
  102. def _parse_current_thresholds(self, func, init_thresholds, process_thresholds):
  103. assert isinstance(func, Callable)
  104. thr2val = {}
  105. for name, param in inspect.signature(func).parameters.items():
  106. if "threshold" in name:
  107. thr2val[name] = None
  108. if init_thresholds is not None:
  109. thr2val.update(init_thresholds)
  110. if process_thresholds is not None:
  111. thr2val.update(process_thresholds)
  112. return thr2val
  113. def build_postprocess(self, **kwargs):
  114. if "GroundingDINO" in self.model_name:
  115. pre_ops = kwargs.get("pre_ops")
  116. return GroundingDINOPostProcessor(
  117. tokenizer=pre_ops[-1].tokenizer,
  118. box_threshold=self.config["box_threshold"],
  119. text_threshold=self.config["text_threshold"],
  120. )
  121. else:
  122. raise NotImplementedError
  123. @register("GroundingDINOProcessor")
  124. def build_grounding_dino_preprocessor(self, text_max_words=256, target_size=(800, 1333)):
  125. return GroundingDINOProcessor(
  126. model_dir=self.model_dir,
  127. text_max_words=text_max_words,
  128. target_size=target_size
  129. )