
feat: optimize HTML-row-to-Paddle-group matching; support skipping irrelevant groups and prevent greedy matching

zhch158_admin committed 1 day ago
parent commit 2e6f0e8da0
1 changed file with 90 additions and 88 deletions

+ 90 - 88
merger/table_cell_matcher.py
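
The heart of the change is a two-level search: for each HTML row the matcher may first skip a few leading paddle groups (page titles, stray noise), then tries merging 1-5 consecutive groups, and subtracts small penalties for skipping and merging so that the cheapest alignment wins while the group pointer only moves forward. A minimal standalone sketch of that search shape, with a toy token-overlap scorer standing in for the class's real coverage/similarity/header scoring (`score_fn`, `match_rows`, and the sample data are all hypothetical):

```python
from typing import Dict, List

def score_fn(row: str, combined: str) -> float:
    # toy stand-in: fraction of row tokens found verbatim in the combined text
    tokens = [t for t in row.split() if len(t) >= 2]
    if not tokens:
        return 0.0
    return sum(1 for t in tokens if t in combined) / len(tokens)

def match_rows(rows: List[str], groups: List[str]) -> Dict[int, List[int]]:
    """Monotonic row-to-group matching with skip and merge windows (sketch)."""
    mapping: Dict[int, List[int]] = {}
    next_group = 0  # global pointer: groups before it are already consumed
    for row_idx, row in enumerate(rows):
        best_groups: List[int] = []
        best_score = 0.0
        max_skip = 15 if row_idx == 0 else 5  # first row may skip a title block
        for skip in range(max_skip + 1):
            start = next_group + skip
            if start >= len(groups):
                break
            for count in range(1, 6):  # merge 1..5 consecutive groups
                end = start + count
                if end > len(groups):
                    break
                raw = score_fn(row, "".join(groups[start:end]))
                # penalties order the candidates: no skip > few merges > big windows
                adjusted = raw - 0.05 * (count - 1) - 0.02 * skip
                if adjusted > best_score:
                    best_score, best_groups = adjusted, list(range(start, end))
                if count == 1 and raw > 0.85:
                    break  # a single group already fits: don't merge greedily
            if best_score > 0.85:
                break  # good enough: larger skips could only land on look-alikes
        if best_groups and best_score >= 0.4:
            mapping[row_idx] = best_groups
            next_group = max(best_groups) + 1  # enforce monotonic order
        else:
            mapping[row_idx] = []  # pointer stays put; the next row gets its chance
    return mapping

# OCR split "revenue 1000" into two groups; the merge window rejoins them
print(match_rows(["revenue 1000", "cost 900"],
                 ["revenue", "1000", "cost 900"]))  # {0: [0, 1], 1: [2]}
```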

@@ -758,11 +758,7 @@ class TableCellMatcher:
     def _match_html_rows_to_paddle_groups(self, html_rows: List, 
                                         grouped_boxes: List[Dict]) -> Dict[int, List[int]]:
         """
-        Smart matching of HTML rows to paddle groups (revised: strict sequential matching)
-
-        Strategy:
-        1. Equal counts: 1:1 mapping
-        2. Unequal counts: match by content while preserving y-coordinate order
+        Smart matching of HTML rows to paddle groups (optimized: skip irrelevant groups + anti-greedy)
         """
         if not html_rows or not grouped_boxes:
             return {}
@@ -775,15 +771,19 @@ class TableCellMatcher:
                 mapping[i] = [i]
             return mapping
         
-        # 🎯 Strategy 2: content-based matching (revised: strictly monotonic)
+        # 🎯 Strategy 2: content-based matching (monotonic, with a skip mechanism)
         from fuzzywuzzy import fuzz
         used_groups = set()
-        next_group_to_check = 0  # 🔑 Key improvement: maintain a global group pointer
+        next_group_to_check = 0
         
         for row_idx, row in enumerate(html_rows):
-            row_texts = [cell.get_text(strip=True) for cell in row.find_all(['td', 'th'])]
+            row_cells = row.find_all(['td', 'th'])
+            row_texts = [cell.get_text(strip=True) for cell in row_cells]
             row_texts = [t for t in row_texts if t]
             
+            # Extract the leading cell text (usually the line-item name) for weighted matching
+            row_header = row_texts[0] if row_texts else ""
+            
             if not row_texts:
                 mapping[row_idx] = []
                 continue
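
The three added lines give the scorer a `row_header`: the first non-empty cell, which in these tables is usually the line-item name. A quick illustration of what they extract, assuming `beautifulsoup4` is installed (the sample row is invented):

```python
from bs4 import BeautifulSoup

html = "<tr><td>营业收入</td><td>1,000</td><td>900</td></tr>"
row = BeautifulSoup(html, "html.parser").find("tr")

row_cells = row.find_all(["td", "th"])                         # same calls as the diff
row_texts = [cell.get_text(strip=True) for cell in row_cells]
row_texts = [t for t in row_texts if t]                        # drop empty cells
row_header = row_texts[0] if row_texts else ""

print(row_texts)   # ['营业收入', '1,000', '900']
print(row_header)  # '营业收入' -- weighted at 40% in the new score
```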
@@ -794,105 +794,107 @@ class TableCellMatcher:
             best_groups = []
             best_score = 0
             
-            # 🔑 Key improvement: start searching from next_group_to_check
-            max_window = 5
-            for group_count in range(1, max_window + 1):
-                # 🔑 Start from the current position, not from the first unused group
-                start_group = next_group_to_check
-                end_group = start_group + group_count
+            # 🆕 Dynamic skip window: the first row may skip many groups (document titles); later rows fewer (noise)
+            max_skip = 15 if row_idx == 0 else 5
+            
+            # Iterate over the possible skip counts
+            for skip in range(max_skip + 1):
+                start_group = next_group_to_check + skip
                 
-                if end_group > len(grouped_boxes):
+                if start_group >= len(grouped_boxes):
                     break
-
-                combined_group_indices = list(range(start_group, end_group))
                 
-                # 🔑 Skip already-used groups (without recomputing start_group)
-                if any(idx in used_groups for idx in combined_group_indices):
-                    continue
-
-                # Collect all text inside the groups
-                combined_texts = []
-                for g_idx in combined_group_indices:
-                    group_boxes = grouped_boxes[g_idx].get('boxes', [])
-                    for box in group_boxes:
-                        if box.get('used'):
-                            continue
-                        normalized_text = self.text_matcher.normalize_text(box.get('text', ''))
-                        if normalized_text:
-                            combined_texts.append(normalized_text)
-
-                if not combined_texts:
-                    continue
+                # Try merging different numbers of groups (1-5)
+                max_merge_window = 5
                 
-                paddle_combined_text = ''.join(combined_texts)
-                
-                # Matching strategy
-                match_count = 0
-                
-                for rt in row_text_normalized:
-                    if len(rt) < 2:
-                        continue
+                for group_count in range(1, max_merge_window + 1):
+                    end_group = start_group + group_count
+                    if end_group > len(grouped_boxes):
+                        break
+
+                    combined_group_indices = list(range(start_group, end_group))
                     
-                    # Exact match
-                    if any(rt == ct for ct in combined_texts):
-                        match_count += 1
-                        continue
+                    # Collect all text inside the groups
+                    combined_texts = []
                     
-                    # Substring match
-                    if any(rt in ct or ct in rt for ct in combined_texts):
-                        match_count += 1
+                    for g_idx in combined_group_indices:
+                        group_boxes = grouped_boxes[g_idx].get('boxes', [])
+                        for box in group_boxes:
+                            if box.get('used'):
+                                continue
+                            normalized_text = self.text_matcher.normalize_text(box.get('text', ''))
+                            if normalized_text:
+                                combined_texts.append(normalized_text)
+
+                    if not combined_texts:
                         continue
                     
-                    # Look for the cell text inside the combined text
-                    if rt in paddle_combined_text:
-                        match_count += 1
-                        continue
+                    paddle_combined_text = ''.join(combined_texts)
+                    
+                    # --- Scoring logic ---
+                    match_count = 0
                     
-                    # Fuzzy match
-                    for ct in combined_texts:
-                        similarity = fuzz.partial_ratio(rt, ct)
-                        if similarity >= 75:
+                    # 1. Cell coverage
+                    for rt in row_text_normalized:
+                        if len(rt) < 2: 
+                            continue
+                        if rt in paddle_combined_text:
                             match_count += 1
-                            break
-                
-                # Whole-row match
-                row_similarity = fuzz.partial_ratio(row_combined_text, paddle_combined_text)
-                
-                coverage = match_count / len(row_texts) if row_texts else 0
-                combined_coverage = row_similarity / 100.0
-                
-                final_score = max(coverage, combined_coverage)
-                
-                if final_score > best_score:
-                    best_score = final_score
-                    best_groups = combined_group_indices
+                            continue
+                        for ct in combined_texts:
+                            if fuzz.partial_ratio(rt, ct) >= 80:
+                                match_count += 1
+                                break
+                    
+                    coverage = match_count / len(row_texts) if row_texts else 0
+                    
+                    # 2. Whole-row similarity
+                    row_similarity = fuzz.partial_ratio(row_combined_text, paddle_combined_text) / 100.0
+                    
+                    # 3. Header key match (weighted)
+                    header_score = 0
+                    if len(row_header) > 1:
+                        if row_header in paddle_combined_text:
+                            header_score = 1.0
+                        else:
+                            header_sim = fuzz.partial_ratio(row_header, paddle_combined_text)
+                            if header_sim > 80:
+                                header_score = 0.8
+                    else:
+                        header_score = 0.5
+                    
+                    final_score = (coverage * 0.3) + (row_similarity * 0.3) + (header_score * 0.4)
+                    
+                    # 🔑 Penalty terms: merge penalty + skip penalty
+                    # Preference order: no skipping > fewer merges
+                    merge_penalty = (group_count - 1) * 0.05
+                    skip_penalty = skip * 0.02
                     
-                    print(f"   行 {row_idx} 候选: 组 {combined_group_indices}, "
-                        f"单元格匹配: {match_count}/{len(row_texts)}, "
-                        f"整行相似度: {row_similarity}%, "
-                        f"最终得分: {final_score:.2f}")
+                    adjusted_score = final_score - merge_penalty - skip_penalty
                     
-                    if final_score >= 0.9:
+                    if adjusted_score > best_score:
+                        best_score = adjusted_score
+                        best_groups = combined_group_indices
+                    
+                    # Early stop: if a single group already matches very well, do not try merging more
+                    if group_count == 1 and final_score > 0.85:
                         break
+                
+                # Optimization: if the current skip already produced a very good match, stop trying larger skips
+                # to avoid skipping past the correct group to a similar one further on
+                if best_score > 0.85:
+                    break
             
-            # 🔑 Lower the threshold
-            if best_groups and best_score >= 0.3:
+            # Accept or reject the match
+            if best_groups and best_score >= 0.4:
                 mapping[row_idx] = best_groups
                 used_groups.update(best_groups)
-                
-                # 🔑 Key improvement: update the next group to check
                 next_group_to_check = max(best_groups) + 1
-                
-                print(f"   ✓ 行 {row_idx}: 匹配组 {best_groups} (得分: {best_score:.2f}), "
-                    f"下次从组 {next_group_to_check} 开始")
+                print(f"   ✓ 行 {row_idx} ('{row_header[:10]}...'): 匹配组 {best_groups} (得分: {best_score:.2f})")
             else:
                 mapping[row_idx] = []
-                # 🔑 Key improvement: advance the pointer even without a match (assume 1 group was skipped)
-                if next_group_to_check < len(grouped_boxes):
-                    next_group_to_check += 1
-                
-                print(f"   ✗ 行 {row_idx}: 无匹配 (最佳得分: {best_score:.2f}), "
-                    f"推进到组 {next_group_to_check}")
+                # 如果没匹配上,next_group_to_check 不变,给下一行机会
+                print(f"   ✗ 行 {row_idx} ('{row_header[:10]}...'): 无匹配 (最佳得分: {best_score:.2f})")
 
        # 🎯 Strategy 3: second pass - handle the unused groups (critical!)
         unused_groups = [i for i in range(len(grouped_boxes)) if i not in used_groups]
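
With the 0.3/0.3/0.4 weights, a cleanly recognized line-item name alone contributes 0.4, enough to clear the acceptance threshold even when every numeric cell is garbled. A hedged sketch of the new scoring in isolation, using `fuzzywuzzy.fuzz.partial_ratio` as the diff does (inputs are assumed to be pre-normalized, standing in for `text_matcher.normalize_text`):

```python
from fuzzywuzzy import fuzz

def score_candidate(row_texts, row_header, combined_texts, group_count, skip):
    """Weighted score for one HTML row against a window of merged paddle groups."""
    combined = "".join(combined_texts)
    row_combined = "".join(row_texts)

    # 1. cell coverage: substring hit or fuzzy hit (>= 80) per cell
    hits = 0
    for rt in row_texts:
        if len(rt) < 2:
            continue
        if rt in combined or any(fuzz.partial_ratio(rt, ct) >= 80
                                 for ct in combined_texts):
            hits += 1
    coverage = hits / len(row_texts) if row_texts else 0

    # 2. whole-row similarity
    row_similarity = fuzz.partial_ratio(row_combined, combined) / 100.0

    # 3. header match, worth 40% of the final score
    if len(row_header) > 1:
        if row_header in combined:
            header_score = 1.0
        elif fuzz.partial_ratio(row_header, combined) > 80:
            header_score = 0.8
        else:
            header_score = 0.0
    else:
        header_score = 0.5  # no usable header: stay neutral

    final = coverage * 0.3 + row_similarity * 0.3 + header_score * 0.4
    # penalties: prefer not skipping over merging, and both over large windows
    return final - (group_count - 1) * 0.05 - skip * 0.02

# header matches exactly; the amount survives only as a fuzzy hit ('O' vs '0')
print(round(score_candidate(["营业收入", "1,000"], "营业收入",
                            ["营业收入", "1,O00"], group_count=1, skip=0), 2))  # ~0.97
```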