predictor.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Union
  15. import numpy as np
  16. from ....modules.multilabel_classification.model_list import MODELS
  17. from ..image_classification import ClasPredictor
  18. from .processors import MultiLabelThreshOutput
  19. from .result import MLClassResult
  class MLClasPredictor(ClasPredictor):
      """Multi-label classification predictor built on top of ClasPredictor.

      Reuses the preprocessing/inference pipeline of ClasPredictor and swaps the
      post-processing step for the "MultiLabelThreshOutput" processor.
      """

      # Model names this predictor is declared for (from the multilabel model list).
      entities = MODELS

      def __init__(
          self,
          threshold: Union[float, dict, list, None] = None,
          *args: Any,
          **kwargs: Any
      ) -> None:
          """Initializes MLClasPredictor.

          Args:
              threshold (float, dict, list, optional): The threshold forwarded to
                  the "MultiLabelThreshOutput" postprocessor. If None, the value
                  passed to ``build_threshoutput`` is adopted, unless a threshold
                  is given per call to ``process``. Defaults to None.
              *args: Arbitrary positional arguments passed to the superclass.
              **kwargs: Arbitrary keyword arguments passed to the superclass.
          """
          # Stored BEFORE calling the superclass initializer.
          # NOTE(review): presumably ClasPredictor.__init__ builds the
          # postprocessors and thereby calls build_threshoutput, which reads
          # self.threshold -- confirm against ClasPredictor before reordering.
          self.threshold = threshold
          super().__init__(*args, **kwargs)

      def _get_result_class(self) -> type:
          """Returns the result class, MLClassResult.

          Returns:
              type: The MLClassResult class.
          """
          return MLClassResult

      def process(
          self,
          batch_data: Any,
          threshold: Union[float, dict, list, None] = None,
      ) -> Dict[str, Any]:
          """Process a batch of data through preprocessing, inference, and postprocessing.

          Args:
              batch_data: A batch object exposing ``instances`` (the inputs handed
                  to the "Read" preprocessor), ``input_paths`` and
                  ``page_indexes``.
              threshold (float, dict, list, optional): Per-call threshold forwarded
                  to the postprocessor. If None, ``self.threshold`` is used
                  instead. Defaults to None.

          Returns:
              dict: A dictionary with the keys 'input_path', 'page_index',
                  'input_img', 'class_ids', 'scores', and 'label_names' for the
                  batch.
          """
          batch_raw_imgs = self.preprocessors["Read"](imgs=batch_data.instances)
          batch_imgs = self.preprocessors["Resize"](imgs=batch_raw_imgs)
          batch_imgs = self.preprocessors["Normalize"](imgs=batch_imgs)
          batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs)
          x = self.preprocessors["ToBatch"](imgs=batch_imgs)
          batch_preds = self.infer(x=x)
          batch_class_ids, batch_scores, batch_label_names = self.postprocessors[
              "MultiLabelThreshOutput"
          ](
              preds=batch_preds,
              # An explicit per-call threshold takes precedence over the
              # instance-level one.
              threshold=self.threshold if threshold is None else threshold,
          )
          return {
              "input_path": batch_data.input_paths,
              "page_index": batch_data.page_indexes,
              # The images as returned by "Read", before resize/normalize.
              "input_img": batch_raw_imgs,
              "class_ids": batch_class_ids,
              "scores": batch_scores,
              "label_names": batch_label_names,
          }

      @ClasPredictor.register("MultiLabelThreshOutput")
      def build_threshoutput(
          self, threshold: Union[float, dict, list], label_list=None
      ) -> tuple:
          """Builds the "MultiLabelThreshOutput" postprocessor.

          Registered with ClasPredictor under the name "MultiLabelThreshOutput".

          Args:
              threshold (float, dict, list): Threshold supplied by the caller of
                  this builder. It is adopted as ``self.threshold`` only when no
                  threshold was given to ``__init__``.
              label_list (optional): Passed to MultiLabelThreshOutput as
                  ``class_ids``. Defaults to None.

          Returns:
              tuple: The processor name "MultiLabelThreshOutput" and the
                  constructed MultiLabelThreshOutput instance.
          """
          # A threshold given at construction time wins over the one passed here.
          if self.threshold is None:
              self.threshold = threshold
          return "MultiLabelThreshOutput", MultiLabelThreshOutput(class_ids=label_list)