|
@@ -1,9 +1,15 @@
|
|
|
# Copyright (c) Opendatalab. All rights reserved.
|
|
# Copyright (c) Opendatalab. All rights reserved.
|
|
|
import json
|
|
import json
|
|
|
import os
|
|
import os
|
|
|
-
|
|
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
+try:
|
|
|
|
|
+ import torch
|
|
|
|
|
+ import torch_npu
|
|
|
|
|
+except ImportError:
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
# 定义配置文件名常量
|
|
# 定义配置文件名常量
|
|
|
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
|
|
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
|
|
|
|
|
|
|
@@ -71,15 +77,12 @@ def get_device():
|
|
|
if device_mode is not None:
|
|
if device_mode is not None:
|
|
|
return device_mode
|
|
return device_mode
|
|
|
else:
|
|
else:
|
|
|
- import torch
|
|
|
|
|
-
|
|
|
|
|
if torch.cuda.is_available():
|
|
if torch.cuda.is_available():
|
|
|
return "cuda"
|
|
return "cuda"
|
|
|
elif torch.backends.mps.is_available():
|
|
elif torch.backends.mps.is_available():
|
|
|
return "mps"
|
|
return "mps"
|
|
|
else:
|
|
else:
|
|
|
try:
|
|
try:
|
|
|
- import torch_npu
|
|
|
|
|
if torch_npu.npu.is_available():
|
|
if torch_npu.npu.is_available():
|
|
|
return "npu"
|
|
return "npu"
|
|
|
except Exception as e:
|
|
except Exception as e:
|