processors.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Union
  15. import numpy as np
  16. from ...utils.benchmark import benchmark
  17. @benchmark.timeit
  18. class MultiLabelThreshOutput:
  19. """MultiLabelThresh Transform"""
  20. def __init__(self, class_ids=None, delimiter=None):
  21. super().__init__()
  22. self.delimiter = delimiter if delimiter is not None else " "
  23. self.class_id_map = self._parse_class_id_map(class_ids)
  24. def _parse_class_id_map(self, class_ids):
  25. """parse class id to label map file"""
  26. if class_ids is None:
  27. return None
  28. class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
  29. return class_id_map
  30. def __call__(self, preds, threshold: Union[float, dict, list]):
  31. threshold_list = []
  32. num_classes = preds[0].shape[-1]
  33. if isinstance(threshold, float):
  34. threshold_list = [threshold for _ in range(num_classes)]
  35. elif isinstance(threshold, dict):
  36. if threshold.get("default") is None:
  37. raise ValueError(
  38. "If using dictionary format, please specify default threshold explicitly with key 'default'."
  39. )
  40. default_threshold = threshold.get("default")
  41. threshold_list = [default_threshold for _ in range(num_classes)]
  42. for k, v in threshold.items():
  43. if k == "default":
  44. continue
  45. if isinstance(k, str):
  46. assert (
  47. k.isdigit()
  48. ), f"Invalid key of threshold: {k}, it must be integer"
  49. k = int(k)
  50. if not isinstance(v, float):
  51. raise ValueError(
  52. f"Invalid value type of threshold: {type(v)}, it must be float"
  53. )
  54. assert (
  55. k < num_classes
  56. ), f"Invalid key of threshold: {k}, it must be less than the number of classes({num_classes})"
  57. threshold_list[k] = v
  58. elif isinstance(threshold, list):
  59. assert (
  60. len(threshold) == num_classes
  61. ), f"The length of threshold({len(threshold)}) should be equal to the number of classes({num_classes})."
  62. threshold_list = threshold
  63. else:
  64. raise ValueError(
  65. "Invalid type of threshold, should be 'list', 'dict' or 'float'."
  66. )
  67. pred_indexes = [
  68. np.argsort(-x[x > threshold])
  69. for x, threshold in zip(preds[0], threshold_list)
  70. ]
  71. indexes = [
  72. np.where(x > threshold)[0][indices]
  73. for x, indices, threshold in zip(preds[0], pred_indexes, threshold_list)
  74. ]
  75. scores = [
  76. np.around(pred[index].astype(float), decimals=5)
  77. for pred, index in zip(preds[0], indexes)
  78. ]
  79. label_names = [[self.class_id_map[i] for i in index] for index in indexes]
  80. return indexes, scores, label_names