Forráskód Böngészése

refactor(magic_pdf): remove bfloat16 support checks and usage

- Remove supports_bfloat16 variable and related checks
- Remove model.bfloat16() call for LayoutLMv3ForTokenClassification
- Simplify device selection logic
myhloli 8 hónapja
szülő
commit
9b00f988ac
1 módosított fájl, 0 hozzáadás és 10 törlés
  1. 0 10
      magic_pdf/pdf_parse_union_core_v2.py

+ 0 - 10
magic_pdf/pdf_parse_union_core_v2.py

@@ -341,21 +341,14 @@ def model_init(model_name: str):
     device = get_device()
     device = get_device()
     if torch.cuda.is_available():
     if torch.cuda.is_available():
         device = torch.device('cuda')
         device = torch.device('cuda')
-        if torch.cuda.is_bf16_supported():
-            supports_bfloat16 = True
-        else:
-            supports_bfloat16 = False
     elif str(device).startswith("npu"):
     elif str(device).startswith("npu"):
         import torch_npu
         import torch_npu
         if torch_npu.npu.is_available():
         if torch_npu.npu.is_available():
             device = torch.device('npu')
             device = torch.device('npu')
-            supports_bfloat16 = False
         else:
         else:
             device = torch.device('cpu')
             device = torch.device('cpu')
-            supports_bfloat16 = False
     else:
     else:
         device = torch.device('cpu')
         device = torch.device('cpu')
-        supports_bfloat16 = False
 
 
     if model_name == 'layoutreader':
     if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
         # 检测modelscope的缓存目录是否存在
@@ -371,9 +364,6 @@ def model_init(model_name: str):
             model = LayoutLMv3ForTokenClassification.from_pretrained(
             model = LayoutLMv3ForTokenClassification.from_pretrained(
                 'hantian/layoutreader'
                 'hantian/layoutreader'
             )
             )
-        # 检查设备是否支持 bfloat16
-        if supports_bfloat16:
-            model.bfloat16()
         model.to(device).eval()
         model.to(device).eval()
     else:
     else:
         logger.error('model name not allow')
         logger.error('model name not allow')