span_pre_proc.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import re
  3. import cv2
  4. import numpy as np
  5. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
  6. get_minbox_if_overlap_by_ratio
  7. from mineru.utils.enum_class import BlockType, ContentType
  8. from mineru.utils.pdf_image_tools import get_crop_img
  9. def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
  10. def get_block_bboxes(blocks, block_type_list):
  11. return [block[0:4] for block in blocks if block[7] in block_type_list]
  12. image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
  13. table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
  14. other_block_type = []
  15. for block_type in BlockType.__dict__.values():
  16. if not isinstance(block_type, str):
  17. continue
  18. if block_type not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  19. other_block_type.append(block_type)
  20. other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
  21. discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])
  22. new_spans = []
  23. for span in spans:
  24. span_bbox = span['bbox']
  25. span_type = span['type']
  26. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
  27. discarded_block_bboxes):
  28. new_spans.append(span)
  29. continue
  30. if span_type == ContentType.IMAGE:
  31. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  32. image_bboxes):
  33. new_spans.append(span)
  34. elif span_type == ContentType.TABLE:
  35. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  36. table_bboxes):
  37. new_spans.append(span)
  38. else:
  39. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  40. other_block_bboxes):
  41. new_spans.append(span)
  42. return new_spans
  43. def remove_overlaps_low_confidence_spans(spans):
  44. dropped_spans = []
  45. # 删除重叠spans中置信度低的的那些
  46. for span1 in spans:
  47. for span2 in spans:
  48. if span1 != span2:
  49. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  50. if span1 in dropped_spans or span2 in dropped_spans:
  51. continue
  52. else:
  53. if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
  54. if span1['score'] < span2['score']:
  55. span_need_remove = span1
  56. else:
  57. span_need_remove = span2
  58. if (
  59. span_need_remove is not None
  60. and span_need_remove not in dropped_spans
  61. ):
  62. dropped_spans.append(span_need_remove)
  63. if len(dropped_spans) > 0:
  64. for span_need_remove in dropped_spans:
  65. spans.remove(span_need_remove)
  66. return spans, dropped_spans
  67. def remove_overlaps_min_spans(spans):
  68. dropped_spans = []
  69. # 删除重叠spans中较小的那些
  70. for span1 in spans:
  71. for span2 in spans:
  72. if span1 != span2:
  73. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  74. if span1 in dropped_spans or span2 in dropped_spans:
  75. continue
  76. else:
  77. overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
  78. if overlap_box is not None:
  79. span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
  80. if span_need_remove is not None and span_need_remove not in dropped_spans:
  81. dropped_spans.append(span_need_remove)
  82. if len(dropped_spans) > 0:
  83. for span_need_remove in dropped_spans:
  84. spans.remove(span_need_remove)
  85. return spans, dropped_spans
  86. def __replace_ligatures(text: str):
  87. ligatures = {
  88. 'fi': 'fi', 'fl': 'fl', 'ff': 'ff', 'ffi': 'ffi', 'ffl': 'ffl', 'ſt': 'ft', 'st': 'st'
  89. }
  90. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  91. def __replace_unicode(text: str):
  92. ligatures = {
  93. '\r\n': '', '\u0002': '-',
  94. }
  95. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  96. def txt_spans_extract(pdf_page, spans, pil_img, scale):
  97. textpage = pdf_page.get_textpage()
  98. width, height = pdf_page.get_size()
  99. cropbox = pdf_page.get_cropbox()
  100. need_ocr_spans = []
  101. for span in spans:
  102. if span['type'] in [ContentType.INTERLINE_EQUATION, ContentType.IMAGE, ContentType.TABLE]:
  103. continue
  104. span_bbox = span['bbox']
  105. rect_box = [span_bbox[0] + cropbox[0],
  106. height - span_bbox[3] + cropbox[1],
  107. span_bbox[2] + cropbox[0],
  108. height - span_bbox[1] + cropbox[1]]
  109. text = textpage.get_text_bounded(left=rect_box[0], top=rect_box[1],
  110. right=rect_box[2], bottom=rect_box[3])
  111. if text and len(text) > 0:
  112. text = __replace_unicode(text)
  113. text = __replace_ligatures(text)
  114. span['content'] = text.strip()
  115. span['score'] = 1.0
  116. else:
  117. need_ocr_spans.append(span)
  118. if len(need_ocr_spans) > 0:
  119. for span in need_ocr_spans:
  120. # 对span的bbox截图再ocr
  121. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  122. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  123. # 计算span的对比度,低于0.20的span不进行ocr
  124. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  125. spans.remove(span)
  126. continue
  127. span['content'] = ''
  128. span['score'] = 1.0
  129. span['np_img'] = span_img
  130. return spans
  131. def calculate_contrast(img, img_mode) -> float:
  132. """
  133. 计算给定图像的对比度。
  134. :param img: 图像,类型为numpy.ndarray
  135. :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
  136. :return: 图像的对比度值
  137. """
  138. if img_mode == 'rgb':
  139. # 将RGB图像转换为灰度图
  140. gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  141. elif img_mode == 'bgr':
  142. # 将BGR图像转换为灰度图
  143. gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  144. else:
  145. raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
  146. # 计算均值和标准差
  147. mean_value = np.mean(gray_img)
  148. std_dev = np.std(gray_img)
  149. # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
  150. contrast = std_dev / (mean_value + 1e-6)
  151. # logger.debug(f"contrast: {contrast}")
  152. return round(contrast, 2)