pdf_classify.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import re
  3. from io import BytesIO
  4. import numpy as np
  5. import pypdfium2 as pdfium
  6. from loguru import logger
  7. from pdfminer.high_level import extract_text
  8. from pdfminer.pdfparser import PDFParser
  9. from pdfminer.pdfdocument import PDFDocument
  10. from pdfminer.pdfpage import PDFPage
  11. from pdfminer.pdfinterp import PDFResourceManager
  12. from pdfminer.pdfinterp import PDFPageInterpreter
  13. from pdfminer.layout import LAParams, LTImage, LTFigure
  14. from pdfminer.converter import PDFPageAggregator
  15. def classify(pdf_bytes):
  16. """
  17. 判断PDF文件是可以直接提取文本还是需要OCR
  18. Args:
  19. pdf_bytes: PDF文件的字节数据
  20. Returns:
  21. str: 'txt' 表示可以直接提取文本,'ocr' 表示需要OCR
  22. """
  23. # 从字节数据加载PDF
  24. sample_pdf_bytes = extract_pages(pdf_bytes)
  25. pdf = pdfium.PdfDocument(sample_pdf_bytes)
  26. try:
  27. # 获取PDF页数
  28. page_count = len(pdf)
  29. # 如果PDF页数为0,直接返回OCR
  30. if page_count == 0:
  31. return 'ocr'
  32. # 检查的页面数(最多检查10页)
  33. pages_to_check = min(page_count, 10)
  34. # 设置阈值:如果每页平均少于50个有效字符,认为需要OCR
  35. chars_threshold = 50
  36. # 检查平均字符数和无效字符
  37. if (get_avg_cleaned_chars_per_page(pdf, pages_to_check) < chars_threshold) or detect_invalid_chars(sample_pdf_bytes):
  38. return 'ocr'
  39. # 检查图像覆盖率
  40. if get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check) >= 0.8:
  41. return 'ocr'
  42. return 'txt'
  43. except Exception as e:
  44. logger.error(f"判断PDF类型时出错: {e}")
  45. # 出错时默认使用OCR
  46. return 'ocr'
  47. finally:
  48. # 无论执行哪个路径,都确保PDF被关闭
  49. pdf.close()
  50. def get_avg_cleaned_chars_per_page(pdf_doc, pages_to_check):
  51. # 总字符数
  52. total_chars = 0
  53. # 清理后的总字符数
  54. cleaned_total_chars = 0
  55. # 检查前几页的文本
  56. for i in range(pages_to_check):
  57. page = pdf_doc[i]
  58. text_page = page.get_textpage()
  59. text = text_page.get_text_bounded()
  60. total_chars += len(text)
  61. # 清理提取的文本,移除空白字符
  62. cleaned_text = re.sub(r'\s+', '', text)
  63. cleaned_total_chars += len(cleaned_text)
  64. # 计算平均每页字符数
  65. avg_cleaned_chars_per_page = cleaned_total_chars / pages_to_check
  66. # logger.debug(f"PDF分析: 平均每页清理后{avg_cleaned_chars_per_page:.1f}字符")
  67. return avg_cleaned_chars_per_page
  68. def get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check):
  69. # 创建内存文件对象
  70. pdf_stream = BytesIO(sample_pdf_bytes)
  71. # 创建PDF解析器
  72. parser = PDFParser(pdf_stream)
  73. # 创建PDF文档对象
  74. document = PDFDocument(parser)
  75. # 检查文档是否允许文本提取
  76. if not document.is_extractable:
  77. # logger.warning("PDF不允许内容提取")
  78. return 1.0 # 默认为高覆盖率,因为无法提取内容
  79. # 创建资源管理器和参数对象
  80. rsrcmgr = PDFResourceManager()
  81. laparams = LAParams(
  82. line_overlap=0.5,
  83. char_margin=2.0,
  84. line_margin=0.5,
  85. word_margin=0.1,
  86. boxes_flow=None,
  87. detect_vertical=False,
  88. all_texts=False,
  89. )
  90. # 创建聚合器
  91. device = PDFPageAggregator(rsrcmgr, laparams=laparams)
  92. # 创建解释器
  93. interpreter = PDFPageInterpreter(rsrcmgr, device)
  94. # 记录高图像覆盖率的页面数量
  95. high_image_coverage_pages = 0
  96. page_count = 0
  97. # 遍历页面
  98. for page in PDFPage.create_pages(document):
  99. # 控制检查的页数
  100. if page_count >= pages_to_check:
  101. break
  102. # 处理页面
  103. interpreter.process_page(page)
  104. layout = device.get_result()
  105. # 页面尺寸
  106. page_width = layout.width
  107. page_height = layout.height
  108. page_area = page_width * page_height
  109. # 计算图像覆盖的总面积
  110. image_area = 0
  111. # 遍历页面元素
  112. for element in layout:
  113. # 检查是否为图像或图形元素
  114. if isinstance(element, (LTImage, LTFigure)):
  115. # 计算图像边界框面积
  116. img_width = element.width
  117. img_height = element.height
  118. img_area = img_width * img_height
  119. image_area += img_area
  120. # 计算覆盖率
  121. coverage_ratio = min(image_area / page_area, 1.0) if page_area > 0 else 0
  122. # logger.debug(f"PDF分析: 页面 {page_count + 1} 图像覆盖率: {coverage_ratio:.2f}")
  123. # 判断是否为高覆盖率
  124. if coverage_ratio >= 0.8: # 使用80%作为高覆盖率的阈值
  125. high_image_coverage_pages += 1
  126. page_count += 1
  127. # 关闭资源
  128. pdf_stream.close()
  129. # 如果没有处理任何页面,返回0
  130. if page_count == 0:
  131. return 0.0
  132. # 计算高图像覆盖率的页面比例
  133. high_coverage_ratio = high_image_coverage_pages / page_count
  134. # logger.debug(f"PDF分析: 高图像覆盖页面比例: {high_coverage_ratio:.2f}")
  135. return high_coverage_ratio
  136. def extract_pages(src_pdf_bytes: bytes) -> bytes:
  137. """
  138. 从PDF字节数据中随机提取最多10页,返回新的PDF字节数据
  139. Args:
  140. src_pdf_bytes: PDF文件的字节数据
  141. Returns:
  142. bytes: 提取页面后的PDF字节数据
  143. """
  144. # 从字节数据加载PDF
  145. pdf = pdfium.PdfDocument(src_pdf_bytes)
  146. # 获取PDF页数
  147. total_page = len(pdf)
  148. if total_page == 0:
  149. # 如果PDF没有页面,直接返回空文档
  150. logger.warning("PDF is empty, return empty document")
  151. return b''
  152. # 选择最多10页
  153. select_page_cnt = min(10, total_page)
  154. # 从总页数中随机选择页面
  155. page_indices = np.random.choice(total_page, select_page_cnt, replace=False).tolist()
  156. # 创建一个新的PDF文档
  157. sample_docs = pdfium.PdfDocument.new()
  158. try:
  159. # 将选择的页面导入新文档
  160. sample_docs.import_pages(pdf, page_indices)
  161. pdf.close()
  162. # 将新PDF保存到内存缓冲区
  163. output_buffer = BytesIO()
  164. sample_docs.save(output_buffer)
  165. # 获取字节数据
  166. return output_buffer.getvalue()
  167. except Exception as e:
  168. pdf.close()
  169. logger.exception(e)
  170. return b'' # 出错时返回空字节
  171. def detect_invalid_chars(sample_pdf_bytes: bytes) -> bool:
  172. """"
  173. 检测PDF中是否包含非法字符
  174. """
  175. '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
  176. # sample_pdf_bytes = extract_pages(src_pdf_bytes)
  177. sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
  178. laparams = LAParams(
  179. line_overlap=0.5,
  180. char_margin=2.0,
  181. line_margin=0.5,
  182. word_margin=0.1,
  183. boxes_flow=None,
  184. detect_vertical=False,
  185. all_texts=False,
  186. )
  187. text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
  188. text = text.replace("\n", "")
  189. # logger.info(text)
  190. '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
  191. cid_pattern = re.compile(r'\(cid:\d+\)')
  192. matches = cid_pattern.findall(text)
  193. cid_count = len(matches)
  194. cid_len = sum(len(match) for match in matches)
  195. text_len = len(text)
  196. if text_len == 0:
  197. cid_chars_radio = 0
  198. else:
  199. cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
  200. # logger.debug(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
  201. '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
  202. if cid_chars_radio > 0.05:
  203. return True # 乱码文档
  204. else:
  205. return False # 正常文档
  206. if __name__ == '__main__':
  207. with open('/Users/myhloli/pdf/luanma2x10.pdf', 'rb') as f:
  208. p_bytes = f.read()
  209. logger.info(f"PDF分类结果: {classify(p_bytes)}")