llm_aided.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import json
  3. from loguru import logger
  4. from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
  5. from openai import OpenAI
  6. #@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
  7. formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
  8. 1. 修正渲染或编译错误:
  9. - Some syntax errors such as mismatched/missing/extra tokens. Your task is to fix these syntax errors and make sure corrected results conform to latex math syntax principles.
  10. - 包含KaTeX不支持的关键词等原因导致的无法编译或渲染的错误
  11. 2. 保留原始信息:
  12. - 保留原始公式中的所有重要信息
  13. - 不要添加任何原始公式中没有的新信息
  14. IMPORTANT:请仅返回修正后的公式,不要包含任何介绍、解释或元数据。
  15. LaTeX recognition result:
  16. $FORMULA
  17. Your corrected result:
  18. """
  19. text_optimize_prompt = f"""请根据以下指南修正OCR引起的错误,确保文本连贯并符合原始内容:
  20. 1. 修正OCR引起的拼写错误和错误:
  21. - 修正常见的OCR错误(例如,'rn' 被误读为 'm')
  22. - 使用上下文和常识进行修正
  23. - 只修正明显的错误,不要不必要的修改内容
  24. - 不要添加额外的句号或其他不必要的标点符号
  25. 2. 保持原始结构:
  26. - 保留所有标题和子标题
  27. 3. 保留原始内容:
  28. - 保留原始文本中的所有重要信息
  29. - 不要添加任何原始文本中没有的新信息
  30. - 保留段落之间的换行符
  31. 4. 保持连贯性:
  32. - 确保内容与前文顺畅连接
  33. - 适当处理在句子中间开始或结束的文本
  34. 5. 修正行内公式:
  35. - 去除行内公式前后多余的空格
  36. - 修正公式中的OCR错误
  37. - 确保公式能够通过KaTeX渲染
  38. 6. 修正全角字符
  39. - 修正全角标点符号为半角标点符号
  40. - 修正全角字母为半角字母
  41. - 修正全角数字为半角数字
  42. IMPORTANT:请仅返回修正后的文本,保留所有原始格式,包括换行符。不要包含任何介绍、解释或元数据。
  43. Previous context:
  44. Current chunk to process:
  45. Corrected text:
  46. """
  47. def llm_aided_formula(pdf_info_dict, formula_aided_config):
  48. pass
  49. def llm_aided_text(pdf_info_dict, text_aided_config):
  50. pass
  51. def llm_aided_title(pdf_info_dict, title_aided_config):
  52. client = OpenAI(
  53. api_key=title_aided_config["api_key"],
  54. base_url=title_aided_config["base_url"],
  55. )
  56. title_dict = {}
  57. origin_title_list = []
  58. i = 0
  59. for page_num, page in pdf_info_dict.items():
  60. blocks = page["para_blocks"]
  61. for block in blocks:
  62. if block["type"] == "title":
  63. origin_title_list.append(block)
  64. title_text = merge_para_with_text(block)
  65. page_line_height_list = []
  66. for line in block['lines']:
  67. bbox = line['bbox']
  68. page_line_height_list.append(int(bbox[3] - bbox[1]))
  69. if len(page_line_height_list) > 0:
  70. line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
  71. else:
  72. line_avg_height = int(block['bbox'][3] - block['bbox'][1])
  73. title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1]
  74. i += 1
  75. # logger.info(f"Title list: {title_dict}")
  76. title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
  77. 1. 字典中每个value均为一个list,包含以下元素:
  78. - 标题文本
  79. - 文本行高是标题所在块的平均行高
  80. - 标题所在的页码
  81. 2. 保留原始内容:
  82. - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
  83. - 请务必保证输出的字典中元素的数量和输入的数量一致
  84. 3. 保持字典内key-value的对应关系不变
  85. 4. 优化层次结构:
  86. - 为每个标题元素添加适当的层次结构
  87. - 行高较大的标题一般是更高级别的标题
  88. - 标题从前至后的层级必须是连续的,不能跳过层级
  89. - 标题层级最多为4级,不要添加过多的层级
  90. - 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
  91. IMPORTANT:
  92. 请直接返回优化过的由标题层级组成的json,格式如下:
  93. {{"0":1,"1":2,"2":2,"3":3}}
  94. 返回的json不需要格式化。
  95. Input title list:
  96. {title_dict}
  97. Corrected title list:
  98. """
  99. retry_count = 0
  100. max_retries = 3
  101. json_completion = None
  102. while retry_count < max_retries:
  103. try:
  104. completion = client.chat.completions.create(
  105. model=title_aided_config["model"],
  106. messages=[
  107. {'role': 'user', 'content': title_optimize_prompt}],
  108. temperature=0.7,
  109. )
  110. json_completion = json.loads(completion.choices[0].message.content)
  111. # logger.info(f"Title completion: {json_completion}")
  112. # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
  113. if len(json_completion) == len(title_dict):
  114. for i, origin_title_block in enumerate(origin_title_list):
  115. origin_title_block["level"] = int(json_completion[str(i)])
  116. break
  117. else:
  118. logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
  119. retry_count += 1
  120. except Exception as e:
  121. if isinstance(e, json.decoder.JSONDecodeError):
  122. logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
  123. else:
  124. logger.exception(e)
  125. retry_count += 1
  126. if json_completion is None:
  127. logger.error("Failed to decode JSON after maximum retries.")