|
@@ -2,6 +2,7 @@
|
|
|
import json
|
|
import json
|
|
|
import os
|
|
import os
|
|
|
|
|
|
|
|
|
|
+import torch
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
|
|
|
|
|
# 定义配置文件名常量
|
|
# 定义配置文件名常量
|
|
@@ -93,8 +94,11 @@ def get_device():
|
|
|
if device_mode is not None:
|
|
if device_mode is not None:
|
|
|
return device_mode
|
|
return device_mode
|
|
|
else:
|
|
else:
|
|
|
- logger.warning(f"not found 'MINERU_DEVICE_MODE' in environment variable, use 'cpu' as default.")
|
|
|
|
|
- return 'cpu'
|
|
|
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
|
+ return "cuda"
|
|
|
|
|
+ if torch.backends.mps.is_available():
|
|
|
|
|
+ return "mps"
|
|
|
|
|
+ return "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_table_recog_config():
|
|
def get_table_recog_config():
|