# custom_logits_processors.py
  1. import os
  2. from loguru import logger
  3. from packaging import version
  4. def enable_custom_logits_processors():
  5. import torch
  6. from vllm import __version__ as vllm_version
  7. if not torch.cuda.is_available():
  8. logger.info("CUDA not available, disabling custom_logits_processors")
  9. return False
  10. major, minor = torch.cuda.get_device_capability()
  11. # 正确计算Compute Capability
  12. compute_capability = f"{major}.{minor}"
  13. compute_capability = float(compute_capability)
  14. # 安全地处理环境变量
  15. vllm_use_v1_str = os.getenv('VLLM_USE_V1', "1")
  16. if vllm_use_v1_str.isdigit():
  17. vllm_use_v1 = int(vllm_use_v1_str)
  18. else:
  19. vllm_use_v1 = 1
  20. if vllm_use_v1 == 0:
  21. logger.info("VLLM_USE_V1 is set to 0, disabling custom_logits_processors")
  22. return False
  23. elif version.parse(vllm_version) < version.parse("0.10.1"):
  24. logger.info(f"vllm version: {vllm_version} < 0.10.1, disable custom_logits_processors")
  25. return False
  26. elif compute_capability < 8.0:
  27. if version.parse(vllm_version) >= version.parse("0.10.2"):
  28. logger.info(f"compute_capability: {compute_capability} < 8.0, but vllm version: {vllm_version} >= 0.10.2, enable custom_logits_processors")
  29. return True
  30. else:
  31. logger.info(f"compute_capability: {compute_capability} < 8.0 and vllm version: {vllm_version} < 0.10.2, disable custom_logits_processors")
  32. return False
  33. else:
  34. logger.info(f"compute_capability: {compute_capability} >= 8.0 and vllm version: {vllm_version} >= 0.10.1, enable custom_logits_processors")
  35. return True