# vlm_analyze.py
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import time
  3. from loguru import logger
  4. from ...data.data_reader_writer import DataWriter
  5. from ...libs.pdf_image_tools import load_images_from_pdf
  6. from .base_predictor import BasePredictor
  7. from .predictor import get_predictor
  8. from .token_to_middle_json import result_to_middle_json
  9. class ModelSingleton:
  10. _instance = None
  11. _models = {}
  12. def __new__(cls, *args, **kwargs):
  13. if cls._instance is None:
  14. cls._instance = super().__new__(cls)
  15. return cls._instance
  16. def get_model(
  17. self,
  18. backend: str,
  19. model_path: str | None,
  20. server_url: str | None,
  21. ) -> BasePredictor:
  22. key = (backend,)
  23. if key not in self._models:
  24. self._models[key] = get_predictor(
  25. backend=backend,
  26. model_path=model_path,
  27. server_url=server_url,
  28. )
  29. return self._models[key]
  30. def doc_analyze(
  31. pdf_bytes,
  32. image_writer: DataWriter | None,
  33. predictor: BasePredictor | None = None,
  34. backend="huggingface",
  35. model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415", # TODO: change to formal path after release.
  36. server_url: str | None = None,
  37. ):
  38. if predictor is None:
  39. predictor = ModelSingleton().get_model(backend, model_path, server_url)
  40. load_images_start = time.time()
  41. images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
  42. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  43. load_images_time = round(time.time() - load_images_start, 2)
  44. logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  45. infer_start = time.time()
  46. results = predictor.batch_predict(images=images_base64_list)
  47. infer_time = round(time.time() - infer_start, 2)
  48. logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  49. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  50. return middle_json, results
  51. async def aio_doc_analyze(
  52. pdf_bytes,
  53. image_writer: DataWriter | None,
  54. predictor: BasePredictor | None = None,
  55. backend="huggingface",
  56. model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415", # TODO: change to formal path after release.
  57. server_url: str | None = None,
  58. ):
  59. if predictor is None:
  60. predictor = ModelSingleton().get_model(backend, model_path, server_url)
  61. load_images_start = time.time()
  62. images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
  63. images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
  64. load_images_time = round(time.time() - load_images_start, 2)
  65. logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
  66. infer_start = time.time()
  67. results = await predictor.aio_batch_predict(images=images_base64_list)
  68. infer_time = round(time.time() - infer_start, 2)
  69. logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  70. middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
  71. return middle_json