|
|
@@ -1419,9 +1419,7 @@ class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
|
|
|
# past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
|
|
past_key_values_length = 0
|
|
|
if past_key_values is not None:
|
|
|
- if hasattr(past_key_values, 'get_seq_length'):
|
|
|
- past_key_values_length = past_key_values.get_seq_length()
|
|
|
- elif isinstance(past_key_values, (list, tuple)) and past_key_values:
|
|
|
+ if isinstance(past_key_values, (list, tuple)) and past_key_values:
|
|
|
past_key_values_length = past_key_values[0][0].shape[2]
|
|
|
|
|
|
if inputs_embeds is None:
|
|
|
@@ -1510,7 +1508,7 @@ class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
|
|
|
# past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
|
past_key_value = None
|
|
|
if past_key_values is not None and len(past_key_values) > idx:
|
|
|
- if hasattr(past_key_values, 'get_usable_length') or isinstance(past_key_values, (list, tuple)):
|
|
|
+ if isinstance(past_key_values, (list, tuple)):
|
|
|
past_key_value = past_key_values[idx]
|
|
|
|
|
|
if self.gradient_checkpointing and self.training:
|