@@ -251,13 +251,13 @@ class AttentionMaskConverter:
         bsz, src_len = mask.shape
         tgt_len = tgt_len if tgt_len is not None else src_len
         expanded_mask = (
-            mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).cast(dtype)
+            mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).to(dtype)
         )

         inverted_mask = 1.0 - expanded_mask

         return inverted_mask.masked_fill_(
-            inverted_mask.cast(torch.bool), torch.finfo(dtype).min
+            inverted_mask.to(torch.bool), torch.finfo(dtype).min
         )

     def _expand_mask_export(self, mask, dtype, tgt_len=None):
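For context on why the change is needed: `torch.Tensor` has no `cast` method (that spelling comes from other frameworks), so the old lines would raise `AttributeError` the first time they ran; `.to(dtype)` is the standard PyTorch dtype-conversion call. Below is a minimal, self-contained sketch of what the patched helper computes; the standalone `expand_mask` function name and the toy input are illustrative, not part of the diff:

```python
import torch


def expand_mask(mask, dtype, tgt_len=None):
    """Expand a (bsz, src_len) padding mask to (bsz, 1, tgt_len, src_len) additive form."""
    bsz, src_len = mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len

    # Broadcast to the 4D attention shape; .to(dtype) replaces the invalid
    # .cast(dtype), since torch.Tensor has no cast method.
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # Flip 1/0 so padded positions become 1.0, then overwrite those positions
    # in place with the most negative representable value for this dtype, so
    # a downstream softmax assigns them effectively zero attention weight.
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill_(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


mask = torch.tensor([[1, 1, 0]])                  # one sequence, last token padded
print(expand_mask(mask, torch.float32).shape)     # torch.Size([1, 1, 3, 3])
print(expand_mask(mask, torch.float32)[0, 0, 0])  # real tokens -> 0., padded -> finfo.min
```

On the toy input, the column for the padded position comes out as `torch.finfo(torch.float32).min` while the real tokens stay at `0.0`, which is exactly the additive-mask contract the surrounding attention code expects.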