|
|
@@ -219,7 +219,7 @@ class AttentionMaskConverter:
|
|
|
|
|
|
if causal_4d_mask is not None:
|
|
|
expanded_attn_mask = causal_4d_mask.masked_fill_(
|
|
|
- expanded_attn_mask.cast(torch.bool), torch.finfo(dtype).min
|
|
|
+ expanded_attn_mask.to(torch.bool), torch.finfo(dtype).min
|
|
|
)
|
|
|
|
|
|
expanded_4d_mask = expanded_attn_mask
|