vlm_analyze.py
# Copyright (c) Opendatalab. All rights reserved.
import os
import time

from loguru import logger

from .model_output_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from ...utils.config_reader import get_device
from ...utils.enum_class import ImageType
from ...utils.model_utils import get_vram
from ...utils.models_download_utils import auto_download_and_get_model_root_path
from mineru_vl_utils import MinerUClient
from packaging import version
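

# ModelSingleton caches one MinerUClient per (backend, model_path, server_url)
# combination, so repeated analyze calls reuse the already-initialized model
# instead of loading it again.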
class ModelSingleton:
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
        **kwargs,
    ) -> MinerUClient:
        key = (backend, model_path, server_url)
        if key not in self._models:
            start_time = time.time()
            model = None
            processor = None
            vllm_llm = None
            vllm_async_llm = None
            batch_size = 0
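            # Local backends need model weights on disk; resolve (and download
            # if necessary) the default VLM weights when no path was supplied.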
            if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
                model_path = auto_download_and_get_model_root_path("/", "vlm")
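            # Probe the CUDA compute capability; custom logits processors are
            # only enabled on GPUs with capability >= 8.0 (Ampere or newer).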
            import torch
            compute_capability = 0.0
            custom_logits_processors = False
            if torch.cuda.is_available():
                major, minor = torch.cuda.get_device_capability()
                compute_capability = float(major) + (float(minor) / 10.0)
                logger.info(f"compute_capability: {compute_capability}")
            if compute_capability >= 8.0:
                custom_logits_processors = True
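            # "transformers" backend: run the model in-process through Hugging
            # Face transformers.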
            if backend == "transformers":
                try:
                    from transformers import (
                        AutoProcessor,
                        Qwen2VLForConditionalGeneration,
                    )
                    from transformers import __version__ as transformers_version
                except ImportError:
                    raise ImportError("Please install transformers to use the transformers backend.")
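                # transformers 4.56.0 renamed from_pretrained's "torch_dtype"
                # keyword to "dtype"; pick the one the installed version expects.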
                if version.parse(transformers_version) >= version.parse("4.56.0"):
                    dtype_key = "dtype"
                else:
                    dtype_key = "torch_dtype"
                device = get_device()
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_path,
                    device_map={"": device},
                    **{dtype_key: "auto"},  # type: ignore
                )
                processor = AutoProcessor.from_pretrained(
                    model_path,
                    use_fast=True,
                )
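                # Heuristic batch sizing by available VRAM: >=16 GB -> 8 pages,
                # >=8 GB -> 4, otherwise 1. MINERU_VIRTUAL_VRAM_SIZE overrides
                # the detected value.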
                try:
                    vram = get_vram(device)
                    if vram is not None:
                        gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
                        if gpu_memory >= 16:
                            batch_size = 8
                        elif gpu_memory >= 8:
                            batch_size = 4
                        else:
                            batch_size = 1
                        logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
                    else:
                        # Default batch_size when VRAM can't be determined
                        batch_size = 1
                        logger.info(f'Could not determine GPU memory, using default batch_size: {batch_size}')
                except Exception as e:
                    logger.warning(f'Error determining VRAM: {e}, using default batch_size: 1')
                    batch_size = 1
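            # "vllm-engine" backend: synchronous in-process vLLM engine.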
            elif backend == "vllm-engine":
                try:
                    import vllm
                    vllm_version = vllm.__version__
                    from mineru_vl_utils import MinerULogitsProcessor
                except ImportError:
                    raise ImportError("Please install vllm to use the vllm-engine backend.")
                if "gpu_memory_utilization" not in kwargs:
                    kwargs["gpu_memory_utilization"] = 0.5
                if "model" not in kwargs:
                    kwargs["model"] = model_path
                if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
                    kwargs["logits_processors"] = [MinerULogitsProcessor]
                # Forward kwargs as vLLM engine initialization parameters
                vllm_llm = vllm.LLM(**kwargs)
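            # "vllm-async-engine" backend: vLLM v1 AsyncLLM, for use from async
            # code paths (see aio_doc_analyze below).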
            elif backend == "vllm-async-engine":
                try:
                    from vllm.engine.arg_utils import AsyncEngineArgs
                    from vllm.v1.engine.async_llm import AsyncLLM
                    from vllm import __version__ as vllm_version
                    from mineru_vl_utils import MinerULogitsProcessor
                except ImportError:
                    raise ImportError("Please install vllm to use the vllm-async-engine backend.")
                if "gpu_memory_utilization" not in kwargs:
                    kwargs["gpu_memory_utilization"] = 0.5
                if "model" not in kwargs:
                    kwargs["model"] = model_path
                if custom_logits_processors and version.parse(vllm_version) >= version.parse("0.10.1") and "logits_processors" not in kwargs:
                    kwargs["logits_processors"] = [MinerULogitsProcessor]
                # Forward kwargs as vLLM engine initialization parameters
                vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
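            # Wrap whichever engine was constructed (or the remote server_url,
            # for server-backed backends) in a MinerUClient.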
            self._models[key] = MinerUClient(
                backend=backend,
                model=model,
                processor=processor,
                vllm_llm=vllm_llm,
                vllm_async_llm=vllm_async_llm,
                server_url=server_url,
                batch_size=batch_size,
            )
            elapsed = round(time.time() - start_time, 2)
            logger.info(f"get {backend} predictor cost: {elapsed}s")
        return self._models[key]
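

# doc_analyze: render each PDF page to a PIL image, run the client's two-step
# extraction over all pages, and convert the raw results into MinerU's
# middle-json format (cropped images are saved through image_writer).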
def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # load_images_start = time.time()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    # load_images_time = round(time.time() - load_images_start, 2)
    # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")

    # infer_start = time.time()
    results = predictor.batch_two_step_extract(images=images_pil_list)
    # infer_time = round(time.time() - infer_start, 2)
    # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
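

# aio_doc_analyze: async variant of doc_analyze; identical pipeline, but awaits
# the client's async batch extraction instead of blocking.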
async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # load_images_start = time.time()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    # load_images_time = round(time.time() - load_images_start, 2)
    # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")

    # infer_start = time.time()
    results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
    # infer_time = round(time.time() - infer_start, 2)
    # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
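

# Minimal usage sketch. Assumptions (not defined in this file): "demo.pdf" is a
# placeholder input path, and FileBasedDataWriter is the concrete DataWriter
# used to persist extracted images.
#
#     from mineru.data.data_reader_writer import FileBasedDataWriter
#
#     with open("demo.pdf", "rb") as f:
#         pdf_bytes = f.read()
#     image_writer = FileBasedDataWriter("output/images")
#     middle_json, results = doc_analyze(pdf_bytes, image_writer, backend="transformers")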