llm_aided.py

# Copyright (c) Opendatalab. All rights reserved.
from loguru import logger
from openai import OpenAI
import json_repair

from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text


def llm_aided_title(page_info_list, title_aided_config):
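    """Assign hierarchy levels to title blocks with the help of an LLM.

    Collects every block of type "title" from `page_info_list`, sends the title texts
    (together with their average line heights and page numbers) to the chat model named
    in `title_aided_config`, and writes the returned level back into each title block
    as `block["level"]`. `title_aided_config` must provide `api_key`, `base_url` and
    `model`; it may optionally provide `enable_thinking`.
    """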
    client = OpenAI(
        api_key=title_aided_config["api_key"],
        base_url=title_aided_config["base_url"],
    )
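    # Collect every title block and record, per title id: its merged text, the average
    # line height of the block, and its 1-based page number.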
    title_dict = {}
    origin_title_list = []
    i = 0
    for page_info in page_info_list:
        blocks = page_info["para_blocks"]
        for block in blocks:
            if block["type"] == "title":
                origin_title_list.append(block)
                title_text = merge_para_with_text(block)
                if 'line_avg_height' in block:
                    line_avg_height = block['line_avg_height']
                else:
                    title_block_line_height_list = []
                    for line in block['lines']:
                        bbox = line['bbox']
                        title_block_line_height_list.append(int(bbox[3] - bbox[1]))
                    if len(title_block_line_height_list) > 0:
                        line_avg_height = sum(title_block_line_height_list) / len(title_block_line_height_list)
                    else:
                        line_avg_height = int(block['bbox'][3] - block['bbox'][1])
                title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
                i += 1
    # logger.info(f"Title list: {title_dict}")
    title_optimize_prompt = f"""The input is a dict of all the titles in a document. Optimize the title levels according to the following guidelines so that the result matches the hierarchy of a normal document:
1. Each value in the dict is a list containing, in order:
    - the title text
    - the line height, i.e. the average line height of the block the title belongs to
    - the page number the title appears on
2. Preserve the original content:
    - every element in the input dict is valid; do not delete any element from the dict
    - make sure the output dict contains exactly as many elements as the input
3. Keep the key-value correspondence of the dict unchanged.
4. Optimize the hierarchy:
    - assign each title an appropriate level based on the meaning of its text
    - titles with a larger line height are usually higher-level titles
    - levels must be continuous from the first title to the last; do not skip levels
    - use at most 4 levels; do not introduce more levels than necessary
    - for each title keep only the integer representing its level; keep no other information
5. Sanity check and fine-tuning:
    - after the initial assignment, carefully check that the levels are reasonable
    - adjust unreasonable levels according to context and logical order
    - make sure the final levels match the actual structure and logic of the document
IMPORTANT:
Return the optimized dict of title levels directly, in the format {{title id: title level}}, for example:
{{
0:1,
1:2,
2:2,
3:3
}}
Do not pretty-print the dict and do not return any other information.
Input title list:
{title_dict}
Corrected title list:
"""
    # 5.
    # - The dict may contain body text that was mistakenly detected as a title; such entries can be
    #   excluded by marking their level as 0.
    retry_count = 0
    max_retries = 3
    dict_completion = None

    # Build API call parameters
    api_params = {
        "model": title_aided_config["model"],
        "messages": [{'role': 'user', 'content': title_optimize_prompt}],
        "temperature": 0.7,
        "stream": True,
    }

    # Only add extra_body when explicitly specified in config
    if "enable_thinking" in title_aided_config:
        api_params["extra_body"] = {"enable_thinking": title_aided_config["enable_thinking"]}
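
    # Query the model, retrying up to max_retries times when the response cannot be parsed
    # or the number of returned levels does not match the number of titles.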
    while retry_count < max_retries:
        try:
            completion = client.chat.completions.create(**api_params)
            content_pieces = []
            for chunk in completion:
                if chunk.choices and chunk.choices[0].delta.content is not None:
                    content_pieces.append(chunk.choices[0].delta.content)
            content = "".join(content_pieces).strip()
            # logger.info(f"Title completion: {content}")
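            # Reasoning models may prefix the answer with a <think>...</think> block;
            # drop it so that only the final answer is parsed.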
            if "</think>" in content:
                idx = content.index("</think>") + len("</think>")
                content = content[idx:].strip()
            dict_completion = json_repair.loads(content)
            dict_completion = {int(k): int(v) for k, v in dict_completion.items()}
            # logger.info(f"len(dict_completion): {len(dict_completion)}, len(title_dict): {len(title_dict)}")
            if len(dict_completion) == len(title_dict):
                for i, origin_title_block in enumerate(origin_title_list):
                    origin_title_block["level"] = int(dict_completion[i])
                break
            else:
                logger.warning(
                    "The number of titles in the optimized result is not equal to the number of titles in the input.")
                retry_count += 1
        except Exception as e:
            logger.exception(e)
            retry_count += 1

    if dict_completion is None:
        logger.error("Failed to decode dict after maximum retries.")
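

# Minimal usage sketch: the config keys mirror the ones read above; the endpoint and model
# name are placeholders, and `page_info_list` would normally come from the pipeline's
# middle JSON (the same para-block structure iterated inside llm_aided_title).
if __name__ == "__main__":
    demo_config = {
        "api_key": "sk-placeholder",               # hypothetical credential
        "base_url": "https://api.example.com/v1",  # any OpenAI-compatible endpoint
        "model": "some-chat-model",                # placeholder model id
        # "enable_thinking": False,                # optional, forwarded via extra_body
    }
    page_info_list = []  # replace with the pipeline's para-block page list
    llm_aided_title(page_info_list, demo_config)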