span_pre_proc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import re
  3. import statistics
  4. import cv2
  5. import numpy as np
  6. from loguru import logger
  7. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
  8. get_minbox_if_overlap_by_ratio
  9. from mineru.utils.enum_class import BlockType, ContentType
  10. from mineru.utils.pdf_image_tools import get_crop_img
  11. from mineru.utils.pdf_text_tool import get_page
  12. def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
  13. def get_block_bboxes(blocks, block_type_list):
  14. return [block[0:4] for block in blocks if block[7] in block_type_list]
  15. image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
  16. table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
  17. other_block_type = []
  18. for block_type in BlockType.__dict__.values():
  19. if not isinstance(block_type, str):
  20. continue
  21. if block_type not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  22. other_block_type.append(block_type)
  23. other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
  24. discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])
  25. new_spans = []
  26. for span in spans:
  27. span_bbox = span['bbox']
  28. span_type = span['type']
  29. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
  30. discarded_block_bboxes):
  31. new_spans.append(span)
  32. continue
  33. if span_type == ContentType.IMAGE:
  34. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  35. image_bboxes):
  36. new_spans.append(span)
  37. elif span_type == ContentType.TABLE:
  38. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  39. table_bboxes):
  40. new_spans.append(span)
  41. else:
  42. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  43. other_block_bboxes):
  44. new_spans.append(span)
  45. return new_spans
  46. def remove_overlaps_low_confidence_spans(spans):
  47. dropped_spans = []
  48. # 删除重叠spans中置信度低的的那些
  49. for span1 in spans:
  50. for span2 in spans:
  51. if span1 != span2:
  52. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  53. if span1 in dropped_spans or span2 in dropped_spans:
  54. continue
  55. else:
  56. if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
  57. if span1['score'] < span2['score']:
  58. span_need_remove = span1
  59. else:
  60. span_need_remove = span2
  61. if (
  62. span_need_remove is not None
  63. and span_need_remove not in dropped_spans
  64. ):
  65. dropped_spans.append(span_need_remove)
  66. if len(dropped_spans) > 0:
  67. for span_need_remove in dropped_spans:
  68. spans.remove(span_need_remove)
  69. return spans, dropped_spans
  70. def remove_overlaps_min_spans(spans):
  71. dropped_spans = []
  72. # 删除重叠spans中较小的那些
  73. for span1 in spans:
  74. for span2 in spans:
  75. if span1 != span2:
  76. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  77. if span1 in dropped_spans or span2 in dropped_spans:
  78. continue
  79. else:
  80. overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
  81. if overlap_box is not None:
  82. span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
  83. if span_need_remove is not None and span_need_remove not in dropped_spans:
  84. dropped_spans.append(span_need_remove)
  85. if len(dropped_spans) > 0:
  86. for span_need_remove in dropped_spans:
  87. spans.remove(span_need_remove)
  88. return spans, dropped_spans
  89. def __replace_ligatures(text: str):
  90. ligatures = {
  91. 'fi': 'fi', 'fl': 'fl', 'ff': 'ff', 'ffi': 'ffi', 'ffl': 'ffl', 'ſt': 'ft', 'st': 'st'
  92. }
  93. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  94. def __replace_unicode(text: str):
  95. ligatures = {
  96. '\r\n': '', '\u0002': '-',
  97. }
  98. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  99. """pdf_text dict方案 char级别"""
  100. def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
  101. page_dict = get_page(pdf_page)
  102. page_all_chars = []
  103. page_all_lines = []
  104. for block in page_dict['blocks']:
  105. for line in block['lines']:
  106. if 0 < abs(line['rotation']) < 90:
  107. # 旋转角度在0-90度之间的行,直接跳过
  108. continue
  109. page_all_lines.append(line)
  110. for span in line['spans']:
  111. for char in span['chars']:
  112. page_all_chars.append(char)
  113. # 计算所有sapn的高度的中位数
  114. span_height_list = []
  115. for span in spans:
  116. if span['type'] in [ContentType.TEXT]:
  117. span_height = span['bbox'][3] - span['bbox'][1]
  118. span['height'] = span_height
  119. span['width'] = span['bbox'][2] - span['bbox'][0]
  120. span_height_list.append(span_height)
  121. if len(span_height_list) == 0:
  122. return spans
  123. else:
  124. median_span_height = statistics.median(span_height_list)
  125. useful_spans = []
  126. unuseful_spans = []
  127. # 纵向span的两个特征:1. 高度超过多个line 2. 高宽比超过某个值
  128. vertical_spans = []
  129. for span in spans:
  130. if span['type'] in [ContentType.TEXT]:
  131. for block in all_bboxes + all_discarded_blocks:
  132. if block[7] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
  133. continue
  134. if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
  135. if span['height'] > median_span_height * 3 and span['height'] > span['width'] * 3:
  136. vertical_spans.append(span)
  137. elif block in all_bboxes:
  138. useful_spans.append(span)
  139. else:
  140. unuseful_spans.append(span)
  141. break
  142. """垂直的span框直接用line进行填充"""
  143. if len(vertical_spans) > 0:
  144. for pdfium_line in page_all_lines:
  145. for span in vertical_spans:
  146. if calculate_overlap_area_in_bbox1_area_ratio(pdfium_line['bbox'].bbox, span['bbox']) > 0.5:
  147. for pdfium_span in pdfium_line['spans']:
  148. span['content'] += pdfium_span['text']
  149. break
  150. for span in vertical_spans:
  151. if len(span['content']) == 0:
  152. spans.remove(span)
  153. """水平的span框先用char填充,再用ocr填充空的span框"""
  154. new_spans = []
  155. for span in useful_spans + unuseful_spans:
  156. if span['type'] in [ContentType.TEXT]:
  157. span['chars'] = []
  158. new_spans.append(span)
  159. need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars)
  160. """对未填充的span进行ocr"""
  161. if len(need_ocr_spans) > 0:
  162. for span in need_ocr_spans:
  163. # 对span的bbox截图再ocr
  164. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  165. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  166. # 计算span的对比度,低于0.20的span不进行ocr
  167. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  168. spans.remove(span)
  169. continue
  170. span['content'] = ''
  171. span['score'] = 1.0
  172. span['np_img'] = span_img
  173. return spans
  174. def fill_char_in_spans(spans, all_chars):
  175. # 简单从上到下排一下序
  176. spans = sorted(spans, key=lambda x: x['bbox'][1])
  177. for char in all_chars:
  178. for span in spans:
  179. if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
  180. span['chars'].append(char)
  181. break
  182. need_ocr_spans = []
  183. for span in spans:
  184. chars_to_content(span)
  185. # 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
  186. if len(span['content']) * span['height'] < span['width'] * 0.5:
  187. # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
  188. need_ocr_spans.append(span)
  189. del span['height'], span['width']
  190. return need_ocr_spans
  191. LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
  192. LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
  193. Span_Height_Radio = 0.33 # 字符的中轴和span的中轴高度差不能超过1/3span高度
  194. def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=Span_Height_Radio):
  195. char_center_x = (char_bbox[0] + char_bbox[2]) / 2
  196. char_center_y = (char_bbox[1] + char_bbox[3]) / 2
  197. span_center_y = (span_bbox[1] + span_bbox[3]) / 2
  198. span_height = span_bbox[3] - span_bbox[1]
  199. if (
  200. span_bbox[0] < char_center_x < span_bbox[2]
  201. and span_bbox[1] < char_center_y < span_bbox[3]
  202. and abs(char_center_y - span_center_y) < span_height * span_height_radio # 字符的中轴和span的中轴高度差不能超过Span_Height_Radio
  203. ):
  204. return True
  205. else:
  206. # 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
  207. # 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
  208. if char in LINE_STOP_FLAG:
  209. if (
  210. (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
  211. and char_center_x > span_bbox[0]
  212. and span_bbox[1] < char_center_y < span_bbox[3]
  213. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  214. ):
  215. return True
  216. elif char in LINE_START_FLAG:
  217. if (
  218. span_bbox[0] < char_bbox[2] < (span_bbox[0] + span_height)
  219. and char_center_x < span_bbox[2]
  220. and span_bbox[1] < char_center_y < span_bbox[3]
  221. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  222. ):
  223. return True
  224. else:
  225. return False
  226. def chars_to_content(span):
  227. # 检查span中的char是否为空
  228. if len(span['chars']) == 0:
  229. pass
  230. else:
  231. # 给chars按char_idx排序
  232. span['chars'] = sorted(span['chars'], key=lambda x: x['char_idx'])
  233. # Calculate the width of each character
  234. char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
  235. # Calculate the median width
  236. median_width = statistics.median(char_widths)
  237. content = ''
  238. for char in span['chars']:
  239. # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
  240. char1 = char
  241. char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
  242. if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['char'] != ' ' and char2['char'] != ' ':
  243. content += f"{char['char']} "
  244. else:
  245. content += char['char']
  246. content = __replace_unicode(content)
  247. content = __replace_ligatures(content)
  248. content = __replace_ligatures(content)
  249. span['content'] = content.strip()
  250. del span['chars']
  251. def calculate_contrast(img, img_mode) -> float:
  252. """
  253. 计算给定图像的对比度。
  254. :param img: 图像,类型为numpy.ndarray
  255. :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
  256. :return: 图像的对比度值
  257. """
  258. if img_mode == 'rgb':
  259. # 将RGB图像转换为灰度图
  260. gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  261. elif img_mode == 'bgr':
  262. # 将BGR图像转换为灰度图
  263. gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  264. else:
  265. raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
  266. # 计算均值和标准差
  267. mean_value = np.mean(gray_img)
  268. std_dev = np.std(gray_img)
  269. # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
  270. contrast = std_dev / (mean_value + 1e-6)
  271. # logger.debug(f"contrast: {contrast}")
  272. return round(contrast, 2)