| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- """
- 设备检测工具
- 从 MinerU 移植,用于自动检测可用设备 (CUDA/MPS/NPU/CPU)
- """
- import os
- import torch
- try:
- import torch_npu
- except ImportError:
- torch_npu = None
- def get_device():
- """
- 自动检测并返回可用的设备
-
- 优先级: CUDA > MPS > NPU > CPU
-
- Returns:
- str: 设备名称 ('cuda', 'mps', 'npu', 'cpu')
-
- Environment Variables:
- MINERU_DEVICE_MODE: 强制指定设备模式
- """
- # 支持通过环境变量强制指定设备
- device_mode = os.getenv('MINERU_DEVICE_MODE', None)
- if device_mode is not None:
- return device_mode
-
- # 自动检测
- if torch.cuda.is_available():
- return "cuda"
- elif torch.backends.mps.is_available():
- return "mps"
- else:
- # 尝试检测华为 NPU
- try:
- if torch_npu is not None and torch_npu.npu.is_available():
- return "npu"
- except Exception:
- pass
-
- return "cpu"
- def get_device_name():
- """
- 获取设备的友好名称
-
- Returns:
- str: 设备的友好名称
- """
- device = get_device()
-
- device_names = {
- "cuda": "NVIDIA CUDA",
- "mps": "Apple Metal (MPS)",
- "npu": "Huawei NPU",
- "cpu": "CPU"
- }
-
- return device_names.get(device, device.upper())
- if __name__ == "__main__":
- """测试设备检测"""
- print(f"🔍 Detecting available device...")
- device = get_device()
- device_name = get_device_name()
- print(f"✅ Device: {device} ({device_name})")
-
- # 测试 torch 设备
- try:
- test_tensor = torch.tensor([1.0, 2.0, 3.0])
- if device != "cpu":
- test_tensor = test_tensor.to(device)
- print(f"✅ Torch tensor moved to {device}")
- print(f" Tensor device: {test_tensor.device}")
- except Exception as e:
- print(f"⚠️ Failed to move tensor to {device}: {e}")
|