|
@@ -190,7 +190,15 @@ def batch_image_analyze(
|
|
|
batch_ratio = 1
|
|
batch_ratio = 1
|
|
|
logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
|
|
logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
|
|
|
|
|
|
|
|
- batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
|
|
|
|
|
|
|
+ # 检测torch的版本号
|
|
|
|
|
+ import torch
|
|
|
|
|
+ from packaging import version
|
|
|
|
|
+ if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
|
|
|
|
|
+ enable_ocr_det_batch = False
|
|
|
|
|
+ else:
|
|
|
|
|
+ enable_ocr_det_batch = True
|
|
|
|
|
+
|
|
|
|
|
+ batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
|
|
|
results = batch_model(images_with_extra_info)
|
|
results = batch_model(images_with_extra_info)
|
|
|
|
|
|
|
|
clean_memory(get_device())
|
|
clean_memory(get_device())
|