# ocr_comparator.py
import os
from collections import Counter
from datetime import datetime
from typing import Dict

# Support both relative (package) and absolute (script) imports.
try:
    from .content_extractor import ContentExtractor
    from .table_comparator import TableComparator
    from .paragraph_comparator import ParagraphComparator
except ImportError:
    from content_extractor import ContentExtractor
    from table_comparator import TableComparator
    from paragraph_comparator import ParagraphComparator

  13. class OCRResultComparator:
  14. """OCR结果比较器主类"""
  15. def __init__(self):
  16. self.content_extractor = ContentExtractor()
  17. self.table_comparator = TableComparator()
  18. self.paragraph_comparator = ParagraphComparator()
  19. self.differences = []
  20. self.paragraph_match_threshold = 80
  21. self.content_similarity_threshold = 95
  22. self.max_paragraph_window = 6
  23. self.table_comparison_mode = 'standard'
  24. self.header_similarity_threshold = 90
  25. def compare_files(self, file1_path: str, file2_path: str) -> Dict:
  26. """比较两个OCR结果文件"""
  27. print(f"\n📖 读取文件...")
  28. # 读取文件内容
  29. with open(file1_path, 'r', encoding='utf-8') as f:
  30. content1 = f.read()
  31. with open(file2_path, 'r', encoding='utf-8') as f:
  32. content2 = f.read()
  33. print(f"✅ 文件读取完成")
  34. print(f" 文件1大小: {len(content1)} 字符")
  35. print(f" 文件2大小: {len(content2)} 字符")
  36. # 提取表格
  37. print(f"\n📊 提取表格...")
  38. tables1 = self.content_extractor.extract_table_data(content1)
  39. tables2 = self.content_extractor.extract_table_data(content2)
  40. print(f" 文件1表格数: {len(tables1)}")
  41. print(f" 文件2表格数: {len(tables2)}")
  42. # 提取段落
  43. print(f"\n📝 提取段落...")
  44. paragraphs1 = self.content_extractor.extract_paragraphs(content1)
  45. paragraphs2 = self.content_extractor.extract_paragraphs(content2)
  46. print(f" 文件1段落数: {len(paragraphs1)}")
  47. print(f" 文件2段落数: {len(paragraphs2)}")
  48. # 比较段落
  49. print(f"\n🔍 开始段落对比...")
  50. paragraph_differences = self.paragraph_comparator.compare_paragraphs(
  51. paragraphs1, paragraphs2
  52. )
  53. print(f"✅ 段落对比完成,发现 {len(paragraph_differences)} 个差异")
  54. # ✅ 初始化所有差异列表 - 用于兼容原版本返回结构
  55. all_differences = []
  56. all_differences.extend(paragraph_differences)
  57. # 比较表格
  58. print(f"\n🔍 开始表格对比...")
  59. # ✅ 处理表格比较 - 支持多表格
  60. if tables1 and tables2:
  61. # 根据模式选择比较方法
  62. if self.table_comparison_mode == 'flow_list':
  63. table_diffs = self.table_comparator.compare_table_flow_list(tables1[0], tables2[0])
  64. else:
  65. table_diffs = self.table_comparator.compare_tables(tables1[0], tables2[0])
  66. all_differences.extend(table_diffs)
  67. print(f"✅ 表格对比完成,发现 {len(table_diffs)} 个差异")
  68. elif tables1 and not tables2:
  69. all_differences.append({
  70. 'type': 'table_structure',
  71. 'position': '表格结构',
  72. 'file1_value': f'包含{len(tables1)}个表格',
  73. 'file2_value': '无表格',
  74. 'description': '文件1包含表格但文件2无表格',
  75. 'severity': 'high'
  76. })
  77. elif not tables1 and tables2:
  78. all_differences.append({
  79. 'type': 'table_structure',
  80. 'position': '表格结构',
  81. 'file1_value': '无表格',
  82. 'file2_value': f'包含{len(tables2)}个表格',
  83. 'description': '文件2包含表格但文件1无表格',
  84. 'severity': 'high'
  85. })
  86. print(f"\n✅ 对比完成")
  87. # ✅ 统计差异 - 细化分类(与原版本保持一致)
  88. stats = {
  89. 'total_differences': len(all_differences),
  90. 'table_differences': len([d for d in all_differences if d['type'].startswith('table')]),
  91. 'paragraph_differences': len([d for d in all_differences if d['type'] == 'paragraph']),
  92. 'amount_differences': len([d for d in all_differences if d['type'] == 'table_amount']),
  93. 'datetime_differences': len([d for d in all_differences if d['type'] == 'table_datetime']),
  94. 'text_differences': len([d for d in all_differences if d['type'] == 'table_text']),
  95. 'table_pre_header': len([d for d in all_differences if d['type'] == 'table_pre_header']),
  96. 'table_header_mismatch': len([d for d in all_differences if d['type'] == 'table_header_mismatch']),
  97. 'table_header_critical': len([d for d in all_differences if d['type'] == 'table_header_critical']),
  98. 'table_header_position': len([d for d in all_differences if d['type'] == 'table_header_position']),
  99. 'table_row_missing': len([d for d in all_differences if d['type'] == 'table_row_missing']),
  100. 'high_severity': len([d for d in all_differences if d.get('severity') in ['critical', 'high']]),
  101. 'medium_severity': len([d for d in all_differences if d.get('severity') == 'medium']),
  102. 'low_severity': len([d for d in all_differences if d.get('severity') == 'low'])
  103. }
  104. # ✅ 构建返回结果 - 与原版本结构保持完全一致
  105. result = {
  106. 'differences': all_differences, # ✅ 原版本使用 differences 而非 paragraph_differences
  107. 'statistics': stats,
  108. 'file1_tables': len(tables1),
  109. 'file2_tables': len(tables2),
  110. 'file1_paragraphs': len(paragraphs1),
  111. 'file2_paragraphs': len(paragraphs2),
  112. 'file1_path': file1_path,
  113. 'file2_path': file2_path,
  114. 'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S') # ✅ 添加时间戳
  115. }
  116. print(f"\n" + "="*60)
  117. print(f"📊 对比结果汇总")
  118. print(f"="*60)
  119. print(f"总差异数: {result['statistics']['total_differences']}")
  120. print(f" - 段落差异: {result['statistics']['paragraph_differences']}")
  121. print(f" - 表格差异: {result['statistics']['table_differences']}")
  122. print(f" - 金额: {result['statistics']['amount_differences']}")
  123. print(f" - 日期: {result['statistics']['datetime_differences']}")
  124. print(f" - 文本: {result['statistics']['text_differences']}")
  125. print(f"\n严重级别分布:")
  126. print(f" 🔴 高: {result['statistics']['high_severity']}")
  127. print(f" 🟡 中: {result['statistics']['medium_severity']}")
  128. print(f" 🟢 低: {result['statistics']['low_severity']}")
  129. print(f"="*60)
  130. return result