vlm_analyze.py

# Copyright (c) Opendatalab. All rights reserved.
import time

from loguru import logger

from .model_output_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from ...utils.enum_class import ImageType
from ...utils.models_download_utils import auto_download_and_get_model_root_path
from mineru_vl_utils import MinerUClient
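

# ModelSingleton caches one MinerUClient per (backend, model_path, server_url)
# key, so repeated analyses within a process reuse the already-loaded model.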
class ModelSingleton:
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
        **kwargs,
    ) -> MinerUClient:
        key = (backend, model_path, server_url)
        if key not in self._models:
            start_time = time.time()
            model = None
            processor = None
            vllm_llm = None
            vllm_async_llm = None
            if backend in ["transformers", "vllm-engine", "vllm-async-engine"] and not model_path:
                model_path = auto_download_and_get_model_root_path("/", "vlm")
            if backend == "transformers":
                if not model_path:
                    raise ValueError("model_path must be provided when model or processor is None.")
                try:
                    from transformers import (
                        AutoProcessor,
                        Qwen2VLForConditionalGeneration,
                    )
                    from transformers import __version__ as transformers_version
                except ImportError:
                    raise ImportError("Please install transformers to use the transformers backend.")
                from packaging import version

                # transformers >= 4.56.0 renamed the `torch_dtype` kwarg to `dtype`.
                if version.parse(transformers_version) >= version.parse("4.56.0"):
                    dtype_key = "dtype"
                else:
                    dtype_key = "torch_dtype"
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_path,
                    device_map="auto",
                    **{dtype_key: "auto"},  # type: ignore
                )
                processor = AutoProcessor.from_pretrained(
                    model_path,
                    use_fast=True,
                )
            elif backend == "vllm-engine":
                if not model_path:
                    raise ValueError("model_path must be provided when vllm_llm is None.")
                try:
                    import vllm
                except ImportError:
                    raise ImportError("Please install vllm to use the vllm-engine backend.")
                # Default to a conservative GPU memory share unless overridden.
                if "gpu_memory_utilization" not in kwargs:
                    kwargs["gpu_memory_utilization"] = 0.5
                if "model" not in kwargs:
                    kwargs["model"] = model_path
                # Initialize the vllm engine with the forwarded kwargs.
                vllm_llm = vllm.LLM(**kwargs)
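                # Extra keyword arguments reach vllm.LLM unchanged; a hypothetical
                # call with real vllm.LLM parameters looks like:
                #   ModelSingleton().get_model("vllm-engine", None, None,
                #                              gpu_memory_utilization=0.8,
                #                              tensor_parallel_size=2)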
            elif backend == "vllm-async-engine":
                if not model_path:
                    raise ValueError("model_path must be provided when vllm_async_llm is None.")
                try:
                    from vllm.engine.arg_utils import AsyncEngineArgs
                    from vllm.v1.engine.async_llm import AsyncLLM
                except ImportError:
                    raise ImportError("Please install vllm to use the vllm-async-engine backend.")
                if "gpu_memory_utilization" not in kwargs:
                    kwargs["gpu_memory_utilization"] = 0.5
                if "model" not in kwargs:
                    kwargs["model"] = model_path
                # Initialize the async vllm engine with the forwarded kwargs.
                vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
            self._models[key] = MinerUClient(
                backend=backend,
                model=model,
                processor=processor,
                vllm_llm=vllm_llm,
                vllm_async_llm=vllm_async_llm,
                server_url=server_url,
            )
            elapsed = round(time.time() - start_time, 2)
            logger.info(f"get {backend} predictor cost: {elapsed}s")
        return self._models[key]
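
# Usage sketch (hypothetical call; backend-specific kwargs are forwarded):
#   predictor = ModelSingleton().get_model("transformers", None, None)
#   # A second call with the same key returns the cached client; no reload.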


def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
    # Render each PDF page to a PIL image, then run the two-step VLM extraction.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    results = predictor.batch_two_step_extract(images=images_pil_list)
    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
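
# Usage sketch (assumes a local "doc.pdf"; FileBasedDataWriter is assumed to be
# the file-backed DataWriter shipped with mineru, adjust the import to your tree):
#   from mineru.data.data_reader_writer import FileBasedDataWriter
#   with open("doc.pdf", "rb") as f:
#       middle_json, results = doc_analyze(f.read(), FileBasedDataWriter("images"))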


async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
    # Same pipeline as doc_analyze, but awaits the async extraction API.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
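
# Async usage sketch (pairs naturally with the vllm-async-engine backend;
# pdf_bytes and image_writer as in the doc_analyze example above):
#   import asyncio
#   middle_json, results = asyncio.run(
#       aio_doc_analyze(pdf_bytes, image_writer, backend="vllm-async-engine")
#   )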