model_output_to_middle_json.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import os
  2. import time
  3. import cv2
  4. import numpy as np
  5. from loguru import logger
  6. from mineru.backend.utils import cross_page_table_merge
  7. from mineru.backend.vlm.vlm_magic_model import MagicModel
  8. from mineru.utils.config_reader import get_table_enable, get_llm_aided_config
  9. from mineru.utils.cut_image import cut_image_and_table
  10. from mineru.utils.enum_class import ContentType
  11. from mineru.utils.hash_utils import bytes_md5
  12. from mineru.utils.pdf_image_tools import get_crop_img
  13. from mineru.version import __version__
  14. heading_level_import_success = False
  15. llm_aided_config = get_llm_aided_config()
  16. if llm_aided_config:
  17. title_aided_config = llm_aided_config.get('title_aided', {})
  18. if title_aided_config.get('enable', False):
  19. try:
  20. from mineru.utils.llm_aided import llm_aided_title
  21. from mineru.backend.pipeline.model_init import AtomModelSingleton
  22. heading_level_import_success = True
  23. except Exception as e:
  24. logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
  25. "please execute `pip install mineru[core]` to install the required packages.")
  26. def blocks_to_page_info(page_blocks, image_dict, page, image_writer, page_index) -> dict:
  27. """将blocks转换为页面信息"""
  28. scale = image_dict["scale"]
  29. # page_pil_img = image_dict["img_pil"]
  30. page_pil_img = image_dict["img_pil"]
  31. page_img_md5 = bytes_md5(page_pil_img.tobytes())
  32. width, height = map(int, page.get_size())
  33. magic_model = MagicModel(page_blocks, width, height)
  34. image_blocks = magic_model.get_image_blocks()
  35. table_blocks = magic_model.get_table_blocks()
  36. title_blocks = magic_model.get_title_blocks()
  37. discarded_blocks = magic_model.get_discarded_blocks()
  38. code_blocks = magic_model.get_code_blocks()
  39. ref_text_blocks = magic_model.get_ref_text_blocks()
  40. phonetic_blocks = magic_model.get_phonetic_blocks()
  41. list_blocks = magic_model.get_list_blocks()
  42. # 如果有标题优化需求,则对title_blocks截图det
  43. if heading_level_import_success:
  44. atom_model_manager = AtomModelSingleton()
  45. ocr_model = atom_model_manager.get_atom_model(
  46. atom_model_name='ocr',
  47. ocr_show_log=False,
  48. det_db_box_thresh=0.3,
  49. lang='ch_lite'
  50. )
  51. for title_block in title_blocks:
  52. title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
  53. title_np_img = np.array(title_pil_img)
  54. # 给title_pil_img添加上下左右各50像素白边padding
  55. title_np_img = cv2.copyMakeBorder(
  56. title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
  57. )
  58. title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
  59. ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
  60. if len(ocr_det_res) > 0:
  61. # 计算所有res的平均高度
  62. avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
  63. title_block['line_avg_height'] = round(avg_height/scale)
  64. text_blocks = magic_model.get_text_blocks()
  65. interline_equation_blocks = magic_model.get_interline_equation_blocks()
  66. all_spans = magic_model.get_all_spans()
  67. # 对image/table/interline_equation的span截图
  68. for span in all_spans:
  69. if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
  70. span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
  71. page_blocks = []
  72. page_blocks.extend([
  73. *image_blocks,
  74. *table_blocks,
  75. *code_blocks,
  76. *ref_text_blocks,
  77. *phonetic_blocks,
  78. *title_blocks,
  79. *text_blocks,
  80. *interline_equation_blocks,
  81. *list_blocks,
  82. ])
  83. # 对page_blocks根据index的值进行排序
  84. page_blocks.sort(key=lambda x: x["index"])
  85. page_info = {"para_blocks": page_blocks, "discarded_blocks": discarded_blocks, "page_size": [width, height], "page_idx": page_index}
  86. return page_info
  87. def result_to_middle_json(model_output_blocks_list, images_list, pdf_doc, image_writer):
  88. middle_json = {"pdf_info": [], "_backend":"vlm", "_version_name": __version__}
  89. for index, page_blocks in enumerate(model_output_blocks_list):
  90. page = pdf_doc[index]
  91. image_dict = images_list[index]
  92. page_info = blocks_to_page_info(page_blocks, image_dict, page, image_writer, index)
  93. middle_json["pdf_info"].append(page_info)
  94. """表格跨页合并"""
  95. table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
  96. if table_enable:
  97. cross_page_table_merge(middle_json["pdf_info"])
  98. """llm优化标题分级"""
  99. if heading_level_import_success:
  100. llm_aided_title_start_time = time.time()
  101. llm_aided_title(middle_json["pdf_info"], title_aided_config)
  102. logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
  103. # 关闭pdf文档
  104. pdf_doc.close()
  105. return middle_json