main_enhanced.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
#!/usr/bin/env python3
"""
Unified entry point for financial document processing.
Supports two scenes: bank statements and financial reports.
"""
import argparse
import json
import sys
from pathlib import Path
from loguru import logger

# Add the project root (the directory one level above this script's directory)
# to the Python path. This must run before the project imports below.
# NOTE(review): presumably `core` and `universal_doc_parser` are resolved via
# this path entry - confirm against the repository layout.
project_root = Path(__file__).parents[1]
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from core.pipeline_manager import FinancialDocPipeline
from universal_doc_parser.utils import OutputFormatter
from dotenv import load_dotenv
load_dotenv(override=True)  # Load environment variables at import time
  19. def setup_logging(log_level: str = "INFO"):
  20. """设置日志"""
  21. logger.remove() # 移除默认处理器
  22. logger.add(sys.stdout, level=log_level, format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
  23. def main():
  24. parser = argparse.ArgumentParser(description="金融文档处理工具")
  25. parser.add_argument("--input", "-i", required=True, help="输入文档路径")
  26. parser.add_argument("--config", "-c", required=True, help="配置文件路径")
  27. parser.add_argument("--output_dir", "-o", default="./output", help="输出目录")
  28. parser.add_argument("--scene", "-s", choices=["bank_statement", "financial_report"],
  29. help="场景类型(会覆盖配置文件中的设置)")
  30. parser.add_argument("--log_level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
  31. help="日志级别")
  32. parser.add_argument("--dry_run", action="store_true", help="仅验证配置,不执行处理")
  33. args = parser.parse_args()
  34. # 设置日志
  35. setup_logging(args.log_level)
  36. # 验证输入文件
  37. input_path = Path(args.input)
  38. if not input_path.exists():
  39. logger.error(f"❌ Input file not found: {input_path}")
  40. return 1
  41. # 验证配置文件
  42. config_path = Path(args.config)
  43. if not config_path.exists():
  44. logger.error(f"❌ Config file not found: {config_path}")
  45. return 1
  46. try:
  47. # 初始化处理流水线
  48. with FinancialDocPipeline(str(config_path)) as pipeline:
  49. # 如果命令行指定了场景,覆盖配置文件
  50. if args.scene:
  51. pipeline.scene_name = args.scene
  52. logger.info(f"🔄 Scene overridden to: {args.scene}")
  53. logger.info(f"🚀 开始处理文档: {input_path}")
  54. logger.info(f"📋 使用场景配置: {pipeline.scene_name}")
  55. logger.info(f"📁 输出目录: {args.output_dir}")
  56. # 仅验证模式
  57. if args.dry_run:
  58. logger.info("✅ Dry run completed - configuration is valid")
  59. return 0
  60. # 处理文档
  61. results = pipeline.process_document(str(input_path))
  62. # 格式化输出
  63. logger.info("💾 Saving results...")
  64. formatter = OutputFormatter(args.output_dir)
  65. output_paths = formatter.save_results(results, pipeline.config['output'])
  66. logger.info(f"✅ 处理完成,结果保存至: {output_paths}")
  67. # 打印关键统计信息
  68. _print_summary(results)
  69. return 0
  70. except KeyboardInterrupt:
  71. logger.warning("⚠️ Process interrupted by user")
  72. return 1
  73. except Exception as e:
  74. logger.error(f"❌ Processing failed: {e}")
  75. if args.log_level == "DEBUG":
  76. logger.exception("Full traceback:")
  77. return 1
  78. def _print_summary(results: dict):
  79. """打印处理结果摘要"""
  80. total_pages = len(results['pages'])
  81. total_tables = sum(
  82. len([e for e in page['elements'] if e.get('type') == 'table'])
  83. for page in results['pages']
  84. )
  85. total_text_blocks = sum(
  86. len([e for e in page['elements'] if e.get('type') in ['text', 'title', 'ocr_text']])
  87. for page in results['pages']
  88. )
  89. total_formulas = sum(
  90. len([e for e in page['elements'] if e.get('type') == 'formula'])
  91. for page in results['pages']
  92. )
  93. print(f"\n📊 处理摘要:")
  94. print(f" 📄 文档: {results['document_path']}")
  95. print(f" 🎯 场景类型: {results['scene']}")
  96. print(f" 📖 页面数量: {total_pages}")
  97. print(f" 📋 表格数量: {total_tables}")
  98. print(f" 📝 文本块数量: {total_text_blocks}")
  99. print(f" 🧮 公式数量: {total_formulas}")
  100. if __name__ == "__main__":
  101. if len(sys.argv) == 1:
  102. # 如果没有命令行参数,使用默认配置运行
  103. print("ℹ️ No command line arguments provided. Running with default configuration...")
  104. # 默认配置
  105. default_config = {
  106. # "input": "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_022.png",
  107. "input": "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
  108. "output_dir": "./output/bank_statement_mineru_vl",
  109. "config": "./config/bank_statement_mineru_vl.yaml",
  110. # "output_dir": "./output/bank_statement_paddle_vl",
  111. # "config": "./config/bank_statement_paddle_vl.yaml",
  112. "scene": "bank_statement",
  113. "log_level": "DEBUG",
  114. }
  115. # 构造参数
  116. sys.argv = [sys.argv[0]]
  117. for key, value in default_config.items():
  118. sys.argv.extend([f"--{key}", str(value)])
  119. sys.exit(main())