浏览代码

feat: integrate LLM optimization for title enhancement in PDF processing

myhloli 4 月之前
父节点
当前提交
35cb414f1c
共有 2 个文件被更改,包括 18 次插入 和 3 次删除
  1. 17 1
      mineru/backend/vlm/token_to_middle_json.py
  2. 1 2
      mineru/utils/llm_aided.py

+ 17 - 1
mineru/backend/vlm/token_to_middle_json.py

@@ -1,9 +1,12 @@
-import re
+import time
+from loguru import logger
 
+from mineru.utils.config_reader import get_llm_aided_config
 from mineru.utils.cut_image import cut_image_and_table
 from mineru.utils.enum_class import BlockType, ContentType
 from mineru.utils.hash_utils import str_md5
 from mineru.backend.vlm.vlm_magic_model import MagicModel
+from mineru.utils.llm_aided import llm_aided_title
 from mineru.version import __version__
 
 
@@ -48,6 +51,19 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
         image_dict = images_list[index]
         page_info = token_to_page_info(token, image_dict, page, image_writer, index)
         middle_json["pdf_info"].append(page_info)
+
+    """llm优化"""
+    llm_aided_config = get_llm_aided_config()
+
+    if llm_aided_config is not None:
+        """标题优化"""
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            if title_aided_config.get('enable', False):
+                llm_aided_title_start_time = time.time()
+                llm_aided_title(middle_json["pdf_info"], title_aided_config)
+                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
+
     # 关闭pdf文档
     pdf_doc.close()
     return middle_json

+ 1 - 2
mineru/utils/llm_aided.py

@@ -1,7 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 from loguru import logger
 from openai import OpenAI
-import ast
+import json_repair
 
 from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text
 
@@ -91,7 +91,6 @@ Corrected title list:
             if "</think>" in content:
                 idx = content.index("</think>") + len("</think>")
                 content = content[idx:].strip()
-            import json_repair
             dict_completion = json_repair.loads(content)
             dict_completion = {int(k): int(v) for k, v in dict_completion.items()}