Преглед на файлове

Merge pull request #3315 from opendatalab/fix-torch_2_8

Fix torch 2 8
Xiaomeng Zhao преди 3 месеца
родител
ревизия
30b698ecc5
променени са 1 файла, в които са добавени 9 реда и са изтрити 1 реда
  1. 9 1
      mineru/backend/pipeline/pipeline_analyze.py

+ 9 - 1
mineru/backend/pipeline/pipeline_analyze.py

@@ -190,7 +190,15 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
-    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
+    # 检测torch的版本号
+    import torch
+    from packaging import version
+    if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
+        enable_ocr_det_batch = False
+    else:
+        enable_ocr_det_batch = True
+
+    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
     results = batch_model(images_with_extra_info)
 
     clean_memory(get_device())