Explorar el Código

refactor: add LLM-aided title optimization and improve config handling

myhloli hace 5 meses
padre
commit
7d4ce0c380

+ 11 - 1
mineru/backend/pipeline/config_reader.py

@@ -124,4 +124,14 @@ def get_latex_delimiter_config():
         logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
         return None
     else:
-        return latex_delimiter_config
+        return latex_delimiter_config
+
+
def get_llm_aided_config():
    """Return the 'llm-aided-config' section of the user config.

    Falls back to ``None`` (with a warning) when the section is absent,
    mirroring the behavior of the sibling config getters in this module.
    """
    llm_aided_config = read_config().get('llm-aided-config')
    if llm_aided_config is not None:
        return llm_aided_config
    logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
    return None

+ 18 - 1
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -1,10 +1,15 @@
 # Copyright (c) Opendatalab. All rights reserved.
-from mineru.backend.pipeline.config_reader import get_device
+import time
+
+from loguru import logger
+
+from mineru.backend.pipeline.config_reader import get_device, get_llm_aided_config
 from mineru.backend.pipeline.model_init import AtomModelSingleton
 from mineru.backend.pipeline.para_split import para_split
 from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
 from mineru.utils.block_sort import sort_blocks_by_bbox
 from mineru.utils.cut_image import cut_image_and_table
+from mineru.utils.llm_aided import llm_aided_title
 from mineru.utils.model_utils import clean_memory
 from mineru.utils.pipeline_magic_model import MagicModel
 from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
@@ -169,6 +174,18 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N
     """分段"""
     para_split(middle_json["pdf_info"])
 
+    """llm优化"""
+    llm_aided_config = get_llm_aided_config()
+
+    if llm_aided_config is not None:
+        """标题优化"""
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            if title_aided_config.get('enable', False):
+                llm_aided_title_start_time = time.time()
+                llm_aided_title(middle_json["pdf_info"], title_aided_config)
+                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
+
     clean_memory(get_device())
 
     return middle_json

+ 1 - 2
mineru/cli/common.py

@@ -215,8 +215,7 @@ def do_parse(
 
 
 if __name__ == "__main__":
-    # pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf"
-    pdf_path = "../../demo/pdfs/demo1.pdf"
+    pdf_path = "../../demo/pdfs/demo2.pdf"
     with open(pdf_path, "rb") as f:
         try:
            do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,)

+ 101 - 0
mineru/utils/llm_aided.py

@@ -0,0 +1,101 @@
+# Copyright (c) Opendatalab. All rights reserved.
+from loguru import logger
+from openai import OpenAI
+import ast
+
+from mineru.api.pipeline_middle_json_mkcontent import merge_para_with_text
+
+
def llm_aided_title(page_info_list, title_aided_config):
    """Use an LLM to assign hierarchy levels to title blocks, in place.

    Collects every block of type ``"title"`` from ``page_info_list``, sends
    the title texts — together with the block's average line height and page
    number as layout hints — to an OpenAI-compatible chat endpoint, and
    writes the returned integer level into each title block's ``"level"``
    key (level 0 marks a block the model judged to be body text, per the
    prompt contract).

    Args:
        page_info_list: the middle-json ``pdf_info`` list; each page dict is
            expected to carry ``para_blocks`` and ``page_idx``.
        title_aided_config: dict with ``api_key``, ``base_url`` and ``model``.

    Side effects:
        Mutates the title blocks inside ``page_info_list``. On repeated
        failure (bad/unparsable model output after all retries) the blocks
        are left untouched and an error is logged.
    """
    client = OpenAI(
        api_key=title_aided_config["api_key"],
        base_url=title_aided_config["base_url"],
    )
    title_dict = {}
    origin_title_list = []
    for page_info in page_info_list:
        for block in page_info["para_blocks"]:
            if block["type"] != "title":
                continue
            origin_title_list.append(block)
            title_text = merge_para_with_text(block)
            # Average line height of the block — a proxy for font size,
            # which the prompt uses as a heading-level hint.
            line_heights = [int(line['bbox'][3] - line['bbox'][1]) for line in block['lines']]
            if line_heights:
                line_avg_height = sum(line_heights) / len(line_heights)
            else:
                # Degenerate block with no lines: fall back to the block bbox height.
                line_avg_height = int(block['bbox'][3] - block['bbox'][1])
            title_dict[f"{len(origin_title_list) - 1}"] = [
                title_text, line_avg_height, int(page_info['page_idx']) + 1
            ]
    # logger.info(f"Title list: {title_dict}")

    # No titles in the document: skip the costly LLM round-trip entirely
    # (it could only ever fail the length check below).
    if not title_dict:
        return

    title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:

1. 字典中每个value均为一个list,包含以下元素:
    - 标题文本
    - 文本行高是标题所在块的平均行高
    - 标题所在的页码

2. 保留原始内容:
    - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
    - 请务必保证输出的字典中元素的数量和输入的数量一致

3. 保持字典内key-value的对应关系不变

4. 优化层次结构:
    - 为每个标题元素添加适当的层次结构
    - 行高较大的标题一般是更高级别的标题
    - 标题从前至后的层级必须是连续的,不能跳过层级
    - 标题层级最多为4级,不要添加过多的层级
    - 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息

5. 合理性检查与微调:
    - 在完成初步分级后,仔细检查分级结果的合理性
    - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
    - 确保最终的分级结果符合文档的实际结构和逻辑
    - 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们

IMPORTANT: 
请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下:
{{0:1,1:2,2:2,3:3}}
不需要对字典格式化,不需要返回任何其他信息。

Input title list:
{title_dict}

Corrected title list:
"""

    max_retries = 3
    for _attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model=title_aided_config["model"],
                messages=[
                    {'role': 'user', 'content': title_optimize_prompt}],
                temperature=0.7,
            )
            # logger.info(f"Title completion: {completion.choices[0].message.content}")
            # literal_eval only parses Python literals, so it is safe on
            # untrusted model output (unlike eval); bad output raises and
            # is handled as a retry below.
            dict_completion = ast.literal_eval(completion.choices[0].message.content)
            # logger.info(f"len(dict_completion): {len(dict_completion)}, len(title_dict): {len(title_dict)}")

            if len(dict_completion) == len(title_dict):
                for i, origin_title_block in enumerate(origin_title_list):
                    origin_title_block["level"] = int(dict_completion[i])
                return
            logger.warning(
                "The number of titles in the optimized result is not equal to the number of titles in the input.")
        except Exception as e:
            # Network errors, malformed output, missing keys — all retried.
            logger.exception(e)

    # Reached only when every attempt failed (parse error OR length
    # mismatch). The original code missed the mismatch case because it
    # only checked `dict_completion is None` — a wrong-length dict left
    # the function silently without any error log.
    logger.error("Failed to decode dict after maximum retries.")