Selaa lähdekoodia

feat: refactor device mode retrieval to use get_device utility

myhloli 5 kuukautta sitten
vanhempi
commit
5f1a509fdd
1 muutettua tiedostoa jossa 4 lisäystä ja 5 poistoa
  1. 4 5
      mineru/cli/client.py

+ 4 - 5
mineru/cli/client.py

@@ -4,6 +4,8 @@ import click
 from pathlib import Path
 import torch
 from loguru import logger
+
+from mineru.utils.config_reader import get_device
 from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@@ -144,11 +146,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
     def get_device_mode() -> str:
         if device_mode is not None:
             return device_mode
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-        return "cpu"
+        else:
+            return get_device()
     if os.getenv('MINERU_DEVICE_MODE', None) is None:
         os.environ['MINERU_DEVICE_MODE'] = get_device_mode()