Browse Source

refactor(magic_pdf): improve title optimization process

- Update instructions for AI-generated titles optimization
- Use ast.literal_eval() instead of json.loads() for parsing completion content
- Refactor variable names and logging for better code readability- Add error handling for JSON decoding issues
myhloli 9 months ago
parent
commit
54940c611d
1 changed files with 14 additions and 16 deletions
  1. 14 16
      magic_pdf/post_proc/llm_aided.py

+ 14 - 16
magic_pdf/post_proc/llm_aided.py

@@ -3,6 +3,7 @@ import json
 from loguru import logger
 from loguru import logger
 from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from openai import OpenAI
 from openai import OpenAI
+import ast
 
 
 
 
 #@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
 #@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
@@ -119,11 +120,12 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
     - 在完成初步分级后,仔细检查分级结果的合理性
     - 在完成初步分级后,仔细检查分级结果的合理性
     - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
     - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
     - 确保最终的分级结果符合文档的实际结构和逻辑
     - 确保最终的分级结果符合文档的实际结构和逻辑
+    - 字典中包含OCR错误识别的标题,你可以通过将其层级标记为 0 来排除它们
     
     
 IMPORTANT: 
 IMPORTANT: 
-请直接返回优化过的由标题层级组成的json,格式如下:
-{{"0":1,"1":2,"2":2,"3":3}}
-返回的json不需要格式化。
+请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下:
+{{0:1,1:2,2:2,3:3}}
+不需要对字典格式化,不需要返回任何其他信息
 
 
 Input title list:
 Input title list:
 {title_dict}
 {title_dict}
@@ -133,7 +135,7 @@ Corrected title list:
 
 
     retry_count = 0
     retry_count = 0
     max_retries = 3
     max_retries = 3
-    json_completion = None
+    dict_completion = None
 
 
     while retry_count < max_retries:
     while retry_count < max_retries:
         try:
         try:
@@ -143,24 +145,20 @@ Corrected title list:
                     {'role': 'user', 'content': title_optimize_prompt}],
                     {'role': 'user', 'content': title_optimize_prompt}],
                 temperature=0.7,
                 temperature=0.7,
             )
             )
-            json_completion = json.loads(completion.choices[0].message.content)
+            # logger.info(f"Title completion: {completion.choices[0].message.content}")
+            dict_completion = ast.literal_eval(completion.choices[0].message.content)
+            # logger.info(f"len(dict_completion): {len(dict_completion)}, len(title_dict): {len(title_dict)}")
 
 
-            # logger.info(f"Title completion: {json_completion}")
-            # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
-
-            if len(json_completion) == len(title_dict):
+            if len(dict_completion) == len(title_dict):
                 for i, origin_title_block in enumerate(origin_title_list):
                 for i, origin_title_block in enumerate(origin_title_list):
-                    origin_title_block["level"] = int(json_completion[str(i)])
+                    origin_title_block["level"] = int(dict_completion[i])
                 break
                 break
             else:
             else:
                 logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
                 logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
                 retry_count += 1
                 retry_count += 1
         except Exception as e:
         except Exception as e:
-            if isinstance(e, json.decoder.JSONDecodeError):
-                logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
-            else:
-                logger.exception(e)
+            logger.exception(e)
             retry_count += 1
             retry_count += 1
 
 
-    if json_completion is None:
-        logger.error("Failed to decode JSON after maximum retries.")
+    if dict_completion is None:
+        logger.error("Failed to decode dict after maximum retries.")