@@ -1,4 +1,5 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import os
 import time

 from loguru import logger
@@ -6,8 +7,10 @@ from loguru import logger
 from .model_output_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from mineru.utils.pdf_image_tools import load_images_from_pdf
+from ...utils.config_reader import get_device

 from ...utils.enum_class import ImageType
+from ...utils.model_utils import get_vram
 from ...utils.models_download_utils import auto_download_and_get_model_root_path

 from mineru_vl_utils import MinerUClient
@@ -36,6 +39,7 @@ class ModelSingleton:
             processor = None
             vllm_llm = None
             vllm_async_llm = None
+            batch_size = 0
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
             if backend == "transformers":
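
Reviewer note on the next hunk: in transformers/accelerate, a `device_map` whose only key is the empty string assigns the entire module tree to that one device, so the change below pins the whole model to the device returned by `get_device()` instead of letting `device_map="auto"` shard it across visible GPUs. A minimal sketch of the distinction (illustrative only, not part of the patch; `model_path` assumed):

from transformers import Qwen2VLForConditionalGeneration

# "auto": accelerate may split layers across every visible GPU (and CPU).
# {"": device}: the "" key is the root module, so everything is placed
# on that single device -- predictable placement for batched inference.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,                  # assumed: local or hub path to the VLM
    device_map={"": "cuda:0"},   # e.g. what get_device() may return
)
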
@@ -53,15 +57,34 @@
                     dtype_key = "dtype"
                 else:
                     dtype_key = "torch_dtype"
+                device = get_device()
                 model = Qwen2VLForConditionalGeneration.from_pretrained(
                     model_path,
-                    device_map="auto",
+                    device_map={"": device},
                     **{dtype_key: "auto"},  # type: ignore
                 )
                 processor = AutoProcessor.from_pretrained(
                     model_path,
                     use_fast=True,
                 )
+                try:
+                    vram = get_vram(device)
+                    if vram is not None:
+                        gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
+                        if gpu_memory >= 16:
+                            batch_size = 8
+                        elif gpu_memory >= 8:
+                            batch_size = 4
+                        else:
+                            batch_size = 1
+                        logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
+                    else:
+                        # Default batch_size when VRAM can't be determined
+                        batch_size = 1
+                        logger.info(f'Could not determine GPU memory, using default batch_size: {batch_size}')
+                except Exception as e:
+                    logger.warning(f'Error determining VRAM: {e}, using default batch_size: 1')
+                    batch_size = 1
             elif backend == "vllm-engine":
                 try:
                     import vllm
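For review convenience, the heuristic added above boils down to a small pure function. A minimal sketch (thresholds and the `MINERU_VIRTUAL_VRAM_SIZE` override mirror the hunk; `detected_vram_gb` stands in for the value returned by `get_vram(device)`):

import os

def pick_batch_size(detected_vram_gb):
    """Map detected VRAM (in GB) to an inference batch size."""
    if detected_vram_gb is None:
        # VRAM could not be determined: fall back to the safe minimum.
        return 1
    # The env var overrides detection, e.g. to cap usage on a shared GPU.
    gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(detected_vram_gb)))
    if gpu_memory >= 16:
        return 8
    elif gpu_memory >= 8:
        return 4
    return 1

# pick_batch_size(24.0) -> 8, pick_batch_size(11.5) -> 4, pick_batch_size(None) -> 1
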
@@ -92,6 +115,7 @@ class ModelSingleton:
                 vllm_llm=vllm_llm,
                 vllm_async_llm=vllm_async_llm,
                 server_url=server_url,
+                batch_size=batch_size,
             )
             elapsed = round(time.time() - start_time, 2)
             logger.info(f"get {backend} predictor cost: {elapsed}s")
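
End-to-end, the new knob can be exercised like this (a sketch under assumptions: the `get_model` keyword names and the singleton accessor are taken from the surrounding file, not from this patch):

import os

# Pretend only 8 GB of VRAM is available, regardless of what get_vram() detects.
os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = '8'

predictor = ModelSingleton().get_model(
    backend='transformers',
    model_path=None,        # triggers auto_download_and_get_model_root_path
    server_url=None,
)
# Expected log lines:
#   gpu_memory: 8 GB, batch_size: 4
#   get transformers predictor cost: ...s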