Procházet zdrojové kódy

feat: 添加PP-StructureV3 API客户端,支持标准化JSON格式和批量处理功能

zhch158_admin před 1 měsícem
rodič
revize
f592aedbfb
1 změnil soubory, kde provedl 468 přidání a 0 odebrání
  1. 468 0
      zhch/test_ppstructure_v3_client.py

+ 468 - 0
zhch/test_ppstructure_v3_client.py

@@ -0,0 +1,468 @@
+import json
+import base64
+from pathlib import Path
+from typing import Dict, Any, List
+import requests
+from utils.normalize_financial_numbers import normalize_financial_numbers, normalize_markdown_table, normalize_json_table
+
+def convert_api_result_to_json(api_result: Dict[str, Any], 
+                              input_image_path: str, 
+                              output_json_path: str,
+                              normalize_numbers: bool = True) -> Dict[str, Any]:
+    """
+    将API返回结果转换为标准JSON格式,并支持数字标准化
+    
+    Args:
+        api_result: API返回的结果
+        input_image_path: 输入图像路径
+        output_json_path: 输出JSON文件路径
+        normalize_numbers: 是否标准化数字格式,默认True
+        
+    Returns:
+        转换后的JSON数据
+    """
+    
+    # 获取主要数据
+    layout_parsing_results = api_result.get('layoutParsingResults', [])
+    data_info = api_result.get('dataInfo', {})
+    
+    if not layout_parsing_results:
+        print("⚠️ Warning: No layoutParsingResults found in API response")
+        return {}
+    
+    # 取第一个结果(通常只有一个)
+    main_result = layout_parsing_results[0]
+    pruned_result = main_result.get('prunedResult', {})
+    
+    # 构造标准格式的JSON
+    converted_json = {
+        "input_path": input_image_path,
+        "page_index": None,
+        "model_settings": pruned_result.get('model_settings', {}),
+        "parsing_res_list": pruned_result.get('parsing_res_list', []),
+        "doc_preprocessor_res": {
+            "input_path": None,
+            "page_index": None,
+            "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}),
+            "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0)
+        },
+        "layout_det_res": {
+            "input_path": None,
+            "page_index": None,
+            "boxes": pruned_result.get('layout_det_res', {}).get('boxes', [])
+        },
+        "overall_ocr_res": {
+            "input_path": None,
+            "page_index": None,
+            "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}),
+            "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []),
+            "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}),
+            "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'),
+            "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []),
+            "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0),
+            "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False),
+            "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []),
+            "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []),
+            "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []),
+            "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', [])
+        },
+        "table_res_list": pruned_result.get('table_res_list', [])
+    }
+    
+    # 数字标准化处理
+    original_json = converted_json.copy()
+    changes_count = 0
+    
+    if normalize_numbers:
+        print("🔧 正在标准化数字格式...")
+        
+        # 1. 标准化 parsing_res_list 中的文本内容
+        for item in converted_json.get('parsing_res_list', []):
+            if 'block_content' in item:
+                original_content = item['block_content']
+                
+                # 根据block_label类型选择标准化方法
+                if item.get('block_label') == 'table':
+                    # 表格内容使用表格专用标准化
+                    normalized_content = normalize_markdown_table(original_content)
+                else:
+                    # 普通文本使用通用标准化
+                    normalized_content = normalize_financial_numbers(original_content)
+                
+                if original_content != normalized_content:
+                    item['block_content'] = normalized_content
+                    changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
+        
+        # 2. 标准化 table_res_list 中的HTML表格
+        for table_item in converted_json.get('table_res_list', []):
+            if 'pred_html' in table_item:
+                original_html = table_item['pred_html']
+                normalized_html = normalize_markdown_table(original_html)
+                
+                if original_html != normalized_html:
+                    table_item['pred_html'] = normalized_html
+                    changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
+        
+        # 3. 标准化 overall_ocr_res 中的识别文本
+        ocr_res = converted_json.get('overall_ocr_res', {})
+        if 'rec_texts' in ocr_res:
+            original_texts = ocr_res['rec_texts'][:]
+            normalized_texts = []
+            
+            for text in original_texts:
+                normalized_text = normalize_financial_numbers(text)
+                normalized_texts.append(normalized_text)
+                if text != normalized_text:
+                    changes_count += len([1 for o, n in zip(text, normalized_text) if o != n])
+            
+            ocr_res['rec_texts'] = normalized_texts
+        
+        # 添加标准化处理信息
+        converted_json['processing_info'] = {
+            "normalize_numbers": normalize_numbers,
+            "changes_applied": changes_count > 0,
+            "character_changes_count": changes_count
+        }
+        
+        if changes_count > 0:
+            print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
+        else:
+            print("ℹ️ 无需标准化(已是标准格式)")
+    else:
+        converted_json['processing_info'] = {
+            "normalize_numbers": False,
+            "changes_applied": False,
+            "character_changes_count": 0
+        }
+    
+    # 保存JSON文件
+    output_path = Path(output_json_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    
+    with open(output_path, 'w', encoding='utf-8') as f:
+        json.dump(converted_json, f, ensure_ascii=False, indent=4)
+    
+    # 如果启用了标准化且有变化,保存原始版本用于对比
+    if normalize_numbers and changes_count > 0:
+        original_output_path = output_path.parent / f"{output_path.stem}_original.json"
+        with open(original_output_path, 'w', encoding='utf-8') as f:
+            json.dump(original_json, f, ensure_ascii=False, indent=4)
+        print(f"📄 原始JSON已保存到: {original_output_path}")
+    
+    print(f"✅ Converted JSON saved to: {output_path}")
+    return converted_json
+
+def save_output_images(api_result: Dict[str, Any], output_dir: str) -> Dict[str, str]:
+    """
+    保存API返回的输出图像
+    
+    Args:
+        api_result: API返回的结果
+        output_dir: 输出目录
+        
+    Returns:
+        保存的图像文件路径字典
+    """
+    layout_parsing_results = api_result.get('layoutParsingResults', [])
+    if not layout_parsing_results:
+        return {}
+    
+    main_result = layout_parsing_results[0]
+    output_images = main_result.get('outputImages', {})
+    
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    
+    saved_images = {}
+    
+    for img_name, img_base64 in output_images.items():
+        try:
+            # 解码base64图像
+            img_data = base64.b64decode(img_base64)
+            
+            # 生成文件名
+            img_filename = f"{img_name}.jpg"
+            img_path = output_dir / img_filename
+            
+            # 保存图像
+            with open(img_path, 'wb') as f:
+                f.write(img_data)
+            
+            saved_images[img_name] = str(img_path)
+            print(f"📷 Saved image: {img_path}")
+            
+        except Exception as e:
+            print(f"❌ Error saving image {img_name}: {e}")
+    
+    return saved_images
+
+def save_markdown_content(api_result: Dict[str, Any], output_dir: str, normalize_numbers: bool = True) -> str:
+    """
+    保存Markdown内容和相关图像,支持数字标准化
+    
+    Args:
+        api_result: API返回的结果
+        output_dir: 输出目录
+        normalize_numbers: 是否标准化数字格式
+        
+    Returns:
+        Markdown文件路径
+    """
+    layout_parsing_results = api_result.get('layoutParsingResults', [])
+    if not layout_parsing_results:
+        return ""
+    
+    main_result = layout_parsing_results[0]
+    markdown_data = main_result.get('markdown', {})
+    
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    
+    # 保存Markdown文本
+    markdown_text = markdown_data.get('text', '')
+    original_markdown_text = markdown_text
+    
+    # 数字标准化处理
+    if normalize_numbers:
+        print("🔧 正在标准化Markdown中的数字格式...")
+        markdown_text = normalize_markdown_table(markdown_text)
+        
+        changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
+        if changes_count > 0:
+            print(f"✅ Markdown中已标准化 {changes_count} 个字符(全角→半角)")
+        else:
+            print("ℹ️ Markdown无需标准化(已是标准格式)")
+    
+    md_file_path = output_dir / 'document.md'
+    
+    with open(md_file_path, 'w', encoding='utf-8') as f:
+        f.write(markdown_text)
+    
+    # 如果启用了标准化且有变化,保存原始版本
+    if normalize_numbers and markdown_text != original_markdown_text:
+        original_md_path = output_dir / 'document_original.md'
+        with open(original_md_path, 'w', encoding='utf-8') as f:
+            f.write(original_markdown_text)
+        print(f"📄 原始Markdown已保存到: {original_md_path}")
+    
+    print(f"📝 Saved Markdown: {md_file_path}")
+    
+    # 保存Markdown中的图像
+    markdown_images = markdown_data.get('images', {})
+    for img_path, img_base64 in markdown_images.items():
+        try:
+            img_data = base64.b64decode(img_base64)
+            full_img_path = output_dir / img_path
+            full_img_path.parent.mkdir(parents=True, exist_ok=True)
+            
+            with open(full_img_path, 'wb') as f:
+                f.write(img_data)
+            
+            print(f"🖼️ Saved Markdown image: {full_img_path}")
+            
+        except Exception as e:
+            print(f"❌ Error saving Markdown image {img_path}: {e}")
+    
+    return str(md_file_path)
+
+def test_ppstructurev3_api_enhanced(image_path: str, 
+                                  API_URL: str, 
+                                  output_dir: str = "./api_output",
+                                  normalize_numbers: bool = True) -> Dict[str, Any]:
+    """
+    增强版的PP-StructureV3 API测试,保存为标准格式并支持数字标准化
+    
+    Args:
+        image_path: 输入图像路径
+        API_URL: API URL
+        output_dir: 输出目录
+        normalize_numbers: 是否标准化数字格式,默认True
+    """
+    # 对本地图像进行Base64编码
+    with open(image_path, "rb") as file:
+        image_bytes = file.read()
+        image_data = base64.b64encode(image_bytes).decode("ascii")
+
+    payload = {
+        "file": image_data,
+        "fileType": 1,
+    }
+
+    # 调用API
+    print(f"🚀 Calling API: {API_URL}")
+    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
+    
+    response = requests.post(API_URL, json=payload)
+
+    # 处理接口返回数据
+    assert response.status_code == 200
+    api_result = response.json()["result"]
+    
+    # 创建输出目录
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    # 获取输入文件的基本名称
+    input_name = Path(image_path).stem
+    
+    # 1. 转换并保存标准JSON格式(包含数字标准化)
+    json_output_path = output_path / f"{input_name}.json"
+    converted_json = convert_api_result_to_json(
+        api_result, 
+        image_path, 
+        str(json_output_path),
+        normalize_numbers=normalize_numbers
+    )
+    
+    # 2. 保存输出图像
+    images_dir = output_path / f"{input_name}_images"
+    saved_images = save_output_images(api_result, str(images_dir))
+    
+    # 3. 保存Markdown内容(包含数字标准化)
+    markdown_dir = output_path / f"{input_name}_markdown"
+    markdown_file = save_markdown_content(api_result, str(markdown_dir), normalize_numbers=normalize_numbers)
+    
+    # 4. 保存完整的API响应(用于调试)
+    full_response_path = output_path / f"{input_name}_full_response.json"
+    with open(full_response_path, 'w', encoding='utf-8') as f:
+        json.dump(api_result, f, ensure_ascii=False, indent=2)
+    
+    print(f"📊 Processing completed!")
+    print(f"  Standard JSON: {json_output_path}")
+    print(f"  Images directory: {images_dir}")
+    print(f"  Markdown file: {markdown_file}")
+    print(f"  Full response: {full_response_path}")
+    
+    # 打印详细统计(仿照ocr_by_vlm.py)
+    processing_info = converted_json.get('processing_info', {})
+    print("\n📊 API处理统计")
+    print(f"   原始图片: {Path(image_path).resolve().as_posix()}")
+    print(f"   输出路径: {json_output_path.resolve().as_posix()}")
+    print(f"   API地址: {API_URL}")
+    print(f"   数字标准化: {processing_info.get('normalize_numbers', False)}")
+    if normalize_numbers:
+        print(f"   字符变化数: {processing_info.get('character_changes_count', 0)}")
+        print(f"   应用了标准化: {processing_info.get('changes_applied', False)}")
+    
+    return {
+        "standard_json": str(json_output_path),
+        "images_dir": str(images_dir),
+        "markdown_file": markdown_file,
+        "full_response": str(full_response_path),
+        "converted_data": converted_json,
+        "processing_info": processing_info
+    }
+
+def batch_process_api_results(image_list: List[str], 
+                            API_URL: str, 
+                            output_base_dir: str,
+                            normalize_numbers: bool = True) -> List[Dict[str, Any]]:
+    """
+    批量处理多个图像文件
+    
+    Args:
+        image_list: 图像文件路径列表
+        API_URL: API URL
+        output_base_dir: 输出基础目录
+        normalize_numbers: 是否标准化数字格式
+    """
+    results = []
+    
+    print(f"🚀 开始批量处理 {len(image_list)} 个图像文件...")
+    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
+    
+    for i, image_path in enumerate(image_list):
+        try:
+            print(f"\n🔄 Processing {i+1}/{len(image_list)}: {Path(image_path).name}")
+            
+            # 为每个文件创建单独的输出目录
+            output_dir = Path(output_base_dir) / f"result_{i+1:03d}_{Path(image_path).stem}"
+            
+            result = test_ppstructurev3_api_enhanced(
+                image_path, 
+                API_URL, 
+                str(output_dir),
+                normalize_numbers=normalize_numbers
+            )
+            results.append(result)
+            
+        except Exception as e:
+            print(f"❌ Error processing {image_path}: {e}")
+            results.append({"error": str(e), "image_path": image_path})
+    
+    # 生成批量处理统计
+    success_count = sum(1 for r in results if 'error' not in r)
+    total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)
+    
+    print(f"\n📊 批量处理完成统计")
+    print(f"   总文件数: {len(image_list)}")
+    print(f"   成功处理: {success_count}")
+    print(f"   失败数量: {len(image_list) - success_count}")
+    if normalize_numbers:
+        print(f"   总标准化字符数: {total_changes}")
+    
+    return results
+
+if __name__ == "__main__":
+    import argparse
+    
+    parser = argparse.ArgumentParser(description='PP-StructureV3 API客户端工具')
+    parser.add_argument('image_path', nargs='?', help='图片文件路径')
+    parser.add_argument('-u', '--url', default='http://localhost:8080/layout-parsing', help='API URL')
+    parser.add_argument('-o', '--output', default='./api_conversion_output', help='输出目录')
+    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
+    parser.add_argument('--batch', help='批量处理,指定包含图像路径的文本文件')
+    
+    args = parser.parse_args()
+    
+    normalize_numbers = not args.no_normalize
+    
+    try:
+        if args.batch:
+            # 批量处理模式
+            with open(args.batch, 'r', encoding='utf-8') as f:
+                image_list = [line.strip() for line in f if line.strip()]
+            
+            results = batch_process_api_results(
+                image_list, 
+                args.url, 
+                args.output,
+                normalize_numbers=normalize_numbers
+            )
+            print("\n🎉 批量API处理完成!")
+            
+        elif args.image_path:
+            # 单文件处理模式
+            result = test_ppstructurev3_api_enhanced(
+                args.image_path, 
+                args.url, 
+                args.output,
+                normalize_numbers=normalize_numbers
+            )
+            
+            # 验证转换结果
+            with open(result["standard_json"], 'r', encoding='utf-8') as f:
+                converted_data = json.load(f)
+            
+            print(f"\n📋 转换后的数据包含:")
+            print(f"  - 解析结果块数: {len(converted_data.get('parsing_res_list', []))}")
+            print(f"  - OCR文本数: {len(converted_data.get('overall_ocr_res', {}).get('rec_texts', []))}")
+            print(f"  - 表格数: {len(converted_data.get('table_res_list', []))}")
+            
+            print("\n🎉 API测试和转换完成!")
+            
+        else:
+            # 默认示例
+            image_path = "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/data_DotsOCR_Results/2023年度报告母公司/2023年度报告母公司_page_004.png"
+            result = test_ppstructurev3_api_enhanced(
+                image_path, 
+                args.url, 
+                args.output,
+                normalize_numbers=normalize_numbers
+            )
+            print("\n🎉 API测试和转换完成!")
+            
+    except Exception as e:
+        print(f"❌ 处理失败: {e}")
+        import traceback
+        traceback.print_exc()