@@ -4,14 +4,13 @@ import time
 
 from loguru import logger
 
-from .custom_logits_processors import enable_custom_logits_processors
+from .utils import enable_custom_logits_processors, set_defult_gpu_memory_utilization, set_defult_batch_size
 from .model_output_to_middle_json import result_to_middle_json
 from ...data.data_reader_writer import DataWriter
 from mineru.utils.pdf_image_tools import load_images_from_pdf
 from ...utils.config_reader import get_device
 
 from ...utils.enum_class import ImageType
-from ...utils.model_utils import get_vram
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 from mineru_vl_utils import MinerUClient
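
Note: this hunk swaps the old inline defaults for two helpers now imported from `.utils`, whose implementations are not part of this diff. Based on the inline logic removed in the hunks below, a plausible sketch of those helpers (module internals and the `get_device()` call are assumptions, not shown in this patch) is:

    import os
    from loguru import logger

    from ...utils.config_reader import get_device
    from ...utils.model_utils import get_vram


    def set_defult_batch_size() -> int:
        """Assumed behavior: the same VRAM tiers as the inline code removed below."""
        try:
            vram = get_vram(get_device())
            if vram is None:
                logger.info('Could not determine GPU memory, using default batch_size: 1')
                return 1
            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
            if gpu_memory >= 16:
                batch_size = 8
            elif gpu_memory >= 8:
                batch_size = 4
            else:
                batch_size = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
            return batch_size
        except Exception as e:
            logger.warning(f'Error determining VRAM: {e}, using default batch_size: 1')
            return 1


    def set_defult_gpu_memory_utilization() -> float:
        """Assumed behavior: centralizes the 0.7 default replaced in the vllm hunks."""
        return 0.7
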
@@ -41,8 +40,13 @@ class ModelSingleton:
             processor = None
             vllm_llm = None
             vllm_async_llm = None
-            batch_size = 0
-            max_concurrency = kwargs.get("max_concurrency", 100)
+            batch_size = kwargs.get("batch_size", 0)  # for transformers backend only
+            max_concurrency = kwargs.get("max_concurrency", 100)  # for http-client backend only
+            http_timeout = kwargs.get("http_timeout", 600)  # for http-client backend only
+            # Remove these parameters from kwargs so they are not passed to unrelated initializers
+            for param in ["batch_size", "max_concurrency", "http_timeout"]:
+                if param in kwargs:
+                    del kwargs[param]
             if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                 model_path = auto_download_and_get_model_root_path("/","vlm")
             if backend == "transformers":
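
With this change the tuning knobs ride in through **kwargs and are read out before the remaining kwargs are forwarded to the engine constructors. As a design note (a sketch of an equivalent form, not what the diff does): `dict.pop` with a default reads and removes in one step, so the explicit deletion loop would become unnecessary.

    kwargs = {"max_concurrency": 50, "gpu_memory_utilization": 0.5}  # example input
    batch_size = kwargs.pop("batch_size", 0)               # transformers backend only
    max_concurrency = kwargs.pop("max_concurrency", 100)   # http-client backend only
    http_timeout = kwargs.pop("http_timeout", 600)         # http-client backend only
    assert kwargs == {"gpu_memory_utilization": 0.5}       # only engine kwargs remain
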
@@ -69,24 +73,8 @@ class ModelSingleton:
                     model_path,
                     use_fast=True,
                 )
-                try:
-                    vram = get_vram(device)
-                    if vram is not None:
-                        gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
-                        if gpu_memory >= 16:
-                            batch_size = 8
-                        elif gpu_memory >= 8:
-                            batch_size = 4
-                        else:
-                            batch_size = 1
-                        logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
-                    else:
-                        # Default batch_ratio when VRAM can't be determined
-                        batch_size = 1
-                        logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_size}')
-                except Exception as e:
-                    logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
-                    batch_size = 1
+                if batch_size == 0:
+                    batch_size = set_defult_batch_size()
             else:
                 os.environ["OMP_NUM_THREADS"] = "1"
                 if backend == "vllm-engine":
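
Since `batch_size` defaults to 0 above, 0 now acts as an "auto-detect" sentinel: the VRAM-based default is computed only when the caller did not pin a value. A hypothetical call that bypasses auto-detection (the `get_model` signature is assumed from context):

    model = ModelSingleton().get_model(
        backend="transformers",
        model_path=None,   # auto-downloaded when empty
        server_url=None,
        batch_size=2,      # used verbatim; set_defult_batch_size() is skipped
    )
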
@@ -96,7 +84,7 @@ class ModelSingleton:
                     except ImportError:
                         raise ImportError("Please install vllm to use the vllm-engine backend.")
                     if "gpu_memory_utilization" not in kwargs:
-                        kwargs["gpu_memory_utilization"] = 0.7
+                        kwargs["gpu_memory_utilization"] = set_defult_gpu_memory_utilization()
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
                     if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
@@ -111,7 +99,7 @@ class ModelSingleton:
                     except ImportError:
                         raise ImportError("Please install vllm to use the vllm-async-engine backend.")
                     if "gpu_memory_utilization" not in kwargs:
-                        kwargs["gpu_memory_utilization"] = 0.7
+                        kwargs["gpu_memory_utilization"] = set_defult_gpu_memory_utilization()
                     if "model" not in kwargs:
                         kwargs["model"] = model_path
                     if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
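
Both vllm branches consult the helper only when `gpu_memory_utilization` is absent from kwargs, so an explicit override still wins. For example (hypothetical call; the kwarg is forwarded to vllm through **kwargs):

    model = ModelSingleton().get_model(
        backend="vllm-engine",
        model_path=None,
        server_url=None,
        gpu_memory_utilization=0.5,  # overrides set_defult_gpu_memory_utilization()
    )
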
@@ -127,6 +115,7 @@ class ModelSingleton:
                 server_url=server_url,
                 batch_size=batch_size,
                 max_concurrency=max_concurrency,
+                http_timeout=http_timeout,
             )
             elapsed = round(time.time() - start_time, 2)
             logger.info(f"get {backend} predictor cost: {elapsed}s")
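
`http_timeout` is forwarded verbatim to `MinerUClient`, which presumes the installed `mineru_vl_utils` release accepts that keyword; an older release without it would raise a `TypeError` here. A hypothetical http-client call exercising the new knobs (server URL assumed):

    client = ModelSingleton().get_model(
        backend="http-client",
        model_path=None,
        server_url="http://localhost:30000",
        max_concurrency=50,
        http_timeout=300,  # seconds; the default stays 600
    )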