custom_logits_processors.py

import os

from loguru import logger


def enable_custom_logits_processors():
    import torch

    if not torch.cuda.is_available():
        logger.info("CUDA not available, disabling custom_logits_processors")
        return False

    # Derive the compute capability as major.minor, e.g. (8, 6) -> 8.6
    major, minor = torch.cuda.get_device_capability()
    compute_capability = float(f"{major}.{minor}")

    # Parse the environment variable defensively
    try:
        vllm_use_v1 = int(os.getenv("VLLM_USE_V1", "1"))
    except (ValueError, TypeError):
        vllm_use_v1 = 1
        logger.warning("Invalid VLLM_USE_V1 value, falling back to 1")

    if vllm_use_v1 == 0:
        logger.info("VLLM_USE_V1 is set to 0, disabling custom_logits_processors")
        return False
    elif compute_capability < 8.0:
        logger.info(f"compute_capability: {compute_capability} < 8.0, disabling custom_logits_processors")
        return False
    else:
        logger.info(f"compute_capability: {compute_capability} >= 8.0, enabling custom_logits_processors")
        return True
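

# A minimal usage sketch (not part of the original file): the returned flag
# would presumably gate whether custom logits processors are registered in the
# serving setup. Only names defined above are used here; the log wording is an
# assumption for illustration.
if __name__ == "__main__":
    if enable_custom_logits_processors():
        logger.info("Custom logits processors enabled")
    else:
        logger.info("Falling back to the default logits processing path")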