vlm_analyze.py

# Copyright (c) Opendatalab. All rights reserved.
import time

from loguru import logger

from .model_output_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from ...utils.enum_class import ImageType
from ...utils.models_download_utils import auto_download_and_get_model_root_path

from mineru_vl_utils import MinerUClient


class ModelSingleton:
    """Caches one MinerUClient per (backend, model_path, server_url) combination."""

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
        **kwargs,
    ) -> MinerUClient:
        key = (backend, model_path, server_url)
        if key not in self._models:
            # Local backends need model weights on disk; resolve the default VLM model path if none was given.
            if backend in ['transformers', 'vllm-engine'] and not model_path:
                model_path = auto_download_and_get_model_root_path("/", "vlm")
            self._models[key] = MinerUClient(
                backend=backend,
                model_path=model_path,
                server_url=server_url,
                **kwargs,
            )
        return self._models[key]


def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # Render each PDF page to a PIL image.
    # load_images_start = time.time()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    # load_images_time = round(time.time() - load_images_start, 2)
    # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")

    # Run the VLM's two-step extraction over all pages.
    # infer_start = time.time()
    results = predictor.batch_two_step_extract(images=images_pil_list)
    # infer_time = round(time.time() - infer_start, 2)
    # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")

    # Assemble the model output into the middle JSON, writing cropped images via image_writer.
    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results


async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    # Async variant of doc_analyze; only the inference call is awaited.
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # load_images_start = time.time()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    # load_images_time = round(time.time() - load_images_start, 2)
    # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")

    # infer_start = time.time()
    results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
    # infer_time = round(time.time() - infer_start, 2)
    # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
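

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way
# doc_analyze might be driven end to end. FileBasedDataWriter is assumed to be
# the package's concrete DataWriter for persisting cropped images to disk;
# adjust the import, backend, and paths to your installation.
if __name__ == "__main__":
    from pathlib import Path

    # Assumed concrete DataWriter implementation; swap in your own if needed.
    from mineru.data.data_reader_writer import FileBasedDataWriter

    pdf_bytes = Path("example.pdf").read_bytes()
    image_writer = FileBasedDataWriter("output/images")

    # Runs page rendering, two-step extraction, and middle-JSON assembly.
    middle_json, results = doc_analyze(
        pdf_bytes,
        image_writer,
        backend="transformers",  # or "vllm-engine", or pass server_url for a remote backend
    )
    print(f"extracted {len(results)} pages")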