- """PDF转图像后统一处理"""
- import json
- import time
- import os
- import traceback
- import argparse
- import sys
- import warnings
- from pathlib import Path
- from typing import List, Dict, Any, Union
- import cv2
- import numpy as np
- # 抑制特定警告
- warnings.filterwarnings("ignore", message="To copy construct from a tensor")
- warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
- warnings.filterwarnings("ignore", category=UserWarning, module="paddlex")
- from paddlex import create_pipeline
- from paddlex.utils.device import constr_device, parse_device
- from tqdm import tqdm
- from dotenv import load_dotenv
- load_dotenv(override=True)
- from utils import (
- get_image_files_from_dir,
- get_image_files_from_list,
- get_image_files_from_csv,
- collect_pid_files,
- load_images_from_pdf,
- normalize_financial_numbers,
- normalize_markdown_table
- )

def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
    """
    Convert a PDF into image files.

    Args:
        pdf_file: Path to the PDF file
        output_dir: Output directory
        dpi: Image resolution

    Returns:
        List of generated image file paths
    """
    pdf_path = Path(pdf_file)
    if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
        print(f"❌ Invalid PDF file: {pdf_path}")
        return []

    # If no output directory is given, use a directory named after the PDF
    if output_dir is None:
        output_path = pdf_path.parent / f"{pdf_path.stem}"
    else:
        output_path = Path(output_dir) / f"{pdf_path.stem}"
    output_path = output_path.resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Load the PDF pages as images via the load_images_from_pdf helper
        images = load_images_from_pdf(str(pdf_path), dpi=dpi)

        image_paths = []
        for i, image in enumerate(images):
            # Build the image file name
            image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
            image_path = output_path / image_filename
            # Save the image
            image.save(str(image_path))
            image_paths.append(str(image_path))

        print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
        return image_paths

    except Exception as e:
        print(f"❌ Error converting PDF {pdf_path}: {e}")
        traceback.print_exc()
        return []
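
# Illustrative call (hypothetical paths), assuming load_images_from_pdf returns
# PIL-style Image objects that expose .save():
#   page_files = convert_pdf_to_images("samples/annual_report.pdf", "./converted", dpi=300)
#   # -> ["converted/annual_report/annual_report_page_001.png", ...]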

def get_input_files(args) -> List[str]:
    """
    Collect the list of input files, handling PDF and image files uniformly.

    Args:
        args: Parsed command-line arguments

    Returns:
        List of image file paths ready for processing
    """
    input_files = []

    # Gather the raw input files
    if args.input_csv:
        raw_files = get_image_files_from_csv(args.input_csv, "fail")
    elif args.input_file_list:
        raw_files = get_image_files_from_list(args.input_file_list)
    elif args.input_file:
        raw_files = [Path(args.input_file).resolve()]
    else:
        input_dir = Path(args.input_dir).resolve()
        if not input_dir.exists():
            print(f"❌ Input directory does not exist: {input_dir}")
            return []

        # Collect all supported files (images and PDFs)
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
        pdf_extensions = ['.pdf']

        raw_files = []
        for ext in image_extensions + pdf_extensions:
            raw_files.extend(list(input_dir.glob(f"*{ext}")))
            raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))

    raw_files = [str(f) for f in raw_files]

    # Handle PDF and image files separately
    pdf_count = 0
    image_count = 0

    for file_path in raw_files:
        file_path = Path(file_path)

        if file_path.suffix.lower() == '.pdf':
            # Convert the PDF into images
            print(f"📄 Processing PDF: {file_path.name}")
            pdf_images = convert_pdf_to_images(
                str(file_path),
                args.output_dir,
                dpi=args.pdf_dpi
            )
            input_files.extend(pdf_images)
            pdf_count += 1
        else:
            # Add image files directly
            if file_path.exists():
                input_files.append(str(file_path))
                image_count += 1

    print(f"📊 Input summary:")
    print(f"  PDF files processed: {pdf_count}")
    print(f"  Image files found: {image_count}")
    print(f"  Total image files to process: {len(input_files)}")

    return input_files

def normalize_pipeline_result(result: Dict[str, Any], normalize_numbers: bool = True) -> Dict[str, Any]:
    """
    Apply number normalization to a pipeline result.

    Args:
        result: Result object returned by the pipeline
        normalize_numbers: Whether number normalization is enabled

    Returns:
        Dictionary describing the normalization that was applied
    """
    if not normalize_numbers:
        return {
            "normalize_numbers": False,
            "changes_applied": False,
            "character_changes_count": 0,
            "parsing_res_tables_count": 0,
            "table_res_list_count": 0,
            "table_consistency_fixed": False
        }

    changes_count = 0
    original_data = {}

    # Back up the original data
    if 'parsing_res_list' in result:
        original_data['parsing_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['parsing_res_list']]
    if 'table_res_list' in result:
        original_data['table_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['table_res_list']]

    try:
        # 1. Normalize the text content in parsing_res_list
        if 'parsing_res_list' in result:
            for item in result['parsing_res_list']:
                if 'block_content' in item and item['block_content']:
                    original_content = str(item['block_content'])
                    normalized_content = original_content

                    # Choose the normalization method based on the block_label type
                    if 'block_label' in item and item['block_label'] == 'table':
                        normalized_content = normalize_markdown_table(original_content)

                    if original_content != normalized_content:
                        item['block_content'] = normalized_content
                        # Approximate character-level change count (position-wise comparison)
                        changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])

        # 2. Normalize the HTML tables in table_res_list
        if 'table_res_list' in result:
            for table_item in result['table_res_list']:
                if 'pred_html' in table_item and table_item['pred_html']:
                    original_html = str(table_item['pred_html'])
                    normalized_html = normalize_markdown_table(original_html)

                    if original_html != normalized_html:
                        table_item['pred_html'] = normalized_html
                        changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])

        # Count the tables
        parsing_res_tables_count = 0
        table_res_list_count = 0
        if 'parsing_res_list' in result:
            parsing_res_tables_count = len([item for item in result['parsing_res_list']
                                            if 'block_label' in item and item['block_label'] == 'table'])
        if 'table_res_list' in result:
            table_res_list_count = len(result['table_res_list'])

        # Check whether table consistency needs fixing (only counted here;
        # an actual fix would require more elaborate logic)
        table_consistency_fixed = False
        if parsing_res_tables_count != table_res_list_count:
            warnings.warn(f"⚠️ Warning: Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
                          f"but table_res_list has {table_res_list_count} tables.")
            table_consistency_fixed = True
            # Actual repair logic (e.g., adding or removing table items) could go here,
            # but without concrete rules we only count and warn for now.

        return {
            "normalize_numbers": normalize_numbers,
            "changes_applied": changes_count > 0,
            "character_changes_count": changes_count,
            "parsing_res_tables_count": parsing_res_tables_count,
            "table_res_list_count": table_res_list_count,
            "table_consistency_fixed": table_consistency_fixed
        }

    except Exception as e:
        print(f"⚠️ Warning: Error during normalization: {e}")
        return {
            "normalize_numbers": normalize_numbers,
            "changes_applied": False,
            "character_changes_count": 0,
            "normalization_error": str(e)
        }

def save_normalized_files(result, output_dir: str, filename: str,
                          processing_info: Dict[str, Any], normalize_numbers: bool = True):
    """
    Save the normalized output files (JSON and Markdown).
    """
    output_path = Path(output_dir)

    # Save the normalized version
    json_output_path = str(output_path / f"{filename}.json")
    md_output_path = str(output_path / f"{filename}.md")

    result.save_to_json(json_output_path)
    result.save_to_markdown(md_output_path)

    # If normalization changed anything, attach the processing info to the JSON
    if normalize_numbers and processing_info.get('changes_applied', False):
        try:
            # Read back the generated JSON file and add the processing info
            with open(json_output_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            json_data['processing_info'] = processing_info

            # Re-save the JSON with the processing info included
            with open(json_output_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, ensure_ascii=False, indent=2)

        except Exception as e:
            print(f"⚠️ Warning: Could not add processing info to JSON: {e}")

    return json_output_path, md_output_path
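
# When changes were applied, the JSON written above gains a "processing_info"
# block shaped like the dict returned by normalize_pipeline_result(), e.g.:
#   "processing_info": {
#       "normalize_numbers": true, "changes_applied": true,
#       "character_changes_count": 12, "parsing_res_tables_count": 2,
#       "table_res_list_count": 2, "table_consistency_fixed": false
#   }
# (the counts shown are illustrative)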

def process_images_unified(image_paths: List[str],
                           pipeline_name: str = "PP-StructureV3",
                           device: str = "gpu:0",
                           output_dir: str = "./output",
                           normalize_numbers: bool = True) -> List[Dict[str, Any]]:
    """
    Unified image-processing function with optional number normalization.
    """
    # Create the output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")

    try:
        # Set an environment variable to reduce warnings
        os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'

        # Initialize the pipeline
        pipeline = create_pipeline(pipeline_name, device=device)
        print(f"Pipeline initialized successfully on {device}")

    except Exception as e:
        print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
        traceback.print_exc()
        return []

    all_results = []
    total_images = len(image_paths)

    print(f"Processing {total_images} images one by one")
    print(f"🔧 Number normalization: {'enabled' if normalize_numbers else 'disabled'}")

    # Show progress with tqdm
    with tqdm(total=total_images, desc="Processing images", unit="img",
              bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:

        # Process the images one at a time
        for img_path in image_paths:
            start_time = time.time()

            try:
                # Run the pipeline prediction on a single image
                results = pipeline.predict(
                    img_path,
                    use_doc_orientation_classify=True,
                    use_doc_unwarping=False,
                    use_seal_recognition=True,
                    use_table_recognition=True,
                    use_formula_recognition=False,
                    use_chart_recognition=True,
                )

                processing_time = time.time() - start_time

                # Handle the results
                for result in results:
                    try:
                        input_path = Path(result["input_path"])

                        # Build the output file name
                        if result.get("page_index") is not None:
                            output_filename = f"{input_path.stem}_{result['page_index']}"
                        else:
                            output_filename = f"{input_path.stem}"

                        # Apply number normalization
                        processing_info = normalize_pipeline_result(result, normalize_numbers)

                        # Save the JSON and Markdown files (with normalization applied)
                        json_output_path, md_output_path = save_normalized_files(
                            result, output_dir, output_filename, processing_info, normalize_numbers
                        )

                        # Report if a table consistency issue was flagged
                        if processing_info.get('table_consistency_fixed', False):
                            print(f"🔧 Table consistency issue flagged for {input_path.name}")

                        # Record the processing result
                        all_results.append({
                            "image_path": str(input_path),
                            "processing_time": processing_time,
                            "success": True,
                            "device": device,
                            "output_json": json_output_path,
                            "output_md": md_output_path,
                            "is_pdf_page": "_page_" in input_path.name,  # Mark whether this is a PDF page
                            "processing_info": processing_info
                        })

                    except Exception as e:
                        print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
                        traceback.print_exc()
                        all_results.append({
                            "image_path": str(img_path),
                            "processing_time": 0,
                            "success": False,
                            "device": device,
                            "error": str(e)
                        })

                # Update the progress bar
                success_count = sum(1 for r in all_results if r.get('success', False))

                pbar.update(1)
                if all_results:  # guard against an empty result list (avoids division by zero)
                    pbar.set_postfix({
                        'time': f"{processing_time:.2f}s",
                        'success': f"{success_count}/{len(all_results)}",
                        'rate': f"{success_count/len(all_results)*100:.1f}%"
                    })

            except Exception as e:
                print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
                traceback.print_exc()

                # Record the error result
                all_results.append({
                    "image_path": str(img_path),
                    "processing_time": 0,
                    "success": False,
                    "device": device,
                    "error": str(e)
                })
                pbar.update(1)

    return all_results
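
# Each record returned by process_images_unified() has this shape (values illustrative):
#   success: {"image_path": "...", "processing_time": 1.8, "success": True, "device": "gpu:0",
#             "output_json": "...", "output_md": "...", "is_pdf_page": False,
#             "processing_info": {...}}
#   failure: {"image_path": "...", "processing_time": 0, "success": False, "device": "gpu:0",
#             "error": "..."}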

def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="PaddleX PP-StructureV3 Unified PDF/Image Processor")

    # Argument definitions
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input_file", type=str, help="Input file (supports both PDF and image files)")
    input_group.add_argument("--input_dir", type=str, help="Input directory (supports both PDF and image files)")
    input_group.add_argument("--input_file_list", type=str, help="Input file list (one file per line)")
    input_group.add_argument("--input_csv", type=str, help="Input CSV file with image_path and status columns")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
    parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
    parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF to image conversion")
    parser.add_argument("--no-normalize", action="store_true", help="Disable number normalization")
    parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
    parser.add_argument("--collect_results", type=str, help="Collect processing results into the specified CSV file")
    args = parser.parse_args()

    normalize_numbers = not args.no_normalize

    try:
        # Gather and preprocess the input files
        print("🔄 Preprocessing input files...")
        input_files = get_input_files(args)

        if not input_files:
            print("❌ No input files found or processed")
            return 1

        if args.test_mode:
            input_files = input_files[:20]
            print(f"Test mode: processing only {len(input_files)} images")

        print(f"Using device: {args.device}")

        # Start processing
        start_time = time.time()
        results = process_images_unified(
            input_files,
            args.pipeline,
            args.device,
            args.output_dir,
            normalize_numbers=normalize_numbers
        )
        total_time = time.time() - start_time

        # Aggregate statistics
        success_count = sum(1 for r in results if r.get('success', False))
        error_count = len(results) - success_count
        pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
        total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)

        print("\n" + "=" * 60)
        print("✅ Processing completed!")
        print("📊 Statistics:")
        print(f"  Total files processed: {len(input_files)}")
        print(f"  PDF pages processed: {pdf_page_count}")
        print(f"  Regular images processed: {len(input_files) - pdf_page_count}")
        print(f"  Successful: {success_count}")
        print(f"  Failed: {error_count}")
        if len(input_files) > 0:
            print(f"  Success rate: {success_count / len(input_files) * 100:.2f}%")
        if normalize_numbers:
            print(f"  Total normalized characters: {total_changes}")
        print("⏱️ Performance:")
        print(f"  Total time: {total_time:.2f} seconds")
        if total_time > 0:
            print(f"  Throughput: {len(input_files) / total_time:.2f} files/second")
            print(f"  Avg time per file: {total_time / len(input_files):.2f} seconds")

        # Collect the run statistics
        stats = {
            "total_files": len(input_files),
            "pdf_pages": pdf_page_count,
            "regular_images": len(input_files) - pdf_page_count,
            "success_count": success_count,
            "error_count": error_count,
            "success_rate": success_count / len(input_files) if len(input_files) > 0 else 0,
            "total_time": total_time,
            "throughput": len(input_files) / total_time if total_time > 0 else 0,
            "avg_time_per_file": total_time / len(input_files) if len(input_files) > 0 else 0,
            "device": args.device,
            "pipeline": args.pipeline,
            "pdf_dpi": args.pdf_dpi,
            "normalize_numbers": normalize_numbers,
            "total_character_changes": total_changes,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Save the final results
        output_file_name = Path(args.output_dir).name
        output_file = os.path.join(args.output_dir, f"{output_file_name}_unified.json")
        final_results = {
            "stats": stats,
            "results": results
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)

        print(f"💾 Results saved to: {output_file}")

        # If no collect_results path was given, use a default file name inside output_dir
        if not args.collect_results:
            output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
        else:
            output_file_processed = Path(args.collect_results).resolve()
        processed_files = collect_pid_files(output_file)
        with open(output_file_processed, 'w', encoding='utf-8') as f:
            f.write("image_path,status\n")
            for file_path, status in processed_files:
                f.write(f"{file_path},{status}\n")
        print(f"💾 Processed files saved to: {output_file_processed}")
        return 0

    except Exception as e:
        print(f"❌ Processing failed: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
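
# Example invocations (script name and paths are illustrative):
#   python pp_structure_unified.py --input_dir ./scans --output_dir ./results --device gpu:0
#   python pp_structure_unified.py --input_file report.pdf --output_dir ./results \
#       --pdf_dpi 300 --no-normalize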
- if __name__ == "__main__":
- print(f"🚀 启动统一PDF/图像处理程序...")
- print(f"🔧 CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
-
- if len(sys.argv) == 1:
- # 如果没有命令行参数,使用默认配置运行
- print("ℹ️ No command line arguments provided. Running with default configuration...")
-
- # 默认配置
- default_config = {
- "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
- "output_dir": "./OmniDocBench_PPStructureV3_Results",
- "pipeline": "./my_config/PP-StructureV3.yaml",
- "device": "gpu:0",
- "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
- }
-
- # 构造参数
- sys.argv = [sys.argv[0]]
- for key, value in default_config.items():
- sys.argv.extend([f"--{key}", str(value)])
-
- # 可以添加禁用标准化选项
- # sys.argv.append("--no-normalize")
-
- # 测试模式
- # sys.argv.append("--test_mode")
-
- sys.exit(main())