device_utils.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. """
  2. 设备检测工具
  3. 从 MinerU 移植,用于自动检测可用设备 (CUDA/MPS/NPU/CPU)
  4. """
  5. import os
  6. try:
  7. import torch
  8. except ImportError:
  9. torch = None
  10. try:
  11. import torch_npu
  12. except ImportError:
  13. torch_npu = None
  14. def get_device():
  15. """
  16. 自动检测并返回可用的设备
  17. 优先级: CUDA > MPS > NPU > CPU
  18. Returns:
  19. str: 设备名称 ('cuda', 'mps', 'npu', 'cpu')
  20. Environment Variables:
  21. MINERU_DEVICE_MODE: 强制指定设备模式
  22. """
  23. # 支持通过环境变量强制指定设备
  24. device_mode = os.getenv('MINERU_DEVICE_MODE', None)
  25. if device_mode is not None:
  26. return device_mode
  27. # 如果没有 torch,返回 cpu
  28. if torch is None:
  29. return "cpu"
  30. # 自动检测
  31. if torch.cuda.is_available():
  32. return "cuda"
  33. elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
  34. return "mps"
  35. else:
  36. # 尝试检测华为 NPU
  37. try:
  38. if torch_npu is not None and torch_npu.npu.is_available():
  39. return "npu"
  40. except Exception:
  41. pass
  42. return "cpu"
  43. def get_device_name():
  44. """
  45. 获取设备的友好名称
  46. Returns:
  47. str: 设备的友好名称
  48. """
  49. device = get_device()
  50. device_names = {
  51. "cuda": "NVIDIA CUDA",
  52. "mps": "Apple Metal (MPS)",
  53. "npu": "Huawei NPU",
  54. "cpu": "CPU"
  55. }
  56. return device_names.get(device, device.upper())
  57. if __name__ == "__main__":
  58. """测试设备检测"""
  59. print(f"🔍 Detecting available device...")
  60. device = get_device()
  61. device_name = get_device_name()
  62. print(f"✅ Device: {device} ({device_name})")
  63. # 测试 torch 设备
  64. if torch is not None:
  65. try:
  66. test_tensor = torch.tensor([1.0, 2.0, 3.0])
  67. if device != "cpu":
  68. test_tensor = test_tensor.to(device)
  69. print(f"✅ Torch tensor moved to {device}")
  70. print(f" Tensor device: {test_tensor.device}")
  71. except Exception as e:
  72. print(f"⚠️ Failed to move tensor to {device}: {e}")