瀏覽代碼

feat: update model.py to automatically set vision tower path using auto_download_and_get_model_root_path

myhloli 5 月之前
父節點
當前提交
878b9e4798
共有 1 個文件被更改,包括 4 次插入0 次删除
  1. 4 0
      mineru/model/vlm_sglang_model/model.py

+ 4 - 0
mineru/model/vlm_sglang_model/model.py

@@ -22,6 +22,7 @@ from transformers import (
 
 from ..vlm_hf_model.configuration_mineru2 import Mineru2QwenConfig
 from ..vlm_hf_model.modeling_mineru2 import build_vision_projector
+from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 def flatten_nested_list(nested_list):
@@ -61,6 +62,9 @@ class Mineru2QwenForCausalLM(nn.Module):
 
         # load vision tower
         mm_vision_tower = self.config.mm_vision_tower
+        model_root_path = auto_download_and_get_model_root_path("/", "vlm")
+        mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"
+
         if "clip" in mm_vision_tower:
             vision_config = CLIPVisionConfig.from_pretrained(mm_vision_tower)
             self.vision_tower = CLIPVisionModel(vision_config)  # type: ignore