| 1234567891011121314151617181920212223 |
- import torch
- import torch.nn.functional as F
- from torch import nn
class ClsHead(nn.Module):
    """Classification head: global average pooling, a linear layer, softmax.

    Args:
        in_channels: Channel count of the incoming feature map; this is the
            input width of the linear layer.
        class_dim: Number of classes, i.e. the output width of the linear
            layer.
        **kwargs: Accepted so callers can pass extra configuration keys;
            they are ignored by this class.
    """

    def __init__(self, in_channels: int, class_dim: int, **kwargs) -> None:
        super().__init__()
        # Collapse the spatial dimensions to 1x1 regardless of input size.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, class_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return softmax-normalised class scores for a batch of feature maps.

        Args:
            x: Feature map, expected to be laid out as (N, C, H, W) with
                C equal to the ``in_channels`` given at construction.

        Returns:
            Tensor of shape (N, class_dim); each row sums to 1 because
            softmax is applied over dim 1.
        """
        x = self.pool(x)
        # After pooling the trailing spatial dims are 1x1; drop them so the
        # linear layer receives a 2-D (N, C) tensor.
        x = torch.reshape(x, shape=[x.shape[0], x.shape[1]])
        x = self.fc(x)
        # Softmax is applied unconditionally (training and eval alike).
        return F.softmax(x, dim=1)
|