span_block_fix.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
  3. from mineru.utils.enum_class import BlockType, ContentType
  4. from mineru.utils.ocr_utils import _is_overlaps_y_exceeds_threshold, _is_overlaps_x_exceeds_threshold
# A span counts as "vertical" when its height / width ratio is greater than this value.
VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD = 2
# A text block is laid out as vertical text when more than this fraction of its spans are vertical.
VERTICAL_SPAN_IN_BLOCK_THRESHOLD = 0.8
  7. def fill_spans_in_blocks(blocks, spans, radio):
  8. """将allspans中的span按位置关系,放入blocks中."""
  9. block_with_spans = []
  10. for block in blocks:
  11. block_type = block[7]
  12. block_bbox = block[0:4]
  13. block_dict = {
  14. 'type': block_type,
  15. 'bbox': block_bbox,
  16. }
  17. if block_type in [
  18. BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE,
  19. BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
  20. ]:
  21. block_dict['group_id'] = block[-1]
  22. block_spans = []
  23. for span in spans:
  24. temp_radio = radio
  25. span_bbox = span['bbox']
  26. if span['type'] in [ContentType.IMAGE, ContentType.TABLE]:
  27. temp_radio = 0.9
  28. if calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > temp_radio and span_block_type_compatible(span['type'], block_type):
  29. block_spans.append(span)
  30. block_dict['spans'] = block_spans
  31. block_with_spans.append(block_dict)
  32. # 从spans删除已经放入block_spans中的span
  33. if len(block_spans) > 0:
  34. for span in block_spans:
  35. spans.remove(span)
  36. return block_with_spans, spans
  37. def span_block_type_compatible(span_type, block_type):
  38. if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
  39. return block_type in [
  40. BlockType.TEXT,
  41. BlockType.TITLE,
  42. BlockType.IMAGE_CAPTION,
  43. BlockType.IMAGE_FOOTNOTE,
  44. BlockType.TABLE_CAPTION,
  45. BlockType.TABLE_FOOTNOTE,
  46. BlockType.DISCARDED
  47. ]
  48. elif span_type == ContentType.INTERLINE_EQUATION:
  49. return block_type in [BlockType.INTERLINE_EQUATION, BlockType.TEXT]
  50. elif span_type == ContentType.IMAGE:
  51. return block_type in [BlockType.IMAGE_BODY]
  52. elif span_type == ContentType.TABLE:
  53. return block_type in [BlockType.TABLE_BODY]
  54. else:
  55. return False
  56. def fix_discarded_block(discarded_block_with_spans):
  57. fix_discarded_blocks = []
  58. for block in discarded_block_with_spans:
  59. block = fix_text_block(block)
  60. fix_discarded_blocks.append(block)
  61. return fix_discarded_blocks
  62. def fix_text_block(block):
  63. # 文本block中的公式span都应该转换成行内type
  64. for span in block['spans']:
  65. if span['type'] == ContentType.INTERLINE_EQUATION:
  66. span['type'] = ContentType.INLINE_EQUATION
  67. # 假设block中的span超过80%的数量高度是宽度的两倍以上,则认为是纵向文本块
  68. vertical_span_count = sum(
  69. 1 for span in block['spans']
  70. if (span['bbox'][3] - span['bbox'][1]) / (span['bbox'][2] - span['bbox'][0]) > VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD
  71. )
  72. total_span_count = len(block['spans'])
  73. if total_span_count == 0:
  74. vertical_ratio = 0
  75. else:
  76. vertical_ratio = vertical_span_count / total_span_count
  77. if vertical_ratio > VERTICAL_SPAN_IN_BLOCK_THRESHOLD:
  78. # 如果是纵向文本块,则按纵向lines处理
  79. block_lines = merge_spans_to_vertical_line(block['spans'])
  80. sort_block_lines = vertical_line_sort_spans_from_top_to_bottom(block_lines)
  81. else:
  82. block_lines = merge_spans_to_line(block['spans'])
  83. sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
  84. block['lines'] = sort_block_lines
  85. del block['spans']
  86. return block
  87. def merge_spans_to_line(spans, threshold=0.6):
  88. if len(spans) == 0:
  89. return []
  90. else:
  91. # 按照y0坐标排序
  92. spans.sort(key=lambda span: span['bbox'][1])
  93. lines = []
  94. current_line = [spans[0]]
  95. for span in spans[1:]:
  96. # 如果当前的span类型为"interline_equation" 或者 当前行中已经有"interline_equation"
  97. # image和table类型,同上
  98. if span['type'] in [
  99. ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
  100. ContentType.TABLE
  101. ] or any(s['type'] in [
  102. ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
  103. ContentType.TABLE
  104. ] for s in current_line):
  105. # 则开始新行
  106. lines.append(current_line)
  107. current_line = [span]
  108. continue
  109. # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
  110. if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
  111. current_line.append(span)
  112. else:
  113. # 否则,开始新行
  114. lines.append(current_line)
  115. current_line = [span]
  116. # 添加最后一行
  117. if current_line:
  118. lines.append(current_line)
  119. return lines
  120. def merge_spans_to_vertical_line(spans, threshold=0.6):
  121. """将纵向文本的spans合并成纵向lines(从右向左阅读)"""
  122. if len(spans) == 0:
  123. return []
  124. else:
  125. # 按照x2坐标从大到小排序(从右向左)
  126. spans.sort(key=lambda span: span['bbox'][2], reverse=True)
  127. vertical_lines = []
  128. current_line = [spans[0]]
  129. for span in spans[1:]:
  130. # 特殊类型元素单独成列
  131. if span['type'] in [
  132. ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
  133. ContentType.TABLE
  134. ] or any(s['type'] in [
  135. ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
  136. ContentType.TABLE
  137. ] for s in current_line):
  138. vertical_lines.append(current_line)
  139. current_line = [span]
  140. continue
  141. # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
  142. if _is_overlaps_x_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
  143. current_line.append(span)
  144. else:
  145. vertical_lines.append(current_line)
  146. current_line = [span]
  147. # 添加最后一列
  148. if current_line:
  149. vertical_lines.append(current_line)
  150. return vertical_lines
  151. # 将每一个line中的span从左到右排序
  152. def line_sort_spans_by_left_to_right(lines):
  153. line_objects = []
  154. for line in lines:
  155. # 按照x0坐标排序
  156. line.sort(key=lambda span: span['bbox'][0])
  157. line_bbox = [
  158. min(span['bbox'][0] for span in line), # x0
  159. min(span['bbox'][1] for span in line), # y0
  160. max(span['bbox'][2] for span in line), # x1
  161. max(span['bbox'][3] for span in line), # y1
  162. ]
  163. line_objects.append({
  164. 'bbox': line_bbox,
  165. 'spans': line,
  166. })
  167. return line_objects
  168. def vertical_line_sort_spans_from_top_to_bottom(vertical_lines):
  169. line_objects = []
  170. for line in vertical_lines:
  171. # 按照y0坐标排序(从上到下)
  172. line.sort(key=lambda span: span['bbox'][1])
  173. # 计算整个列的边界框
  174. line_bbox = [
  175. min(span['bbox'][0] for span in line), # x0
  176. min(span['bbox'][1] for span in line), # y0
  177. max(span['bbox'][2] for span in line), # x1
  178. max(span['bbox'][3] for span in line), # y1
  179. ]
  180. # 组装结果
  181. line_objects.append({
  182. 'bbox': line_bbox,
  183. 'spans': line,
  184. })
  185. return line_objects
  186. def fix_block_spans(block_with_spans):
  187. fix_blocks = []
  188. for block in block_with_spans:
  189. block_type = block['type']
  190. if block_type in [BlockType.TEXT, BlockType.TITLE,
  191. BlockType.IMAGE_CAPTION, BlockType.IMAGE_CAPTION,
  192. BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE
  193. ]:
  194. block = fix_text_block(block)
  195. elif block_type in [BlockType.INTERLINE_EQUATION, BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  196. block = fix_interline_block(block)
  197. else:
  198. continue
  199. fix_blocks.append(block)
  200. return fix_blocks
  201. def fix_interline_block(block):
  202. block_lines = merge_spans_to_line(block['spans'])
  203. sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
  204. block['lines'] = sort_block_lines
  205. del block['spans']
  206. return block