瀏覽代碼

feat(llm): add LLM-aided formula and text correction

- Add LLM-aided formula and text correction functionality
- Update config reader to include LLM-aided settings
- Create new LLM-aided processing module
- Update main processing script to incorporate LLM-aided corrections
- Modify download scripts to check for new config version
myhloli 10 月之前
父節點
當前提交
c660fdc8f0

+ 15 - 1
magic-pdf.template.json

@@ -19,5 +19,19 @@
         "enable": false,
         "max_time": 400
     },
-    "config_version": "1.0.0"
+    "llm-aided-config": {
+        "formula_aided": {
+            "api_key": "your_api_key",
+            "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+            "model": "qwen2.5-72b-instruct",
+            "enable": false
+        },
+        "text_aided": {
+            "api_key": "your_api_key",
+            "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+            "model": "qwen2.5-7b-instruct",
+            "enable": false
+        }
+    },
+    "config_version": "1.1.0"
 }

+ 1 - 1
magic_pdf/dict2md/ocr_mkcontent.py

@@ -7,7 +7,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.libs.commons import join_path
 from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
-from magic_pdf.para.para_split_v3 import ListLineTag
+from magic_pdf.post_proc.para_split_v3 import ListLineTag
 
 
 def __is_hyphen_at_line_end(line):

+ 9 - 0
magic_pdf/libs/config_reader.py

@@ -116,6 +116,15 @@ def get_formula_config():
     else:
         return formula_config
 
+def get_llm_aided_config():
+    config = read_config()
+    llm_aided_config = config.get('llm-aided-config')
+    if llm_aided_config is None:
+        logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
+        return None
+    else:
+        return llm_aided_config
+
 
 if __name__ == '__main__':
     ak, sk, endpoint = get_s3_config('llm-raw')

+ 0 - 0
magic_pdf/para/__init__.py


+ 15 - 2
magic_pdf/pdf_parse_union_core_v2.py

@@ -14,11 +14,12 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
 from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
+from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
 from magic_pdf.model.magic_model import MagicModel
+from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text
 
 try:
     import torchtext
@@ -29,7 +30,7 @@ except ImportError:
     pass
 
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
-from magic_pdf.para.para_split_v3 import para_split
+from magic_pdf.post_proc.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
 from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
 from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
@@ -828,6 +829,18 @@ def pdf_parse_union(
     """分段"""
     para_split(pdf_info_dict)
 
+    """llm优化"""
+    llm_aided_config = get_llm_aided_config()
+    if llm_aided_config is not None:
+        """公式优化"""
+        formula_aided_config = llm_aided_config.get('formula_aided', None)
+        if formula_aided_config is not None:
+            llm_aided_formula(pdf_info_dict, formula_aided_config)
+        """文本优化"""
+        text_aided_config = llm_aided_config.get('text_aided', None)
+        if text_aided_config is not None:
+            llm_aided_text(pdf_info_dict, text_aided_config)
+
     """dict转list"""
     pdf_info_list = dict_to_list(pdf_info_dict)
     new_pdf_info_dict = {

+ 64 - 0
magic_pdf/post_proc/llm_aided.py

@@ -0,0 +1,64 @@
+# Copyright (c) Opendatalab. All rights reserved.
+
+formula_correction_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
+
+1. 修正渲染或编译错误:
+    - Some syntax errors such as mismatched/missing/extra tokens. Your task is to fix these syntax errors and make sure corrected results conform to latex math syntax principles.
+    - 包含KaTeX不支持的关键词等原因导致的无法编译或渲染的错误
+
+2. 保留原始信息:
+   - 保留原始公式中的所有重要信息
+   - 不要添加任何原始公式中没有的新信息
+
+IMPORTANT:请仅返回修正后的公式,不要包含任何介绍、解释或元数据。
+
+LaTeX recognition result:
+$FORMULA
+
+Your corrected result:
+"""
+
+text_correction_prompt = f"""请根据以下指南修正OCR引起的错误,确保文本连贯并符合原始内容:
+
+1. 修正OCR引起的拼写错误和错误:
+   - 修正常见的OCR错误(例如,'rn' 被误读为 'm')
+   - 使用上下文和常识进行修正
+   - 只修正明显的错误,不要不必要的修改内容
+   - 不要添加额外的句号或其他不必要的标点符号
+
+2. 保持原始结构:
+   - 保留所有标题和子标题
+
+3. 保留原始内容:
+   - 保留原始文本中的所有重要信息
+   - 不要添加任何原始文本中没有的新信息
+   - 保留段落之间的换行符
+
+4. 保持连贯性:
+   - 确保内容与前文顺畅连接
+   - 适当处理在句子中间开始或结束的文本
+   
+5. 修正行内公式:
+   - 去除行内公式前后多余的空格
+   - 修正公式中的OCR错误
+   - 确保公式能够通过KaTeX渲染
+   
+6. 修正全角字符
+    - 修正全角标点符号为半角标点符号
+    - 修正全角字母为半角字母
+    - 修正全角数字为半角数字
+
+IMPORTANT:请仅返回修正后的文本,保留所有原始格式,包括换行符。不要包含任何介绍、解释或元数据。
+
+Previous context:
+
+Current chunk to process:
+
+Corrected text:
+"""
+
+def llm_aided_formula(pdf_info_dict, formula_aided_config):
+    pass
+
+def llm_aided_text(pdf_info_dict, text_aided_config):
+    pass

+ 0 - 0
magic_pdf/para/para_split_v3.py → magic_pdf/post_proc/para_split_v3.py


+ 1 - 1
scripts/download_models.py

@@ -16,7 +16,7 @@ def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
         config_version = data.get('config_version', '0.0.0')
-        if config_version < '1.0.0':
+        if config_version < '1.1.0':
             data = download_json(url)
     else:
         data = download_json(url)

+ 1 - 1
scripts/download_models_hf.py

@@ -16,7 +16,7 @@ def download_and_modify_json(url, local_filename, modifications):
     if os.path.exists(local_filename):
         data = json.load(open(local_filename))
         config_version = data.get('config_version', '0.0.0')
-        if config_version < '1.0.0':
+        if config_version < '1.1.0':
             data = download_json(url)
     else:
         data = download_json(url)