|
|
@@ -9,14 +9,27 @@ class Im2Seq(nn.Module):
|
|
|
super().__init__()
|
|
|
self.out_channels = in_channels
|
|
|
|
|
|
+ # def forward(self, x):
|
|
|
+ # B, C, H, W = x.shape
|
|
|
+ # # assert H == 1
|
|
|
+ # x = x.squeeze(dim=2)
|
|
|
+ # # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
|
|
|
+ # x = x.permute(0, 2, 1)
|
|
|
+ # return x
|
|
|
+
|
|
|
def forward(self, x):
|
|
|
B, C, H, W = x.shape
|
|
|
- # assert H == 1
|
|
|
- x = x.squeeze(dim=2)
|
|
|
- # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
|
|
|
- x = x.permute(0, 2, 1)
|
|
|
- return x
|
|
|
+ # 处理四维张量,将空间维度展平为序列
|
|
|
+ if H == 1:
|
|
|
+ # 原来的处理逻辑,适用于H=1的情况
|
|
|
+ x = x.squeeze(dim=2)
|
|
|
+ x = x.permute(0, 2, 1) # (B, W, C)
|
|
|
+ else:
|
|
|
+ # 处理H不为1的情况
|
|
|
+ x = x.permute(0, 2, 3, 1) # (B, H, W, C)
|
|
|
+ x = x.reshape(B, H * W, C) # (B, H*W, C)
|
|
|
|
|
|
+ return x
|
|
|
|
|
|
class EncoderWithRNN_(nn.Module):
|
|
|
def __init__(self, in_channels, hidden_size):
|