span_pre_proc.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import re
  3. import cv2
  4. import numpy as np
  5. from loguru import logger
  6. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
  7. get_minbox_if_overlap_by_ratio
  8. from mineru.utils.enum_class import BlockType, ContentType
  9. from mineru.utils.pdf_image_tools import get_crop_img
  10. from mineru.utils.pdf_text_tool import get_page
  11. def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
  12. def get_block_bboxes(blocks, block_type_list):
  13. return [block[0:4] for block in blocks if block[7] in block_type_list]
  14. image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
  15. table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
  16. other_block_type = []
  17. for block_type in BlockType.__dict__.values():
  18. if not isinstance(block_type, str):
  19. continue
  20. if block_type not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  21. other_block_type.append(block_type)
  22. other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
  23. discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])
  24. new_spans = []
  25. for span in spans:
  26. span_bbox = span['bbox']
  27. span_type = span['type']
  28. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
  29. discarded_block_bboxes):
  30. new_spans.append(span)
  31. continue
  32. if span_type == ContentType.IMAGE:
  33. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  34. image_bboxes):
  35. new_spans.append(span)
  36. elif span_type == ContentType.TABLE:
  37. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  38. table_bboxes):
  39. new_spans.append(span)
  40. else:
  41. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  42. other_block_bboxes):
  43. new_spans.append(span)
  44. return new_spans
  45. def remove_overlaps_low_confidence_spans(spans):
  46. dropped_spans = []
  47. # 删除重叠spans中置信度低的的那些
  48. for span1 in spans:
  49. for span2 in spans:
  50. if span1 != span2:
  51. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  52. if span1 in dropped_spans or span2 in dropped_spans:
  53. continue
  54. else:
  55. if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
  56. if span1['score'] < span2['score']:
  57. span_need_remove = span1
  58. else:
  59. span_need_remove = span2
  60. if (
  61. span_need_remove is not None
  62. and span_need_remove not in dropped_spans
  63. ):
  64. dropped_spans.append(span_need_remove)
  65. if len(dropped_spans) > 0:
  66. for span_need_remove in dropped_spans:
  67. spans.remove(span_need_remove)
  68. return spans, dropped_spans
  69. def remove_overlaps_min_spans(spans):
  70. dropped_spans = []
  71. # 删除重叠spans中较小的那些
  72. for span1 in spans:
  73. for span2 in spans:
  74. if span1 != span2:
  75. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  76. if span1 in dropped_spans or span2 in dropped_spans:
  77. continue
  78. else:
  79. overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
  80. if overlap_box is not None:
  81. span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
  82. if span_need_remove is not None and span_need_remove not in dropped_spans:
  83. dropped_spans.append(span_need_remove)
  84. if len(dropped_spans) > 0:
  85. for span_need_remove in dropped_spans:
  86. spans.remove(span_need_remove)
  87. return spans, dropped_spans
  88. def __replace_ligatures(text: str):
  89. ligatures = {
  90. 'fi': 'fi', 'fl': 'fl', 'ff': 'ff', 'ffi': 'ffi', 'ffl': 'ffl', 'ſt': 'ft', 'st': 'st'
  91. }
  92. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  93. def __replace_unicode(text: str):
  94. ligatures = {
  95. '\r\n': '', '\u0002': '-',
  96. }
  97. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  98. def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
  99. textpage = pdf_page.get_textpage()
  100. width, height = pdf_page.get_size()
  101. cropbox = pdf_page.get_cropbox()
  102. need_ocr_spans = []
  103. for span in spans:
  104. if span['type'] in [ContentType.INTERLINE_EQUATION, ContentType.IMAGE, ContentType.TABLE]:
  105. continue
  106. span_bbox = span['bbox']
  107. rect_box = [span_bbox[0] + cropbox[0],
  108. height - span_bbox[3] + cropbox[1],
  109. span_bbox[2] + cropbox[0],
  110. height - span_bbox[1] + cropbox[1]]
  111. # logger.info(f"span bbox: {span_bbox}, rect_box: {rect_box}")
  112. middle_height = (rect_box[1] + rect_box[3]) / 2
  113. rect_box[1] = middle_height - 1
  114. rect_box[3] = middle_height + 1
  115. text = textpage.get_text_bounded(left=rect_box[0], top=rect_box[1],
  116. right=rect_box[2], bottom=rect_box[3])
  117. if text and len(text) > 0:
  118. text = __replace_unicode(text)
  119. text = __replace_ligatures(text)
  120. span['content'] = text.strip()
  121. span['score'] = 1.0
  122. else:
  123. need_ocr_spans.append(span)
  124. if len(need_ocr_spans) > 0:
  125. for span in need_ocr_spans:
  126. # 对span的bbox截图再ocr
  127. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  128. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  129. # 计算span的对比度,低于0.20的span不进行ocr
  130. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  131. spans.remove(span)
  132. continue
  133. span['content'] = ''
  134. span['score'] = 1.0
  135. span['np_img'] = span_img
  136. return spans
  137. def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
  138. page_dict = get_page(pdf_page)
  139. page_all_spans = []
  140. for block in page_dict['blocks']:
  141. for line in block['lines']:
  142. if 0 < abs(line['rotation']) < 90:
  143. # 旋转角度在0-90度之间的行,直接跳过
  144. continue
  145. for span in line['spans']:
  146. page_all_spans.append(span)
  147. need_ocr_spans = []
  148. for span in spans:
  149. if span['type'] in [ContentType.TEXT]:
  150. span['sub_spans'] = []
  151. matched_spans = []
  152. for page_span in page_all_spans:
  153. if calculate_overlap_area_in_bbox1_area_ratio(page_span['bbox'].bbox, span['bbox']) > 0.5:
  154. span['sub_spans'].append(page_span)
  155. matched_spans.append(page_span)
  156. # 从page_all_spans中移除已匹配的元素
  157. page_all_spans = [span for span in page_all_spans if span not in matched_spans]
  158. # 对sub_spans按照bbox的x坐标进行排序
  159. span['sub_spans'].sort(key=lambda x: x['bbox'].x_start)
  160. # 对sub_spans的content进行拼接
  161. span_content = ''.join([sub_span['text'] for sub_span in span['sub_spans']])
  162. if span_content and len(span_content) > 0:
  163. span_content = __replace_unicode(span_content)
  164. span_content = __replace_ligatures(span_content)
  165. span['content'] = span_content.strip()
  166. span['score'] = 1.0
  167. else:
  168. need_ocr_spans.append(span)
  169. # 移除span的sub_spans
  170. span.pop('sub_spans', None)
  171. else:
  172. pass
  173. if len(need_ocr_spans) > 0:
  174. for span in need_ocr_spans:
  175. # 对span的bbox截图再ocr
  176. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  177. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  178. # 计算span的对比度,低于0.20的span不进行ocr
  179. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  180. spans.remove(span)
  181. continue
  182. span['content'] = ''
  183. span['score'] = 1.0
  184. span['np_img'] = span_img
  185. return spans
  186. def calculate_contrast(img, img_mode) -> float:
  187. """
  188. 计算给定图像的对比度。
  189. :param img: 图像,类型为numpy.ndarray
  190. :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
  191. :return: 图像的对比度值
  192. """
  193. if img_mode == 'rgb':
  194. # 将RGB图像转换为灰度图
  195. gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  196. elif img_mode == 'bgr':
  197. # 将BGR图像转换为灰度图
  198. gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  199. else:
  200. raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
  201. # 计算均值和标准差
  202. mean_value = np.mean(gray_img)
  203. std_dev = np.std(gray_img)
  204. # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
  205. contrast = std_dev / (mean_value + 1e-6)
  206. # logger.debug(f"contrast: {contrast}")
  207. return round(contrast, 2)