|
@@ -4,6 +4,8 @@ import click
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
import torch
|
|
import torch
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
|
|
+
|
|
|
|
|
+from mineru.utils.config_reader import get_device
|
|
|
from mineru.utils.model_utils import get_vram
|
|
from mineru.utils.model_utils import get_vram
|
|
|
from ..version import __version__
|
|
from ..version import __version__
|
|
|
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
|
|
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
|
|
@@ -144,11 +146,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
|
|
|
def get_device_mode() -> str:
|
|
def get_device_mode() -> str:
|
|
|
if device_mode is not None:
|
|
if device_mode is not None:
|
|
|
return device_mode
|
|
return device_mode
|
|
|
- if torch.cuda.is_available():
|
|
|
|
|
- return "cuda"
|
|
|
|
|
- if torch.backends.mps.is_available():
|
|
|
|
|
- return "mps"
|
|
|
|
|
- return "cpu"
|
|
|
|
|
|
|
+ else:
|
|
|
|
|
+ return get_device()
|
|
|
if os.getenv('MINERU_DEVICE_MODE', None) is None:
|
|
if os.getenv('MINERU_DEVICE_MODE', None) is None:
|
|
|
os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
|
|
os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
|
|
|
|
|
|