Przeglądaj źródła

feat: 新增智能表格匹配功能,优化表格相似度计算和表头检测

zhch158_admin 4 tygodni temu
rodzic
commit
e02c0f2d22
1 zmienionych plików z 358 dodań i 50 usunięć
  1. 358 50
      comparator/table_comparator.py

+ 358 - 50
comparator/table_comparator.py

@@ -1,6 +1,6 @@
 import re
-from typing import Dict, List
-# ✅ 兼容相对导入和绝对导入
+from typing import Dict, List, Tuple, Optional
+
 try:
     from .data_type_detector import DataTypeDetector
     from .similarity_calculator import SimilarityCalculator
@@ -10,6 +10,7 @@ except ImportError:
     from similarity_calculator import SimilarityCalculator
     from text_processor import TextProcessor
 
+
 class TableComparator:
     """表格数据比较"""
     
@@ -21,13 +22,368 @@ class TableComparator:
         self.content_similarity_threshold = 95
         self.max_paragraph_window = 6
     
+    def find_matching_tables(self, tables1: List[List[List[str]]], 
+                            tables2: List[List[List[str]]]) -> List[Tuple[int, int, float]]:
+        """
+        智能匹配两个文件中的表格
+        
+        Returns:
+            List[Tuple[int, int, float]]: (table1_index, table2_index, similarity_score)
+        """
+        matches = []
+        
+        for i, table1 in enumerate(tables1):
+            if not table1:
+                continue
+            
+            best_match = None
+            best_score = 0
+            
+            for j, table2 in enumerate(tables2):
+                if not table2:
+                    continue
+                
+                # 计算表格相似度
+                score = self._calculate_table_similarity(table1, table2)
+                
+                if score > best_score:
+                    best_score = score
+                    best_match = j
+            
+            if best_match is not None and best_score > 50:  # 至少50%相似度
+                matches.append((i, best_match, best_score))
+                print(f"   📊 表格匹配: 文件1表格{i+1} ↔ 文件2表格{best_match+1} (相似度: {best_score:.1f}%)")
+        
+        return matches
+    
+    def _get_max_columns(self, table: List[List[str]]) -> int:
+        """获取表格的最大列数"""
+        if not table:
+            return 0
+        return max(len(row) for row in table)
+    
+    def _calculate_table_similarity(self, table1: List[List[str]], 
+                                   table2: List[List[str]]) -> float:
+        """计算两个表格的相似度"""
+        if not table1 or not table2:
+            return 0.0
+        
+        # 1. 行数相似度 (权重: 15%)
+        row_count1 = len(table1)
+        row_count2 = len(table2)
+        row_similarity = 100 * (1 - abs(row_count1 - row_count2) / max(row_count1, row_count2))
+        
+        # 2. 列数相似度 (权重: 15%) - ✅ 使用最大列数
+        col_count1 = self._get_max_columns(table1)
+        col_count2 = self._get_max_columns(table2)
+        
+        max_cols = max(col_count1, col_count2)
+        min_cols = min(col_count1, col_count2)
+        
+        if max_cols == 0:
+            col_similarity = 0
+        else:
+            # 如果列数差异在合理范围内(比如差1-2列),给予较高分数
+            col_diff = abs(col_count1 - col_count2)
+            if col_diff == 0:
+                col_similarity = 100
+            elif col_diff <= 2:
+                # 差1-2列,给予80-95分
+                col_similarity = 100 - (col_diff * 10)
+            else:
+                # 差异较大时,使用比例计算
+                col_similarity = 100 * (min_cols / max_cols)
+        
+        print(f"      行数对比: {row_count1} vs {row_count2}, 相似度: {row_similarity:.1f}%")
+        print(f"      列数对比: {col_count1} vs {col_count2}, 相似度: {col_similarity:.1f}%")
+        
+        # 3. 表头相似度 (权重: 50%) - ✅ 先检测表头位置
+        header_row_idx1 = self.detect_table_header_row(table1)
+        header_row_idx2 = self.detect_table_header_row(table2)
+        
+        print(f"      表头位置: 文件1第{header_row_idx1+1}行, 文件2第{header_row_idx2+1}行")
+        
+        header_similarity = 0
+        if header_row_idx1 < len(table1) and header_row_idx2 < len(table2):
+            header1 = table1[header_row_idx1]
+            header2 = table2[header_row_idx2]
+            
+            if header1 and header2:
+                # ✅ 智能表头匹配
+                header_similarity = self._calculate_header_similarity_smart(header1, header2)
+        
+        print(f"      表头相似度: {header_similarity:.1f}%")
+        
+        # 4. 内容特征相似度 (权重: 20%)
+        content_similarity = self._calculate_content_features_similarity(table1, table2)
+        
+        print(f"      内容特征相似度: {content_similarity:.1f}%")
+        
+        # ✅ 调整权重分配
+        total_similarity = (
+            row_similarity * 0.15 +      # 行数 15%
+            col_similarity * 0.15 +      # 列数 15%  
+            header_similarity * 0.50 +   # 表头 50% (最重要)
+            content_similarity * 0.20    # 内容 20%
+        )
+        
+        return total_similarity
+    
+    def _calculate_header_similarity_smart(self, header1: List[str], 
+                                          header2: List[str]) -> float:
+        """
+        智能计算表头相似度
+        
+        处理以下情况:
+        1. 列数不同但表头内容相似
+        2. PaddleOCR可能将多行表头合并
+        3. 表头顺序可能不同
+        """
+        if not header1 or not header2:
+            return 0.0
+        
+        # 标准化表头
+        norm_headers1 = [self.normalize_header_text(h) for h in header1]
+        norm_headers2 = [self.normalize_header_text(h) for h in header2]
+        
+        # 方法1: 精确匹配 (最高优先级)
+        common_headers = set(norm_headers1) & set(norm_headers2)
+        max_len = max(len(norm_headers1), len(norm_headers2))
+        
+        if max_len == 0:
+            return 0.0
+        
+        exact_match_ratio = len(common_headers) / max_len
+        
+        # 方法2: 模糊匹配 (针对列数不同的情况)
+        fuzzy_matches = 0
+        
+        # 使用较短的表头作为基准
+        if len(norm_headers1) <= len(norm_headers2):
+            base_headers = norm_headers1
+            compare_headers = norm_headers2
+        else:
+            base_headers = norm_headers2
+            compare_headers = norm_headers1
+        
+        for base_h in base_headers:
+            best_similarity = 0
+            for comp_h in compare_headers:
+                similarity = self.calculator.calculate_text_similarity(base_h, comp_h)
+                if similarity > best_similarity:
+                    best_similarity = similarity
+                    if best_similarity == 100:
+                        break
+            
+            # 如果相似度超过70%,认为是匹配的
+            if best_similarity > 70:
+                fuzzy_matches += 1
+        
+        fuzzy_match_ratio = fuzzy_matches / max_len if max_len > 0 else 0
+        
+        # 方法3: 关键字匹配 (识别常见表头)
+        key_headers = {
+            'date': ['日期', 'date', '时间', 'time'],
+            'type': ['类型', 'type', '业务', 'business'],
+            'number': ['号', 'no', '编号', 'id', '票据', 'bill'],
+            'description': ['摘要', 'description', '说明', 'remark'],
+            'amount': ['金额', 'amount', '借方', 'debit', '贷方', 'credit'],
+            'balance': ['余额', 'balance'],
+            'counterparty': ['对手', 'counterparty', '账户', 'account', '户名', 'name']
+        }
+        
+        def categorize_header(h: str) -> set:
+            categories = set()
+            h_lower = h.lower()
+            for category, keywords in key_headers.items():
+                for keyword in keywords:
+                    if keyword in h_lower:
+                        categories.add(category)
+            return categories
+        
+        categories1 = set()
+        for h in norm_headers1:
+            categories1.update(categorize_header(h))
+        
+        categories2 = set()
+        for h in norm_headers2:
+            categories2.update(categorize_header(h))
+        
+        common_categories = categories1 & categories2
+        all_categories = categories1 | categories2
+        
+        category_match_ratio = len(common_categories) / len(all_categories) if all_categories else 0
+        
+        # ✅ 综合三种方法,加权计算
+        final_similarity = (
+            exact_match_ratio * 0.4 +      # 精确匹配 40%
+            fuzzy_match_ratio * 0.4 +      # 模糊匹配 40%
+            category_match_ratio * 0.2     # 语义匹配 20%
+        ) * 100
+        
+        print(f"        精确匹配: {exact_match_ratio:.1%}, 模糊匹配: {fuzzy_match_ratio:.1%}, 语义匹配: {category_match_ratio:.1%}")
+        
+        return final_similarity
+    
+    def _calculate_content_features_similarity(self, table1: List[List[str]], 
+                                              table2: List[List[str]]) -> float:
+        """计算表格内容特征相似度"""
+        # 统计数字、日期等特征
+        features1 = self._extract_table_features(table1)
+        features2 = self._extract_table_features(table2)
+        
+        # 比较特征
+        similarity_scores = []
+        
+        for key in ['numeric_ratio', 'date_ratio', 'empty_ratio']:
+            if key in features1 and key in features2:
+                diff = abs(features1[key] - features2[key])
+                similarity_scores.append(100 * (1 - diff))
+        
+        return sum(similarity_scores) / len(similarity_scores) if similarity_scores else 0
+    
+    def _extract_table_features(self, table: List[List[str]]) -> Dict:
+        """提取表格特征"""
+        total_cells = 0
+        numeric_cells = 0
+        date_cells = 0
+        empty_cells = 0
+        
+        for row in table:
+            for cell in row:
+                total_cells += 1
+                
+                if not cell or cell.strip() == '':
+                    empty_cells += 1
+                    continue
+                
+                if self.detector.is_numeric(cell):
+                    numeric_cells += 1
+                
+                if self.detector.extract_datetime(cell):
+                    date_cells += 1
+        
+        return {
+            'numeric_ratio': numeric_cells / total_cells if total_cells > 0 else 0,
+            'date_ratio': date_cells / total_cells if total_cells > 0 else 0,
+            'empty_ratio': empty_cells / total_cells if total_cells > 0 else 0,
+            'total_cells': total_cells
+        }
+    
     def normalize_header_text(self, text: str) -> str:
         """标准化表头文本"""
+        # 移除括号内容
         text = re.sub(r'[((].*?[))]', '', text)
+        # 移除空格
         text = re.sub(r'\s+', '', text)
+        # 只保留字母、数字和中文
         text = re.sub(r'[^\w\u4e00-\u9fff]', '', text)
         return text.lower().strip()
     
+    def detect_table_header_row(self, table: List[List[str]]) -> int:
+        """
+        智能检测表格的表头行索引
+        
+        检测策略:
+        1. 查找包含表头关键字最多的行
+        2. 确认下一行是数据行
+        3. 避免将合并单元格的元数据行误判为表头
+        """
+        if not table:
+            return 0
+        
+        header_keywords = [
+            '日期', 'date', '时间', 'time',
+            '类型', 'type', '业务', 'business',
+            '号', 'no', '编号', 'id', '票据', 'bill',
+            '摘要', 'description', '说明', 'remark',
+            '金额', 'amount', '借方', 'debit', '贷方', 'credit',
+            '余额', 'balance',
+            '对手', 'counterparty', '账户', 'account', '户名', 'name'
+        ]
+        
+        best_header_row = 0
+        best_score = 0
+        
+        for row_idx, row in enumerate(table[:5]):  # 只检查前5行
+            if not row:
+                continue
+            
+            # 计算关键字匹配分数
+            keyword_count = 0
+            non_empty_cells = 0
+            
+            for cell in row:
+                cell_text = str(cell).strip()
+                if cell_text:
+                    non_empty_cells += 1
+                    cell_lower = cell_text.lower()
+                    
+                    for keyword in header_keywords:
+                        if keyword in cell_lower:
+                            keyword_count += 1
+                            break
+            
+            # 避免空行或几乎空的行
+            if non_empty_cells < 3:
+                continue
+            
+            # 计算得分:关键字比例 + 列数奖励
+            keyword_ratio = keyword_count / non_empty_cells if non_empty_cells > 0 else 0
+            column_bonus = min(non_empty_cells / 5, 1.0)  # 列数越多,奖励越高
+            score = keyword_ratio * 0.7 + column_bonus * 0.3
+            
+            # 如果下一行是数据行,加分
+            if row_idx + 1 < len(table):
+                next_row = table[row_idx + 1]
+                if self._is_data_row(next_row):
+                    score += 0.2
+            
+            if score > best_score:
+                best_score = score
+                best_header_row = row_idx
+        
+        # 如果最佳得分太低,返回0(第一行)
+        if best_score < 0.3:
+            print(f"   ⚠️  未检测到明确表头,默认使用第1行 (得分: {best_score:.2f})")
+            return 0
+        
+        print(f"   📍 检测到表头在第 {best_header_row + 1} 行 (得分: {best_score:.2f})")
+        return best_header_row
+    
+    def _is_data_row(self, row: List[str]) -> bool:
+        """判断是否为数据行"""
+        if not row:
+            return False
+        
+        data_pattern_count = 0
+        non_empty_count = 0
+        
+        for cell in row:
+            cell_text = str(cell).strip()
+            if not cell_text:
+                continue
+            
+            non_empty_count += 1
+            
+            # 包含数字
+            if re.search(r'\d', cell_text):
+                data_pattern_count += 1
+            
+            # 包含日期格式
+            if re.search(r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}', cell_text):
+                data_pattern_count += 1
+            
+            # 包含金额格式
+            if re.search(r'-?\d+[,,]?\d*\.?\d+', cell_text):
+                data_pattern_count += 1
+        
+        if non_empty_count == 0:
+            return False
+        
+        # 至少30%的单元格包含数据特征
+        return data_pattern_count / non_empty_count >= 0.3
+    
     def compare_table_headers(self, headers1: List[str], headers2: List[str]) -> Dict:
         """比较表格表头"""
         result = {
@@ -74,54 +430,6 @@ class TableComparator:
         
         return result
     
-    def detect_table_header_row(self, table: List[List[str]]) -> int:
-        """智能检测表格的表头行索引"""
-        header_keywords = [
-            '序号', '编号', '时间', '日期', '名称', '类型', '金额', '数量', '单价',
-            '备注', '说明', '状态', '类别', '方式', '账号', '单号', '订单',
-            '交易单号', '交易时间', '交易类型', '收/支', '支出', '收入', 
-            '交易方式', '交易对方', '商户单号', '付款方式', '收款方',
-            'no', 'id', 'time', 'date', 'name', 'type', 'amount', 'status'
-        ]
-        
-        for row_idx, row in enumerate(table):
-            if not row:
-                continue
-            
-            keyword_count = 0
-            for cell in row:
-                cell_lower = cell.lower().strip()
-                for keyword in header_keywords:
-                    if keyword in cell_lower:
-                        keyword_count += 1
-                        break
-            
-            if keyword_count >= len(row) * 0.4 and keyword_count >= 2:
-                if row_idx + 1 < len(table):
-                    next_row = table[row_idx + 1]
-                    if self._is_data_row(next_row):
-                        print(f"   📍 检测到表头在第 {row_idx + 1} 行")
-                        return row_idx
-        
-        print(f"   ⚠️  未检测到明确表头,默认使用第1行")
-        return 0
-    
-    def _is_data_row(self, row: List[str]) -> bool:
-        """判断是否为数据行"""
-        data_pattern_count = 0
-        
-        for cell in row:
-            if not cell:
-                continue
-            
-            if re.search(r'\d', cell):
-                data_pattern_count += 1
-            
-            if re.search(r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}', cell):
-                data_pattern_count += 1
-        
-        return data_pattern_count >= len(row) * 0.5
-    
     def compare_cell_value(self, value1: str, value2: str, column_type: str, 
                           column_name: str = '') -> Dict:
         """比较单元格值"""