processors.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from typing import Union
  16. class MultiLabelThreshOutput:
  17. """MultiLabelThresh Transform"""
  18. def __init__(self, class_ids=None, delimiter=None):
  19. super().__init__()
  20. self.delimiter = delimiter if delimiter is not None else " "
  21. self.class_id_map = self._parse_class_id_map(class_ids)
  22. def _parse_class_id_map(self, class_ids):
  23. """parse class id to label map file"""
  24. if class_ids is None:
  25. return None
  26. class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
  27. return class_id_map
  28. def __call__(self, preds, threshold: Union[float, dict, list]):
  29. threshold_list = []
  30. num_classes = preds[0].shape[-1]
  31. if isinstance(threshold, float):
  32. threshold_list = [threshold for _ in range(num_classes)]
  33. elif isinstance(threshold, dict):
  34. if threshold.get("default") is None:
  35. raise ValueError(
  36. "If using dictionary format, please specify default threshold explicitly with key 'default'."
  37. )
  38. default_threshold = threshold.get("default")
  39. threshold_list = [default_threshold for _ in range(num_classes)]
  40. for k, v in threshold.items():
  41. if k == "default":
  42. continue
  43. if isinstance(k, str):
  44. assert (
  45. k.isdigit()
  46. ), f"Invalid key of threshold: {k}, it must be integer"
  47. k = int(k)
  48. if not isinstance(v, float):
  49. raise ValueError(
  50. f"Invalid value type of threshold: {type(v)}, it must be float"
  51. )
  52. assert (
  53. k < num_classes
  54. ), f"Invalid key of threshold: {k}, it must be less than the number of classes({num_classes})"
  55. threshold_list[k] = v
  56. elif isinstance(threshold, list):
  57. assert (
  58. len(threshold) == num_classes
  59. ), f"The length of threshold({len(threshold)}) should be equal to the number of classes({num_classes})."
  60. threshold_list = threshold
  61. else:
  62. raise ValueError(
  63. "Invalid type of threshold, should be 'list', 'dict' or 'float'."
  64. )
  65. pred_indexes = [
  66. np.argsort(-x[x > threshold])
  67. for x, threshold in zip(preds[0], threshold_list)
  68. ]
  69. indexes = [
  70. np.where(x > threshold)[0][indices]
  71. for x, indices, threshold in zip(preds[0], pred_indexes, threshold_list)
  72. ]
  73. scores = [
  74. np.around(pred[index].astype(float), decimals=5)
  75. for pred, index in zip(preds[0], indexes)
  76. ]
  77. label_names = [[self.class_id_map[i] for i in index] for index in indexes]
  78. return indexes, scores, label_names