device_utils.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. """
  2. 设备检测工具
  3. 从 MinerU 移植,用于自动检测可用设备 (CUDA/MPS/NPU/CPU)
  4. """
  5. import os
  6. import torch
  7. try:
  8. import torch_npu
  9. except ImportError:
  10. torch_npu = None
  11. def get_device():
  12. """
  13. 自动检测并返回可用的设备
  14. 优先级: CUDA > MPS > NPU > CPU
  15. Returns:
  16. str: 设备名称 ('cuda', 'mps', 'npu', 'cpu')
  17. Environment Variables:
  18. MINERU_DEVICE_MODE: 强制指定设备模式
  19. """
  20. # 支持通过环境变量强制指定设备
  21. device_mode = os.getenv('MINERU_DEVICE_MODE', None)
  22. if device_mode is not None:
  23. return device_mode
  24. # 自动检测
  25. if torch.cuda.is_available():
  26. return "cuda"
  27. elif torch.backends.mps.is_available():
  28. return "mps"
  29. else:
  30. # 尝试检测华为 NPU
  31. try:
  32. if torch_npu is not None and torch_npu.npu.is_available():
  33. return "npu"
  34. except Exception:
  35. pass
  36. return "cpu"
  37. def get_device_name():
  38. """
  39. 获取设备的友好名称
  40. Returns:
  41. str: 设备的友好名称
  42. """
  43. device = get_device()
  44. device_names = {
  45. "cuda": "NVIDIA CUDA",
  46. "mps": "Apple Metal (MPS)",
  47. "npu": "Huawei NPU",
  48. "cpu": "CPU"
  49. }
  50. return device_names.get(device, device.upper())
  51. if __name__ == "__main__":
  52. """测试设备检测"""
  53. print(f"🔍 Detecting available device...")
  54. device = get_device()
  55. device_name = get_device_name()
  56. print(f"✅ Device: {device} ({device_name})")
  57. # 测试 torch 设备
  58. try:
  59. test_tensor = torch.tensor([1.0, 2.0, 3.0])
  60. if device != "cpu":
  61. test_tensor = test_tensor.to(device)
  62. print(f"✅ Torch tensor moved to {device}")
  63. print(f" Tensor device: {test_tensor.device}")
  64. except Exception as e:
  65. print(f"⚠️ Failed to move tensor to {device}: {e}")