cls_head.py 596 B

1234567891011121314151617181920212223
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. class ClsHead(nn.Module):
  5. """
  6. Class orientation
  7. Args:
  8. params(dict): super parameters for build Class network
  9. """
  10. def __init__(self, in_channels, class_dim, **kwargs):
  11. super(ClsHead, self).__init__()
  12. self.pool = nn.AdaptiveAvgPool2d(1)
  13. self.fc = nn.Linear(in_channels, class_dim, bias=True)
  14. def forward(self, x):
  15. x = self.pool(x)
  16. x = torch.reshape(x, shape=[x.shape[0], x.shape[1]])
  17. x = self.fc(x)
  18. x = F.softmax(x, dim=1)
  19. return x