Ver código fonte

feat: 添加保存API输出图像和Markdown图像的功能,优化JSON和Markdown文件的保存逻辑

zhch158_admin 1 mês atrás
pai
commit
d6084f033b
1 arquivos alterados com 79 adições e 7 exclusões
  1. 79 7
      zhch/ppstructurev3_single_client.py

+ 79 - 7
zhch/ppstructurev3_single_client.py

@@ -253,20 +253,65 @@ def convert_api_result_to_json(api_result: Dict[str, Any],
         }
     
     # 保存JSON文件
-    output_path = Path(output_dir).resolve() / f"{filename}.json"
-    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_path = Path(output_dir).resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
     
-    with open(output_path, 'w', encoding='utf-8') as f:
+    json_file_path = output_path / f"{filename}.json"
+    with open(json_file_path, 'w', encoding='utf-8') as f:
         json.dump(converted_json, f, ensure_ascii=False, indent=2)
     
     # 如果启用了标准化且有变化,保存原始版本用于对比
     if normalize_numbers and changes_count > 0:
-        original_output_path = output_path.parent / f"{output_path.stem}_original.json"
+        original_output_path = output_path / f"{filename}_original.json"
         with open(original_output_path, 'w', encoding='utf-8') as f:
             json.dump(original_json, f, ensure_ascii=False, indent=2)
     
     return str(output_path), converted_json
 
+def save_output_images(api_result: Dict[str, Any], output_dir: str, output_filename: str) -> Dict[str, str]:
+    """
+    Decode and write the base64 images of the first layout-parsing result to disk.
+
+    Args:
+        api_result: API response; only ``layoutParsingResults[0]['outputImages']`` is read.
+        output_dir: Directory that receives the images (created if missing).
+        output_filename: Prefix for each file, which is named ``{output_filename}_{img_name}.jpg``.
+
+    Returns:
+        Mapping of image name to saved file path; images that fail to save are omitted.
+    """
+    layout_parsing_results = api_result.get('layoutParsingResults', [])
+    if not layout_parsing_results:
+        return {}
+
+    main_result = layout_parsing_results[0]
+    # A key that is present with a null value bypasses the .get() default, so map None to {}.
+    output_images = main_result.get('outputImages') or {}
+
+    output_path = Path(output_dir).resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    saved_images = {}
+
+    for img_name, img_base64 in output_images.items():
+        try:
+            # Decode the base64 payload into raw image bytes.
+            img_data = base64.b64decode(img_base64)
+
+            # The .jpg extension is fixed; the payload format is not inspected.
+            img_filename = f"{output_filename}_{img_name}.jpg"
+            img_path = output_path / img_filename
+
+            with open(img_path, 'wb') as f:
+                f.write(img_data)
+
+            saved_images[img_name] = str(img_path)
+        except (ValueError, TypeError, OSError) as e:
+            # Best effort: report the failure and continue with the remaining images.
+            print(f"❌ Error saving image {img_name}: {e}")
+
+    return saved_images
+
 def save_markdown_content(api_result: Dict[str, Any], output_dir: str, 
                          filename: str, normalize_numbers: bool = True) -> str:
     """
@@ -286,6 +331,7 @@ def save_markdown_content(api_result: Dict[str, Any], output_dir: str,
     markdown_text = markdown_data.get('text', '')
     
     # 数字标准化处理
+    changes_count = 0
     if normalize_numbers and markdown_text:
         original_markdown_text = markdown_text
         markdown_text = normalize_markdown_table(markdown_text)
@@ -300,10 +346,26 @@ def save_markdown_content(api_result: Dict[str, Any], output_dir: str,
     
     # 如果启用了标准化且有变化,保存原始版本用于对比
     if normalize_numbers and changes_count > 0:
-        original_output_path = output_path.parent / f"{output_path.stem}_original.json"
+        original_output_path = output_path / f"{filename}_original.md"
         with open(original_output_path, 'w', encoding='utf-8') as f:
             f.write(original_markdown_text)
 
+    # 保存Markdown中的图像
+    markdown_images = markdown_data.get('images', {})
+    for img_path, img_base64 in markdown_images.items():
+        try:
+            img_data = base64.b64decode(img_base64)
+            full_img_path = output_path / img_path
+            full_img_path.parent.mkdir(parents=True, exist_ok=True)
+            
+            with open(full_img_path, 'wb') as f:
+                f.write(img_data)
+            
+            # print(f"🖼️ Saved Markdown image: {full_img_path}")
+            
+        except Exception as e:
+            print(f"❌ Error saving Markdown image {img_path}: {e}")
+
     return str(md_file_path)
 
 def call_api_for_image(image_path: str, api_url: str, timeout: int = 300) -> Dict[str, Any]:
@@ -327,6 +389,14 @@ def call_api_for_image(image_path: str, api_url: str, timeout: int = 300) -> Dic
         payload = {
             "file": image_data,
             "fileType": 1,
+            # 添加管道参数设置
+            "useDocOrientationClassify": True,
+            "useDocUnwarping": False,
+            "useSealRecognition": True,
+            "useTableRecognition": True,
+            "useFormulaRecognition": False,  # 避免公式识别的索引错误
+            "useChartRecognition": True,
+            "useRegionDetection": False,
         }
 
         # 调用API
@@ -398,7 +468,10 @@ def process_images_via_api(image_paths: List[str],
                     output_filename,
                     normalize_numbers=normalize_numbers
                 )
-                
+
+                # 保存输出图像
+                saved_images = save_output_images(api_result, str(output_dir), output_filename) 
+
                 # 保存Markdown内容
                 md_output_path = save_markdown_content(
                     api_result, 
@@ -576,7 +649,6 @@ def main():
 
 if __name__ == "__main__":
     print(f"🚀 启动PP-StructureV3 API客户端...")
-    print(f"🔧 环境变量检查: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
     
     if len(sys.argv) == 1:
         # 如果没有命令行参数,使用默认配置运行