predictor.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple, Optional, Callable
  15. import numpy as np
  16. import inspect
  17. from ....utils.func_register import FuncRegister
  18. from ....modules.open_vocabulary_detection.model_list import MODELS
  19. from ...common.batch_sampler import ImageBatchSampler
  20. from ...common.reader import ReadImage
  21. from .processors import (
  22. GroundingDINOProcessor,
  23. GroundingDINOPostProcessor,
  24. YOLOWorldProcessor,
  25. YOLOWorldPostProcessor,
  26. )
  27. from ..common import StaticInfer
  28. from ..base import BasicPredictor
  29. from ..object_detection.result import DetResult
  class OVDetPredictor(BasicPredictor):
      """Predictor for open-vocabulary object detection models.

      The postprocessor is selected from ``self.model_name`` (GroundingDINO or
      YOLO-World). ``process`` takes a batch of images plus a free-text prompt
      and returns the detected boxes together with the decoded source images.
      """

      # Model names handled by this predictor; presumably consulted by
      # BasicPredictor when matching a model to a predictor class -- confirm.
      entities = MODELS

      # Maps a preprocess op name (``config["Preprocess"][i]["type"]``) to the
      # builder method registered for it via ``@register(...)`` below.
      _FUNC_MAP = {}
      register = FuncRegister(_FUNC_MAP)

      def __init__(
          self, *args, thresholds: Optional[Union[Dict, float]] = None, **kwargs
      ):
          """Initializes OVDetPredictor.

          Args:
              *args: Arbitrary positional arguments passed to the superclass.
              thresholds (Optional[Union[Dict, float]], optional): Default
                  thresholds for filtering out low-confidence predictions. A
                  dict maps threshold parameter names of the postprocessor
                  (e.g. ``"box_threshold"``) to values; a bare int/float is
                  stored as ``{"threshold": value}``. Defaults to None.
              **kwargs: Arbitrary keyword arguments passed to the superclass.
          """
          super().__init__(*args, **kwargs)
          self.thresholds = self._normalize_thresholds(thresholds)
          self.pre_ops, self.infer, self.post_op = self._build()

      @staticmethod
      def _normalize_thresholds(thresholds):
          """Return ``thresholds`` as a dict, or unchanged if it is not a number.

          A bare int/float is wrapped as ``{"threshold": value}``; dicts and
          None pass through untouched.
          """
          # bool is an int subclass but is never a meaningful threshold value.
          if isinstance(thresholds, (int, float)) and not isinstance(
              thresholds, bool
          ):
              return {"threshold": thresholds}
          return thresholds

      def _build_batch_sampler(self):
          """Return an ``ImageBatchSampler`` for grouping inputs into batches."""
          return ImageBatchSampler()

      def _get_result_class(self):
          """Return ``DetResult`` as this predictor's result class."""
          return DetResult

      def _build(self):
          """Build the preprocess ops, the static inference engine and the postprocess op.

          Returns:
              tuple: ``(pre_ops, infer, post_op)``.

          Raises:
              KeyError: If a ``Preprocess`` entry names an op with no registered
                  builder.
          """
          # The first op always decodes the inputs as RGB images.
          pre_ops = [ReadImage(format="RGB")]
          # ``process`` calls the last op with the prompt, so the model-specific
          # processor must be the final ``Preprocess`` entry.
          for cfg in self.config["Preprocess"]:
              tf_key = cfg["type"]
              if tf_key not in self._FUNC_MAP:
                  raise KeyError(f"Unsupported preprocess op: {tf_key!r}")
              func = self._FUNC_MAP[tf_key]
              # Copy the remaining keys rather than popping "type": popping
              # mutated the shared config and broke any later rebuild.
              args = {key: value for key, value in cfg.items() if key != "type"}
              op = func(self, **args)
              if op:
                  pre_ops.append(op)

          infer = StaticInfer(
              model_dir=self.model_dir,
              model_prefix=self.MODEL_FILE_PREFIX,
              option=self.pp_option,
          )

          post_op = self.build_postprocess(pre_ops=pre_ops)
          return pre_ops, infer, post_op

      def process(
          self,
          batch_data: Any,
          prompt: str,
          thresholds: Optional[Union[Dict, float]] = None,
      ):
          """Run preprocessing, inference and postprocessing on one batch.

          Args:
              batch_data: A batch from the batch sampler; its ``input_paths``
                  and ``instances`` attributes are read.
              prompt (str): Text prompt for open-vocabulary detection.
              thresholds (Optional[Union[Dict, float]]): Per-call thresholds
                  for postprocessing. They override those given at
                  construction; a bare int/float is treated as
                  ``{"threshold": value}``.

          Returns:
              dict: ``{"input_path": ..., "input_img": ..., "boxes": ...}``.
              ``input_img`` holds the decoded images with the last axis
              reversed (presumably RGB -> BGR for HWC arrays), and ``boxes``
              is the postprocessor output.
          """
          image_paths = batch_data.input_paths
          # pre_ops[0] is ReadImage; its output is kept for postprocess/result.
          src_images = self.pre_ops[0](batch_data.instances)
          datas = src_images
          # Image-only ops sit between the reader and the model processor.
          for pre_op in self.pre_ops[1:-1]:
              datas = pre_op(datas)
          # The model-specific processor also consumes the text prompt.
          batch_inputs = self.pre_ops[-1](datas, prompt)
          batch_preds = self.infer(batch_inputs)
          current_thresholds = self._parse_current_thresholds(
              self.post_op, self.thresholds, thresholds
          )
          boxes = self.post_op(
              *batch_preds, prompt=prompt, src_images=src_images, **current_thresholds
          )
          return {
              "input_path": image_paths,
              "input_img": [img[..., ::-1] for img in src_images],
              "boxes": boxes,
          }

      def _parse_current_thresholds(self, func, init_thresholds, process_thresholds):
          """Merge the threshold keyword arguments for a call to ``func``.

          Every parameter of ``func`` whose name contains ``"threshold"``
          starts as None. Construction-time thresholds are applied next and
          per-call thresholds last, so later sources win. A bare int/float in
          either source is treated as ``{"threshold": value}``.

          Raises:
              TypeError: If ``func`` is not callable.
          """
          if not callable(func):
              raise TypeError(
                  f"Expected a callable postprocessor, got {type(func).__name__}"
              )
          thr2val = {
              name: None
              for name in inspect.signature(func).parameters
              if "threshold" in name
          }
          for source in (init_thresholds, process_thresholds):
              normalized = self._normalize_thresholds(source)
              if normalized is not None:
                  thr2val.update(normalized)
          return thr2val

      def build_postprocess(self, **kwargs):
          """Create the postprocessor matching ``self.model_name``.

          Args:
              **kwargs: ``pre_ops`` (list) is required for GroundingDINO
                  models, whose postprocessor reuses the tokenizer of the last
                  preprocess op.

          Raises:
              NotImplementedError: If the model name matches neither
                  GroundingDINO nor YOLO-World.
          """
          if "GroundingDINO" in self.model_name:
              pre_ops = kwargs.get("pre_ops")
              return GroundingDINOPostProcessor(
                  tokenizer=pre_ops[-1].tokenizer,
                  box_threshold=self.config["box_threshold"],
                  text_threshold=self.config["text_threshold"],
              )
          if "YOLO-World" in self.model_name:
              return YOLOWorldPostProcessor(
                  threshold=self.config["threshold"],
              )
          raise NotImplementedError(
              f"No postprocessor available for model {self.model_name!r}"
          )

      @register("GroundingDINOProcessor")
      def build_grounding_dino_preprocessor(
          self, text_max_words=256, target_size=(800, 1333)
      ):
          """Build the GroundingDINO processor that formats images and the prompt."""
          return GroundingDINOProcessor(
              model_dir=self.model_dir,
              text_max_words=text_max_words,
              target_size=target_size,
          )

      @register("YOLOWorldProcessor")
      def build_yoloworld_preprocessor(
          self,
          image_target_size=(640, 640),
          image_mean=None,
          image_std=None,
      ):
          """Build the YOLO-World processor that formats images and the prompt.

          ``image_mean`` / ``image_std`` default to ``[0.0, 0.0, 0.0]`` and
          ``[1.0, 1.0, 1.0]``. A fresh list is created per call instead of
          sharing one mutable default.
          """
          if image_mean is None:
              image_mean = [0.0, 0.0, 0.0]
          if image_std is None:
              image_std = [1.0, 1.0, 1.0]
          return YOLOWorldProcessor(
              model_dir=self.model_dir,
              image_target_size=image_target_size,
              image_mean=image_mean,
              image_std=image_std,
          )