Prechádzať zdrojové kódy

Merge pull request #1787 from myhloli/dev

refactor(magic_pdf): remove bfloat16 support checks and usage
Xiaomeng Zhao 8 mesiacov pred
rodič
commit
71b5024a34
1 zmenil súbory, kde vykonal 1 pridanie a 21 odobranie
  1. 1 21
      magic_pdf/pdf_parse_union_core_v2.py

+ 1 - 21
magic_pdf/pdf_parse_union_core_v2.py

@@ -338,24 +338,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
 
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
-    device = get_device()
-    if torch.cuda.is_available():
-        device = torch.device('cuda')
-        if torch.cuda.is_bf16_supported():
-            supports_bfloat16 = True
-        else:
-            supports_bfloat16 = False
-    elif str(device).startswith("npu"):
-        import torch_npu
-        if torch_npu.npu.is_available():
-            device = torch.device('npu')
-            supports_bfloat16 = False
-        else:
-            device = torch.device('cpu')
-            supports_bfloat16 = False
-    else:
-        device = torch.device('cpu')
-        supports_bfloat16 = False
+    device = torch.device(get_device())
 
     if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
@@ -371,9 +354,6 @@ def model_init(model_name: str):
             model = LayoutLMv3ForTokenClassification.from_pretrained(
                 'hantian/layoutreader'
             )
-        # 检查设备是否支持 bfloat16
-        if supports_bfloat16:
-            model.bfloat16()
         model.to(device).eval()
     else:
         logger.error('model name not allow')