span_block_fix.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
  3. from mineru.utils.enum_class import BlockType, ContentType
  4. from mineru.utils.ocr_utils import __is_overlaps_y_exceeds_threshold
  5. def fill_spans_in_blocks(blocks, spans, radio):
  6. """将allspans中的span按位置关系,放入blocks中."""
  7. block_with_spans = []
  8. for block in blocks:
  9. block_type = block[7]
  10. block_bbox = block[0:4]
  11. block_dict = {
  12. 'type': block_type,
  13. 'bbox': block_bbox,
  14. }
  15. if block_type in [
  16. BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
  17. BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
  18. ]:
  19. block_dict['group_id'] = block[-1]
  20. block_spans = []
  21. for span in spans:
  22. span_bbox = span['bbox']
  23. if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > radio and span_block_type_compatible(
  24. span['type'], block_type):
  25. block_spans.append(span)
  26. block_dict['spans'] = block_spans
  27. block_with_spans.append(block_dict)
  28. # 从spans删除已经放入block_spans中的span
  29. if len(block_spans) > 0:
  30. for span in block_spans:
  31. spans.remove(span)
  32. return block_with_spans, spans
  33. def span_block_type_compatible(span_type, block_type):
  34. if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
  35. return block_type in [
  36. BlockType.TEXT,
  37. BlockType.TITLE,
  38. BlockType.IMAGE_CAPTION,
  39. BlockType.IMAGE_FOOTNOTE,
  40. BlockType.TABLE_CAPTION,
  41. BlockType.TABLE_FOOTNOTE,
  42. BlockType.DISCARDED
  43. ]
  44. elif span_type == ContentType.INTERLINE_EQUATION:
  45. return block_type in [BlockType.INTERLINE_EQUATION, BlockType.TEXT]
  46. elif span_type == ContentType.IMAGE:
  47. return block_type in [BlockType.IMAGE_BODY]
  48. elif span_type == ContentType.TABLE:
  49. return block_type in [BlockType.TABLE_BODY]
  50. else:
  51. return False
  52. def fix_discarded_block(discarded_block_with_spans):
  53. fix_discarded_blocks = []
  54. for block in discarded_block_with_spans:
  55. block = fix_text_block(block)
  56. fix_discarded_blocks.append(block)
  57. return fix_discarded_blocks
  58. def fix_text_block(block):
  59. # 文本block中的公式span都应该转换成行内type
  60. for span in block['spans']:
  61. if span['type'] == ContentType.INTERLINE_EQUATION:
  62. span['type'] = ContentType.INLINE_EQUATION
  63. block_lines = merge_spans_to_line(block['spans'])
  64. sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
  65. block['lines'] = sort_block_lines
  66. del block['spans']
  67. return block
  68. def merge_spans_to_line(spans, threshold=0.6):
  69. if len(spans) == 0:
  70. return []
  71. else:
  72. # 按照y0坐标排序
  73. spans.sort(key=lambda span: span['bbox'][1])
  74. lines = []
  75. current_line = [spans[0]]
  76. for span in spans[1:]:
  77. # 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation"
  78. # image和table类型,同上
  79. if span['type'] in [
  80. ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
  81. ContentType.TABLE
  82. ] or any(s['type'] in [
  83. ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
  84. ContentType.TABLE
  85. ] for s in current_line):
  86. # 则开始新行
  87. lines.append(current_line)
  88. current_line = [span]
  89. continue
  90. # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
  91. if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
  92. current_line.append(span)
  93. else:
  94. # 否则,开始新行
  95. lines.append(current_line)
  96. current_line = [span]
  97. # 添加最后一行
  98. if current_line:
  99. lines.append(current_line)
  100. return lines
  101. # 将每一个line中的span从左到右排序
  102. def line_sort_spans_by_left_to_right(lines):
  103. line_objects = []
  104. for line in lines:
  105. # 按照x0坐标排序
  106. line.sort(key=lambda span: span['bbox'][0])
  107. line_bbox = [
  108. min(span['bbox'][0] for span in line), # x0
  109. min(span['bbox'][1] for span in line), # y0
  110. max(span['bbox'][2] for span in line), # x1
  111. max(span['bbox'][3] for span in line), # y1
  112. ]
  113. line_objects.append({
  114. 'bbox': line_bbox,
  115. 'spans': line,
  116. })
  117. return line_objects
  118. def fix_block_spans(block_with_spans):
  119. fix_blocks = []
  120. for block in block_with_spans:
  121. block_type = block['type']
  122. if block_type in [BlockType.TEXT, BlockType.TITLE,
  123. BlockType.IMAGE_CAPTION, BlockType.IMAGE_CAPTION,
  124. BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
  125. ]:
  126. block = fix_text_block(block)
  127. elif block_type in [BlockType.INTERLINE_EQUATION, BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  128. block = fix_interline_block(block)
  129. else:
  130. continue
  131. fix_blocks.append(block)
  132. return fix_blocks
  133. def fix_interline_block(block):
  134. block_lines = merge_spans_to_line(block['spans'])
  135. sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
  136. block['lines'] = sort_block_lines
  137. del block['spans']
  138. return block