# utils.py
  1. import os
  2. from loguru import logger
  3. from packaging import version
  4. from mineru.utils.config_reader import get_device
  5. from mineru.utils.model_utils import get_vram
  6. def enable_custom_logits_processors() -> bool:
  7. import torch
  8. from vllm import __version__ as vllm_version
  9. if not torch.cuda.is_available():
  10. logger.info("CUDA not available, disabling custom_logits_processors")
  11. return False
  12. major, minor = torch.cuda.get_device_capability()
  13. # 正确计算Compute Capability
  14. compute_capability = f"{major}.{minor}"
  15. # 安全地处理环境变量
  16. vllm_use_v1_str = os.getenv('VLLM_USE_V1', "1")
  17. if vllm_use_v1_str.isdigit():
  18. vllm_use_v1 = int(vllm_use_v1_str)
  19. else:
  20. vllm_use_v1 = 1
  21. if vllm_use_v1 == 0:
  22. logger.info("VLLM_USE_V1 is set to 0, disabling custom_logits_processors")
  23. return False
  24. elif version.parse(vllm_version) < version.parse("0.10.1"):
  25. logger.info(f"vllm version: {vllm_version} < 0.10.1, disable custom_logits_processors")
  26. return False
  27. elif version.parse(compute_capability) < version.parse("8.0"):
  28. if version.parse(vllm_version) >= version.parse("0.10.2"):
  29. logger.info(f"compute_capability: {compute_capability} < 8.0, but vllm version: {vllm_version} >= 0.10.2, enable custom_logits_processors")
  30. return True
  31. else:
  32. logger.info(f"compute_capability: {compute_capability} < 8.0 and vllm version: {vllm_version} < 0.10.2, disable custom_logits_processors")
  33. return False
  34. else:
  35. logger.info(f"compute_capability: {compute_capability} >= 8.0 and vllm version: {vllm_version} >= 0.10.1, enable custom_logits_processors")
  36. return True
  37. def set_defult_gpu_memory_utilization() -> float:
  38. from vllm import __version__ as vllm_version
  39. if version.parse(vllm_version) >= version.parse("0.11.0"):
  40. return 0.7
  41. else:
  42. return 0.5
  43. def set_defult_batch_size() -> int:
  44. try:
  45. device = get_device()
  46. vram = get_vram(device)
  47. if vram is not None:
  48. gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
  49. if gpu_memory >= 16:
  50. batch_size = 8
  51. elif gpu_memory >= 8:
  52. batch_size = 4
  53. else:
  54. batch_size = 1
  55. logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
  56. else:
  57. # Default batch_ratio when VRAM can't be determined
  58. batch_size = 1
  59. logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_size}')
  60. except Exception as e:
  61. logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
  62. batch_size = 1
  63. return batch_size