# rec_multi_head.py
  1. from torch import nn
  2. from ..necks.rnn import Im2Seq, SequenceEncoder
  3. from .rec_ctc_head import CTCHead
  4. class FCTranspose(nn.Module):
  5. def __init__(self, in_channels, out_channels, only_transpose=False):
  6. super().__init__()
  7. self.only_transpose = only_transpose
  8. if not self.only_transpose:
  9. self.fc = nn.Linear(in_channels, out_channels, bias=False)
  10. def forward(self, x):
  11. if self.only_transpose:
  12. return x.permute([0, 2, 1])
  13. else:
  14. return self.fc(x.permute([0, 2, 1]))
  15. class MultiHead(nn.Module):
  16. def __init__(self, in_channels, out_channels_list, **kwargs):
  17. super().__init__()
  18. self.head_list = kwargs.pop("head_list")
  19. self.gtc_head = "sar"
  20. assert len(self.head_list) >= 2
  21. for idx, head_name in enumerate(self.head_list):
  22. name = list(head_name)[0]
  23. if name == "SARHead":
  24. pass
  25. elif name == "NRTRHead":
  26. pass
  27. elif name == "CTCHead":
  28. # ctc neck
  29. self.encoder_reshape = Im2Seq(in_channels)
  30. neck_args = self.head_list[idx][name]["Neck"]
  31. encoder_type = neck_args.pop("name")
  32. self.ctc_encoder = SequenceEncoder(
  33. in_channels=in_channels, encoder_type=encoder_type, **neck_args
  34. )
  35. # ctc head
  36. head_args = self.head_list[idx][name].get("Head", {})
  37. if head_args is None:
  38. head_args = {}
  39. self.ctc_head = CTCHead(
  40. in_channels=self.ctc_encoder.out_channels,
  41. out_channels=out_channels_list["CTCLabelDecode"],
  42. **head_args,
  43. )
  44. else:
  45. raise NotImplementedError(f"{name} is not supported in MultiHead yet")
  46. def forward(self, x, data=None):
  47. ctc_encoder = self.ctc_encoder(x)
  48. return self.ctc_head(ctc_encoder)