span_pre_proc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import collections
  3. import re
  4. import statistics
  5. import cv2
  6. import numpy as np
  7. from loguru import logger
  8. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
  9. get_minbox_if_overlap_by_ratio
  10. from mineru.utils.enum_class import BlockType, ContentType
  11. from mineru.utils.pdf_image_tools import get_crop_img
  12. from mineru.utils.pdf_text_tool import get_page
  13. def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
  14. def get_block_bboxes(blocks, block_type_list):
  15. return [block[0:4] for block in blocks if block[7] in block_type_list]
  16. image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
  17. table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
  18. other_block_type = []
  19. for block_type in BlockType.__dict__.values():
  20. if not isinstance(block_type, str):
  21. continue
  22. if block_type not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  23. other_block_type.append(block_type)
  24. other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
  25. discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])
  26. new_spans = []
  27. for span in spans:
  28. span_bbox = span['bbox']
  29. span_type = span['type']
  30. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
  31. discarded_block_bboxes):
  32. new_spans.append(span)
  33. continue
  34. if span_type == ContentType.IMAGE:
  35. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  36. image_bboxes):
  37. new_spans.append(span)
  38. elif span_type == ContentType.TABLE:
  39. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  40. table_bboxes):
  41. new_spans.append(span)
  42. else:
  43. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  44. other_block_bboxes):
  45. new_spans.append(span)
  46. return new_spans
  47. def remove_overlaps_low_confidence_spans(spans):
  48. dropped_spans = []
  49. # 删除重叠spans中置信度低的的那些
  50. for span1 in spans:
  51. for span2 in spans:
  52. if span1 != span2:
  53. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  54. if span1 in dropped_spans or span2 in dropped_spans:
  55. continue
  56. else:
  57. if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
  58. if span1['score'] < span2['score']:
  59. span_need_remove = span1
  60. else:
  61. span_need_remove = span2
  62. if (
  63. span_need_remove is not None
  64. and span_need_remove not in dropped_spans
  65. ):
  66. dropped_spans.append(span_need_remove)
  67. if len(dropped_spans) > 0:
  68. for span_need_remove in dropped_spans:
  69. spans.remove(span_need_remove)
  70. return spans, dropped_spans
  71. def remove_overlaps_min_spans(spans):
  72. dropped_spans = []
  73. # 删除重叠spans中较小的那些
  74. for span1 in spans:
  75. for span2 in spans:
  76. if span1 != span2:
  77. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  78. if span1 in dropped_spans or span2 in dropped_spans:
  79. continue
  80. else:
  81. overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
  82. if overlap_box is not None:
  83. span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
  84. if span_need_remove is not None and span_need_remove not in dropped_spans:
  85. dropped_spans.append(span_need_remove)
  86. if len(dropped_spans) > 0:
  87. for span_need_remove in dropped_spans:
  88. spans.remove(span_need_remove)
  89. return spans, dropped_spans
  90. def __replace_ligatures(text: str):
  91. ligatures = {
  92. 'fi': 'fi', 'fl': 'fl', 'ff': 'ff', 'ffi': 'ffi', 'ffl': 'ffl', 'ſt': 'ft', 'st': 'st'
  93. }
  94. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  95. def __replace_unicode(text: str):
  96. ligatures = {
  97. '\r\n': '', '\u0002': '-',
  98. }
  99. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
"""pdf_text dict approach, char level"""
  101. def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
  102. page_dict = get_page(pdf_page)
  103. page_all_chars = []
  104. page_all_lines = []
  105. for block in page_dict['blocks']:
  106. for line in block['lines']:
  107. if 0 < abs(line['rotation']) < 90:
  108. # 旋转角度在0-90度之间的行,直接跳过
  109. continue
  110. page_all_lines.append(line)
  111. for span in line['spans']:
  112. for char in span['chars']:
  113. page_all_chars.append(char)
  114. # 计算所有sapn的高度的中位数
  115. span_height_list = []
  116. for span in spans:
  117. if span['type'] in [ContentType.TEXT]:
  118. span_height = span['bbox'][3] - span['bbox'][1]
  119. span['height'] = span_height
  120. span['width'] = span['bbox'][2] - span['bbox'][0]
  121. span_height_list.append(span_height)
  122. if len(span_height_list) == 0:
  123. return spans
  124. else:
  125. median_span_height = statistics.median(span_height_list)
  126. useful_spans = []
  127. unuseful_spans = []
  128. # 纵向span的两个特征:1. 高度超过多个line 2. 高宽比超过某个值
  129. vertical_spans = []
  130. for span in spans:
  131. if span['type'] in [ContentType.TEXT]:
  132. for block in all_bboxes + all_discarded_blocks:
  133. if block[7] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
  134. continue
  135. if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
  136. if span['height'] > median_span_height * 3 and span['height'] > span['width'] * 3:
  137. vertical_spans.append(span)
  138. elif block in all_bboxes:
  139. useful_spans.append(span)
  140. else:
  141. unuseful_spans.append(span)
  142. break
  143. """垂直的span框直接用line进行填充"""
  144. if len(vertical_spans) > 0:
  145. for pdfium_line in page_all_lines:
  146. for span in vertical_spans:
  147. if calculate_overlap_area_in_bbox1_area_ratio(pdfium_line['bbox'].bbox, span['bbox']) > 0.5:
  148. for pdfium_span in pdfium_line['spans']:
  149. span['content'] += pdfium_span['text']
  150. break
  151. for span in vertical_spans:
  152. if len(span['content']) == 0:
  153. spans.remove(span)
  154. """水平的span框先用char填充,再用ocr填充空的span框"""
  155. new_spans = []
  156. for span in useful_spans + unuseful_spans:
  157. if span['type'] in [ContentType.TEXT]:
  158. span['chars'] = []
  159. new_spans.append(span)
  160. need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars, median_span_height)
  161. """对未填充的span进行ocr"""
  162. if len(need_ocr_spans) > 0:
  163. for span in need_ocr_spans:
  164. # 对span的bbox截图再ocr
  165. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  166. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  167. # 计算span的对比度,低于0.20的span不进行ocr
  168. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  169. spans.remove(span)
  170. continue
  171. span['content'] = ''
  172. span['score'] = 1.0
  173. span['np_img'] = span_img
  174. return spans
  175. def fill_char_in_spans(spans, all_chars, median_span_height):
  176. # 简单从上到下排一下序
  177. spans = sorted(spans, key=lambda x: x['bbox'][1])
  178. grid_size = median_span_height
  179. grid = collections.defaultdict(list)
  180. for i, span in enumerate(spans):
  181. start_cell = int(span['bbox'][1] / grid_size)
  182. end_cell = int(span['bbox'][3] / grid_size)
  183. for cell_idx in range(start_cell, end_cell + 1):
  184. grid[cell_idx].append(i)
  185. for char in all_chars:
  186. char_center_y = (char['bbox'][1] + char['bbox'][3]) / 2
  187. cell_idx = int(char_center_y / grid_size)
  188. candidate_span_indices = grid.get(cell_idx, [])
  189. for span_idx in candidate_span_indices:
  190. span = spans[span_idx]
  191. if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
  192. span['chars'].append(char)
  193. break
  194. need_ocr_spans = []
  195. for span in spans:
  196. chars_to_content(span)
  197. # 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
  198. if len(span['content']) * span['height'] < span['width'] * 0.5:
  199. # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
  200. need_ocr_spans.append(span)
  201. del span['height'], span['width']
  202. return need_ocr_spans
  203. LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
  204. LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
  205. Span_Height_Radio = 0.33 # 字符的中轴和span的中轴高度差不能超过1/3span高度
  206. def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=Span_Height_Radio):
  207. char_center_x = (char_bbox[0] + char_bbox[2]) / 2
  208. char_center_y = (char_bbox[1] + char_bbox[3]) / 2
  209. span_center_y = (span_bbox[1] + span_bbox[3]) / 2
  210. span_height = span_bbox[3] - span_bbox[1]
  211. if (
  212. span_bbox[0] < char_center_x < span_bbox[2]
  213. and span_bbox[1] < char_center_y < span_bbox[3]
  214. and abs(char_center_y - span_center_y) < span_height * span_height_radio # 字符的中轴和span的中轴高度差不能超过Span_Height_Radio
  215. ):
  216. return True
  217. else:
  218. # 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
  219. # 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
  220. if char in LINE_STOP_FLAG:
  221. if (
  222. (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
  223. and char_center_x > span_bbox[0]
  224. and span_bbox[1] < char_center_y < span_bbox[3]
  225. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  226. ):
  227. return True
  228. elif char in LINE_START_FLAG:
  229. if (
  230. span_bbox[0] < char_bbox[2] < (span_bbox[0] + span_height)
  231. and char_center_x < span_bbox[2]
  232. and span_bbox[1] < char_center_y < span_bbox[3]
  233. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  234. ):
  235. return True
  236. else:
  237. return False
  238. def chars_to_content(span):
  239. # 检查span中的char是否为空
  240. if len(span['chars']) == 0:
  241. pass
  242. else:
  243. # 给chars按char_idx排序
  244. span['chars'] = sorted(span['chars'], key=lambda x: x['char_idx'])
  245. # Calculate the width of each character
  246. char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
  247. # Calculate the median width
  248. median_width = statistics.median(char_widths)
  249. content = ''
  250. for char in span['chars']:
  251. # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
  252. char1 = char
  253. char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
  254. if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['char'] != ' ' and char2['char'] != ' ':
  255. content += f"{char['char']} "
  256. else:
  257. content += char['char']
  258. content = __replace_unicode(content)
  259. content = __replace_ligatures(content)
  260. content = __replace_ligatures(content)
  261. span['content'] = content.strip()
  262. del span['chars']
  263. def calculate_contrast(img, img_mode) -> float:
  264. """
  265. 计算给定图像的对比度。
  266. :param img: 图像,类型为numpy.ndarray
  267. :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
  268. :return: 图像的对比度值
  269. """
  270. if img_mode == 'rgb':
  271. # 将RGB图像转换为灰度图
  272. gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  273. elif img_mode == 'bgr':
  274. # 将BGR图像转换为灰度图
  275. gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  276. else:
  277. raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
  278. # 计算均值和标准差
  279. mean_value = np.mean(gray_img)
  280. std_dev = np.std(gray_img)
  281. # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
  282. contrast = std_dev / (mean_value + 1e-6)
  283. # logger.debug(f"contrast: {contrast}")
  284. return round(contrast, 2)