# vlm_analyze.py
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import time
  3. from loguru import logger
  4. from ...data.data_reader_writer import DataWriter
  5. from mineru.utils.pdf_image_tools import load_images_from_pdf
  6. from .base_predictor import BasePredictor
  7. from .predictor import get_predictor
  8. from .token_to_middle_json import result_to_middle_json
  9. from ...utils.enum_class import ModelPath
  10. class ModelSingleton:
  11. _instance = None
  12. _models = {}
  13. def __new__(cls, *args, **kwargs):
  14. if cls._instance is None:
  15. cls._instance = super().__new__(cls)
  16. return cls._instance
  17. def get_model(
  18. self,
  19. backend: str,
  20. model_path: str | None,
  21. server_url: str | None,
  22. ) -> BasePredictor:
  23. key = (backend,)
  24. if key not in self._models:
  25. self._models[key] = get_predictor(
  26. backend=backend,
  27. model_path=model_path,
  28. server_url=server_url,
  29. )
  30. return self._models[key]
  31. def doc_analyze(
  32. pdf_bytes,
  33. image_writer: DataWriter | None,
  34. predictor: BasePredictor | None = None,
  35. backend="huggingface",
  36. model_path=ModelPath.vlm_root_hf,
  37. server_url: str | None = None,
  38. ):
  39. if predictor is None:
  40. predictor = ModelSingleton().get_model(backend, model_path, server_url)
  41. # load_images_start = time.time()
  42. images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
  43. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  44. # load_images_time = round(time.time() - load_images_start, 2)
  45. # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  46. infer_start = time.time()
  47. results = predictor.batch_predict(images=images_base64_list)
  48. infer_time = round(time.time() - infer_start, 2)
  49. logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  50. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  51. return middle_json, results
  52. async def aio_doc_analyze(
  53. pdf_bytes,
  54. image_writer: DataWriter | None,
  55. predictor: BasePredictor | None = None,
  56. backend="huggingface",
  57. model_path=ModelPath.vlm_root_hf,
  58. server_url: str | None = None,
  59. ):
  60. if predictor is None:
  61. predictor = ModelSingleton().get_model(backend, model_path, server_url)
  62. load_images_start = time.time()
  63. images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
  64. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  65. load_images_time = round(time.time() - load_images_start, 2)
  66. logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  67. infer_start = time.time()
  68. results = await predictor.aio_batch_predict(images=images_base64_list)
  69. infer_time = round(time.time() - infer_start, 2)
  70. logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  71. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  72. return middle_json