Explorar o código

feat: 添加PP-StructureV3 API客户端,支持PDF和图像文件的统一处理及数字标准化功能

zhch158_admin hai 1 mes
pai
achega
3d8461fffc
Modificáronse 1 ficheiros con 604 adicións e 0 borrados
  1. 604 0
      zhch/ppstructurev3_single_client.py

+ 604 - 0
zhch/ppstructurev3_single_client.py

@@ -0,0 +1,604 @@
+"""PDF转图像后通过API统一处理"""
+import json
+import time
+import os
+import traceback
+import argparse
+import sys
+import warnings
+import base64
+from pathlib import Path
+from typing import List, Dict, Any, Union
+import requests
+from tqdm import tqdm
+
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+from utils import (
+    get_image_files_from_dir,
+    get_image_files_from_list,
+    get_image_files_from_csv,
+    collect_pid_files,
+    load_images_from_pdf,
+    normalize_financial_numbers,
+    normalize_markdown_table
+)
+
+def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
+    """
+    将PDF转换为图像文件
+    
+    Args:
+        pdf_file: PDF文件路径
+        output_dir: 输出目录
+        dpi: 图像分辨率
+        
+    Returns:
+        生成的图像文件路径列表
+    """
+    pdf_path = Path(pdf_file)
+    if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
+        print(f"❌ Invalid PDF file: {pdf_path}")
+        return []
+
+    # 如果没有指定输出目录,使用PDF同名目录
+    if output_dir is None:
+        output_path = pdf_path.parent / f"{pdf_path.stem}"
+    else:
+        output_path = Path(output_dir) / f"{pdf_path.stem}"
+    output_path = output_path.resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # 使用doc_utils中的函数加载PDF图像
+        images = load_images_from_pdf(str(pdf_path), dpi=dpi)
+        
+        image_paths = []
+        for i, image in enumerate(images):
+            # 生成图像文件名
+            image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
+            image_path = output_path / image_filename
+
+            # 保存图像
+            image.save(str(image_path))
+            image_paths.append(str(image_path))
+            
+        print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
+        return image_paths
+        
+    except Exception as e:
+        print(f"❌ Error converting PDF {pdf_path}: {e}")
+        traceback.print_exc()
+        return []
+
+def get_input_files(args) -> List[str]:
+    """
+    获取输入文件列表,统一处理PDF和图像文件
+    
+    Args:
+        args: 命令行参数
+        
+    Returns:
+        处理后的图像文件路径列表
+    """
+    input_files = []
+    
+    # 获取原始输入文件
+    if args.input_csv:
+        raw_files = get_image_files_from_csv(args.input_csv, "fail")
+    elif args.input_file_list:
+        raw_files = get_image_files_from_list(args.input_file_list)
+    elif args.input_file:
+        raw_files = [Path(args.input_file).resolve()]
+    else:
+        input_dir = Path(args.input_dir).resolve()
+        if not input_dir.exists():
+            print(f"❌ Input directory does not exist: {input_dir}")
+            return []
+        
+        # 获取所有支持的文件(图像和PDF)
+        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
+        pdf_extensions = ['.pdf']
+        
+        raw_files = []
+        for ext in image_extensions + pdf_extensions:
+            raw_files.extend(list(input_dir.glob(f"*{ext}")))
+            raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
+        
+        raw_files = [str(f) for f in raw_files]
+    
+    # 分别处理PDF和图像文件
+    pdf_count = 0
+    image_count = 0
+    
+    for file_path in raw_files:
+        file_path = Path(file_path)
+        
+        if file_path.suffix.lower() == '.pdf':
+            # 转换PDF为图像
+            print(f"📄 Processing PDF: {file_path.name}")
+            pdf_images = convert_pdf_to_images(
+                str(file_path), 
+                args.output_dir,
+                dpi=args.pdf_dpi
+            )
+            input_files.extend(pdf_images)
+            pdf_count += 1
+        else:
+            # 直接添加图像文件
+            if file_path.exists():
+                input_files.append(str(file_path))
+                image_count += 1
+    
+    print(f"📊 Input summary:")
+    print(f"  PDF files processed: {pdf_count}")
+    print(f"  Image files found: {image_count}")
+    print(f"  Total image files to process: {len(input_files)}")
+    
+    return input_files
+
+def convert_api_result_to_json(api_result: Dict[str, Any], 
+                              input_image_path: str, 
+                              output_dir: str, 
+                              filename: str,
+                              normalize_numbers: bool = True) -> tuple[str, Dict[str, Any]]:
+    """
+    将API返回结果转换为标准JSON格式,并支持数字标准化
+    """
+    # 获取主要数据
+    layout_parsing_results = api_result.get('layoutParsingResults', [])
+    
+    if not layout_parsing_results:
+        print("⚠️ Warning: No layoutParsingResults found in API response")
+        return {}
+    
+    # 取第一个结果(通常只有一个)
+    main_result = layout_parsing_results[0]
+    pruned_result = main_result.get('prunedResult', {})
+    
+    # 构造标准格式的JSON
+    converted_json = {
+        "input_path": input_image_path,
+        "page_index": None,
+        "model_settings": pruned_result.get('model_settings', {}),
+        "parsing_res_list": pruned_result.get('parsing_res_list', []),
+        "doc_preprocessor_res": {
+            "input_path": None,
+            "page_index": None,
+            "model_settings": pruned_result.get('doc_preprocessor_res', {}).get('model_settings', {}),
+            "angle": pruned_result.get('doc_preprocessor_res', {}).get('angle', 0)
+        },
+        "layout_det_res": {
+            "input_path": None,
+            "page_index": None,
+            "boxes": pruned_result.get('layout_det_res', {}).get('boxes', [])
+        },
+        "overall_ocr_res": {
+            "input_path": None,
+            "page_index": None,
+            "model_settings": pruned_result.get('overall_ocr_res', {}).get('model_settings', {}),
+            "dt_polys": pruned_result.get('overall_ocr_res', {}).get('dt_polys', []),
+            "text_det_params": pruned_result.get('overall_ocr_res', {}).get('text_det_params', {}),
+            "text_type": pruned_result.get('overall_ocr_res', {}).get('text_type', 'general'),
+            "textline_orientation_angles": pruned_result.get('overall_ocr_res', {}).get('textline_orientation_angles', []),
+            "text_rec_score_thresh": pruned_result.get('overall_ocr_res', {}).get('text_rec_score_thresh', 0.0),
+            "return_word_box": pruned_result.get('overall_ocr_res', {}).get('return_word_box', False),
+            "rec_texts": pruned_result.get('overall_ocr_res', {}).get('rec_texts', []),
+            "rec_scores": pruned_result.get('overall_ocr_res', {}).get('rec_scores', []),
+            "rec_polys": pruned_result.get('overall_ocr_res', {}).get('rec_polys', []),
+            "rec_boxes": pruned_result.get('overall_ocr_res', {}).get('rec_boxes', [])
+        },
+        "table_res_list": pruned_result.get('table_res_list', [])
+    }
+    
+    # 数字标准化处理
+    original_json = converted_json.copy()
+    changes_count = 0
+    
+    if normalize_numbers:
+        # 1. 标准化 parsing_res_list 中的文本内容
+        for item in converted_json.get('parsing_res_list', []):
+            if 'block_content' in item:
+                original_content = item['block_content']
+                
+                # 根据block_label类型选择标准化方法
+                if item.get('block_label') == 'table':
+                    normalized_content = normalize_markdown_table(original_content)
+                # else:
+                #     normalized_content = normalize_financial_numbers(original_content)
+                
+                if original_content != normalized_content:
+                    item['block_content'] = normalized_content
+                    changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
+        
+        # 2. 标准化 table_res_list 中的HTML表格
+        for table_item in converted_json.get('table_res_list', []):
+            if 'pred_html' in table_item:
+                original_html = table_item['pred_html']
+                normalized_html = normalize_markdown_table(original_html)
+                
+                if original_html != normalized_html:
+                    table_item['pred_html'] = normalized_html
+                    changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
+        
+        # 3. 标准化 overall_ocr_res 中的识别文本
+        # ocr_res = converted_json.get('overall_ocr_res', {})
+        # if 'rec_texts' in ocr_res:
+        #     original_texts = ocr_res['rec_texts'][:]
+        #     normalized_texts = []
+            
+        #     for text in original_texts:
+        #         normalized_text = normalize_financial_numbers(text)
+        #         normalized_texts.append(normalized_text)
+        #         if text != normalized_text:
+        #             changes_count += len([1 for o, n in zip(text, normalized_text) if o != n])
+            
+        #     ocr_res['rec_texts'] = normalized_texts
+        
+        # 添加标准化处理信息
+        converted_json['processing_info'] = {
+            "normalize_numbers": normalize_numbers,
+            "changes_applied": changes_count > 0,
+            "character_changes_count": changes_count
+        }
+        
+        if changes_count > 0:
+            print(f"🔧 已标准化 {changes_count} 个字符(全角→半角)")
+    else:
+        converted_json['processing_info'] = {
+            "normalize_numbers": False,
+            "changes_applied": False,
+            "character_changes_count": 0
+        }
+    
+    # 保存JSON文件
+    output_path = Path(output_dir).resolve() / f"{filename}.json"
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    
+    with open(output_path, 'w', encoding='utf-8') as f:
+        json.dump(converted_json, f, ensure_ascii=False, indent=2)
+    
+    # 如果启用了标准化且有变化,保存原始版本用于对比
+    if normalize_numbers and changes_count > 0:
+        original_output_path = output_path.parent / f"{output_path.stem}_original.json"
+        with open(original_output_path, 'w', encoding='utf-8') as f:
+            json.dump(original_json, f, ensure_ascii=False, indent=2)
+    
+    return str(output_path), converted_json
+
+def save_markdown_content(api_result: Dict[str, Any], output_dir: str, 
+                         filename: str, normalize_numbers: bool = True) -> str:
+    """
+    保存Markdown内容,支持数字标准化
+    """
+    layout_parsing_results = api_result.get('layoutParsingResults', [])
+    if not layout_parsing_results:
+        return ""
+    
+    main_result = layout_parsing_results[0]
+    markdown_data = main_result.get('markdown', {})
+    
+    output_path = Path(output_dir).resolve()
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    # 保存Markdown文本
+    markdown_text = markdown_data.get('text', '')
+    
+    # 数字标准化处理
+    if normalize_numbers and markdown_text:
+        original_markdown_text = markdown_text
+        markdown_text = normalize_markdown_table(markdown_text)
+        
+        changes_count = len([1 for o, n in zip(original_markdown_text, markdown_text) if o != n])
+        if changes_count > 0:
+            print(f"🔧 Markdown中已标准化 {changes_count} 个字符(全角→半角)")
+    
+    md_file_path = output_path / f"{filename}.md"
+    with open(md_file_path, 'w', encoding='utf-8') as f:
+        f.write(markdown_text)
+    
+    # 如果启用了标准化且有变化,保存原始版本用于对比
+    if normalize_numbers and changes_count > 0:
+        original_output_path = output_path.parent / f"{output_path.stem}_original.json"
+        with open(original_output_path, 'w', encoding='utf-8') as f:
+            f.write(original_markdown_text)
+
+    return str(md_file_path)
+
+def call_api_for_image(image_path: str, api_url: str, timeout: int = 300) -> Dict[str, Any]:
+    """
+    为单个图像调用API
+    
+    Args:
+        image_path: 图像文件路径
+        api_url: API URL
+        timeout: 超时时间(秒)
+        
+    Returns:
+        API返回结果
+    """
+    try:
+        # 对本地图像进行Base64编码
+        with open(image_path, "rb") as file:
+            image_bytes = file.read()
+            image_data = base64.b64encode(image_bytes).decode("ascii")
+
+        payload = {
+            "file": image_data,
+            "fileType": 1,
+        }
+
+        # 调用API
+        response = requests.post(api_url, json=payload, timeout=timeout)
+        response.raise_for_status()
+        
+        return response.json()["result"]
+        
+    except requests.exceptions.Timeout:
+        raise Exception(f"API调用超时 ({timeout}秒)")
+    except requests.exceptions.RequestException as e:
+        raise Exception(f"API调用失败: {e}")
+    except KeyError:
+        raise Exception("API返回格式错误,缺少'result'字段")
+    except Exception as e:
+        raise Exception(f"处理图像时发生错误: {e}")
+
+def process_images_via_api(image_paths: List[str],
+                          api_url: str,
+                          output_dir: str = "./output",
+                          normalize_numbers: bool = True,
+                          timeout: int = 300) -> List[Dict[str, Any]]:
+    """
+    通过API统一处理图像文件
+    
+    Args:
+        image_paths: 图像路径列表
+        api_url: API URL
+        output_dir: 输出目录
+        normalize_numbers: 是否标准化数字格式
+        timeout: API调用超时时间
+        
+    Returns:
+        处理结果列表
+    """
+    # 创建输出目录
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    print(f"🚀 Using API: {api_url}")
+    print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
+    
+    all_results = []
+    total_images = len(image_paths)
+    
+    print(f"Processing {total_images} images via API")
+    
+    # 使用tqdm显示进度
+    with tqdm(total=total_images, desc="Processing images", unit="img", 
+              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
+        
+        # 逐个处理图像
+        for img_path in image_paths:
+            start_time = time.time()
+            
+            try:
+                # 调用API处理图像
+                api_result = call_api_for_image(img_path, api_url, timeout)
+                processing_time = time.time() - start_time
+                
+                # 处理API返回结果
+                input_path = Path(img_path)
+                
+                # 生成输出文件名
+                output_filename = input_path.stem
+                
+                # 转换并保存标准JSON格式
+                json_output_path, converted_json = convert_api_result_to_json(
+                    api_result, 
+                    str(input_path), 
+                    output_dir,
+                    output_filename,
+                    normalize_numbers=normalize_numbers
+                )
+                
+                # 保存Markdown内容
+                md_output_path = save_markdown_content(
+                    api_result, 
+                    output_dir, 
+                    output_filename,
+                    normalize_numbers=normalize_numbers
+                )
+                
+                # 记录处理结果
+                all_results.append({
+                    "image_path": str(input_path),
+                    "processing_time": processing_time,
+                    "success": True,
+                    "api_url": api_url,
+                    "output_json": json_output_path,
+                    "output_md": md_output_path,
+                    "is_pdf_page": "_page_" in input_path.name,  # 标记是否为PDF页面
+                    "processing_info": converted_json.get('processing_info', {})
+                })
+                
+                # 更新进度条
+                success_count = sum(1 for r in all_results if r.get('success', False))
+                
+                pbar.update(1)
+                pbar.set_postfix({
+                    'time': f"{processing_time:.2f}s",
+                    'success': f"{success_count}/{len(all_results)}",
+                    'rate': f"{success_count/len(all_results)*100:.1f}%"
+                })
+                
+            except Exception as e:
+                print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
+                
+                # 添加错误结果
+                all_results.append({
+                    "image_path": str(img_path),
+                    "processing_time": 0,
+                    "success": False,
+                    "api_url": api_url,
+                    "error": str(e),
+                    "is_pdf_page": "_page_" in Path(img_path).name
+                })
+                pbar.update(1)
+    
+    return all_results
+
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 API Client - Unified PDF/Image Processor")
+    
+    # 参数定义
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument("--input_file", type=str, help="Input file (supports both PDF and image file)")
+    input_group.add_argument("--input_dir", type=str, help="Input directory (supports both PDF and image files)")
+    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
+    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
+
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+    parser.add_argument("--api_url", type=str, default="http://localhost:8080/layout-parsing", help="API URL")
+    parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF to image conversion")
+    parser.add_argument("--timeout", type=int, default=300, help="API timeout in seconds")
+    parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化")
+    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
+    parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
+
+    args = parser.parse_args()
+    
+    normalize_numbers = not args.no_normalize
+    
+    try:
+        # 获取并预处理输入文件
+        print("🔄 Preprocessing input files...")
+        input_files = get_input_files(args)
+        
+        if not input_files:
+            print("❌ No input files found or processed")
+            return 1
+        
+        if args.test_mode:
+            input_files = input_files[:20]
+            print(f"Test mode: processing only {len(input_files)} images")
+        
+        print(f"🌐 Using API: {args.api_url}")
+        print(f"⏱️ Timeout: {args.timeout} seconds")
+        
+        # 开始处理
+        start_time = time.time()
+        results = process_images_via_api(
+            input_files,
+            args.api_url,
+            args.output_dir,
+            normalize_numbers=normalize_numbers,
+            timeout=args.timeout
+        )
+        total_time = time.time() - start_time
+        
+        # 统计结果
+        success_count = sum(1 for r in results if r.get('success', False))
+        error_count = len(results) - success_count
+        pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
+        total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)
+        
+        print(f"\n" + "="*60)
+        print(f"✅ API Processing completed!")
+        print(f"📊 Statistics:")
+        print(f"  Total files processed: {len(input_files)}")
+        print(f"  PDF pages processed: {pdf_page_count}")
+        print(f"  Regular images processed: {len(input_files) - pdf_page_count}")
+        print(f"  Successful: {success_count}")
+        print(f"  Failed: {error_count}")
+        if len(input_files) > 0:
+            print(f"  Success rate: {success_count / len(input_files) * 100:.2f}%")
+        if normalize_numbers:
+            print(f"  总标准化字符数: {total_changes}")
+        print(f"⏱️ Performance:")
+        print(f"  Total time: {total_time:.2f} seconds")
+        if total_time > 0:
+            print(f"  Throughput: {len(input_files) / total_time:.2f} files/second")
+            print(f"  Avg time per file: {total_time / len(input_files):.2f} seconds")
+        
+        # 保存结果统计
+        stats = {
+            "total_files": len(input_files),
+            "pdf_pages": pdf_page_count,
+            "regular_images": len(input_files) - pdf_page_count,
+            "success_count": success_count,
+            "error_count": error_count,
+            "success_rate": success_count / len(input_files) if len(input_files) > 0 else 0,
+            "total_time": total_time,
+            "throughput": len(input_files) / total_time if total_time > 0 else 0,
+            "avg_time_per_file": total_time / len(input_files) if len(input_files) > 0 else 0,
+            "api_url": args.api_url,
+            "pdf_dpi": args.pdf_dpi,
+            "normalize_numbers": normalize_numbers,
+            "total_character_changes": total_changes,
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
+        }
+        
+        # 保存最终结果
+        output_file_name = Path(args.output_dir).name
+        output_file = os.path.join(args.output_dir, f"{output_file_name}_api_results.json")
+        final_results = {
+            "stats": stats,
+            "results": results
+        }
+        
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump(final_results, f, ensure_ascii=False, indent=2)
+        
+        print(f"💾 Results saved to: {output_file}")
+
+        # 如果没有收集结果的路径,使用缺省文件名,和output_dir同一路径
+        if not args.collect_results:
+            output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
+        else:
+            output_file_processed = Path(args.collect_results).resolve()
+        
+        processed_files = collect_pid_files(output_file)
+        with open(output_file_processed, 'w', encoding='utf-8') as f:
+            f.write("image_path,status\n")
+            for file_path, status in processed_files:
+                f.write(f"{file_path},{status}\n")
+        print(f"💾 Processed files saved to: {output_file_processed}")
+
+        return 0
+        
+    except Exception as e:
+        print(f"❌ Processing failed: {e}", file=sys.stderr)
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == "__main__":
+    print(f"🚀 启动PP-StructureV3 API客户端...")
+    print(f"🔧 环境变量检查: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
+    
+    if len(sys.argv) == 1:
+        # 如果没有命令行参数,使用默认配置运行
+        print("ℹ️  No command line arguments provided. Running with default configuration...")
+        
+        # 默认配置
+        default_config = {
+            "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
+            "output_dir": "./OmniDocBench_API_Results",
+            "api_url": "http://10.192.72.11:8111/layout-parsing",
+            "timeout": "300",
+            "collect_results": f"./OmniDocBench_API_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
+        }
+        
+        # 构造参数
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            sys.argv.extend([f"--{key}", str(value)])
+        
+        # sys.argv.append("--no-normalize")
+
+        # 测试模式
+        # sys.argv.append("--test_mode")
+    
+    sys.exit(main())