clas.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ....utils import logging
  16. from ...results import TopkResult
  17. from ..base import BaseComponent
  18. __all__ = ["Topk", "NormalizeFeatures"]
  19. def _parse_class_id_map(class_ids):
  20. """parse class id to label map file"""
  21. if class_ids is None:
  22. return None
  23. class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
  24. return class_id_map
  25. class Topk(BaseComponent):
  26. """Topk Transform"""
  27. INPUT_KEYS = ["pred", "img_path"]
  28. OUTPUT_KEYS = ["topk_res"]
  29. DEAULT_INPUTS = {"pred": "pred", "img_path": "img_path"}
  30. DEAULT_OUTPUTS = {"topk_res": "topk_res"}
  31. def __init__(self, topk, class_ids=None):
  32. super().__init__()
  33. assert isinstance(topk, (int,))
  34. self.topk = topk
  35. self.class_id_map = _parse_class_id_map(class_ids)
  36. def apply(self, pred, img_path):
  37. """apply"""
  38. cls_pred = pred
  39. class_id_map = self.class_id_map
  40. index = cls_pred.argsort(axis=0)[-self.topk :][::-1].astype("int32")
  41. clas_id_list = []
  42. score_list = []
  43. label_name_list = []
  44. for i in index:
  45. clas_id_list.append(i.item())
  46. score_list.append(cls_pred[i].item())
  47. if class_id_map is not None:
  48. label_name_list.append(class_id_map[i.item()])
  49. result = {
  50. "img_path": img_path,
  51. "class_ids": clas_id_list,
  52. "scores": np.around(score_list, decimals=5).tolist(),
  53. }
  54. if label_name_list is not None:
  55. result["label_names"] = label_name_list
  56. return {"topk_res": TopkResult(result)}
  57. class NormalizeFeatures(BaseComponent):
  58. """Normalize Features Transform"""
  59. INPUT_KEYS = ["cls_pred"]
  60. OUTPUT_KEYS = ["cls_res"]
  61. DEAULT_INPUTS = {"cls_res": "cls_res"}
  62. DEAULT_OUTPUTS = {"cls_pred": "cls_pred"}
  63. def apply(self, cls_pred):
  64. """apply"""
  65. feas_norm = np.sqrt(np.sum(np.square(cls_pred), axis=0, keepdims=True))
  66. cls_res = np.divide(cls_pred, feas_norm)
  67. return {"cls_res": cls_res}