predictor.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. from typing import Any, Callable, Dict, List, Optional, Union
  16. from ....modules.open_vocabulary_detection.model_list import MODELS
  17. from ....utils.func_register import FuncRegister
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..base import BasePredictor
  21. from ..object_detection.result import DetResult
  22. from .processors import (
  23. GroundingDINOPostProcessor,
  24. GroundingDINOProcessor,
  25. YOLOWorldPostProcessor,
  26. YOLOWorldProcessor,
  27. )
  28. class OVDetPredictor(BasePredictor):
  29. entities = MODELS
  30. _FUNC_MAP = {}
  31. register = FuncRegister(_FUNC_MAP)
  32. def __init__(
  33. self, *args, thresholds: Optional[Union[Dict, float]] = None, **kwargs
  34. ):
  35. """Initializes DetPredictor.
  36. Args:
  37. *args: Arbitrary positional arguments passed to the superclass.
  38. thresholds (Optional[Union[Dict, float]], optional): The thresholds for filtering out low-confidence predictions, using a dict to record multiple thresholds
  39. Defaults to None.
  40. **kwargs: Arbitrary keyword arguments passed to the superclass.
  41. """
  42. super().__init__(*args, **kwargs)
  43. if isinstance(thresholds, float):
  44. thresholds = {"threshold": thresholds}
  45. self.thresholds = thresholds
  46. self.pre_ops, self.infer, self.post_op = self._build()
  47. def _build_batch_sampler(self):
  48. return ImageBatchSampler()
  49. def _get_result_class(self):
  50. return DetResult
  51. def _build(self):
  52. # build model preprocess ops
  53. pre_ops = [ReadImage(format="RGB")]
  54. for cfg in self.config["Preprocess"]:
  55. tf_key = cfg["type"]
  56. func = self._FUNC_MAP[tf_key]
  57. cfg.pop("type")
  58. args = cfg
  59. op = func(self, **args) if args else func(self)
  60. if op:
  61. pre_ops.append(op)
  62. # build infer
  63. infer = self.create_static_infer()
  64. # build postprocess op
  65. post_op = self.build_postprocess(pre_ops=pre_ops)
  66. return pre_ops, infer, post_op
  67. def process(
  68. self, batch_data: List[Any], prompt: str, thresholds: Optional[dict] = None
  69. ):
  70. """
  71. Process a batch of data through the preprocessing, inference, and postprocessing.
  72. Args:
  73. batch_data (List[str]): A batch of input data (e.g., image file paths).
  74. prompt (str): Text prompt for open vocabulary detection.
  75. thresholds (Optional[dict]): thresholds used for postprocess.
  76. Returns:
  77. dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
  78. for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
  79. """
  80. image_paths = batch_data.input_paths
  81. src_images = self.pre_ops[0](batch_data.instances)
  82. datas = src_images
  83. # preprocess for image only
  84. for pre_op in self.pre_ops[1:-1]:
  85. datas = pre_op(datas)
  86. # use Model-specific preprocessor to format batch inputs
  87. batch_inputs = self.pre_ops[-1](datas, prompt)
  88. # do infer
  89. batch_preds = self.infer(batch_inputs)
  90. # postprocess
  91. current_thresholds = self._parse_current_thresholds(
  92. self.post_op, self.thresholds, thresholds
  93. )
  94. boxes = self.post_op(
  95. *batch_preds, prompt=prompt, src_images=src_images, **current_thresholds
  96. )
  97. return {
  98. "input_path": image_paths,
  99. "input_img": [img[..., ::-1] for img in src_images],
  100. "boxes": boxes,
  101. }
  102. def _parse_current_thresholds(self, func, init_thresholds, process_thresholds):
  103. assert isinstance(func, Callable)
  104. thr2val = {}
  105. for name, param in inspect.signature(func).parameters.items():
  106. if "threshold" in name:
  107. thr2val[name] = None
  108. if init_thresholds is not None:
  109. thr2val.update(init_thresholds)
  110. if process_thresholds is not None:
  111. thr2val.update(process_thresholds)
  112. return thr2val
  113. def build_postprocess(self, **kwargs):
  114. if "GroundingDINO" in self.model_name:
  115. pre_ops = kwargs.get("pre_ops")
  116. return GroundingDINOPostProcessor(
  117. tokenizer=pre_ops[-1].tokenizer,
  118. box_threshold=self.config["box_threshold"],
  119. text_threshold=self.config["text_threshold"],
  120. )
  121. elif "YOLO-World" in self.model_name:
  122. return YOLOWorldPostProcessor(
  123. threshold=self.config["threshold"],
  124. )
  125. else:
  126. raise NotImplementedError
  127. @register("GroundingDINOProcessor")
  128. def build_grounding_dino_preprocessor(
  129. self, text_max_words=256, target_size=(800, 1333)
  130. ):
  131. return GroundingDINOProcessor(
  132. model_dir=self.model_dir,
  133. text_max_words=text_max_words,
  134. target_size=target_size,
  135. )
  136. @register("YOLOWorldProcessor")
  137. def build_yoloworld_preprocessor(
  138. self,
  139. image_target_size=(640, 640),
  140. image_mean=[0.0, 0.0, 0.0],
  141. image_std=[1.0, 1.0, 1.0],
  142. ):
  143. return YOLOWorldProcessor(
  144. model_dir=self.model_dir,
  145. image_target_size=image_target_size,
  146. image_mean=image_mean,
  147. image_std=image_std,
  148. )