vlm_analyze.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. import time
  4. from loguru import logger
  5. from .custom_logits_processors import enable_custom_logits_processors
  6. from .model_output_to_middle_json import result_to_middle_json
  7. from ...data.data_reader_writer import DataWriter
  8. from mineru.utils.pdf_image_tools import load_images_from_pdf
  9. from ...utils.config_reader import get_device
  10. from ...utils.enum_class import ImageType
  11. from ...utils.model_utils import get_vram
  12. from ...utils.models_download_utils import auto_download_and_get_model_root_path
  13. from mineru_vl_utils import MinerUClient
  14. from packaging import version
  class ModelSingleton:
      """Process-wide cache of ``MinerUClient`` instances.

      ``__new__`` always hands back one shared instance, and built clients are kept in
      the class-level ``_models`` dict keyed on ``(backend, model_path, server_url)``,
      so each combination is constructed at most once.

      NOTE: extra ``**kwargs`` given to ``get_model`` are not part of the cache key; a
      later call with the same key but different kwargs returns the client that was
      built first.
      """

      _instance = None
      # (backend, model_path, server_url) -> MinerUClient, shared by every caller.
      _models = {}

      def __new__(cls, *args, **kwargs):
          # Create the shared instance on first use; every later call reuses it.
          if cls._instance is None:
              cls._instance = super().__new__(cls)
          return cls._instance

      @staticmethod
      def _batch_size_from_gpu_memory(gpu_memory: int) -> int:
          """Map available GPU memory (in GB) to a transformers-backend batch size."""
          if gpu_memory >= 16:
              return 8
          if gpu_memory >= 8:
              return 4
          return 1

      @classmethod
      def _transformers_batch_size(cls, device) -> int:
          """Choose the transformers batch size from the VRAM reported for ``device``.

          The ``MINERU_VIRTUAL_VRAM_SIZE`` environment variable, when set, replaces the
          detected VRAM figure. Falls back to 1 when VRAM cannot be determined
          (``get_vram`` returns None) or when detection raises.
          """
          try:
              vram = get_vram(device)
              if vram is None:
                  # Default batch_ratio when VRAM can't be determined
                  batch_size = 1
                  logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_size}')
                  return batch_size
              gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
              batch_size = cls._batch_size_from_gpu_memory(gpu_memory)
              logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
              return batch_size
          except Exception as e:
              # Best-effort: a failure to size the batch must not prevent model loading.
              logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
              return 1

      @classmethod
      def _load_transformers(cls, model_path):
          """Load the Qwen2-VL model and processor from ``model_path``.

          Returns:
              ``(model, processor, batch_size)`` for the transformers backend.

          Raises:
              ImportError: If ``transformers`` is not installed.
          """
          try:
              from transformers import (
                  AutoProcessor,
                  Qwen2VLForConditionalGeneration,
              )
              from transformers import __version__ as transformers_version
          except ImportError as e:
              raise ImportError("Please install transformers to use the transformers backend.") from e
          # This code passes the dtype keyword as "dtype" for transformers >= 4.56.0
          # and as "torch_dtype" for older releases.
          if version.parse(transformers_version) >= version.parse("4.56.0"):
              dtype_key = "dtype"
          else:
              dtype_key = "torch_dtype"
          device = get_device()
          model = Qwen2VLForConditionalGeneration.from_pretrained(
              model_path,
              device_map={"": device},  # place the whole model on the selected device
              **{dtype_key: "auto"},  # type: ignore
          )
          processor = AutoProcessor.from_pretrained(
              model_path,
              use_fast=True,
          )
          return model, processor, cls._transformers_batch_size(device)

      @staticmethod
      def _fill_vllm_defaults(engine_kwargs: dict, model_path, logits_processor_cls) -> dict:
          """Add default vLLM engine arguments without overriding caller-supplied ones.

          Mutates and returns ``engine_kwargs``.
          """
          engine_kwargs.setdefault("gpu_memory_utilization", 0.5)
          engine_kwargs.setdefault("model", model_path)
          if enable_custom_logits_processors() and ("logits_processors" not in engine_kwargs):
              engine_kwargs["logits_processors"] = [logits_processor_cls]
          return engine_kwargs

      def get_model(
          self,
          backend: str,
          model_path: str | None,
          server_url: str | None,
          **kwargs,
      ) -> MinerUClient:
          """Return the cached ``MinerUClient`` for this configuration, building it on first use.

          Args:
              backend: ``"transformers"``, ``"vllm-engine"`` or ``"vllm-async-engine"``
                  build a local model here. Any other value builds no local model, sets
                  ``OMP_NUM_THREADS=1``, and is passed to ``MinerUClient`` with only
                  ``server_url``/``max_concurrency`` populated.
              model_path: Local model directory. It is auto-downloaded when empty for
                  the three local backends listed above.
              server_url: Forwarded to ``MinerUClient``.
              **kwargs: ``max_concurrency`` (default 100) is consumed here and passed
                  to ``MinerUClient``. For the vLLM backends every remaining key is
                  forwarded to the vLLM engine constructor. The transformers backend
                  ignores the remaining keys.

          Returns:
              The ``MinerUClient`` cached under ``(backend, model_path, server_url)``.

          Raises:
              ImportError: If the packages required by a local backend are missing.
          """
          key = (backend, model_path, server_url)
          if key in self._models:
              return self._models[key]

          start_time = time.time()
          model = None
          processor = None
          vllm_llm = None
          vllm_async_llm = None
          batch_size = 0
          # Pop rather than get: max_concurrency configures MinerUClient only, and
          # leaving it in kwargs would forward an unknown argument to vllm.LLM /
          # AsyncEngineArgs, which raise TypeError on unexpected keywords.
          max_concurrency = kwargs.pop("max_concurrency", 100)

          if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
              model_path = auto_download_and_get_model_root_path("/", "vlm")

          if backend == "transformers":
              model, processor, batch_size = self._load_transformers(model_path)
          else:
              os.environ["OMP_NUM_THREADS"] = "1"
              if backend == "vllm-engine":
                  try:
                      import vllm
                      from mineru_vl_utils import MinerULogitsProcessor
                  except ImportError as e:
                      raise ImportError("Please install vllm to use the vllm-engine backend.") from e
                  # The caller's kwargs become vLLM's initialization arguments.
                  engine_kwargs = self._fill_vllm_defaults(kwargs, model_path, MinerULogitsProcessor)
                  vllm_llm = vllm.LLM(**engine_kwargs)
              elif backend == "vllm-async-engine":
                  try:
                      from vllm.engine.arg_utils import AsyncEngineArgs
                      from vllm.v1.engine.async_llm import AsyncLLM
                      from mineru_vl_utils import MinerULogitsProcessor
                  except ImportError as e:
                      raise ImportError("Please install vllm to use the vllm-async-engine backend.") from e
                  # The caller's kwargs become vLLM's initialization arguments.
                  engine_kwargs = self._fill_vllm_defaults(kwargs, model_path, MinerULogitsProcessor)
                  vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**engine_kwargs))

          self._models[key] = MinerUClient(
              backend=backend,
              model=model,
              processor=processor,
              vllm_llm=vllm_llm,
              vllm_async_llm=vllm_async_llm,
              server_url=server_url,
              batch_size=batch_size,
              max_concurrency=max_concurrency,
          )
          elapsed = round(time.time() - start_time, 2)
          logger.info(f"get {backend} predictor cost: {elapsed}s")
          return self._models[key]
  def doc_analyze(
      pdf_bytes,
      image_writer: DataWriter | None,
      predictor: MinerUClient | None = None,
      backend="transformers",
      model_path: str | None = None,
      server_url: str | None = None,
      **kwargs,
  ):
      """Run the VLM two-step extraction over a PDF synchronously.

      Args:
          pdf_bytes: Raw PDF content, handed to ``load_images_from_pdf``.
          image_writer: Passed through to ``result_to_middle_json``; may be None.
          predictor: A ready client. When None, a client is obtained from
              ``ModelSingleton().get_model(backend, model_path, server_url, **kwargs)``.
          backend: Backend name used only when ``predictor`` is None.
          model_path: Model directory used only when ``predictor`` is None.
          server_url: Server URL used only when ``predictor`` is None.
          **kwargs: Extra options forwarded to ``get_model`` when ``predictor`` is None.

      Returns:
          ``(middle_json, results)``: the output of ``result_to_middle_json`` and the
          raw results from ``batch_two_step_extract`` (presumably one per page —
          confirm against ``MinerUClient``).
      """
      client = (
          predictor
          if predictor is not None
          else ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
      )

      # Render the PDF to PIL images; each entry exposes its image under "img_pil".
      page_images, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
      pil_pages = [page["img_pil"] for page in page_images]

      raw_results = client.batch_two_step_extract(images=pil_pages)
      middle_json = result_to_middle_json(raw_results, page_images, pdf_doc, image_writer)
      return middle_json, raw_results
  async def aio_doc_analyze(
      pdf_bytes,
      image_writer: DataWriter | None,
      predictor: MinerUClient | None = None,
      backend="transformers",
      model_path: str | None = None,
      server_url: str | None = None,
      **kwargs,
  ):
      """Run the VLM two-step extraction over a PDF, awaiting the model inference.

      Args:
          pdf_bytes: Raw PDF content, handed to ``load_images_from_pdf``.
          image_writer: Passed through to ``result_to_middle_json``; may be None.
          predictor: A ready client. When None, a client is obtained from
              ``ModelSingleton().get_model(backend, model_path, server_url, **kwargs)``.
          backend: Backend name used only when ``predictor`` is None.
          model_path: Model directory used only when ``predictor`` is None.
          server_url: Server URL used only when ``predictor`` is None.
          **kwargs: Extra options forwarded to ``get_model`` when ``predictor`` is None.

      Returns:
          ``(middle_json, results)``: the output of ``result_to_middle_json`` and the
          raw results from ``aio_batch_two_step_extract``.

      NOTE: only the extraction call is awaited. Obtaining the client, rendering the
      PDF, and building the middle JSON are ordinary synchronous calls and run
      directly on the event loop.
      """
      client = (
          predictor
          if predictor is not None
          else ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
      )

      # Render the PDF to PIL images; each entry exposes its image under "img_pil".
      page_images, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
      pil_pages = [page["img_pil"] for page in page_images]

      raw_results = await client.aio_batch_two_step_extract(images=pil_pages)
      middle_json = result_to_middle_json(raw_results, page_images, pdf_doc, image_writer)
      return middle_json, raw_results