processors.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. from typing import Union
  16. class MultiLabelThreshOutput:
  17. """MultiLabelThresh Transform"""
  18. def __init__(self, class_ids=None, delimiter=None):
  19. super().__init__()
  20. self.delimiter = delimiter if delimiter is not None else " "
  21. self.class_id_map = self._parse_class_id_map(class_ids)
  22. def _parse_class_id_map(self, class_ids):
  23. """parse class id to label map file"""
  24. if class_ids is None:
  25. return None
  26. class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
  27. return class_id_map
  28. def __call__(self, preds, threshold: Union[float, dict, list]):
  29. threshold_list = []
  30. num_classes = preds[0].shape[-1]
  31. if isinstance(threshold, float):
  32. threshold_list = [threshold for _ in range(num_classes)]
  33. elif isinstance(threshold, dict):
  34. if threshold.get("default") is None:
  35. raise ValueError(
  36. "If using dictionary format, please specify default threshold explicitly with key 'default'."
  37. )
  38. default_threshold = threshold.pop("default")
  39. threshold_list = [default_threshold for _ in range(num_classes)]
  40. for k, v in threshold.items():
  41. if isinstance(k, str):
  42. assert (
  43. k.isdigit()
  44. ), f"Invalid key of threshold: {k}, it must be integer"
  45. k = int(k)
  46. if not isinstance(v, float):
  47. raise ValueError(
  48. f"Invalid value type of threshold: {type(v)}, it must be float"
  49. )
  50. assert (
  51. k < num_classes
  52. ), f"Invalid key of threshold: {k}, it must be less than the number of classes({num_classes})"
  53. threshold_list[k] = v
  54. elif isinstance(threshold, list):
  55. assert (
  56. len(threshold) == num_classes
  57. ), f"The length of threshold({len(threshold)}) should be equal to the number of classes({num_classes})."
  58. threshold_list = threshold
  59. else:
  60. raise ValueError(
  61. "Invalid type of threshold, should be 'list', 'dict' or 'float'."
  62. )
  63. pred_indexes = [
  64. np.argsort(-x[x > threshold])
  65. for x, threshold in zip(preds[0], threshold_list)
  66. ]
  67. indexes = [
  68. np.where(x > threshold)[0][indices]
  69. for x, indices, threshold in zip(preds[0], pred_indexes, threshold_list)
  70. ]
  71. scores = [
  72. np.around(pred[index].astype(float), decimals=5)
  73. for pred, index in zip(preds[0], indexes)
  74. ]
  75. label_names = [[self.class_id_map[i] for i in index] for index in indexes]
  76. return indexes, scores, label_names