predictor.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple, Optional, Callable
  15. import numpy as np
  16. import inspect
  17. from ....utils.func_register import FuncRegister
  18. from ....modules.open_vocabulary_detection.model_list import MODELS
  19. from ...common.batch_sampler import ImageBatchSampler
  20. from ...common.reader import ReadImage
  21. from .processors import (
  22. GroundingDINOProcessor,
  23. GroundingDINOPostProcessor,
  24. YOLOWorldProcessor,
  25. YOLOWorldPostProcessor,
  26. )
  27. from ..common import StaticInfer
  28. from ..base import BasePredictor
  29. from ..object_detection.result import DetResult
  30. class OVDetPredictor(BasePredictor):
  31. entities = MODELS
  32. _FUNC_MAP = {}
  33. register = FuncRegister(_FUNC_MAP)
  34. def __init__(
  35. self, *args, thresholds: Optional[Union[Dict, float]] = None, **kwargs
  36. ):
  37. """Initializes DetPredictor.
  38. Args:
  39. *args: Arbitrary positional arguments passed to the superclass.
  40. thresholds (Optional[Union[Dict, float]], optional): The thresholds for filtering out low-confidence predictions, using a dict to record multiple thresholds
  41. Defaults to None.
  42. **kwargs: Arbitrary keyword arguments passed to the superclass.
  43. """
  44. super().__init__(*args, **kwargs)
  45. if isinstance(thresholds, float):
  46. thresholds = {"threshold": thresholds}
  47. self.thresholds = thresholds
  48. self.pre_ops, self.infer, self.post_op = self._build()
  49. def _build_batch_sampler(self):
  50. return ImageBatchSampler()
  51. def _get_result_class(self):
  52. return DetResult
  53. def _build(self):
  54. # build model preprocess ops
  55. pre_ops = [ReadImage(format="RGB")]
  56. for cfg in self.config["Preprocess"]:
  57. tf_key = cfg["type"]
  58. func = self._FUNC_MAP[tf_key]
  59. cfg.pop("type")
  60. args = cfg
  61. op = func(self, **args) if args else func(self)
  62. if op:
  63. pre_ops.append(op)
  64. # build infer
  65. infer = self.create_static_infer()
  66. # build postprocess op
  67. post_op = self.build_postprocess(pre_ops=pre_ops)
  68. return pre_ops, infer, post_op
  69. def process(
  70. self, batch_data: List[Any], prompt: str, thresholds: Optional[dict] = None
  71. ):
  72. """
  73. Process a batch of data through the preprocessing, inference, and postprocessing.
  74. Args:
  75. batch_data (List[str]): A batch of input data (e.g., image file paths).
  76. prompt (str): Text prompt for open vocabulary detection.
  77. thresholds (Optional[dict]): thresholds used for postprocess.
  78. Returns:
  79. dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
  80. for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
  81. """
  82. image_paths = batch_data.input_paths
  83. src_images = self.pre_ops[0](batch_data.instances)
  84. datas = src_images
  85. # preprocess for image only
  86. for pre_op in self.pre_ops[1:-1]:
  87. datas = pre_op(datas)
  88. # use Model-specific preprocessor to format batch inputs
  89. batch_inputs = self.pre_ops[-1](datas, prompt)
  90. # do infer
  91. batch_preds = self.infer(batch_inputs)
  92. # postprocess
  93. current_thresholds = self._parse_current_thresholds(
  94. self.post_op, self.thresholds, thresholds
  95. )
  96. boxes = self.post_op(
  97. *batch_preds, prompt=prompt, src_images=src_images, **current_thresholds
  98. )
  99. return {
  100. "input_path": image_paths,
  101. "input_img": [img[..., ::-1] for img in src_images],
  102. "boxes": boxes,
  103. }
  104. def _parse_current_thresholds(self, func, init_thresholds, process_thresholds):
  105. assert isinstance(func, Callable)
  106. thr2val = {}
  107. for name, param in inspect.signature(func).parameters.items():
  108. if "threshold" in name:
  109. thr2val[name] = None
  110. if init_thresholds is not None:
  111. thr2val.update(init_thresholds)
  112. if process_thresholds is not None:
  113. thr2val.update(process_thresholds)
  114. return thr2val
  115. def build_postprocess(self, **kwargs):
  116. if "GroundingDINO" in self.model_name:
  117. pre_ops = kwargs.get("pre_ops")
  118. return GroundingDINOPostProcessor(
  119. tokenizer=pre_ops[-1].tokenizer,
  120. box_threshold=self.config["box_threshold"],
  121. text_threshold=self.config["text_threshold"],
  122. )
  123. elif "YOLO-World" in self.model_name:
  124. return YOLOWorldPostProcessor(
  125. threshold=self.config["threshold"],
  126. )
  127. else:
  128. raise NotImplementedError
  129. @register("GroundingDINOProcessor")
  130. def build_grounding_dino_preprocessor(
  131. self, text_max_words=256, target_size=(800, 1333)
  132. ):
  133. return GroundingDINOProcessor(
  134. model_dir=self.model_dir,
  135. text_max_words=text_max_words,
  136. target_size=target_size,
  137. )
  138. @register("YOLOWorldProcessor")
  139. def build_yoloworld_preprocessor(
  140. self,
  141. image_target_size=(640, 640),
  142. image_mean=[0.0, 0.0, 0.0],
  143. image_std=[1.0, 1.0, 1.0],
  144. ):
  145. return YOLOWorldProcessor(
  146. model_dir=self.model_dir,
  147. image_target_size=image_target_size,
  148. image_mean=image_mean,
  149. image_std=image_std,
  150. )