span_pre_proc.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import re
  3. import statistics
  4. import cv2
  5. import numpy as np
  6. from loguru import logger
  7. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
  8. get_minbox_if_overlap_by_ratio
  9. from mineru.utils.enum_class import BlockType, ContentType
  10. from mineru.utils.pdf_image_tools import get_crop_img
  11. from mineru.utils.pdf_text_tool import get_page
  12. def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
  13. def get_block_bboxes(blocks, block_type_list):
  14. return [block[0:4] for block in blocks if block[7] in block_type_list]
  15. image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
  16. table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
  17. other_block_type = []
  18. for block_type in BlockType.__dict__.values():
  19. if not isinstance(block_type, str):
  20. continue
  21. if block_type not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  22. other_block_type.append(block_type)
  23. other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
  24. discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])
  25. new_spans = []
  26. for span in spans:
  27. span_bbox = span['bbox']
  28. span_type = span['type']
  29. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
  30. discarded_block_bboxes):
  31. new_spans.append(span)
  32. continue
  33. if span_type == ContentType.IMAGE:
  34. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  35. image_bboxes):
  36. new_spans.append(span)
  37. elif span_type == ContentType.TABLE:
  38. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  39. table_bboxes):
  40. new_spans.append(span)
  41. else:
  42. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  43. other_block_bboxes):
  44. new_spans.append(span)
  45. return new_spans
  46. def remove_overlaps_low_confidence_spans(spans):
  47. dropped_spans = []
  48. # 删除重叠spans中置信度低的的那些
  49. for span1 in spans:
  50. for span2 in spans:
  51. if span1 != span2:
  52. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  53. if span1 in dropped_spans or span2 in dropped_spans:
  54. continue
  55. else:
  56. if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
  57. if span1['score'] < span2['score']:
  58. span_need_remove = span1
  59. else:
  60. span_need_remove = span2
  61. if (
  62. span_need_remove is not None
  63. and span_need_remove not in dropped_spans
  64. ):
  65. dropped_spans.append(span_need_remove)
  66. if len(dropped_spans) > 0:
  67. for span_need_remove in dropped_spans:
  68. spans.remove(span_need_remove)
  69. return spans, dropped_spans
  70. def remove_overlaps_min_spans(spans):
  71. dropped_spans = []
  72. # 删除重叠spans中较小的那些
  73. for span1 in spans:
  74. for span2 in spans:
  75. if span1 != span2:
  76. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  77. if span1 in dropped_spans or span2 in dropped_spans:
  78. continue
  79. else:
  80. overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
  81. if overlap_box is not None:
  82. span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
  83. if span_need_remove is not None and span_need_remove not in dropped_spans:
  84. dropped_spans.append(span_need_remove)
  85. if len(dropped_spans) > 0:
  86. for span_need_remove in dropped_spans:
  87. spans.remove(span_need_remove)
  88. return spans, dropped_spans
  89. def __replace_ligatures(text: str):
  90. ligatures = {
  91. 'fi': 'fi', 'fl': 'fl', 'ff': 'ff', 'ffi': 'ffi', 'ffl': 'ffl', 'ſt': 'ft', 'st': 'st'
  92. }
  93. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  94. def __replace_unicode(text: str):
  95. ligatures = {
  96. '\r\n': '', '\u0002': '-',
  97. }
  98. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  99. """textpage.get_text_bounded方案"""
  100. def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
  101. textpage = pdf_page.get_textpage()
  102. width, height = pdf_page.get_size()
  103. cropbox = pdf_page.get_cropbox()
  104. need_ocr_spans = []
  105. for span in spans:
  106. if span['type'] in [ContentType.INTERLINE_EQUATION, ContentType.IMAGE, ContentType.TABLE]:
  107. continue
  108. span_bbox = span['bbox']
  109. rect_box = [span_bbox[0] + cropbox[0],
  110. height - span_bbox[3] + cropbox[1],
  111. span_bbox[2] + cropbox[0],
  112. height - span_bbox[1] + cropbox[1]]
  113. # logger.info(f"span bbox: {span_bbox}, rect_box: {rect_box}")
  114. middle_height = (rect_box[1] + rect_box[3]) / 2
  115. rect_box[1] = middle_height - 1
  116. rect_box[3] = middle_height + 1
  117. text = textpage.get_text_bounded(left=rect_box[0], top=rect_box[1],
  118. right=rect_box[2], bottom=rect_box[3])
  119. if text and len(text) > 0:
  120. text = __replace_unicode(text)
  121. text = __replace_ligatures(text)
  122. span['content'] = text.strip()
  123. span['score'] = 1.0
  124. else:
  125. need_ocr_spans.append(span)
  126. if len(need_ocr_spans) > 0:
  127. for span in need_ocr_spans:
  128. # 对span的bbox截图再ocr
  129. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  130. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  131. # 计算span的对比度,低于0.20的span不进行ocr
  132. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  133. spans.remove(span)
  134. continue
  135. span['content'] = ''
  136. span['score'] = 1.0
  137. span['np_img'] = span_img
  138. return spans
  139. """pdf_text dict方案 span级别"""
  140. def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
  141. page_dict = get_page(pdf_page)
  142. page_all_spans = []
  143. for block in page_dict['blocks']:
  144. for line in block['lines']:
  145. if 0 < abs(line['rotation']) < 90:
  146. # 旋转角度在0-90度之间的行,直接跳过
  147. continue
  148. for span in line['spans']:
  149. page_all_spans.append(span)
  150. need_ocr_spans = []
  151. for span in spans:
  152. if span['type'] in [ContentType.TEXT]:
  153. span['sub_spans'] = []
  154. matched_spans = []
  155. for page_span in page_all_spans:
  156. if calculate_overlap_area_in_bbox1_area_ratio(page_span['bbox'].bbox, span['bbox']) > 0.5:
  157. span['sub_spans'].append(page_span)
  158. matched_spans.append(page_span)
  159. # 从page_all_spans中移除已匹配的元素
  160. page_all_spans = [span for span in page_all_spans if span not in matched_spans]
  161. # 对sub_spans按照bbox的x坐标进行排序
  162. span['sub_spans'].sort(key=lambda x: x['bbox'].x_start)
  163. # 对sub_spans的content进行拼接
  164. span_content = ''.join([sub_span['text'] for sub_span in span['sub_spans']])
  165. if span_content and len(span_content) > 0:
  166. span_content = __replace_unicode(span_content)
  167. span_content = __replace_ligatures(span_content)
  168. span['content'] = span_content.strip()
  169. span['score'] = 1.0
  170. else:
  171. need_ocr_spans.append(span)
  172. # 移除span的sub_spans
  173. span.pop('sub_spans', None)
  174. else:
  175. pass
  176. if len(need_ocr_spans) > 0:
  177. for span in need_ocr_spans:
  178. # 对span的bbox截图再ocr
  179. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  180. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  181. # 计算span的对比度,低于0.20的span不进行ocr
  182. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  183. spans.remove(span)
  184. continue
  185. span['content'] = ''
  186. span['score'] = 1.0
  187. span['np_img'] = span_img
  188. return spans
  189. """pdf_text dict方案 char级别"""
  190. def txt_spans_extract_v3(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
  191. page_dict = get_page(pdf_page)
  192. page_all_chars = []
  193. page_all_lines = []
  194. for block in page_dict['blocks']:
  195. for line in block['lines']:
  196. if 0 < abs(line['rotation']) < 90:
  197. # 旋转角度在0-90度之间的行,直接跳过
  198. continue
  199. page_all_lines.append(line)
  200. for span in line['spans']:
  201. for char in span['chars']:
  202. page_all_chars.append(char)
  203. # 计算所有sapn的高度的中位数
  204. span_height_list = []
  205. for span in spans:
  206. if span['type'] in [ContentType.TEXT]:
  207. span_height = span['bbox'][3] - span['bbox'][1]
  208. span['height'] = span_height
  209. span['width'] = span['bbox'][2] - span['bbox'][0]
  210. span_height_list.append(span_height)
  211. if len(span_height_list) == 0:
  212. return spans
  213. else:
  214. median_span_height = statistics.median(span_height_list)
  215. useful_spans = []
  216. unuseful_spans = []
  217. # 纵向span的两个特征:1. 高度超过多个line 2. 高宽比超过某个值
  218. vertical_spans = []
  219. for span in spans:
  220. if span['type'] in [ContentType.TEXT]:
  221. for block in all_bboxes + all_discarded_blocks:
  222. if block[7] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
  223. continue
  224. if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
  225. if span['height'] > median_span_height * 3 and span['height'] > span['width'] * 3:
  226. vertical_spans.append(span)
  227. elif block in all_bboxes:
  228. useful_spans.append(span)
  229. else:
  230. unuseful_spans.append(span)
  231. break
  232. """垂直的span框直接用line进行填充"""
  233. if len(vertical_spans) > 0:
  234. for pdfium_line in page_all_lines:
  235. for span in vertical_spans:
  236. if calculate_overlap_area_in_bbox1_area_ratio(pdfium_line['bbox'].bbox, span['bbox']) > 0.5:
  237. for pdfium_span in pdfium_line['spans']:
  238. span['content'] += pdfium_span['text']
  239. break
  240. for span in vertical_spans:
  241. if len(span['content']) == 0:
  242. spans.remove(span)
  243. """水平的span框先用char填充,再用ocr填充空的span框"""
  244. new_spans = []
  245. for span in useful_spans + unuseful_spans:
  246. if span['type'] in [ContentType.TEXT]:
  247. span['chars'] = []
  248. new_spans.append(span)
  249. need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars)
  250. """对未填充的span进行ocr"""
  251. if len(need_ocr_spans) > 0:
  252. for span in need_ocr_spans:
  253. # 对span的bbox截图再ocr
  254. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  255. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  256. # 计算span的对比度,低于0.20的span不进行ocr
  257. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  258. spans.remove(span)
  259. continue
  260. span['content'] = ''
  261. span['score'] = 1.0
  262. span['np_img'] = span_img
  263. return spans
  264. def fill_char_in_spans(spans, all_chars):
  265. # 简单从上到下排一下序
  266. spans = sorted(spans, key=lambda x: x['bbox'][1])
  267. for char in all_chars:
  268. for span in spans:
  269. if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
  270. span['chars'].append(char)
  271. break
  272. need_ocr_spans = []
  273. for span in spans:
  274. chars_to_content(span)
  275. # 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
  276. if len(span['content']) * span['height'] < span['width'] * 0.5:
  277. # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
  278. need_ocr_spans.append(span)
  279. del span['height'], span['width']
  280. return need_ocr_spans
  281. LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
  282. LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
  283. def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
  284. char_center_x = (char_bbox[0] + char_bbox[2]) / 2
  285. char_center_y = (char_bbox[1] + char_bbox[3]) / 2
  286. span_center_y = (span_bbox[1] + span_bbox[3]) / 2
  287. span_height = span_bbox[3] - span_bbox[1]
  288. if (
  289. span_bbox[0] < char_center_x < span_bbox[2]
  290. and span_bbox[1] < char_center_y < span_bbox[3]
  291. and abs(char_center_y - span_center_y) < span_height * span_height_radio # 字符的中轴和span的中轴高度差不能超过1/4span高度
  292. ):
  293. return True
  294. else:
  295. # 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
  296. # 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
  297. if char in LINE_STOP_FLAG:
  298. if (
  299. (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
  300. and char_center_x > span_bbox[0]
  301. and span_bbox[1] < char_center_y < span_bbox[3]
  302. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  303. ):
  304. return True
  305. elif char in LINE_START_FLAG:
  306. if (
  307. span_bbox[0] < char_bbox[2] < (span_bbox[0] + span_height)
  308. and char_center_x < span_bbox[2]
  309. and span_bbox[1] < char_center_y < span_bbox[3]
  310. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  311. ):
  312. return True
  313. else:
  314. return False
  315. def chars_to_content(span):
  316. # 检查span中的char是否为空
  317. if len(span['chars']) == 0:
  318. pass
  319. else:
  320. # 先给chars按char['bbox']的中心点的x坐标排序
  321. span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
  322. # Calculate the width of each character
  323. char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
  324. # Calculate the median width
  325. median_width = statistics.median(char_widths)
  326. # 通过x轴重叠比率移除一部分char
  327. span = remove_x_overlapping_chars(span, median_width)
  328. content = ''
  329. for char in span['chars']:
  330. # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
  331. char1 = char
  332. char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
  333. if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['char'] != ' ' and char2['char'] != ' ':
  334. content += f"{char['char']} "
  335. else:
  336. content += char['char']
  337. content = __replace_unicode(content)
  338. content = __replace_ligatures(content)
  339. content = __replace_ligatures(content)
  340. span['content'] = content.strip()
  341. del span['chars']
  342. def remove_x_overlapping_chars(span, median_width):
  343. """
  344. Remove characters from a span that overlap significantly on the x-axis.
  345. Args:
  346. median_width:
  347. span (dict): A span containing a list of chars, each with bbox coordinates
  348. in the format [x0, y0, x1, y1]
  349. Returns:
  350. dict: The span with overlapping characters removed
  351. """
  352. if 'chars' not in span or len(span['chars']) < 2:
  353. return span
  354. overlap_threshold = median_width * 0.3
  355. i = 0
  356. while i < len(span['chars']) - 1:
  357. char1 = span['chars'][i]
  358. char2 = span['chars'][i + 1]
  359. # Calculate overlap width
  360. x_left = max(char1['bbox'][0], char2['bbox'][0])
  361. x_right = min(char1['bbox'][2], char2['bbox'][2])
  362. if x_right > x_left: # There is overlap
  363. overlap_width = x_right - x_left
  364. if overlap_width > overlap_threshold:
  365. if char1['char'] == char2['char'] or char1['char'] == ' ' or char2['char'] == ' ':
  366. # Determine which character to remove
  367. width1 = char1['bbox'][2] - char1['bbox'][0]
  368. width2 = char2['bbox'][2] - char2['bbox'][0]
  369. if width1 < width2:
  370. # Remove the narrower character
  371. span['chars'].pop(i)
  372. else:
  373. span['chars'].pop(i + 1)
  374. else:
  375. i += 1
  376. # Don't increment i since we need to check the new pair
  377. else:
  378. i += 1
  379. else:
  380. i += 1
  381. return span
  382. def calculate_contrast(img, img_mode) -> float:
  383. """
  384. 计算给定图像的对比度。
  385. :param img: 图像,类型为numpy.ndarray
  386. :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
  387. :return: 图像的对比度值
  388. """
  389. if img_mode == 'rgb':
  390. # 将RGB图像转换为灰度图
  391. gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  392. elif img_mode == 'bgr':
  393. # 将BGR图像转换为灰度图
  394. gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  395. else:
  396. raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
  397. # 计算均值和标准差
  398. mean_value = np.mean(gray_img)
  399. std_dev = np.std(gray_img)
  400. # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
  401. contrast = std_dev / (mean_value + 1e-6)
  402. # logger.debug(f"contrast: {contrast}")
  403. return round(contrast, 2)