custom_logits_processors.py

from loguru import logger


def enable_custom_logits_processors():
    """Return True if custom logits processors should be enabled on this machine."""
    # Deferred import: torch is only needed when this function is called.
    import torch

    compute_capability = 0.0
    custom_logits_processors = False
    if torch.cuda.is_available():
        # get_device_capability() returns (major, minor), e.g. (8, 6) -> 8.6.
        major, minor = torch.cuda.get_device_capability()
        compute_capability = float(major) + (float(minor) / 10.0)
        # Custom logits processors require compute capability >= 8.0 (Ampere or newer).
        if compute_capability >= 8.0:
            logger.info(f"compute_capability: {compute_capability}, enable custom_logits_processors")
            custom_logits_processors = True
    return custom_logits_processors
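

# Hypothetical usage sketch (not part of the original file): callers are assumed
# to gate optional logits-processor setup on the returned flag.
if __name__ == "__main__":
    if enable_custom_logits_processors():
        logger.info("custom logits processors enabled")
    else:
        logger.info("custom logits processors disabled (no CUDA device or compute capability < 8.0)")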