Просмотр исходного кода

Add custom logits processors functionality with compute capability check

myhloli 1 месяц назад
Родитель
Сommit
f5e0e67545

+ 14 - 0
mineru/backend/vlm/custom_logits_processors.py

@@ -0,0 +1,14 @@
+from loguru import logger
+
+
+def enable_custom_logits_processors():
+    import torch
+    compute_capability = 0.0
+    custom_logits_processors = False
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        compute_capability = float(major) + (float(minor) / 10.0)
+    if compute_capability >= 8.0:
+        logger.info(f"compute_capability: {compute_capability}, enable custom_logits_processors")
+        custom_logits_processors = True
+    return custom_logits_processors

+ 3 - 9
mineru/backend/vlm/vlm_analyze.py

@@ -4,6 +4,7 @@ import time
 
 from loguru import logger
 
+from .custom_logits_processors import enable_custom_logits_processors
 from .model_output_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from mineru.utils.pdf_image_tools import load_images_from_pdf
@@ -43,15 +44,8 @@ class ModelSingleton:
             batch_size = 0
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
-                import torch
-                compute_capability = 0.0
-                custom_logits_processors = False
-                if torch.cuda.is_available():
-                    major, minor = torch.cuda.get_device_capability()
-                    compute_capability = float(major) + (float(minor) / 10.0)
-                    logger.info(f"compute_capability: {compute_capability}")
-                if compute_capability >= 8.0:
-                    custom_logits_processors = True
+
+                custom_logits_processors = enable_custom_logits_processors()
 
                 if backend == "transformers":
                     try:

+ 2 - 11
mineru/model/vlm_vllm_model/server.py

@@ -1,7 +1,6 @@
 import sys
 
-from loguru import logger
-
+from mineru.backend.vlm.custom_logits_processors import enable_custom_logits_processors
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 from vllm.entrypoints.cli.main import main as vllm_main
@@ -39,15 +38,7 @@ def main():
         for index in sorted(model_arg_indices, reverse=True):
             args.pop(index)
 
-    import torch
-    compute_capability = 0.0
-    custom_logits_processors = False
-    if torch.cuda.is_available():
-        major, minor = torch.cuda.get_device_capability()
-        compute_capability = float(major) + (float(minor) / 10.0)
-        logger.info(f"compute_capability: {compute_capability}")
-    if compute_capability >= 8.0:
-        custom_logits_processors = True
+    custom_logits_processors = enable_custom_logits_processors()
 
     # 添加默认参数
     if not has_port_arg: