cls_postprocess.py 685 B

1234567891011121314151617181920
  1. import torch
  2. class ClsPostProcess(object):
  3. """ Convert between text-label and text-index """
  4. def __init__(self, label_list, **kwargs):
  5. super(ClsPostProcess, self).__init__()
  6. self.label_list = label_list
  7. def __call__(self, preds, label=None, *args, **kwargs):
  8. if isinstance(preds, torch.Tensor):
  9. preds = preds.cpu().numpy()
  10. pred_idxs = preds.argmax(axis=1)
  11. decode_out = [(self.label_list[idx], preds[i, idx])
  12. for i, idx in enumerate(pred_idxs)]
  13. if label is None:
  14. return decode_out
  15. label = [(self.label_list[idx], 1.0) for idx in label]
  16. return decode_out, label