Browse Source

refactor: enhance device mode detection to support CUDA and MPS

myhloli 5 months ago
parent
commit
7c8fb44b13
1 changed files with 6 additions and 2 deletions
  1. 6 2
      mineru/backend/pipeline/config_reader.py

+ 6 - 2
mineru/backend/pipeline/config_reader.py

@@ -2,6 +2,7 @@
 import json
 import os
 
+import torch
 from loguru import logger
 
 # 定义配置文件名常量
@@ -93,8 +94,11 @@ def get_device():
     if device_mode is not None:
         return device_mode
     else:
-        logger.warning(f"not found 'MINERU_DEVICE_MODE' in environment variable, use 'cpu' as default.")
-        return 'cpu'
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
 
 
 def get_table_recog_config():