Pārlūkot izejas kodu

refactor(magic_pdf): improve title optimization process

- Update instructions for AI-generated titles optimization
- Use ast.literal_eval() instead of json.loads() for parsing completion content
- Refactor variable names and logging for better code readability
- Add error handling for JSON decoding issues
myhloli 9 mēneši atpakaļ
vecāks
revīzija
54940c611d
1 mainītis faili ar 14 papildinājumiem un 16 dzēšanām
  1. 14 16
      magic_pdf/post_proc/llm_aided.py

+ 14 - 16
magic_pdf/post_proc/llm_aided.py

@@ -3,6 +3,7 @@ import json
 from loguru import logger
 from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from openai import OpenAI
+import ast
 
 
 #@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
@@ -119,11 +120,12 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
     - 在完成初步分级后,仔细检查分级结果的合理性
     - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
     - 确保最终的分级结果符合文档的实际结构和逻辑
+    - 字典中包含OCR错误识别的标题,你可以通过将其层级标记为 0 来排除它们
     
 IMPORTANT: 
-请直接返回优化过的由标题层级组成的json,格式如下:
-{{"0":1,"1":2,"2":2,"3":3}}
-返回的json不需要格式化。
+请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下:
+{{0:1,1:2,2:2,3:3}}
+不需要对字典格式化,不需要返回任何其他信息
 
 Input title list:
 {title_dict}
@@ -133,7 +135,7 @@ Corrected title list:
 
     retry_count = 0
     max_retries = 3
-    json_completion = None
+    dict_completion = None
 
     while retry_count < max_retries:
         try:
@@ -143,24 +145,20 @@ Corrected title list:
                     {'role': 'user', 'content': title_optimize_prompt}],
                 temperature=0.7,
             )
-            json_completion = json.loads(completion.choices[0].message.content)
+            # logger.info(f"Title completion: {completion.choices[0].message.content}")
+            dict_completion = ast.literal_eval(completion.choices[0].message.content)
+            # logger.info(f"len(dict_completion): {len(dict_completion)}, len(title_dict): {len(title_dict)}")
 
-            # logger.info(f"Title completion: {json_completion}")
-            # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
-
-            if len(json_completion) == len(title_dict):
+            if len(dict_completion) == len(title_dict):
                 for i, origin_title_block in enumerate(origin_title_list):
-                    origin_title_block["level"] = int(json_completion[str(i)])
+                    origin_title_block["level"] = int(dict_completion[i])
                 break
             else:
                 logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
                 retry_count += 1
         except Exception as e:
-            if isinstance(e, json.decoder.JSONDecodeError):
-                logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
-            else:
-                logger.exception(e)
+            logger.exception(e)
             retry_count += 1
 
-    if json_completion is None:
-        logger.error("Failed to decode JSON after maximum retries.")
+    if dict_completion is None:
+        logger.error("Failed to decode dict after maximum retries.")