model_json_to_middle_json.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. import time
  4. from loguru import logger
  5. from tqdm import tqdm
  6. from mineru.backend.utils import cross_page_table_merge
  7. from mineru.utils.config_reader import get_device, get_llm_aided_config, get_formula_enable
  8. from mineru.backend.pipeline.model_init import AtomModelSingleton
  9. from mineru.backend.pipeline.para_split import para_split
  10. from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
  11. from mineru.utils.block_sort import sort_blocks_by_bbox
  12. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
  13. from mineru.utils.cut_image import cut_image_and_table
  14. from mineru.utils.enum_class import ContentType
  15. from mineru.utils.llm_aided import llm_aided_title
  16. from mineru.utils.model_utils import clean_memory
  17. from mineru.backend.pipeline.pipeline_magic_model import MagicModel
  18. from mineru.utils.ocr_utils import OcrConfidence
  19. from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
  20. from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
  21. remove_overlaps_min_spans, txt_spans_extract
  22. from mineru.version import __version__
  23. from mineru.utils.hash_utils import bytes_md5
  24. def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
  25. scale = image_dict["scale"]
  26. page_pil_img = image_dict["img_pil"]
  27. # page_img_md5 = str_md5(image_dict["img_base64"])
  28. page_img_md5 = bytes_md5(page_pil_img.tobytes())
  29. page_w, page_h = map(int, page.get_size())
  30. magic_model = MagicModel(page_model_info, scale)
  31. """从magic_model对象中获取后面会用到的区块信息"""
  32. discarded_blocks = magic_model.get_discarded()
  33. text_blocks = magic_model.get_text_blocks()
  34. title_blocks = magic_model.get_title_blocks()
  35. inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
  36. img_groups = magic_model.get_imgs()
  37. table_groups = magic_model.get_tables()
  38. """对image和table的区块分组"""
  39. img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
  40. img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
  41. )
  42. table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
  43. table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
  44. )
  45. """获取所有的spans信息"""
  46. spans = magic_model.get_all_spans()
  47. """某些图可能是文本块,通过简单的规则判断一下"""
  48. if len(maybe_text_image_blocks) > 0:
  49. for block in maybe_text_image_blocks:
  50. should_add_to_text_blocks = False
  51. if ocr_enable:
  52. # 找到与当前block重叠的text spans
  53. span_in_block_list = [
  54. span for span in spans
  55. if span['type'] == 'text' and
  56. calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7
  57. ]
  58. if len(span_in_block_list) > 0:
  59. # 计算spans总面积
  60. spans_area = sum(
  61. (span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1])
  62. for span in span_in_block_list
  63. )
  64. # 计算block面积
  65. block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
  66. # 判断是否符合文本图条件
  67. if block_area > 0 and spans_area / block_area > 0.25:
  68. should_add_to_text_blocks = True
  69. # 根据条件决定添加到哪个列表
  70. if should_add_to_text_blocks:
  71. block.pop('group_id', None) # 移除group_id
  72. text_blocks.append(block)
  73. else:
  74. img_body_blocks.append(block)
  75. """将所有区块的bbox整理到一起"""
  76. if formula_enabled:
  77. interline_equation_blocks = []
  78. if len(interline_equation_blocks) > 0:
  79. for block in interline_equation_blocks:
  80. spans.append({
  81. "type": ContentType.INTERLINE_EQUATION,
  82. 'score': block['score'],
  83. "bbox": block['bbox'],
  84. "content": "",
  85. })
  86. all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
  87. img_body_blocks, img_caption_blocks, img_footnote_blocks,
  88. table_body_blocks, table_caption_blocks, table_footnote_blocks,
  89. discarded_blocks,
  90. text_blocks,
  91. title_blocks,
  92. interline_equation_blocks,
  93. page_w,
  94. page_h,
  95. )
  96. else:
  97. all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
  98. img_body_blocks, img_caption_blocks, img_footnote_blocks,
  99. table_body_blocks, table_caption_blocks, table_footnote_blocks,
  100. discarded_blocks,
  101. text_blocks,
  102. title_blocks,
  103. interline_equations,
  104. page_w,
  105. page_h,
  106. )
  107. """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
  108. """顺便删除大水印并保留abandon的span"""
  109. spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
  110. """删除重叠spans中置信度较低的那些"""
  111. spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
  112. """删除重叠spans中较小的那些"""
  113. spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
  114. """根据parse_mode,构造spans,主要是文本类的字符填充"""
  115. if ocr_enable:
  116. pass
  117. else:
  118. """使用新版本的混合ocr方案."""
  119. spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
  120. """先处理不需要排版的discarded_blocks"""
  121. discarded_block_with_spans, spans = fill_spans_in_blocks(
  122. all_discarded_blocks, spans, 0.4
  123. )
  124. fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
  125. """如果当前页面没有有效的bbox则跳过"""
  126. if len(all_bboxes) == 0 and len(fix_discarded_blocks) == 0:
  127. return None
  128. """对image/table/interline_equation截图"""
  129. for span in spans:
  130. if span['type'] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
  131. span = cut_image_and_table(
  132. span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
  133. )
  134. """span填充进block"""
  135. block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
  136. """对block进行fix操作"""
  137. fix_blocks = fix_block_spans(block_with_spans)
  138. """对block进行排序"""
  139. sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)
  140. """构造page_info"""
  141. page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)
  142. return page_info
  143. def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
  144. middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
  145. formula_enabled = get_formula_enable(formula_enabled)
  146. for page_index, page_model_info in tqdm(enumerate(model_list), total=len(model_list), desc="Processing pages"):
  147. page = pdf_doc[page_index]
  148. image_dict = images_list[page_index]
  149. page_info = page_model_info_to_page_info(
  150. page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
  151. )
  152. if page_info is None:
  153. page_w, page_h = map(int, page.get_size())
  154. page_info = make_page_info_dict([], page_index, page_w, page_h, [])
  155. middle_json["pdf_info"].append(page_info)
  156. """后置ocr处理"""
  157. need_ocr_list = []
  158. img_crop_list = []
  159. text_block_list = []
  160. for page_info in middle_json["pdf_info"]:
  161. for block in page_info['preproc_blocks']:
  162. if block['type'] in ['table', 'image']:
  163. for sub_block in block['blocks']:
  164. if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
  165. text_block_list.append(sub_block)
  166. elif block['type'] in ['text', 'title']:
  167. text_block_list.append(block)
  168. for block in page_info['discarded_blocks']:
  169. text_block_list.append(block)
  170. for block in text_block_list:
  171. for line in block['lines']:
  172. for span in line['spans']:
  173. if 'np_img' in span:
  174. need_ocr_list.append(span)
  175. img_crop_list.append(span['np_img'])
  176. span.pop('np_img')
  177. if len(img_crop_list) > 0:
  178. atom_model_manager = AtomModelSingleton()
  179. ocr_model = atom_model_manager.get_atom_model(
  180. atom_model_name='ocr',
  181. det_db_box_thresh=0.3,
  182. lang=lang
  183. )
  184. ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
  185. assert len(ocr_res_list) == len(
  186. need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
  187. for index, span in enumerate(need_ocr_list):
  188. ocr_text, ocr_score = ocr_res_list[index]
  189. if ocr_score > OcrConfidence.min_confidence:
  190. span['content'] = ocr_text
  191. span['score'] = float(f"{ocr_score:.3f}")
  192. else:
  193. span['content'] = ''
  194. span['score'] = 0.0
  195. """分段"""
  196. para_split(middle_json["pdf_info"])
  197. """表格跨页合并"""
  198. cross_page_table_merge(middle_json["pdf_info"])
  199. """llm优化"""
  200. llm_aided_config = get_llm_aided_config()
  201. if llm_aided_config is not None:
  202. """标题优化"""
  203. title_aided_config = llm_aided_config.get('title_aided', None)
  204. if title_aided_config is not None:
  205. if title_aided_config.get('enable', False):
  206. llm_aided_title_start_time = time.time()
  207. llm_aided_title(middle_json["pdf_info"], title_aided_config)
  208. logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
  209. """清理内存"""
  210. pdf_doc.close()
  211. if os.getenv('MINERU_DONOT_CLEAN_MEM') is None and len(model_list) >= 10:
  212. clean_memory(get_device())
  213. return middle_json
  214. def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
  215. return_dict = {
  216. 'preproc_blocks': blocks,
  217. 'page_idx': page_id,
  218. 'page_size': [page_w, page_h],
  219. 'discarded_blocks': discarded_blocks,
  220. }
  221. return return_dict