|
|
@@ -15,7 +15,6 @@ def enable_custom_logits_processors():
|
|
|
major, minor = torch.cuda.get_device_capability()
|
|
|
# 正确计算Compute Capability
|
|
|
compute_capability = f"{major}.{minor}"
|
|
|
- compute_capability = float(compute_capability)
|
|
|
|
|
|
# 安全地处理环境变量
|
|
|
vllm_use_v1_str = os.getenv('VLLM_USE_V1', "1")
|
|
|
@@ -30,7 +29,7 @@ def enable_custom_logits_processors():
|
|
|
elif version.parse(vllm_version) < version.parse("0.10.1"):
|
|
|
logger.info(f"vllm version: {vllm_version} < 0.10.1, disable custom_logits_processors")
|
|
|
return False
|
|
|
- elif compute_capability < 8.0:
|
|
|
+ elif version.parse(compute_capability) < version.parse("8.0"):
|
|
|
if version.parse(vllm_version) >= version.parse("0.10.2"):
|
|
|
logger.info(f"compute_capability: {compute_capability} < 8.0, but vllm version: {vllm_version} >= 0.10.2, enable custom_logits_processors")
|
|
|
return True
|