Ver Fonte

fix: improve device compatibility check for bf16 support in model initialization

myhloli há 2 semanas atrás
pai
commit
279e84bf58
1 ficheiros alterados com 3 adições e 2 exclusões
  1. 3 2
      mineru/utils/block_sort.py

+ 3 - 2
mineru/utils/block_sort.py

@@ -179,13 +179,14 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
     device_name = get_device()
+    device = torch.device(device_name)
     bf_16_support = False
     if device_name.startswith("cuda"):
-        bf_16_support = torch.cuda.is_bf16_supported()
+        if torch.cuda.get_device_properties(device).major >= 8:
+            bf_16_support = True
     elif device_name.startswith("mps"):
         bf_16_support = True
 
-    device = torch.device(device_name)
     if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
         layoutreader_model_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.layout_reader), ModelPath.layout_reader)