Bladeren bron

feat(llm_aided): add title optimization feature

- Implement llm_aided_title function to optimize document titles using LLM
- Update pdf_parse_union_core_v2.py to include title optimization
- Modify ocr_mkcontent.py to use optimized title levels- Add openai SDK dependency in setup.py
myhloli 10 maanden geleden
bovenliggende
commit
0a468eca6e
5 gewijzigde bestanden met toevoegingen van 97 en 7 verwijderingen
  1. 7 1
      magic-pdf.template.json
  2. 13 2
      magic_pdf/dict2md/ocr_mkcontent.py
  3. 5 1
      magic_pdf/pdf_parse_union_core_v2.py
  4. 71 3
      magic_pdf/post_proc/llm_aided.py
  5. 1 0
      setup.py

+ 7 - 1
magic-pdf.template.json

@@ -23,7 +23,7 @@
         "formula_aided": {
             "api_key": "your_api_key",
             "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
-            "model": "qwen2.5-72b-instruct",
+            "model": "qwen2.5-7b-instruct",
             "enable": false
         },
         "text_aided": {
@@ -31,6 +31,12 @@
             "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
             "model": "qwen2.5-7b-instruct",
             "enable": false
+        },
+        "title_aided": {
+            "api_key": "your_api_key",
+            "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+            "model": "qwen2.5-32b-instruct",
+            "enable": false
         }
     },
     "config_version": "1.1.0"

+ 13 - 2
magic_pdf/dict2md/ocr_mkcontent.py

@@ -61,7 +61,8 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
         if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
             para_text = merge_para_with_text(para_block)
         elif para_type == BlockType.Title:
-            para_text = f'# {merge_para_with_text(para_block)}'
+            title_level = get_title_level(para_block)
+            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
         elif para_type == BlockType.InterlineEquation:
             para_text = merge_para_with_text(para_block)
         elif para_type == BlockType.Image:
@@ -186,10 +187,11 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
             'text': merge_para_with_text(para_block),
         }
     elif para_type == BlockType.Title:
+        title_level = get_title_level(para_block)
         para_content = {
             'type': 'text',
             'text': merge_para_with_text(para_block),
-            'text_level': 1,
+            'text_level': title_level,
         }
     elif para_type == BlockType.InterlineEquation:
         para_content = {
@@ -289,3 +291,12 @@ def union_make(pdf_info_dict: list,
         return '\n\n'.join(output_content)
     elif make_mode == MakeMode.STANDARD_FORMAT:
         return output_content
+
+
+def get_title_level(block):
+    title_level = block.get('level', 1)
+    if title_level > 4:
+        title_level = 4
+    elif title_level < 1:
+        title_level = 1
+    return title_level

+ 5 - 1
magic_pdf/pdf_parse_union_core_v2.py

@@ -19,7 +19,7 @@ from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
 from magic_pdf.model.magic_model import MagicModel
-from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text
+from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
 
 try:
     import torchtext
@@ -846,6 +846,10 @@ def pdf_parse_union(
         text_aided_config = llm_aided_config.get('text_aided', None)
         if text_aided_config is not None:
             llm_aided_text(pdf_info_dict, text_aided_config)
+        """标题优化"""
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            llm_aided_title(pdf_info_dict, title_aided_config)
 
     """dict转list"""
     pdf_info_list = dict_to_list(pdf_info_dict)

+ 71 - 3
magic_pdf/post_proc/llm_aided.py

@@ -1,6 +1,11 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import json
+from loguru import logger
+from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
+from openai import OpenAI
 
-formula_correction_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
+
+formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
 
 1. 修正渲染或编译错误:
     - Some syntax errors such as mismatched/missing/extra tokens. Your task is to fix these syntax errors and make sure corrected results conform to latex math syntax principles.
@@ -18,7 +23,7 @@ $FORMULA
 Your corrected result:
 """
 
-text_correction_prompt = f"""请根据以下指南修正OCR引起的错误,确保文本连贯并符合原始内容:
+text_optimize_prompt = f"""请根据以下指南修正OCR引起的错误,确保文本连贯并符合原始内容:
 
 1. 修正OCR引起的拼写错误和错误:
    - 修正常见的OCR错误(例如,'rn' 被误读为 'm')
@@ -61,4 +66,67 @@ def llm_aided_formula(pdf_info_dict, formula_aided_config):
     pass
 
 def llm_aided_text(pdf_info_dict, text_aided_config):
-    pass
+    pass
+
+def llm_aided_title(pdf_info_dict, title_aided_config):
+    client = OpenAI(
+        api_key=title_aided_config["api_key"],
+        base_url=title_aided_config["base_url"],
+    )
+    title_dict = {}
+    origin_title_list = []
+    i = 0
+    for page_num, page in pdf_info_dict.items():
+        blocks = page["para_blocks"]
+        for block in blocks:
+            if block["type"] == "title":
+                origin_title_list.append(block)
+                title_text = merge_para_with_text(block)
+                title_dict[f"{i}"] = title_text
+                i += 1
+    logger.info(f"Title list: {title_dict}")
+
+    title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
+
+1. 保留原始内容:
+    - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
+    - 请务必保证输出的字典中元素的数量和输入的数量一致
+
+2. 保持字典内key-value的对应关系不变
+
+3. 优化层次结构:
+    - 为每个标题元素添加适当的层次结构
+    - 标题层级应具有连续性,不能跳过某一层级
+    - 标题层级最多为4级,不要添加过多的层级
+    - 优化后的标题为一个整数,代表该标题的层级
+
+IMPORTANT: 
+请直接返回优化过的由标题层级组成的json,返回的json不需要格式化。
+
+Input title list:
+{title_dict}
+
+Corrected title list:
+"""
+
+    completion = client.chat.completions.create(
+        model=title_aided_config["model"],
+        messages=[
+            {'role': 'user', 'content': title_optimize_prompt}],
+        temperature=0.7,
+    )
+
+    json_completion = json.loads(completion.choices[0].message.content)
+
+    logger.info(f"Title completion: {json_completion}")
+
+    logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
+    if len(json_completion) == len(title_dict):
+        try:
+            for i, origin_title_block in enumerate(origin_title_list):
+               origin_title_block["level"] = int(json_completion[str(i)])
+        except Exception as e:
+            logger.exception(e)
+    else:
+        logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.")
+

+ 1 - 0
setup.py

@@ -52,6 +52,7 @@ if __name__ == '__main__':
                      "rapidocr-paddle",  # rapidocr-paddle
                      "rapid_table",  # rapid_table
                      "PyYAML",  # yaml
+                     "openai",  # openai SDK
                      "detectron2"
                      ],
             "old_linux":[