|
|
@@ -1506,10 +1506,11 @@ class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
|
|
|
continue
|
|
|
|
|
|
# past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
|
- past_key_value = None
|
|
|
- if past_key_values is not None and len(past_key_values) > idx:
|
|
|
- if isinstance(past_key_values, (list, tuple)):
|
|
|
- past_key_value = past_key_values[idx]
|
|
|
+ past_key_value = past_key_values[idx] if (
|
|
|
+ past_key_values is not None and
|
|
|
+ isinstance(past_key_values, (list, tuple)) and
|
|
|
+ idx < len(past_key_values)
|
|
|
+ ) else None
|
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|
|
|
layer_outputs = self._gradient_checkpointing_func(
|