@@ -4,7 +4,7 @@ from typing import Iterable, List, Optional, Union
 import torch
 from PIL import Image
 from tqdm import tqdm
-from transformers import AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoTokenizer, BitsAndBytesConfig, __version__
 
 from ...model.vlm_hf_model import Mineru2QwenForCausalLM
 from ...model.vlm_hf_model.image_processing_mineru2 import process_images
@@ -66,7 +66,11 @@ class HuggingfacePredictor(BasePredictor):
                 bnb_4bit_quant_type="nf4",
             )
         else:
-            kwargs["torch_dtype"] = torch_dtype
+            from packaging import version
+            if version.parse(__version__) >= version.parse("4.56.0"):
+                kwargs["dtype"] = torch_dtype
+            else:
+                kwargs["torch_dtype"] = torch_dtype
 
         if use_flash_attn:
             kwargs["attn_implementation"] = "flash_attention_2"
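
For reference, the gate above pins transformers 4.56.0 as the release where the `dtype` keyword of `from_pretrained` supersedes the older `torch_dtype` name. A minimal standalone sketch of the same check, assuming `packaging` is installed (it ships as a transformers dependency); the helper name `dtype_kwargs` is hypothetical and not part of this PR:

import torch
from packaging import version
from transformers import __version__

def dtype_kwargs(torch_dtype: torch.dtype) -> dict:
    # transformers >= 4.56.0 accepts `dtype`; older releases expect `torch_dtype`.
    if version.parse(__version__) >= version.parse("4.56.0"):
        return {"dtype": torch_dtype}
    return {"torch_dtype": torch_dtype}

# Usage (hypothetical): kwargs.update(dtype_kwargs(torch.bfloat16)) before
# calling Mineru2QwenForCausalLM.from_pretrained(model_path, **kwargs).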