Selaa lähdekoodia

feat: update dtype handling based on transformers version

myhloli 2 kuukautta sitten
vanhempi
commit
cb4d1cceb3
1 muutettua tiedostoa jossa 6 lisäystä ja 2 poistoa
  1. 6 2
      mineru/backend/vlm/hf_predictor.py

+ 6 - 2
mineru/backend/vlm/hf_predictor.py

@@ -4,7 +4,7 @@ from typing import Iterable, List, Optional, Union
 import torch
 from PIL import Image
 from tqdm import tqdm
-from transformers import AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig, __version__
 
 from ...model.vlm_hf_model import Mineru2QwenForCausalLM
 from ...model.vlm_hf_model.image_processing_mineru2 import process_images
@@ -66,7 +66,11 @@ class HuggingfacePredictor(BasePredictor):
                 bnb_4bit_quant_type="nf4",
             )
         else:
-            kwargs["torch_dtype"] = torch_dtype
+            from packaging import version
+            if version.parse(__version__) >= version.parse("4.56.0"):
+                kwargs["dtype"] = torch_dtype
+            else:
+                kwargs["torch_dtype"] = torch_dtype
 
         if use_flash_attn:
             kwargs["attn_implementation"] = "flash_attention_2"