Эх сурвалжийг харах

Merge pull request #1959 from myhloli/dev

Dev push
Xiaomeng Zhao 8 сар өмнө
parent
commit
07eaa2d7e5

+ 4 - 4
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -257,13 +257,13 @@ def may_batch_image_analyze(
     if str(device).startswith('npu') or str(device).startswith('cuda'):
         gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
         if gpu_memory is not None:
-            if gpu_memory >= 20:
+            if gpu_memory >= 16:
                 batch_ratio = 16
-            elif gpu_memory >= 15:
+            elif gpu_memory >= 12:
                 batch_ratio = 8
-            elif gpu_memory >= 10:
+            elif gpu_memory >= 8:
                 batch_ratio = 4
-            elif gpu_memory >= 7:
+            elif gpu_memory >= 6:
                 batch_ratio = 2
             else:
                 batch_ratio = 1

+ 12 - 3
magic_pdf/pdf_parse_union_core_v2.py

@@ -333,8 +333,14 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
 
 
 def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
-    device = torch.device(get_device())
-
+    device_name = get_device()
+    bf_16_support = False
+    if device_name.startswith("cuda"):
+        bf_16_support = torch.cuda.is_bf16_supported()
+    elif device_name.startswith("mps"):
+        bf_16_support = True
+
+    device = torch.device(device_name)
     if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
         layoutreader_model_dir = get_local_layoutreader_model_dir()
@@ -349,7 +355,10 @@ def model_init(model_name: str):
             model = LayoutLMv3ForTokenClassification.from_pretrained(
                 'hantian/layoutreader'
             )
-        model.to(device).eval()
+        if bf_16_support:
+            model.to(device).eval().bfloat16()
+        else:
+            model.to(device).eval()
     else:
         logger.error('model name not allow')
         exit(1)