
fix: transformers 4.54.0 adaptation

myhloli, 3 months ago
parent commit 0490af1cd7

mineru/model/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py (+2 −4)

@@ -1419,9 +1419,7 @@ class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
         # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
         past_key_values_length = 0
         if past_key_values is not None:
-            if hasattr(past_key_values, 'get_seq_length'):
-                past_key_values_length = past_key_values.get_seq_length()
-            elif isinstance(past_key_values, (list, tuple)) and past_key_values:
+            if isinstance(past_key_values, (list, tuple)) and past_key_values:
                 past_key_values_length = past_key_values[0][0].shape[2]
 
         if inputs_embeds is None:
@@ -1510,7 +1508,7 @@ class UnimerMBartDecoder(UnimerMBartPreTrainedModel):
             # past_key_value = past_key_values[idx] if past_key_values is not None else None
             past_key_value = None
             if past_key_values is not None and len(past_key_values) > idx:
-                if hasattr(past_key_values, 'get_usable_length') or isinstance(past_key_values, (list, tuple)):
+                if isinstance(past_key_values, (list, tuple)):
                     past_key_value = past_key_values[idx]
 
             if self.gradient_checkpointing and self.training:
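For context, below is a minimal standalone sketch of the two patched checks; the helper names, example tensors, and shapes are illustrative assumptions, not code from the repository. After this commit only the legacy list/tuple cache format (one (key, value) pair per decoder layer) yields a past length or a per-layer cache entry; any other past_key_values object falls through to the defaults of 0 and None.

import torch

def cached_length(past_key_values):
    # Mirrors the first hunk: only the legacy list/tuple format
    # contributes a length; anything else leaves it at 0.
    past_key_values_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, (list, tuple)) and past_key_values:
            # Each tensor is shaped (batch, num_heads, seq_len, head_dim),
            # so the cached sequence length sits in dimension 2.
            past_key_values_length = past_key_values[0][0].shape[2]
    return past_key_values_length

def layer_cache(past_key_values, idx):
    # Mirrors the second hunk: return the per-layer (key, value)
    # pair only when the cache is the legacy list/tuple format.
    if past_key_values is not None and len(past_key_values) > idx:
        if isinstance(past_key_values, (list, tuple)):
            return past_key_values[idx]
    return None

# Illustrative one-layer legacy cache holding 5 past positions.
k = v = torch.zeros(1, 16, 5, 64)
cache = [(k, v)]
print(cached_length(cache))   # 5
print(layer_cache(cache, 0))  # the (key, value) pair for layer 0
print(cached_length(None))    # 0

The removed hasattr branches had also accepted transformers Cache objects, which expose a get_seq_length() method; dropping them means only the legacy tuple format is handled in these two spots.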