vlm_analyze.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import time
  3. from loguru import logger
  4. from ...data.data_reader_writer import DataWriter
  5. from mineru.utils.pdf_image_tools import load_images_from_pdf
  6. from .base_predictor import BasePredictor
  7. from .predictor import get_predictor
  8. from .token_to_middle_json import result_to_middle_json
  9. from ...utils.enum_class import ModelPath
  10. from ...utils.models_download_utils import auto_download_and_get_model_root_path
  11. class ModelSingleton:
  12. _instance = None
  13. _models = {}
  14. def __new__(cls, *args, **kwargs):
  15. if cls._instance is None:
  16. cls._instance = super().__new__(cls)
  17. return cls._instance
  18. def get_model(
  19. self,
  20. backend: str,
  21. model_path: str | None,
  22. server_url: str | None,
  23. ) -> BasePredictor:
  24. key = (backend, model_path, server_url)
  25. if key not in self._models:
  26. if backend in ['transformers', 'sglang-engine'] and not model_path:
  27. model_path = auto_download_and_get_model_root_path("/","vlm")
  28. self._models[key] = get_predictor(
  29. backend=backend,
  30. model_path=model_path,
  31. server_url=server_url,
  32. )
  33. return self._models[key]
  34. def doc_analyze(
  35. pdf_bytes,
  36. image_writer: DataWriter | None,
  37. predictor: BasePredictor | None = None,
  38. backend="transformers",
  39. model_path: str | None = None,
  40. server_url: str | None = None,
  41. ):
  42. if predictor is None:
  43. predictor = ModelSingleton().get_model(backend, model_path, server_url)
  44. # load_images_start = time.time()
  45. images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
  46. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  47. # load_images_time = round(time.time() - load_images_start, 2)
  48. # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  49. # infer_start = time.time()
  50. results = predictor.batch_predict(images=images_base64_list)
  51. # infer_time = round(time.time() - infer_start, 2)
  52. # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  53. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  54. return middle_json, results
  55. async def aio_doc_analyze(
  56. pdf_bytes,
  57. image_writer: DataWriter | None,
  58. predictor: BasePredictor | None = None,
  59. backend="transformers",
  60. model_path: str | None = None,
  61. server_url: str | None = None,
  62. ):
  63. if predictor is None:
  64. predictor = ModelSingleton().get_model(backend, model_path, server_url)
  65. load_images_start = time.time()
  66. images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
  67. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  68. load_images_time = round(time.time() - load_images_start, 2)
  69. logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  70. infer_start = time.time()
  71. results = await predictor.aio_batch_predict(images=images_base64_list)
  72. infer_time = round(time.time() - infer_start, 2)
  73. logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  74. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  75. return middle_json