# rec_ctc_head.py (1.3 KB)
  1. import torch.nn.functional as F
  2. from torch import nn
  3. class CTCHead(nn.Module):
  4. def __init__(
  5. self,
  6. in_channels,
  7. out_channels=6625,
  8. fc_decay=0.0004,
  9. mid_channels=None,
  10. return_feats=False,
  11. **kwargs
  12. ):
  13. super(CTCHead, self).__init__()
  14. if mid_channels is None:
  15. self.fc = nn.Linear(
  16. in_channels,
  17. out_channels,
  18. bias=True,
  19. )
  20. else:
  21. self.fc1 = nn.Linear(
  22. in_channels,
  23. mid_channels,
  24. bias=True,
  25. )
  26. self.fc2 = nn.Linear(
  27. mid_channels,
  28. out_channels,
  29. bias=True,
  30. )
  31. self.out_channels = out_channels
  32. self.mid_channels = mid_channels
  33. self.return_feats = return_feats
  34. def forward(self, x, labels=None):
  35. if self.mid_channels is None:
  36. predicts = self.fc(x)
  37. else:
  38. x = self.fc1(x)
  39. predicts = self.fc2(x)
  40. if self.return_feats:
  41. result = (x, predicts)
  42. else:
  43. result = predicts
  44. if not self.training:
  45. predicts = F.softmax(predicts, dim=2)
  46. result = predicts
  47. return result