vlm_analyze.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. import time
  4. from loguru import logger
  5. from .utils import enable_custom_logits_processors, set_defult_gpu_memory_utilization, set_defult_batch_size
  6. from .model_output_to_middle_json import result_to_middle_json
  7. from ...data.data_reader_writer import DataWriter
  8. from mineru.utils.pdf_image_tools import load_images_from_pdf
  9. from ...utils.config_reader import get_device
  10. from ...utils.enum_class import ImageType
  11. from ...utils.models_download_utils import auto_download_and_get_model_root_path
  12. from mineru_vl_utils import MinerUClient
  13. from packaging import version
  14. class ModelSingleton:
  15. _instance = None
  16. _models = {}
  17. def __new__(cls, *args, **kwargs):
  18. if cls._instance is None:
  19. cls._instance = super().__new__(cls)
  20. return cls._instance
  21. def get_model(
  22. self,
  23. backend: str,
  24. model_path: str | None,
  25. server_url: str | None,
  26. **kwargs,
  27. ) -> MinerUClient:
  28. key = (backend, model_path, server_url)
  29. if key not in self._models:
  30. start_time = time.time()
  31. model = None
  32. processor = None
  33. vllm_llm = None
  34. vllm_async_llm = None
  35. batch_size = kwargs.get("batch_size", 0) # for transformers backend only
  36. max_concurrency = kwargs.get("max_concurrency", 100) # for http-client backend only
  37. http_timeout = kwargs.get("http_timeout", 600) # for http-client backend only
  38. # 从kwargs中移除这些参数,避免传递给不相关的初始化函数
  39. for param in ["batch_size", "max_concurrency", "http_timeout"]:
  40. if param in kwargs:
  41. del kwargs[param]
  42. if backend in ['transformers', 'vllm-engine', "vllm-async-engine"] and not model_path:
  43. model_path = auto_download_and_get_model_root_path("/","vlm")
  44. if backend == "transformers":
  45. try:
  46. from transformers import (
  47. AutoProcessor,
  48. Qwen2VLForConditionalGeneration,
  49. )
  50. from transformers import __version__ as transformers_version
  51. except ImportError:
  52. raise ImportError("Please install transformers to use the transformers backend.")
  53. if version.parse(transformers_version) >= version.parse("4.56.0"):
  54. dtype_key = "dtype"
  55. else:
  56. dtype_key = "torch_dtype"
  57. device = get_device()
  58. model = Qwen2VLForConditionalGeneration.from_pretrained(
  59. model_path,
  60. device_map={"": device},
  61. **{dtype_key: "auto"}, # type: ignore
  62. )
  63. processor = AutoProcessor.from_pretrained(
  64. model_path,
  65. use_fast=True,
  66. )
  67. if batch_size == 0:
  68. batch_size = set_defult_batch_size()
  69. else:
  70. if os.getenv('OMP_NUM_THREADS') is None:
  71. os.environ["OMP_NUM_THREADS"] = "1"
  72. if backend == "vllm-engine":
  73. try:
  74. import vllm
  75. from mineru_vl_utils import MinerULogitsProcessor
  76. except ImportError:
  77. raise ImportError("Please install vllm to use the vllm-engine backend.")
  78. if "gpu_memory_utilization" not in kwargs:
  79. kwargs["gpu_memory_utilization"] = set_defult_gpu_memory_utilization()
  80. if "model" not in kwargs:
  81. kwargs["model"] = model_path
  82. if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
  83. kwargs["logits_processors"] = [MinerULogitsProcessor]
  84. # 使用kwargs为 vllm初始化参数
  85. vllm_llm = vllm.LLM(**kwargs)
  86. elif backend == "vllm-async-engine":
  87. try:
  88. from vllm.engine.arg_utils import AsyncEngineArgs
  89. from vllm.v1.engine.async_llm import AsyncLLM
  90. from mineru_vl_utils import MinerULogitsProcessor
  91. except ImportError:
  92. raise ImportError("Please install vllm to use the vllm-async-engine backend.")
  93. if "gpu_memory_utilization" not in kwargs:
  94. kwargs["gpu_memory_utilization"] = set_defult_gpu_memory_utilization()
  95. if "model" not in kwargs:
  96. kwargs["model"] = model_path
  97. if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
  98. kwargs["logits_processors"] = [MinerULogitsProcessor]
  99. # 使用kwargs为 vllm初始化参数
  100. vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
  101. self._models[key] = MinerUClient(
  102. backend=backend,
  103. model=model,
  104. processor=processor,
  105. vllm_llm=vllm_llm,
  106. vllm_async_llm=vllm_async_llm,
  107. server_url=server_url,
  108. batch_size=batch_size,
  109. max_concurrency=max_concurrency,
  110. http_timeout=http_timeout,
  111. )
  112. elapsed = round(time.time() - start_time, 2)
  113. logger.info(f"get {backend} predictor cost: {elapsed}s")
  114. return self._models[key]
  115. def doc_analyze(
  116. pdf_bytes,
  117. image_writer: DataWriter | None,
  118. predictor: MinerUClient | None = None,
  119. backend="transformers",
  120. model_path: str | None = None,
  121. server_url: str | None = None,
  122. **kwargs,
  123. ):
  124. if predictor is None:
  125. predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
  126. # load_images_start = time.time()
  127. images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
  128. images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
  129. # load_images_time = round(time.time() - load_images_start, 2)
  130. # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  131. # infer_start = time.time()
  132. results = predictor.batch_two_step_extract(images=images_pil_list)
  133. # infer_time = round(time.time() - infer_start, 2)
  134. # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  135. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  136. return middle_json, results
  137. async def aio_doc_analyze(
  138. pdf_bytes,
  139. image_writer: DataWriter | None,
  140. predictor: MinerUClient | None = None,
  141. backend="transformers",
  142. model_path: str | None = None,
  143. server_url: str | None = None,
  144. **kwargs,
  145. ):
  146. if predictor is None:
  147. predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
  148. # load_images_start = time.time()
  149. images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
  150. images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
  151. # load_images_time = round(time.time() - load_images_start, 2)
  152. # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  153. # infer_start = time.time()
  154. results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
  155. # infer_time = round(time.time() - infer_start, 2)
  156. # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  157. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  158. return middle_json, results