pdf_classify.py

# Copyright (c) Opendatalab. All rights reserved.
import re
from io import BytesIO

import numpy as np
import pypdfium2 as pdfium
from loguru import logger
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams
from pypdf import PdfReader
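
# Classification heuristic, summarizing the checks implemented below: a random
# sample of up to 10 pages is drawn from the input PDF, and the document is
# routed to 'ocr' when the sample averages fewer than 50 non-whitespace
# characters per page, when more than 5% of the pdfminer-extracted text is
# (cid:xxx) placeholders, or when at least 90% of the sampled pages are almost
# entirely covered by images; otherwise it is classified as 'txt'.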


def classify(pdf_bytes):
    """
    Determine whether text can be extracted from the PDF directly or whether OCR is needed.

    Args:
        pdf_bytes: raw bytes of the PDF file

    Returns:
        str: 'txt' if text can be extracted directly, 'ocr' if OCR is required
    """
    try:
        # Sample up to 10 pages from the byte data and load them as a new PDF
        sample_pdf_bytes = extract_pages(pdf_bytes)
        pdf = pdfium.PdfDocument(sample_pdf_bytes)

        # Number of pages in the sampled PDF
        page_count = len(pdf)

        # If the PDF has no pages, go straight to OCR
        if page_count == 0:
            return 'ocr'

        # Number of pages to check (at most 10)
        pages_to_check = min(page_count, 10)

        # Threshold: fewer than 50 cleaned characters per page on average means OCR is needed
        chars_threshold = 50

        if (get_avg_cleaned_chars_per_page(pdf, pages_to_check) < chars_threshold) or detect_invalid_chars(sample_pdf_bytes):
            return 'ocr'
        else:
            if get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check) >= 0.9:
                return 'ocr'
            return 'txt'
    except Exception as e:
        logger.error(f"Error while classifying the PDF: {e}")
        # Default to OCR on error
        return 'ocr'
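
# Minimal usage sketch (hypothetical file path and downstream functions, kept as
# comments so this module's behavior is unchanged):
#
#     with open('example.pdf', 'rb') as f:
#         pdf_bytes = f.read()
#     if classify(pdf_bytes) == 'txt':
#         run_text_extraction(pdf_bytes)  # hypothetical direct-extraction pipeline
#     else:
#         run_ocr(pdf_bytes)              # hypothetical OCR pipeline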


def get_avg_cleaned_chars_per_page(pdf_doc, pages_to_check):
    """Average number of non-whitespace characters per page over the first pages_to_check pages."""
    # Total number of extracted characters
    total_chars = 0
    # Total number of characters after cleaning
    cleaned_total_chars = 0

    # Check the text of the first few pages
    for i in range(pages_to_check):
        page = pdf_doc[i]
        text_page = page.get_textpage()
        text = text_page.get_text_bounded()
        total_chars += len(text)
        # Clean the extracted text by removing all whitespace
        cleaned_text = re.sub(r'\s+', '', text)
        cleaned_total_chars += len(cleaned_text)

    # Average number of cleaned characters per page
    avg_cleaned_chars_per_page = cleaned_total_chars / pages_to_check
    # logger.debug(f"PDF analysis: {avg_cleaned_chars_per_page:.1f} cleaned characters per page on average")

    pdf_doc.close()  # Close the PDF document
    return avg_cleaned_chars_per_page


def get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check):
    """Fraction of the first pages_to_check pages whose estimated image coverage is high."""
    pdf_stream = BytesIO(sample_pdf_bytes)
    pdf_reader = PdfReader(pdf_stream)

    # Number of pages with high image coverage
    high_image_coverage_pages = 0

    # Inspect the images on the first few pages
    for i in range(pages_to_check):
        page = pdf_reader.pages[i]

        # Page dimensions
        page_width = float(page.mediabox.width)
        page_height = float(page.mediabox.height)
        page_area = page_width * page_height

        # Estimate the area occupied by images
        image_area = 0
        if '/Resources' in page:
            resources = page['/Resources']
            if '/XObject' in resources:
                x_objects = resources['/XObject']
                # Sum the area taken up by all image XObjects
                for obj_name in x_objects:
                    try:
                        obj = x_objects[obj_name]
                        if obj['/Subtype'] == '/Image':
                            # Image width and height in pixels
                            width = obj.get('/Width', 0)
                            height = obj.get('/Height', 0)
                            # Estimated area the image takes up on the page.
                            # Note: this is only an estimate, since the image
                            # transformation matrix is not taken into account.
                            scale_factor = 1.0  # assumed scale factor
                            img_area = width * height * scale_factor
                            image_area += img_area
                    except Exception as e:
                        # logger.debug(f"Error while processing an image object: {e}")
                        continue

        # Estimated image coverage of the page (capped at 1.0)
        estimated_coverage = min(image_area / page_area, 1.0) if page_area > 0 else 0
        # logger.debug(f"PDF analysis: page {i + 1} image coverage: {estimated_coverage:.2f}")

        # Count the page as high-coverage when the estimated coverage reaches the 1.0 cap
        if estimated_coverage >= 1:
            high_image_coverage_pages += 1

    # Ratio of pages with high image coverage
    high_image_coverage_ratio = high_image_coverage_pages / pages_to_check
    # logger.debug(f"PDF analysis: ratio of high-image-coverage pages: {high_image_coverage_ratio:.2f}")

    pdf_stream.close()  # Close the byte stream
    pdf_reader.close()
    return high_image_coverage_ratio


def extract_pages(src_pdf_bytes: bytes) -> bytes:
    """
    Randomly sample up to 10 pages from the PDF byte data and return them as new PDF bytes.

    Args:
        src_pdf_bytes: raw bytes of the source PDF file

    Returns:
        bytes: byte data of the PDF built from the sampled pages
    """
    # Load the PDF from the byte data
    pdf = pdfium.PdfDocument(src_pdf_bytes)

    # Total number of pages
    total_page = len(pdf)
    if total_page == 0:
        # If the PDF has no pages, return an empty document
        logger.warning("PDF is empty, return empty document")
        return b''

    # Select at most 10 pages
    select_page_cnt = min(10, total_page)

    # Randomly pick page indices without replacement
    page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()

    # Create a new, empty PDF document
    sample_docs = pdfium.PdfDocument.new()
    try:
        # Import the selected pages into the new document
        sample_docs.import_pages(pdf, page_indices)

        # Save the new PDF into an in-memory buffer
        output_buffer = BytesIO()
        sample_docs.save(output_buffer)

        # Return the byte data
        return output_buffer.getvalue()
    except Exception as e:
        logger.exception(e)
        return b''  # Return empty bytes on error


def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
    """
    Detect whether the PDF contains invalid (garbled) characters.
    """
    # pdfminer is slow, so a random sample of about 10 pages is taken beforehand
    # sample_pdf_bytes = extract_pages(src_pdf_bytes)
    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
    laparams = LAParams(
        line_overlap=0.5,
        char_margin=2.0,
        line_margin=0.5,
        word_margin=0.1,
        boxes_flow=None,
        detect_vertical=False,
        all_texts=False,
    )
    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
    text = text.replace("\n", "")
    # logger.info(text)

    # Garbled text extracted by pdfminer shows up as (cid:xxx) placeholders
    cid_pattern = re.compile(r'\(cid:\d+\)')
    matches = cid_pattern.findall(text)
    cid_count = len(matches)
    cid_len = sum(len(match) for match in matches)
    text_len = len(text)
    if text_len == 0:
        cid_chars_ratio = 0
    else:
        cid_chars_ratio = cid_count / (cid_count + text_len - cid_len)
    # logger.debug(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_ratio: {cid_chars_ratio}")

    # If more than 5% of the text is garbled, treat the whole document as garbled
    if cid_chars_ratio > 0.05:
        return True  # garbled document
    else:
        return False  # normal document


if __name__ == '__main__':
    with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
        p_bytes = f.read()
        logger.info(f"PDF classification result: {classify(p_bytes)}")