瀏覽代碼

refactor: update config file name and enhance model path handling

myhloli 5 月之前
父節點
當前提交
a29489ef51

+ 0 - 0
magic-pdf.template.json → mineru.template.json


+ 1 - 0
mineru/backend/vlm/hf_predictor.py

@@ -77,6 +77,7 @@ class HuggingfacePredictor(BasePredictor):
             low_cpu_mem_usage=True,
             **kwargs,
         )
+        setattr(self.model.config, "_name_or_path", model_path)
         self.model.eval()
 
         vision_tower = self.model.get_model().vision_tower

+ 6 - 3
mineru/cli/common.py

@@ -158,6 +158,9 @@ def do_parse(
 
             logger.info(f"local output dir is {local_md_dir}")
     else:
+        if backend.startswith("vlm-"):
+            backend = backend[4:]
+
         f_draw_span_bbox = False
         parse_method = "vlm"
         for idx, pdf_bytes in enumerate(pdf_bytes_list):
@@ -216,10 +219,10 @@ def do_parse(
 
 
 if __name__ == "__main__":
-    # pdf_path = "../../demo/pdfs/demo2.pdf"
-    pdf_path = "C:/Users/zhaoxiaomeng/Downloads/input_img_0.jpg"
+    pdf_path = "../../demo/pdfs/demo2.pdf"
+    # pdf_path = "C:/Users/zhaoxiaomeng/Downloads/input_img_0.jpg"
 
     try:
-       do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"], end_page_id=20,)
+       do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"], end_page_id=1, backend='vlm-huggingface')
     except Exception as e:
         logger.exception(e)

+ 5 - 1
mineru/model/vlm_hf_model/modeling_mineru2.py

@@ -79,8 +79,12 @@ class SiglipVisionTower(nn.Module):
 
 def build_vision_tower(config: Mineru2QwenConfig):
     vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
+    model_path = getattr(config, "_name_or_path", "")
     if "siglip" in vision_tower.lower():
-        return SiglipVisionTower(vision_tower)
+        if model_path:
+            return SiglipVisionTower(f"{model_path}/{vision_tower}")
+        else:
+            return SiglipVisionTower(vision_tower)
     raise ValueError(f"Unknown vision tower: {vision_tower}")
 
 

+ 1 - 1
mineru/utils/config_reader.py

@@ -6,7 +6,7 @@ import torch
 from loguru import logger
 
 # 定义配置文件名常量
-CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
+CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
 
 
 def read_config():