predictor.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Tuple, Union
  15. import numpy as np
  16. from ....modules.image_classification.model_list import MODELS
  17. from ....utils.func_register import FuncRegister
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..base import BasePredictor
  21. from ..common import Normalize, Resize, ResizeByShort, ToBatch, ToCHWImage
  22. from .processors import Crop, Topk
  23. from .result import TopkResult
  24. class ClasPredictor(BasePredictor):
  25. """ClasPredictor that inherits from BasePredictor."""
  26. entities = MODELS
  27. _FUNC_MAP = {}
  28. register = FuncRegister(_FUNC_MAP)
  29. def __init__(
  30. self, topk: Union[int, None] = None, *args: List, **kwargs: Dict
  31. ) -> None:
  32. """Initializes ClasPredictor.
  33. Args:
  34. topk (int, optional): The number of top-k predictions to return. If None, it will be depending on config of inference or predict. Defaults to None.
  35. *args: Arbitrary positional arguments passed to the superclass.
  36. **kwargs: Arbitrary keyword arguments passed to the superclass.
  37. """
  38. super().__init__(*args, **kwargs)
  39. self.topk = topk
  40. self.preprocessors, self.infer, self.postprocessors = self._build()
  41. def _build_batch_sampler(self) -> ImageBatchSampler:
  42. """Builds and returns an ImageBatchSampler instance.
  43. Returns:
  44. ImageBatchSampler: An instance of ImageBatchSampler.
  45. """
  46. return ImageBatchSampler()
  47. def _get_result_class(self) -> type:
  48. """Returns the result class, TopkResult.
  49. Returns:
  50. type: The TopkResult class.
  51. """
  52. return TopkResult
  53. def _build(self) -> Tuple:
  54. """Build the preprocessors, inference engine, and postprocessors based on the configuration.
  55. Returns:
  56. tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
  57. """
  58. preprocessors = {"Read": ReadImage(format="RGB")}
  59. for cfg in self.config["PreProcess"]["transform_ops"]:
  60. tf_key = list(cfg.keys())[0]
  61. func = self._FUNC_MAP[tf_key]
  62. args = cfg.get(tf_key, {})
  63. name, op = func(self, **args) if args else func(self)
  64. preprocessors[name] = op
  65. preprocessors["ToBatch"] = ToBatch()
  66. infer = self.create_static_infer()
  67. postprocessors = {}
  68. for key in self.config["PostProcess"]:
  69. func = self._FUNC_MAP.get(key)
  70. args = self.config["PostProcess"].get(key, {})
  71. name, op = func(self, **args) if args else func(self)
  72. postprocessors[name] = op
  73. return preprocessors, infer, postprocessors
  74. def process(
  75. self, batch_data: List[Union[str, np.ndarray]], topk: Union[int, None] = None
  76. ) -> Dict[str, Any]:
  77. """
  78. Process a batch of data through the preprocessing, inference, and postprocessing.
  79. Args:
  80. batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
  81. topk: The number of top predictions to keep. If None, it will be depending on `self.topk`. Defaults to None.
  82. Returns:
  83. dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
  84. """
  85. batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data.instances)
  86. batch_imgs = self.preprocessors["Resize"](imgs=batch_raw_imgs)
  87. if "Crop" in self.preprocessors:
  88. batch_imgs = self.preprocessors["Crop"](imgs=batch_imgs)
  89. batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
  90. batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
  91. x = self.preprocessors["ToBatch"](imgs=batch_imgs)
  92. batch_preds = self.infer(x=x)
  93. batch_class_ids, batch_scores, batch_label_names = self.postprocessors["Topk"](
  94. batch_preds, topk=topk or self.topk
  95. )
  96. return {
  97. "input_path": batch_data.input_paths,
  98. "page_index": batch_data.page_indexes,
  99. "input_img": batch_raw_imgs,
  100. "class_ids": batch_class_ids,
  101. "scores": batch_scores,
  102. "label_names": batch_label_names,
  103. }
  104. @register("ResizeImage")
  105. # TODO(gaotingquan): backend & interpolation
  106. def build_resize(
  107. self, resize_short=None, size=None, backend="cv2", interpolation="LINEAR"
  108. ):
  109. assert resize_short or size
  110. if resize_short:
  111. op = ResizeByShort(
  112. target_short_edge=resize_short,
  113. size_divisor=None,
  114. interp=interpolation,
  115. backend=backend,
  116. )
  117. else:
  118. op = Resize(
  119. target_size=size,
  120. size_divisor=None,
  121. interp=interpolation,
  122. backend=backend,
  123. )
  124. return "Resize", op
  125. @register("CropImage")
  126. def build_crop(self, size=224):
  127. return "Crop", Crop(crop_size=size)
  128. @register("NormalizeImage")
  129. def build_normalize(
  130. self,
  131. mean=[0.485, 0.456, 0.406],
  132. std=[0.229, 0.224, 0.225],
  133. scale=1 / 255,
  134. order="",
  135. channel_num=3,
  136. ):
  137. assert channel_num == 3
  138. assert order == ""
  139. return "Normalize", Normalize(scale=scale, mean=mean, std=std)
  140. @register("ToCHWImage")
  141. def build_to_chw(self):
  142. return "ToCHW", ToCHWImage()
  143. @register("Topk")
  144. def build_topk(self, topk, label_list=None):
  145. if not self.topk:
  146. self.topk = int(topk)
  147. return "Topk", Topk(class_ids=label_list)