@@ -4,7 +4,8 @@ import time

from loguru import logger

-from .utils import enable_custom_logits_processors, set_default_gpu_memory_utilization, set_default_batch_size
+from .utils import enable_custom_logits_processors, set_default_gpu_memory_utilization, set_default_batch_size, \
+    set_lmdeploy_backend
from .model_output_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
@@ -128,15 +129,23 @@ class ModelSingleton:
                if "cache_max_entry_count" not in kwargs:
                    kwargs["cache_max_entry_count"] = 0.5

-                # Use turbomind by default
-                lm_backend = "turbomind"
-                device = kwargs.get("device", "cuda").lower()
-                # Certain devices must use the pytorch backend
-                if device in ["ascend", "maca", "camb"]:
-                    lm_backend = "pytorch"
-                    backend_config = PytorchEngineConfig(**kwargs)
+                if "device" in kwargs:
+                    device_type = kwargs.pop("device")
                else:
+                    device_type = os.getenv('MINERU_DEVICE_MODE', "cuda").lower()
+                # If device_type has a ":" suffix, strip it (e.g. "cuda:0" -> "cuda")
+                if ":" in device_type:
+                    device_type = device_type.split(":")[0]
+
+                lm_backend = set_lmdeploy_backend(device_type)
+
+                if lm_backend == "pytorch":
+                    kwargs["device_type"] = device_type
+                    backend_config = PytorchEngineConfig(**kwargs)
+                elif lm_backend == "turbomind":
                    backend_config = TurbomindEngineConfig(**kwargs)
+                else:
+                    raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")

                log_level = 'ERROR'
                from lmdeploy.utils import get_logger
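
For reference, set_lmdeploy_backend itself is outside this hunk. A minimal sketch of what the helper plausibly does, inferred from the inline logic the hunk removes (the function body below is an assumption, not the actual implementation in .utils):

def set_lmdeploy_backend(device_type: str) -> str:
    # Assumed to mirror the removed inline rule: ascend/maca/camb only run
    # on lmdeploy's pytorch engine; other devices (e.g. cuda) default to
    # turbomind.
    if device_type in ["ascend", "maca", "camb"]:
        return "pytorch"
    return "turbomind"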
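
Downstream, the resolved backend_config is what lmdeploy consumes when the pipeline is built. A usage sketch of that hand-off, assuming lmdeploy's public pipeline() API (pipeline and PytorchEngineConfig are real lmdeploy names; the model path is a placeholder):

from lmdeploy import pipeline, PytorchEngineConfig

# On the pytorch path the normalized device lands in the engine config,
# matching the kwargs["device_type"] = device_type line above.
backend_config = PytorchEngineConfig(device_type="ascend", cache_max_entry_count=0.5)
pipe = pipeline("path/to/vlm-model", backend_config=backend_config)  # placeholder path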