# Copyright (c) Opendatalab. All rights reserved.
from mineru.backend.pipeline.config_reader import get_device
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.model_utils import clean_memory
from mineru.utils.pipeline_magic_model import MagicModel
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
    remove_overlaps_min_spans, txt_spans_extract
from mineru.version import __version__
from mineru.utils.hash_utils import str_md5
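
# A sketch of the per-page inputs that page_model_info_to_page_info below reads,
# inferred from its body rather than an authoritative schema: `page` appears to be
# a pypdfium2 page (indexing pdf_doc and calling page.get_size() match that API and
# return the page size in PDF points), and `image_dict` is expected to carry at least:
#   {
#       "img_pil": <PIL image of the rendered page>,
#       "img_base64": "<base64-encoded rendering, used only to derive an md5 key>",
#       "scale": <render scale relative to the PDF page size>,
#   }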

def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr=False):
    """Convert the model output of a single page into a page_info dict."""
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    page_img_md5 = str_md5(image_dict["img_base64"])
    page_w, page_h = map(int, page.get_size())
    magic_model = MagicModel(page_model_info, scale)

    """Pull the block information we will need later out of the magic_model object."""
    img_groups = magic_model.get_imgs()
    table_groups = magic_model.get_tables()

    """Group the image and table blocks."""
    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
    )
    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
    )

    discarded_blocks = magic_model.get_discarded()
    text_blocks = magic_model.get_text_blocks()
    title_blocks = magic_model.get_title_blocks()
    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()

    """Gather the bboxes of all blocks in one place."""
    # interline_equation_blocks is cleared here, so the else branch below
    # (which uses interline_equations) is always the one taken.
    interline_equation_blocks = []
    if len(interline_equation_blocks) > 0:
        all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equation_blocks,
            page_w,
            page_h,
        )
    else:
        all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equations,
            page_w,
            page_h,
        )

    """Get all span information."""
    spans = magic_model.get_all_spans()

    """Before removing duplicate spans, filter the image and table spans against the image_body and table_body blocks."""
    """This also drops large watermarks while keeping the spans of abandoned (discarded) blocks."""
    spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)

    """Among overlapping spans, drop the ones with lower confidence."""
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)

    """Among overlapping spans, drop the smaller ones."""
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)

    """Build the spans according to parse_mode; this mainly fills in the characters of text spans."""
    if ocr:
        pass
    else:
        """Use the newer hybrid OCR scheme."""
        spans = txt_spans_extract(page, spans, page_pil_img, scale)

    """Handle the discarded_blocks first; they do not take part in layout."""
    discarded_block_with_spans, spans = fill_spans_in_blocks(
        all_discarded_blocks, spans, 0.4
    )
    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)

    """Skip the page if it has no valid bboxes."""
    if len(all_bboxes) == 0:
        return None

    """Crop the images and tables."""
    for span in spans:
        if span['type'] in ['image', 'table']:
            span = cut_image_and_table(
                span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
            )

    """Fill the spans into their blocks."""
    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)

    """Fix the blocks."""
    fix_blocks = fix_block_spans(block_with_spans)

    """Merge titles that were broken up within the same line."""
    # merge_title_blocks(fix_blocks)

    """Sort the blocks."""
    sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)

    """Build page_info."""
    page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
    return page_info
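
# A sketch of how each entry of middle_json["pdf_info"] ends up nested, based on
# make_page_info_dict and the traversal in result_to_middle_json below; the field
# names come from this module, the concrete values are only illustrative:
#   {
#       "preproc_blocks": [
#           {"type": "text" | "title", "lines": [{"spans": [{"content": ..., "score": ...}]}]},
#           {"type": "table" | "image", "blocks": [<body/caption/footnote sub-blocks>]},
#       ],
#       "page_idx": 0,
#       "page_size": [page_w, page_h],
#       "discarded_blocks": [...],
#   }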


def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr=False):
    """Convert per-page model output into the middle-json structure, then run the deferred OCR pass and paragraph splitting."""
    middle_json = {"pdf_info": [], "_backend": "vlm", "_version_name": __version__}
    for page_index, page_model_info in enumerate(model_list):
        page = pdf_doc[page_index]
        image_dict = images_list[page_index]
        page_info = page_model_info_to_page_info(
            page_model_info, image_dict, page, image_writer, page_index, ocr=ocr
        )
        if page_info is None:
            page_w, page_h = map(int, page.get_size())
            page_info = make_page_info_dict([], page_index, page_w, page_h, [])
        middle_json["pdf_info"].append(page_info)

    """Deferred OCR pass: collect every span that still carries a cropped image and recognize them in one batch."""
    need_ocr_list = []
    img_crop_list = []
    text_block_list = []
    for page_info in middle_json["pdf_info"]:
        for block in page_info['preproc_blocks']:
            if block['type'] in ['table', 'image']:
                for sub_block in block['blocks']:
                    if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
                        text_block_list.append(sub_block)
            elif block['type'] in ['text', 'title']:
                text_block_list.append(block)
        for block in page_info['discarded_blocks']:
            text_block_list.append(block)
    for block in text_block_list:
        for line in block['lines']:
            for span in line['spans']:
                if 'np_img' in span:
                    need_ocr_list.append(span)
                    img_crop_list.append(span['np_img'])
                    span.pop('np_img')
    if len(img_crop_list) > 0:
        atom_model_manager = AtomModelSingleton()
        ocr_model = atom_model_manager.get_atom_model(
            atom_model_name='ocr',
            ocr_show_log=False,
            det_db_box_thresh=0.3,
            lang=lang
        )
        ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
        assert len(ocr_res_list) == len(need_ocr_list), \
            f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
        for index, span in enumerate(need_ocr_list):
            ocr_text, ocr_score = ocr_res_list[index]
            span['content'] = ocr_text
            span['score'] = float(f"{ocr_score:.3f}")

    """Split into paragraphs."""
    para_split(middle_json["pdf_info"])

    clean_memory(get_device())

    return middle_json


def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
    return_dict = {
        'preproc_blocks': blocks,
        'page_idx': page_id,
        'page_size': [page_w, page_h],
        'discarded_blocks': discarded_blocks,
    }
    return return_dict
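

# A minimal usage sketch, not part of this module and built on assumptions: it
# presumes pypdfium2-style page access (pdf_doc[i] / page.get_size()), a
# caller-supplied image_writer that stores the cropped image/table files, and
# per-page image dicts shaped as described near the top of this file.
# load_model_output, render_page_images and OutputWriter are hypothetical
# placeholders for whatever the calling pipeline actually uses.
#
# import pypdfium2 as pdfium
#
# pdf_doc = pdfium.PdfDocument("paper.pdf")
# model_list = load_model_output("paper_model.json")    # hypothetical helper
# images_list = render_page_images(pdf_doc)             # hypothetical helper
# image_writer = OutputWriter("./output/images")        # hypothetical writer
#
# middle_json = result_to_middle_json(
#     model_list, images_list, pdf_doc, image_writer, lang="en", ocr=False
# )
# for page_info in middle_json["pdf_info"]:
#     print(page_info["page_idx"], len(page_info["preproc_blocks"]))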