
feat: optimize table cell matching logic, enhance DP to support skipping HTML rows, improve similarity calculation

zhch158_admin · 2 days ago
commit 631d6cdd7d
1 file changed, with 222 insertions and 200 deletions

+ 222 - 200
merger/table_cell_matcher.py
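Review note: the enhanced matcher below is a banded, merge-aware variant of a classic sequence-alignment DP. A minimal, runnable sketch of the core idea, assuming generic inputs; `align_rows`, `sim`, `skip_html`, and `skip_paddle` are illustrative names, not this module's API, and the sketch omits the commit's group-merging window (MAX_MERGE), search window, and pruning:

    import numpy as np

    def align_rows(html_texts, group_texts, sim, skip_html=0.3, skip_paddle=0.1):
        """Monotonic alignment of HTML row texts to OCR group texts.
        Each row either matches one group or is skipped (fixed penalty);
        groups may also be skipped (per-group penalty)."""
        n, m = len(html_texts), len(group_texts)
        dp = np.full((n + 1, m + 1), -np.inf)
        bt = {}  # backtrack: (i, j) -> "match" | "skip_row" | "skip_grp"
        dp[0, 0] = 0.0
        for j in range(1, m + 1):  # leading OCR groups may be skipped
            dp[0, j] = dp[0, j - 1] - skip_paddle
            bt[(0, j)] = "skip_grp"
        for i in range(1, n + 1):
            dp[i, 0] = dp[i - 1, 0] - skip_html
            bt[(i, 0)] = "skip_row"
            for j in range(1, m + 1):
                cands = [
                    (dp[i - 1, j - 1] + sim(html_texts[i - 1], group_texts[j - 1]), "match"),
                    (dp[i - 1, j] - skip_html, "skip_row"),
                    (dp[i, j - 1] - skip_paddle, "skip_grp"),
                ]
                dp[i, j], bt[(i, j)] = max(cands, key=lambda c: c[0])
        mapping, i, j = {}, n, m  # backtrack into {row: group or None}
        while i > 0 or j > 0:
            move = bt[(i, j)]
            if move == "match":
                mapping[i - 1], i, j = j - 1, i - 1, j - 1
            elif move == "skip_row":
                mapping[i - 1], i = None, i - 1
            else:
                j -= 1
        return mapping

    # e.g. align_rows(["营业收入", "净利润"], ["营业收入 1,234", "noise", "净利润 567"],
    #                 sim=lambda a, b: 1.0 if a in b else 0.0)
    # maps row 0 -> group 0 and row 1 -> group 2, skipping the "noise" group.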

@@ -7,6 +7,11 @@ from bs4 import BeautifulSoup
 import numpy as np
 
 try:
+    from rapidfuzz import fuzz
+except ImportError:
+    from fuzzywuzzy import fuzz
+
+try:
     from .text_matcher import TextMatcher
     from .bbox_extractor import BBoxExtractor
 except ImportError:
@@ -441,7 +446,6 @@ class TableCellMatcher:
             # 🔑 Fallback: if exact matching fails, use fuzzy matching
             print("   ℹ️ Exact matching failed, trying fuzzy matching...")
             
-            from fuzzywuzzy import fuzz
             for box in paddle_boxes:
                 normalized_text = self.text_matcher.normalize_text(box['text'])
                 
@@ -574,7 +578,7 @@ class TableCellMatcher:
     def _match_html_rows_to_paddle_groups(self, html_rows: List, 
                                         grouped_boxes: List[Dict]) -> Dict[int, List[int]]:
         """
-        Intelligently match HTML rows to paddle groups (optimized: supports skipping irrelevant groups + anti-greedy)
+        Intelligently match HTML rows to paddle groups (enhanced DP: supports skipping HTML rows, preventing chain breakage)
         """
         if not html_rows or not grouped_boxes:
             return {}
@@ -587,222 +591,247 @@ class TableCellMatcher:
                 mapping[i] = [i]
             return mapping
         
-        # 🎯 Strategy 2: content-based matching (monotonic matching with a skip mechanism)
-        from fuzzywuzzy import fuzz
-        used_groups = set()
-        next_group_to_check = 0
-        
-        for row_idx, row in enumerate(html_rows):
-            row_cells = row.find_all(['td', 'th'])
-            row_texts = [cell.get_text(strip=True) for cell in row_cells]
-            row_texts = [t for t in row_texts if t]
+        # --- Prepare data ---
+        # Extract HTML text
+        html_row_texts = []
+        for row in html_rows:
+            cells = row.find_all(['td', 'th'])
+            texts = [self.text_matcher.normalize_text(c.get_text(strip=True)) for c in cells]
+            html_row_texts.append("".join(texts))
+
+        # Precompute the text of every group
+        group_texts = []
+        for group in grouped_boxes:
+            boxes = group['boxes']
+            texts = [self.text_matcher.normalize_text(b['text']) for b in boxes]
+            group_texts.append("".join(texts))
+
+        n_html = len(html_row_texts)
+        n_paddle = len(grouped_boxes)
+
+        # ⚡️ Optimization 3: precompute merged texts
+        MAX_MERGE = 4
+        merged_cache = {}
+        for j in range(n_paddle):
+            current_t = ""
+            for k in range(MAX_MERGE):
+                if j + k < n_paddle:
+                    current_t += group_texts[j + k]
+                    merged_cache[(j, k + 1)] = current_t
+                else:
+                    break
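+        # merged_cache[(j, k)] holds "".join(group_texts[j:j+k]) for k in 1..MAX_MERGE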
+
+        # --- Dynamic programming (DP) ---
+        # dp[i][j]: max score for matching the first i HTML rows (0..i) to the first j paddle groups (0..j)
+        # Initialized to negative infinity
+        dp = np.full((n_html, n_paddle), -np.inf)
+        # Record the path: path[i][j] = (prev_j, start_j)
+        # prev_j: paddle index where the previous row ended
+        # start_j: paddle index where the current row starts (one row may span several groups)
+        path = {}
+
+        # Parameter configuration
+        SEARCH_WINDOW = 15  # forward search window
+        SKIP_PADDLE_PENALTY = 0.1  # penalty for skipping a paddle group
+        SKIP_HTML_PENALTY = 0.3    # key: penalty for skipping an HTML row
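+        # Transitions, informally:
+        #   match row i:  dp[i][end_j] = dp[i-1][prev_j] + sim(row i, groups start_j..end_j) - penalties
+        #   skip row i:   dp[i][j]     = dp[i-1][j] - SKIP_HTML_PENALTY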
+        # --- 1. Initialize the first row ---
+        # Option A: match paddle groups
+        for end_j in range(min(n_paddle, SEARCH_WINDOW + MAX_MERGE)):
+            for count in range(1, MAX_MERGE + 1):
+                start_j = end_j - count + 1
+                if start_j < 0: continue
+                
+                current_text = merged_cache.get((start_j, count), "")
+                similarity = self._calculate_similarity(html_row_texts[0], current_text)
+                
+                penalty = start_j * SKIP_PADDLE_PENALTY
+                score = similarity - penalty
+                
+                # Only keep the state if the score is reasonable
+                if score > 0.1:
+                    if score > dp[0][end_j]:
+                        dp[0][end_j] = score
+                        path[(0, end_j)] = (-1, start_j)
+        
+        # Option B: skip the first row outright (rare, but noted for completeness)
+        # If the first row is skipped, no paddle group is consumed, which dp[0][j] cannot express
+        # Simplified here: assume the first row matches something, or let later rows correct it
+
+        # --- 2. State transitions ---
+        for i in range(1, n_html):
+            html_text = html_row_texts[i]
             
-            # Extract the leading cell text (usually the item name) for weighted matching
-            row_header = row_texts[0] if row_texts else ""
+            # Collect all valid positions from the previous row
+            valid_prev_indices = [j for j in range(n_paddle) if dp[i-1][j] > -np.inf]
             
-            if not row_texts:
-                mapping[row_idx] = []
+            # Pruning: keep only the top 30 previous states
+            if len(valid_prev_indices) > 30:
+                valid_prev_indices.sort(key=lambda j: dp[i-1][j], reverse=True)
+                valid_prev_indices = valid_prev_indices[:30]
+
+            # 🛡️ Key fix: allow skipping the current HTML row (inherit the previous row's state)
+            # If the current row is skipped, the paddle pointer j stays unchanged
+            for prev_j in valid_prev_indices:
+                score_skip = dp[i-1][prev_j] - SKIP_HTML_PENALTY
+                if score_skip > dp[i][prev_j]:
+                    dp[i][prev_j] = score_skip
+                    # Record the path: start_j = prev_j + 1 marks that no new group was consumed (empty range)
+                    path[(i, prev_j)] = (prev_j, prev_j + 1)
+
+            # For an empty row, skip the computation and keep only the inherited state
+            if not html_text:
                 continue
-            
-            row_text_normalized = [self.text_matcher.normalize_text(t) for t in row_texts]
-            row_combined_text = ''.join(row_text_normalized)
-            
-            best_groups = []
-            best_score = 0
-            
-            # 🆕 Dynamic skip window: the first row may skip more (document titles), later rows less (noise)
-            max_skip = 15 if row_idx == 0 else 5
-            
-            # Iterate over the possible number of skips
-            for skip in range(max_skip + 1):
-                start_group = next_group_to_check + skip
-                
-                if start_group >= len(grouped_boxes):
-                    break
+
+            # Normal matching logic
+            for prev_j in valid_prev_indices:
+                prev_score = dp[i-1][prev_j]
                 
-                # Try merging different numbers of groups (1-5)
-                max_merge_window = 5
+                max_gap = min(SEARCH_WINDOW, n_paddle - prev_j - 1)
                 
-                for group_count in range(1, max_merge_window + 1):
-                    end_group = start_group + group_count
-                    if end_group > len(grouped_boxes):
-                        break
-
-                    combined_group_indices = list(range(start_group, end_group))
-                    
-                    # Collect all text within the groups
-                    combined_texts = []
+                for gap in range(max_gap):
+                    start_j = prev_j + 1 + gap
                     
-                    for g_idx in combined_group_indices:
-                        group_boxes = grouped_boxes[g_idx].get('boxes', [])
-                        for box in group_boxes:
-                            if box.get('used'):
-                                continue
-                            normalized_text = self.text_matcher.normalize_text(box.get('text', ''))
-                            if normalized_text:
-                                combined_texts.append(normalized_text)
-
-                    if not combined_texts:
-                        continue
-                    
-                    paddle_combined_text = ''.join(combined_texts)
-                    
-                    # --- Scoring logic ---
-                    match_count = 0
-
-                    # 1. Cell coverage
-                    for rt in row_text_normalized:
-                        if len(rt) < 2: 
+                    for count in range(1, MAX_MERGE + 1):
+                        end_j = start_j + count - 1
+                        if end_j >= n_paddle: break
+                        
+                        current_text = merged_cache.get((start_j, count), "")
+                        
+                        # Length pre-filter
+                        h_len = len(html_text)
+                        p_len = len(current_text)
+                        if h_len > 10 and p_len < h_len * 0.2:
                             continue
-                        if rt in paddle_combined_text:
-                            match_count += 1
-                            continue
-                        for ct in combined_texts:
-                            if fuzz.partial_ratio(rt, ct) >= 80:
-                                match_count += 1
-                                break
-                    
-                    coverage = match_count / len(row_texts) if row_texts else 0
-                    
-                    # 2. Whole-row similarity
-                    row_similarity = fuzz.partial_ratio(row_combined_text, paddle_combined_text) / 100.0
+
+                        similarity = self._calculate_similarity(html_text, current_text)
+                        
+                        # Compute penalties
+                        # 1. Skip penalty (gap)
+                        # 2. Length penalty (to prevent over-merging)
+                        len_penalty = 0.0
+                        if h_len > 0:
+                            ratio = p_len / h_len
+                            if ratio > 2.0: len_penalty = (ratio - 2.0) * 0.2
+
+                        current_score = similarity - (gap * SKIP_PADDLE_PENALTY) - len_penalty
+                        
+                        # Only transition on a positive gain
+                        if current_score > 0.1:
+                            total_score = prev_score + current_score
+                            
+                            if total_score > dp[i][end_j]:
+                                dp[i][end_j] = total_score
+                                path[(i, end_j)] = (prev_j, start_j)
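+        # Cost note: after pruning, each row checks at most 30 predecessor states
+        # × SEARCH_WINDOW gaps × MAX_MERGE widths, i.e. O(n_html·30·15·4) similarity calls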
+
+        # --- 3. Backtrack to find the optimal path ---
+        # Find the end position with the highest score on the last row
+        best_end_j = -1
+        max_score = -np.inf
+        
+        # Prefer the last row; if it found no match, search earlier rows
+        found_end = False
+        for i in range(n_html - 1, -1, -1):
+            for j in range(n_paddle):
+                if dp[i][j] > max_score:
+                    max_score = dp[i][j]
+                    best_end_j = j
+                    best_last_row = i
+            if max_score > -np.inf:
+                found_end = True
+                break
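+        # Rows after best_last_row have no finite dp state; they stay unmapped (filled with [] below)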
+        
+        mapping = {}
+        used_groups = set()
+        
+        if found_end:
+            curr_i = best_last_row
+            curr_j = best_end_j
+            
+            while curr_i >= 0:
+                if (curr_i, curr_j) in path:
+                    prev_j, start_j = path[(curr_i, curr_j)]
                     
-                    # 3. Key header match (weighted)
-                    header_score = 0
-                    if len(row_header) > 1:
-                        if row_header in paddle_combined_text:
-                            header_score = 1.0
-                        else:
-                            header_sim = fuzz.partial_ratio(row_header, paddle_combined_text)
-                            if header_sim > 80:
-                                header_score = 0.8
+                    # If start_j <= curr_j, paddle groups were consumed
+                    # If start_j > curr_j, this HTML row was skipped (empty range)
+                    if start_j <= curr_j:
+                        indices = list(range(start_j, curr_j + 1))
+                        mapping[curr_i] = indices
+                        used_groups.update(indices)
                     else:
-                        header_score = 0.5
-                    
-                    final_score = (coverage * 0.3) + (row_similarity * 0.3) + (header_score * 0.4)
+                        mapping[curr_i] = []
                     
-                    # 🔑 Penalty terms: merge penalty + skip penalty
-                    # Preference: no skipping > fewer merges
-                    merge_penalty = (group_count - 1) * 0.05
-                    skip_penalty = skip * 0.02
-                    
-                    adjusted_score = final_score - merge_penalty - skip_penalty
-                    
-                    if adjusted_score > best_score:
-                        best_score = adjusted_score
-                        best_groups = combined_group_indices
-                    
-                    # Early stop: if a single group matches extremely well, don't try merging more
-                    if group_count == 1 and final_score > 0.85:
-                        break
-                
-                # Optimization: if the current skip found a very good match, don't try larger skips
-                # (avoids skipping the correct group to match a similar later one)
-                if best_score > 0.85:
+                    curr_j = prev_j
+                    curr_i -= 1
+                else:
                     break
-            
-            # Decide the match
-            if best_groups and best_score >= 0.4:
-                mapping[row_idx] = best_groups
-                used_groups.update(best_groups)
-                next_group_to_check = max(best_groups) + 1
-                print(f"   ✓ 行 {row_idx} ('{row_header[:10]}...'): 匹配组 {best_groups} (得分: {best_score:.2f})")
-            else:
-                mapping[row_idx] = []
-                # If nothing matched, next_group_to_check stays put, giving the next row a chance
-                print(f"   ✗ Row {row_idx} ('{row_header[:10]}...'): no match (best score: {best_score:.2f})")
+        
+        # Fill in unmatched rows
+        for i in range(n_html):
+            if i not in mapping:
+                mapping[i] = []
 
-        # 🎯 Strategy 3: second pass - handle unused groups (critical!)
+        # --- 4. Post-processing: assign unmatched groups (orphans) ---
         unused_groups = [i for i in range(len(grouped_boxes)) if i not in used_groups]
         
         if unused_groups:
             print(f"   ℹ️ 发现 {len(unused_groups)} 个未匹配的 paddle 组: {unused_groups}")
-            
-            # 🔑 Merge unused groups into adjacent matched rows
             for unused_idx in unused_groups:
-                # 🎯 Key improvement: compute the boundary distance to adjacent rows
                 unused_group = grouped_boxes[unused_idx]
                 unused_y_min = min(b['bbox'][1] for b in unused_group['boxes'])
                 unused_y_max = max(b['bbox'][3] for b in unused_group['boxes'])
                 
-                # 🔑 Find the nearest used groups above and below
                 above_idx = None
                 below_idx = None
                 above_distance = float('inf')
                 below_distance = float('inf')
                 
-                # Search upward
                 for i in range(unused_idx - 1, -1, -1):
                     if i in used_groups:
                         above_idx = i
-                        # 🎯 Boundary distance: unused group's min y - above group's max y
                         above_group = grouped_boxes[i]
-                        max_y_box = max(
-                            above_group['boxes'],
-                            key=lambda b: b['bbox'][3]
-                        )
+                        max_y_box = max(above_group['boxes'], key=lambda b: b['bbox'][3])
                         above_y_center = (max_y_box['bbox'][1] + max_y_box['bbox'][3]) / 2
                         above_distance = abs(unused_y_min - above_y_center)
-                        print(f"      • 组 {unused_idx} 与上方组 {i} 距离: {above_distance:.1f}px")
                         break
                 
-                # Search downward
                 for i in range(unused_idx + 1, len(grouped_boxes)):
                     if i in used_groups:
                         below_idx = i
-                        # 🎯 Boundary distance: below group's min y - unused group's max y
                         below_group = grouped_boxes[i]
-                        min_y_box = min(
-                            below_group['boxes'],
-                            key=lambda b: b['bbox'][1]
-                        )
+                        min_y_box = min(below_group['boxes'], key=lambda b: b['bbox'][1])
                         below_y_center = (min_y_box['bbox'][1] + min_y_box['bbox'][3]) / 2
                         below_distance = abs(below_y_center - unused_y_max)
-                        print(f"      • 组 {unused_idx} 与下方组 {i} 距离: {below_distance:.1f}px")
                         break
                 
-                # 🎯 Choose the closer side
+                closest_used_idx = None
+                merge_direction = ""
+                
                 if above_idx is not None and below_idx is not None:
-                    # Both exist; choose the closer one
                     if above_distance < below_distance:
                         closest_used_idx = above_idx
                         merge_direction = "上方"
                     else:
                         closest_used_idx = below_idx
                         merge_direction = "下方"
-                    print(f"      ✓ 组 {unused_idx} 选择合并到{merge_direction}组 {closest_used_idx}")
                 elif above_idx is not None:
                     closest_used_idx = above_idx
                     merge_direction = "上方"
                 elif below_idx is not None:
                     closest_used_idx = below_idx
                     merge_direction = "下方"
-                else:
-                    print(f"      ⚠️ 组 {unused_idx} 无相邻已使用组,跳过")
-                    continue
-                
-                # 🔑 Find the HTML row corresponding to this group
-                target_html_row = None
-                for html_row_idx, group_indices in mapping.items():
-                    if closest_used_idx in group_indices:
-                        target_html_row = html_row_idx
-                        break
-                
-                if target_html_row is not None:
-                    # 🎯 Decide the target row by merge direction
-                    if merge_direction == "above":
-                        # Merge into the corresponding HTML row above
-                        if target_html_row in mapping:
-                            if unused_idx not in mapping[target_html_row]:
-                                mapping[target_html_row].append(unused_idx)
-                                print(f"      • 组 {unused_idx} 合并到 HTML 行 {target_html_row}(上方行)")
-                    else:
-                        # Merge into the corresponding HTML row below
-                        if target_html_row in mapping:
-                            if unused_idx not in mapping[target_html_row]:
-                                mapping[target_html_row].append(unused_idx)
-                                print(f"      • 组 {unused_idx} 合并到 HTML 行 {target_html_row}(下方行)")
                 
+                if closest_used_idx is not None:
+                    target_html_row = None
+                    for html_row_idx, group_indices in mapping.items():
+                        if closest_used_idx in group_indices:
+                            target_html_row = html_row_idx
+                            break
+                    
+                    if target_html_row is not None:
+                        if unused_idx not in mapping[target_html_row]:
+                            mapping[target_html_row].append(unused_idx)
+                            mapping[target_html_row].sort()
+                            print(f"      • 组 {unused_idx} 合并到 HTML 行 {target_html_row}({merge_direction}行)")                
                 used_groups.add(unused_idx)
         
         # 🔑 Strategy 4: third pass - sort each row's group indices by y coordinate
@@ -812,48 +841,43 @@ class TableCellMatcher:
         
         return mapping
 
-    def _preprocess_close_groups(self, grouped_boxes: List[Dict], 
-                                y_gap_threshold: int = 10) -> List[List[int]]:
+    def _calculate_similarity(self, text1: str, text2: str) -> float:
         """
-        🆕 Preprocessing: pre-merge groups with very small y gaps
-
-        Args:
-            grouped_boxes: original groups
-            y_gap_threshold: Y gap threshold (smaller gaps count as the same row)
-
-        Returns:
-            preprocessed list of group index lists [[0,1], [2], [3,4,5], ...]
+        Compute the similarity of two texts, combining character coverage and sequence similarity (performance-optimized)
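+
+        Example (illustrative): text1 = "营业收入", text2 = "营业收入1,234"
+        gives coverage = 4/4 = 1.0, seq ≈ 0.6, so score ≈ 0.7·1.0 + 0.3·0.6 ≈ 0.88
+        (the exact value depends on the fuzz backend)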
         """
-        if not grouped_boxes:
-            return []
+        if not text1 or not text2:
+            return 0.0
+            
+        len1, len2 = len(text1), len(text2)
+        
+        # ⚡️ Optimization 1: quick length check
+        # If the lengths differ too much (e.g. 50 chars vs 2), treat as no match outright
+        if len1 > 0 and len2 > 0:
+            min_l, max_l = min(len1, len2), max(len1, len2)
+            if max_l > 10 and min_l / max_l < 0.2:
+                return 0.0
+
+        # 1. Character coverage (character overlap)
+        from collections import Counter
+        c1 = Counter(text1)
+        c2 = Counter(text2)
         
-        preprocessed = []
-        current_group = [0]
+        intersection = c1 & c2
+        overlap_count = sum(intersection.values())
         
-        for i in range(1, len(grouped_boxes)):
-            prev_group = grouped_boxes[i - 1]
-            curr_group = grouped_boxes[i]
-            
-            # Compute the gap
-            prev_y_max = max(b['bbox'][3] for b in prev_group['boxes'])
-            curr_y_min = min(b['bbox'][1] for b in curr_group['boxes'])
-            
-            gap = abs(curr_y_min - prev_y_max)
-            
-            if gap <= y_gap_threshold:
-                # 间距很小,合并
-                current_group.append(i)
-                print(f"   预合并: 组 {i-1} 和 {i} (间距: {gap}px)")
-            else:
-                # 间距较大,开始新组
-                preprocessed.append(current_group)
-                current_group = [i]
+        coverage = overlap_count / len1 if len1 > 0 else 0
         
-        # Append the final group
-        if current_group:
-            preprocessed.append(current_group)
+        # ⚡️ Optimization 2: skip the expensive fuzz computation when coverage is low
+        # If character overlap is below 30%, the contents are largely unrelated; sequence similarity is pointless
+        if coverage < 0.3:
+            return coverage * 0.7
+
+        # 2. Sequence similarity
+        # Uses the module-level fuzz (rapidfuzz if available, fuzzywuzzy as fallback),
+        # so no local import is needed; token_sort_ratio tolerates some reordering
+        seq_score = fuzz.token_sort_ratio(text1, text2) / 100.0
         
-        return preprocessed
+        return (coverage * 0.7) + (seq_score * 0.3)
 
     def _match_cell_sequential(self, cell_text: str, 
                             boxes: List[Dict],
@@ -878,8 +902,6 @@ class TableCellMatcher:
             'paddle_indices': [idx1, idx2], 'used_boxes': [box1, box2],
             'last_used_index': int}
         """
-        from fuzzywuzzy import fuzz
-        
         cell_text_normalized = self.text_matcher.normalize_text(cell_text)
         
         if len(cell_text_normalized) < 2: