Bläddra i källkod

Enhance custom logits processors with improved compute capability checks and environment variable handling

myhloli 1 månad sedan
förälder
incheckning
5a0cf9af7f
1 ändrade filer med 28 tillägg och 9 borttagningar
  1. 28 9
      mineru/backend/vlm/custom_logits_processors.py

+ 28 - 9
mineru/backend/vlm/custom_logits_processors.py

@@ -1,14 +1,33 @@
+import os
+
 from loguru import logger
 
 
 def enable_custom_logits_processors():
     import torch
-    compute_capability = 0.0
-    custom_logits_processors = False
-    if torch.cuda.is_available():
-        major, minor = torch.cuda.get_device_capability()
-        compute_capability = float(major) + (float(minor) / 10.0)
-    if compute_capability >= 8.0:
-        logger.info(f"compute_capability: {compute_capability}, enable custom_logits_processors")
-        custom_logits_processors = True
-    return custom_logits_processors
+
+    if not torch.cuda.is_available():
+        logger.info("CUDA not available, disabling custom_logits_processors")
+        return False
+
+    major, minor = torch.cuda.get_device_capability()
+    # 正确计算Compute Capability
+    compute_capability = f"{major}.{minor}"
+    compute_capability = float(compute_capability)
+
+    # 安全地处理环境变量
+    try:
+        vllm_use_v1 = int(os.getenv('VLLM_USE_V1', "1"))
+    except (ValueError, TypeError):
+        vllm_use_v1 = 1
+        logger.warning("Invalid VLLM_USE_V1 value")
+
+    if vllm_use_v1 == 0:
+        logger.info("VLLM_USE_V1 is set to 0, disabling custom_logits_processors")
+        return False
+    elif compute_capability < 8.0:
+        logger.info(f"compute_capability: {compute_capability} < 8.0, disable custom_logits_processors")
+        return False
+    else:
+        logger.info(f"compute_capability: {compute_capability} >= 8.0, enable custom_logits_processors")
+        return True