# vlm_analyze.py
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from .base_predictor import BasePredictor
from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json
from ...utils.enum_class import ImageType
from ...utils.models_download_utils import auto_download_and_get_model_root_path
  11. class ModelSingleton:
  12. _instance = None
  13. _models = {}
  14. def __new__(cls, *args, **kwargs):
  15. if cls._instance is None:
  16. cls._instance = super().__new__(cls)
  17. return cls._instance
  18. def get_model(
  19. self,
  20. backend: str,
  21. model_path: str | None,
  22. server_url: str | None,
  23. **kwargs,
  24. ) -> BasePredictor:
  25. key = (backend, model_path, server_url)
  26. if key not in self._models:
  27. if backend in ['transformers', 'sglang-engine'] and not model_path:
  28. model_path = auto_download_and_get_model_root_path("/","vlm")
  29. self._models[key] = get_predictor(
  30. backend=backend,
  31. model_path=model_path,
  32. server_url=server_url,
  33. **kwargs,
  34. )
  35. return self._models[key]
  36. def doc_analyze(
  37. pdf_bytes,
  38. image_writer: DataWriter | None,
  39. predictor: BasePredictor | None = None,
  40. backend="transformers",
  41. model_path: str | None = None,
  42. server_url: str | None = None,
  43. **kwargs,
  44. ):
  45. if predictor is None:
  46. predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
  47. # load_images_start = time.time()
  48. images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.BASE64)
  49. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  50. # load_images_time = round(time.time() - load_images_start, 2)
  51. # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  52. # infer_start = time.time()
  53. results = predictor.batch_predict(images=images_base64_list)
  54. # infer_time = round(time.time() - infer_start, 2)
  55. # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  56. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  57. return middle_json, results
  58. async def aio_doc_analyze(
  59. pdf_bytes,
  60. image_writer: DataWriter | None,
  61. predictor: BasePredictor | None = None,
  62. backend="transformers",
  63. model_path: str | None = None,
  64. server_url: str | None = None,
  65. **kwargs,
  66. ):
  67. if predictor is None:
  68. predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
  69. # load_images_start = time.time()
  70. images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.BASE64)
  71. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  72. # load_images_time = round(time.time() - load_images_start, 2)
  73. # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  74. # infer_start = time.time()
  75. results = await predictor.aio_batch_predict(images=images_base64_list)
  76. # infer_time = round(time.time() - infer_start, 2)
  77. # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  78. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  79. return middle_json, results