clas.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ....utils import logging
  16. from ..base import BaseComponent
# Public components exported by this module.
__all__ = ["Topk", "NormalizeFeatures"]
  18. def _parse_class_id_map(class_ids):
  19. """parse class id to label map file"""
  20. if class_ids is None:
  21. return None
  22. class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
  23. return class_id_map
  24. class Topk(BaseComponent):
  25. """Topk Transform"""
  26. INPUT_KEYS = ["pred"]
  27. OUTPUT_KEYS = [["class_ids", "scores"], ["class_ids", "scores", "label_names"]]
  28. DEAULT_INPUTS = {"pred": "pred"}
  29. DEAULT_OUTPUTS = {
  30. "class_ids": "class_ids",
  31. "scores": "scores",
  32. "label_names": "label_names",
  33. }
  34. def __init__(self, topk, class_ids=None):
  35. super().__init__()
  36. assert isinstance(topk, (int,))
  37. self.topk = topk
  38. self.class_id_map = _parse_class_id_map(class_ids)
  39. def apply(self, pred):
  40. """apply"""
  41. cls_pred = pred[0]
  42. class_id_map = self.class_id_map
  43. index = cls_pred.argsort(axis=0)[-self.topk :][::-1].astype("int32")
  44. clas_id_list = []
  45. score_list = []
  46. label_name_list = []
  47. for i in index:
  48. clas_id_list.append(i.item())
  49. score_list.append(cls_pred[i].item())
  50. if class_id_map is not None:
  51. label_name_list.append(class_id_map[i.item()])
  52. result = {
  53. "class_ids": clas_id_list,
  54. "scores": np.around(score_list, decimals=5).tolist(),
  55. }
  56. if label_name_list is not None:
  57. result["label_names"] = label_name_list
  58. return result
  59. class NormalizeFeatures(BaseComponent):
  60. """Normalize Features Transform"""
  61. INPUT_KEYS = ["cls_pred"]
  62. OUTPUT_KEYS = ["cls_res"]
  63. DEAULT_INPUTS = {"cls_res": "cls_res"}
  64. DEAULT_OUTPUTS = {"cls_pred": "cls_pred"}
  65. def apply(self, cls_pred):
  66. """apply"""
  67. feas_norm = np.sqrt(np.sum(np.square(cls_pred), axis=0, keepdims=True))
  68. cls_res = np.divide(cls_pred, feas_norm)
  69. return {"cls_res": cls_res}