Browse Source

feat: enhance device detection to support NPU and improve error handling

myhloli 5 months ago
parent
commit
7ae4f80dad
2 changed files with 20 additions and 6 deletions
  1. 12 5
      mineru/backend/pipeline/pipeline_analyze.py
  2. 8 1
      mineru/utils/config_reader.py

+ 12 - 5
mineru/backend/pipeline/pipeline_analyze.py

@@ -1,6 +1,7 @@
 import os
 import time
-import numpy as np
+from typing import List, Tuple
+import PIL.Image
 import torch
 
 from .model_init import MineruPipelineModel
@@ -150,7 +151,7 @@ def doc_analyze(
 
 
 def batch_image_analyze(
-        images_with_extra_info: list[(np.ndarray, bool, str)],
+        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
         formula_enable=None,
         table_enable=None):
     # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
@@ -163,9 +164,15 @@ def batch_image_analyze(
     device = get_device()
 
     if str(device).startswith('npu'):
-        import torch_npu
-        if torch_npu.npu.is_available():
-            torch.npu.set_compile_mode(jit_compile=False)
+        try:
+            import torch_npu
+            if torch_npu.npu.is_available():
+                torch.npu.set_compile_mode(jit_compile=False)
+        except Exception as e:
+            raise RuntimeError(
+                "NPU is selected as device, but torch_npu is not available. "
+                "Please ensure that the torch_npu package is installed correctly."
+            ) from e
 
     if str(device).startswith('npu') or str(device).startswith('cuda'):
         vram = get_vram(device)

+ 8 - 1
mineru/utils/config_reader.py

@@ -74,8 +74,15 @@ def get_device():
     else:
         if torch.cuda.is_available():
             return "cuda"
-        if torch.backends.mps.is_available():
+        elif torch.backends.mps.is_available():
             return "mps"
+        else:
+            try:
+                import torch_npu
+                if torch_npu.npu.is_available():
+                    return "npu"
+            except Exception as e:
+                pass
         return "cpu"