seg_metrics.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. def loss_computation(logits_list, labels, losses):
  17. loss_list = []
  18. for i in range(len(logits_list)):
  19. logits = logits_list[i]
  20. loss_i = losses['types'][i]
  21. loss_list.append(losses['coef'][i] * loss_i(logits, labels))
  22. return loss_list
  23. def f1_score(intersect_area, pred_area, label_area):
  24. intersect_area = intersect_area.numpy()
  25. pred_area = pred_area.numpy()
  26. label_area = label_area.numpy()
  27. class_f1_sco = []
  28. for i in range(len(intersect_area)):
  29. if pred_area[i] + label_area[i] == 0:
  30. f1_sco = 0
  31. elif pred_area[i] == 0:
  32. f1_sco = 0
  33. else:
  34. prec = intersect_area[i] / pred_area[i]
  35. rec = intersect_area[i] / label_area[i]
  36. f1_sco = 2 * prec * rec / (prec + rec)
  37. class_f1_sco.append(f1_sco)
  38. return np.array(class_f1_sco)
  39. def confusion_matrix(pred, label, num_classes, ignore_index=255):
  40. label = paddle.transpose(label, (0, 2, 3, 1))
  41. pred = paddle.transpose(pred, (0, 2, 3, 1))
  42. mask = label != ignore_index
  43. label = paddle.masked_select(label, mask)
  44. pred = paddle.masked_select(pred, mask)
  45. cat_matrix = num_classes * label + pred
  46. conf_mat = paddle.histogram(
  47. cat_matrix,
  48. bins=num_classes * num_classes,
  49. min=0,
  50. max=num_classes * num_classes - 1).reshape([num_classes, num_classes])
  51. return conf_mat