pdf_classify.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import re
  3. from io import BytesIO
  4. import numpy as np
  5. import pypdfium2 as pdfium
  6. from loguru import logger
  7. from pdfminer.high_level import extract_text
  8. from pdfminer.layout import LAParams
  9. def classify(pdf_bytes):
  10. """
  11. 判断PDF文件是可以直接提取文本还是需要OCR
  12. Args:
  13. pdf_bytes: PDF文件的字节数据
  14. Returns:
  15. str: 'txt' 表示可以直接提取文本,'ocr' 表示需要OCR
  16. """
  17. try:
  18. # 从字节数据加载PDF
  19. sample_pdf_bytes = extract_pages(pdf_bytes)
  20. pdf = pdfium.PdfDocument(sample_pdf_bytes)
  21. # 获取PDF页数
  22. page_count = len(pdf)
  23. # 如果PDF页数为0,直接返回OCR
  24. if page_count == 0:
  25. return 'ocr'
  26. # 总字符数
  27. total_chars = 0
  28. # 清理后的总字符数
  29. cleaned_total_chars = 0
  30. # 检查的页面数(最多检查10页)
  31. pages_to_check = min(page_count, 10)
  32. # 检查前几页的文本
  33. for i in range(pages_to_check):
  34. page = pdf[i]
  35. text_page = page.get_textpage()
  36. text = text_page.get_text_bounded()
  37. total_chars += len(text)
  38. # 清理提取的文本,移除空白字符
  39. cleaned_text = re.sub(r'\s+', '', text)
  40. cleaned_total_chars += len(cleaned_text)
  41. # 计算平均每页字符数
  42. # avg_chars_per_page = total_chars / pages_to_check
  43. avg_cleaned_chars_per_page = cleaned_total_chars / pages_to_check
  44. # 设置阈值:如果每页平均少于50个有效字符,认为需要OCR
  45. chars_threshold = 50
  46. # logger.debug(f"PDF分析: 平均每页{avg_chars_per_page:.1f}字符, 清理后{avg_cleaned_chars_per_page:.1f}字符")
  47. if (avg_cleaned_chars_per_page < chars_threshold) or detect_invalid_chars(sample_pdf_bytes):
  48. return 'ocr'
  49. else:
  50. return 'txt'
  51. except Exception as e:
  52. logger.error(f"判断PDF类型时出错: {e}")
  53. # 出错时默认使用OCR
  54. return 'ocr'
  55. def extract_pages(src_pdf_bytes: bytes) -> bytes:
  56. """
  57. 从PDF字节数据中随机提取最多10页,返回新的PDF字节数据
  58. Args:
  59. src_pdf_bytes: PDF文件的字节数据
  60. Returns:
  61. bytes: 提取页面后的PDF字节数据
  62. """
  63. # 从字节数据加载PDF
  64. pdf = pdfium.PdfDocument(src_pdf_bytes)
  65. # 获取PDF页数
  66. total_page = len(pdf)
  67. if total_page == 0:
  68. # 如果PDF没有页面,直接返回空文档
  69. logger.warning("PDF is empty, return empty document")
  70. return b''
  71. # 选择最多10页
  72. select_page_cnt = min(10, total_page)
  73. # 从总页数中随机选择页面
  74. page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()
  75. # 创建一个新的PDF文档
  76. sample_docs = pdfium.PdfDocument.new()
  77. try:
  78. # 将选择的页面导入新文档
  79. sample_docs.import_pages(pdf, page_indices)
  80. # 将新PDF保存到内存缓冲区
  81. output_buffer = BytesIO()
  82. sample_docs.save(output_buffer)
  83. # 获取字节数据
  84. return output_buffer.getvalue()
  85. except Exception as e:
  86. logger.exception(e)
  87. return b'' # 出错时返回空字节
  88. def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
  89. """"
  90. 检测PDF中是否包含非法字符
  91. """
  92. '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
  93. # sample_pdf_bytes = extract_pages(src_pdf_bytes)
  94. sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
  95. laparams = LAParams(
  96. line_overlap=0.5,
  97. char_margin=2.0,
  98. line_margin=0.5,
  99. word_margin=0.1,
  100. boxes_flow=None,
  101. detect_vertical=False,
  102. all_texts=False,
  103. )
  104. text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
  105. text = text.replace("\n", "")
  106. # logger.info(text)
  107. '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
  108. cid_pattern = re.compile(r'\(cid:\d+\)')
  109. matches = cid_pattern.findall(text)
  110. cid_count = len(matches)
  111. cid_len = sum(len(match) for match in matches)
  112. text_len = len(text)
  113. if text_len == 0:
  114. cid_chars_radio = 0
  115. else:
  116. cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
  117. # logger.debug(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
  118. '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
  119. if cid_chars_radio > 0.05:
  120. return True # 乱码文档
  121. else:
  122. return False # 正常文档
  123. if __name__ == '__main__':
  124. with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
  125. p_bytes = f.read()
  126. logger.info(f"PDF分类结果: {classify(p_bytes)}")