clas.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ....utils import logging
  16. from ..base import BaseComponent
  17. __all__ = ["Topk", "NormalizeFeatures"]
  18. def _parse_class_id_map(class_ids):
  19. """parse class id to label map file"""
  20. if class_ids is None:
  21. return None
  22. class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
  23. return class_id_map
  24. class Topk(BaseComponent):
  25. """Topk Transform"""
  26. INPUT_KEYS = ["pred"]
  27. OUTPUT_KEYS = [["class_ids", "scores"], ["class_ids", "scores", "label_names"]]
  28. DEAULT_INPUTS = {"pred": "pred"}
  29. DEAULT_OUTPUTS = {
  30. "class_ids": "class_ids",
  31. "scores": "scores",
  32. "label_names": "label_names",
  33. }
  34. def __init__(self, topk, class_ids=None):
  35. super().__init__()
  36. assert isinstance(topk, (int,))
  37. self.topk = topk
  38. self.class_id_map = _parse_class_id_map(class_ids)
  39. def apply(self, pred):
  40. """apply"""
  41. cls_pred = pred[0]
  42. index = cls_pred.argsort(axis=0)[-self.topk :][::-1].astype("int32")
  43. clas_id_list = []
  44. score_list = []
  45. label_name_list = []
  46. for i in index:
  47. clas_id_list.append(i.item())
  48. score_list.append(cls_pred[i].item())
  49. if self.class_id_map is not None:
  50. label_name_list.append(self.class_id_map[i.item()])
  51. result = {
  52. "class_ids": clas_id_list,
  53. "scores": np.around(score_list, decimals=5),
  54. }
  55. if label_name_list is not None:
  56. result["label_names"] = label_name_list
  57. return result
  58. class MultiLabelThreshOutput(BaseComponent):
  59. INPUT_KEYS = ["pred"]
  60. OUTPUT_KEYS = [["class_ids", "scores"], ["class_ids", "scores", "label_names"]]
  61. DEAULT_INPUTS = {"pred": "pred"}
  62. DEAULT_OUTPUTS = {
  63. "class_ids": "class_ids",
  64. "scores": "scores",
  65. "label_names": "label_names",
  66. }
  67. def __init__(self, threshold=0.5, class_ids=None, delimiter=None):
  68. super().__init__()
  69. assert isinstance(threshold, (float,))
  70. self.threshold = threshold
  71. self.delimiter = delimiter if delimiter is not None else " "
  72. self.class_id_map = _parse_class_id_map(class_ids)
  73. def apply(self, pred):
  74. """apply"""
  75. y = []
  76. x = pred[0]
  77. pred_index = np.where(x >= self.threshold)[0].astype("int32")
  78. index = pred_index[np.argsort(x[pred_index])][::-1]
  79. clas_id_list = []
  80. score_list = []
  81. label_name_list = []
  82. for i in index:
  83. clas_id_list.append(i.item())
  84. score_list.append(x[i].item())
  85. if self.class_id_map is not None:
  86. label_name_list.append(self.class_id_map[i.item()])
  87. result = {
  88. "class_ids": clas_id_list,
  89. "scores": np.around(score_list, decimals=5),
  90. }
  91. if label_name_list is not None:
  92. result["label_names"] = label_name_list
  93. return result
  94. class NormalizeFeatures(BaseComponent):
  95. """Normalize Features Transform"""
  96. INPUT_KEYS = ["pred"]
  97. OUTPUT_KEYS = ["rec_feature"]
  98. DEAULT_INPUTS = {"pred": "pred"}
  99. DEAULT_OUTPUTS = {"rec_feature": "rec_feature"}
  100. def apply(self, pred):
  101. """apply"""
  102. feas_norm = np.sqrt(np.sum(np.square(pred[0]), axis=0, keepdims=True))
  103. rec_feature = np.divide(pred[0], feas_norm)
  104. return {"rec_feature": rec_feature}