
feat: add merging of MinerU and PaddleOCR results, with batch processing and enhanced Markdown generation

zhch158_admin 1 month ago
commit 1bd52e4660
1 changed file with 399 additions and 0 deletions
  merge_mineru_paddle_ocr.py  +399 −0

@@ -0,0 +1,399 @@
+"""
+合并 MinerU 和 PaddleOCR 的结果
+使用 MinerU 的表格结构识别 + PaddleOCR 的文字框坐标
+"""
+import json
+import re
+from pathlib import Path
+from typing import List, Dict, Tuple, Optional
+from bs4 import BeautifulSoup
+from fuzzywuzzy import fuzz
+
+
+class MinerUPaddleOCRMerger:
+    """合并 MinerU 和 PaddleOCR 的结果"""
+    
+    def __init__(self, look_ahead_window: int = 10, similarity_threshold: int = 80):
+        """
+        Args:
+            look_ahead_window: 向前查找的窗口大小
+            similarity_threshold: 文本相似度阈值
+        """
+        self.look_ahead_window = look_ahead_window
+        self.similarity_threshold = similarity_threshold
+    
+    def merge_table_with_bbox(self, mineru_json_path: str, paddle_json_path: str,
+                              output_path: Optional[str] = None) -> List[Dict]:
+        """
+        Merge MinerU and PaddleOCR results for one page.
+        
+        Args:
+            mineru_json_path: path to the MinerU output JSON
+            paddle_json_path: path to the PaddleOCR output JSON
+            output_path: optional path to write the merged JSON to
+        
+        Returns:
+            The merged content blocks (list of dicts).
+        """
+        # Load both result files
+        with open(mineru_json_path, 'r', encoding='utf-8') as f:
+            mineru_data = json.load(f)
+        
+        with open(paddle_json_path, 'r', encoding='utf-8') as f:
+            paddle_data = json.load(f)
+        
+        # Extract PaddleOCR text boxes (text, bbox, polygon, score)
+        paddle_text_boxes = self._extract_paddle_text_boxes(paddle_data)
+        
+        # Walk the MinerU blocks and attach matching bboxes
+        merged_data = self._process_mineru_data(mineru_data, paddle_text_boxes)
+        
+        # Optionally write the merged result to disk
+        if output_path:
+            output_path = Path(output_path).resolve()
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+            with open(str(output_path), 'w', encoding='utf-8') as f:
+                json.dump(merged_data, f, ensure_ascii=False, indent=2)
+        
+        return merged_data
+    
+    def _extract_paddle_text_boxes(self, paddle_data: Dict) -> List[Dict]:
+        """提取 PaddleOCR 的文字框信息"""
+        text_boxes = []
+        
+        if 'overall_ocr_res' in paddle_data:
+            ocr_res = paddle_data['overall_ocr_res']
+            rec_texts = ocr_res.get('rec_texts', [])
+            rec_polys = ocr_res.get('rec_polys', [])
+            rec_scores = ocr_res.get('rec_scores', [])
+
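+            # rec_texts / rec_polys / rec_scores are treated as parallel lists
+            # (the layout of PP-StructureV3's overall_ocr_res); each poly is a
+            # list of [x, y] points outlining the recognized text region.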
+            for i, (text, poly, score) in enumerate(zip(rec_texts, rec_polys, rec_scores)):
+                if text and text.strip():
+                    # Compute the axis-aligned bbox (x_min, y_min, x_max, y_max) from the polygon
+                    xs = [p[0] for p in poly]
+                    ys = [p[1] for p in poly]
+                    bbox = [min(xs), min(ys), max(xs), max(ys)]
+                    
+                    text_boxes.append({
+                        'text': text,
+                        'bbox': bbox,
+                        'poly': poly,
+                        'score': score,
+                        'paddle_bbox_index': i,
+                        'used': False  # whether this box has already been matched
+                    })
+
+        return text_boxes
+    
+    def _process_mineru_data(self, mineru_data: List[Dict], 
+                            paddle_text_boxes: List[Dict]) -> List[Dict]:
+        """处理 MinerU 数据,添加 bbox 信息"""
+        merged_data = []
+        paddle_pointer = 0  # current position in the PaddleOCR text-box list
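+        # MinerU blocks and PaddleOCR boxes are assumed to share the same
+        # reading order: the pointer only moves forward, boxes behind it are
+        # never revisited, and boxes marked 'used' inside the window are skipped.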
+        
+        for item in mineru_data:
+            if item['type'] == 'table':
+                # Table block: annotate each HTML cell with a bbox
+                merged_item = item.copy()
+                table_html = item.get('table_body', '')
+                
+                # Parse the HTML table and add bbox attributes
+                enhanced_html, paddle_pointer = self._enhance_table_html_with_bbox(
+                    table_html, paddle_text_boxes, paddle_pointer
+                )
+                
+                merged_item['table_body'] = enhanced_html
+                merged_item['table_body_with_bbox'] = enhanced_html
+                merged_item['bbox_mapping'] = 'merged_from_paddle_ocr'
+                
+                merged_data.append(merged_item)
+            
+            elif item['type'] in ['text', 'header']:
+                # Plain text or header block
+                merged_item = item.copy()
+                text = item.get('text', '')
+                
+                # Find a matching text box
+                matched_bbox, paddle_pointer = self._find_matching_bbox(
+                    text, paddle_text_boxes, paddle_pointer
+                )
+                
+                if matched_bbox:
+                    merged_item['bbox'] = matched_bbox['bbox']
+                    merged_item['bbox_source'] = 'paddle_ocr'
+                    merged_item['text_score'] = matched_bbox['score']
+                    # Mark the box as consumed
+                    matched_bbox['used'] = True
+                
+                merged_data.append(merged_item)
+            
+            else:
+                # Other block types are copied unchanged
+                merged_data.append(item.copy())
+        
+        return merged_data
+    
+    def _enhance_table_html_with_bbox(self, html: str, paddle_text_boxes: List[Dict], 
+                                      start_pointer: int) -> Tuple[str, int]:
+        """
+        为 HTML 表格添加 bbox 信息
+        
+        Args:
+            html: 原始 HTML 表格
+            paddle_text_boxes: PaddleOCR 文字框列表
+            start_pointer: 起始指针位置
+        
+        Returns:
+            (增强后的 HTML, 新的指针位置)
+        """
+        soup = BeautifulSoup(html, 'html.parser')
+        current_pointer = start_pointer
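+        # Table cells share the same forward-moving pointer as the rest of the
+        # page, so cell texts are matched against PaddleOCR boxes roughly in
+        # the order the OCR engine emitted them.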
+        
+        # Visit every cell in document order
+        for cell in soup.find_all(['td', 'th']):
+            cell_text = cell.get_text(strip=True)
+            
+            if not cell_text:
+                continue
+            
+            # Find a matching text box
+            matched_bbox, current_pointer = self._find_matching_bbox(
+                cell_text, paddle_text_boxes, current_pointer
+            )
+            
+            if matched_bbox:
+                # Attach data-* attributes with the match details
+                bbox = matched_bbox['bbox']
+                cell['data-bbox'] = f"[{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]}]"
+                cell['data-score'] = f"{matched_bbox['score']:.4f}"
+                cell['data-paddle-index'] = str(matched_bbox['paddle_bbox_index'])
+                
+                # Mark the box as consumed
+                matched_bbox['used'] = True
+        
+        return str(soup), current_pointer
+    
+    def _find_matching_bbox(self, target_text: str, text_boxes: List[Dict], 
+                           start_index: int) -> Tuple[Optional[Dict], int]:
+        """
+        查找匹配的文字框
+        
+        Args:
+            target_text: 目标文本
+            text_boxes: 文字框列表
+            start_index: 起始索引
+        
+        Returns:
+            匹配的文字框信息,如果未找到返回 None
+        """
+        target_text = self._normalize_text(target_text)
+        
+        # Search within the look-ahead window
+        search_end = min(start_index + self.look_ahead_window, len(text_boxes))
+        
+        best_match = None
+        best_index = start_index
+        best_similarity = 0
+        
+        for i in range(start_index, search_end):
+            if text_boxes[i]['used']:
+                continue
+            
+            box_text = self._normalize_text(text_boxes[i]['text'])
+            
+            # Compute similarity; token_set_ratio is used rather than fuzz.ratio
+            # because it tolerates token reordering and partial overlaps.
+            similarity = fuzz.token_set_ratio(target_text, box_text)
+            
+            # An exact (normalized) match wins immediately
+            if target_text == box_text:
+                return text_boxes[i], i + 1
+            
+            # Otherwise remember the best match above the threshold
+            if similarity > best_similarity and similarity >= self.similarity_threshold:
+                best_similarity = similarity
+                best_match = text_boxes[i]
+                best_index = i + 1
+
+        return best_match, best_index
+
+    def _normalize_text(self, text: str) -> str:
+        """标准化文本(去除空格、标点等)"""
+        # 移除所有空白字符
+        text = re.sub(r'\s+', '', text)
+        # 转换全角数字和字母为半角
+        text = self._full_to_half(text)
+        return text.lower()
+    
+    def _full_to_half(self, text: str) -> str:
+        """全角转半角"""
+        result = []
+        for char in text:
+            code = ord(char)
+            if code == 0x3000:  # full-width (ideographic) space
+                code = 0x0020
+            elif 0xFF01 <= code <= 0xFF5E:  # full-width ASCII variants
+                code -= 0xFEE0
+            result.append(chr(code))
+        return ''.join(result)
+    
+    def generate_enhanced_markdown(self, merged_data: List[Dict], 
+                                   output_path: Optional[str] = None) -> str:
+        """
+        生成增强的 Markdown(包含 bbox 信息的注释)
+        
+        Args:
+            merged_data: 合并后的数据
+            output_path: 输出路径(可选)
+        
+        Returns:
+            Markdown 内容
+        """
+        md_lines = []
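+        # bbox values are emitted as HTML comments so that standard Markdown
+        # renderers ignore them while downstream tooling can still read them.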
+        
+        for item in merged_data:
+            if item['type'] == 'header':
+                text = item.get('text', '')
+                bbox = item.get('bbox', [])
+                if bbox:
+                    md_lines.append(f"<!-- bbox: {bbox} -->")
+                md_lines.append(f"# {text}\n")
+            
+            elif item['type'] == 'text':
+                text = item.get('text', '')
+                bbox = item.get('bbox', [])
+                if bbox:
+                    md_lines.append(f"<!-- bbox: {bbox} -->")
+                md_lines.append(f"{text}\n")
+            
+            elif item['type'] == 'table':
+                md_lines.append("\n## Table\n")
+                md_lines.append("<!-- table cells carry data-bbox attributes -->\n")
+                md_lines.append(item.get('table_body_with_bbox', item.get('table_body', '')))
+                md_lines.append("\n")
+        
+        markdown_content = '\n'.join(md_lines)
+        
+        if output_path:
+            with open(output_path, 'w', encoding='utf-8') as f:
+                f.write(markdown_content)
+        
+        return markdown_content
+    
+    def extract_table_cells_with_bbox(self, merged_data: List[Dict]) -> List[Dict]:
+        """
+        提取所有表格单元格及其 bbox 信息
+        
+        Returns:
+            单元格列表,每个包含 text, bbox, row, col 等信息
+        """
+        cells = []
+        
+        for item in merged_data:
+            if item['type'] != 'table':
+                continue
+            
+            html = item.get('table_body_with_bbox', item.get('table_body', ''))
+            soup = BeautifulSoup(html, 'html.parser')
+            
+            # Iterate over all rows
+            for row_idx, row in enumerate(soup.find_all('tr')):
+                # Iterate over the cells in this row
+                for col_idx, cell in enumerate(row.find_all(['td', 'th'])):
+                    cell_text = cell.get_text(strip=True)
+                    bbox_str = cell.get('data-bbox', '')
+                    
+                    if bbox_str:
+                        try:
+                            bbox = json.loads(bbox_str)
+                            cells.append({
+                                'text': cell_text,
+                                'bbox': bbox,
+                                'row': row_idx,
+                                'col': col_idx,
+                                'score': float(cell.get('data-score', 0)),
+                                'paddle_index': int(cell.get('data-paddle-index', -1))
+                            })
+                        except (json.JSONDecodeError, ValueError):
+                            pass
+        
+        return cells
+
+
+def merge_mineru_paddle_batch(mineru_dir: str, paddle_dir: str, output_dir: str):
+    """
+    批量合并 MinerU 和 PaddleOCR 的结果
+    
+    Args:
+        mineru_dir: MinerU 结果目录
+        paddle_dir: PaddleOCR 结果目录
+        output_dir: 输出目录
+    """
+    mineru_path = Path(mineru_dir)
+    paddle_path = Path(paddle_dir)
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    
+    merger = MinerUPaddleOCRMerger(look_ahead_window=10, similarity_threshold=80)
+    
+    # Find all per-page MinerU JSON files (e.g. *_page_001.json)
+    mineru_files = list(mineru_path.glob('*_page_*[0-9].json'))
+    mineru_files.sort()
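+    # Assumes both tools write one JSON per page with identical file names
+    # (e.g. <doc>_page_001.json), so the PaddleOCR counterpart is found by name.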
+    
+    print(f"找到 {len(mineru_files)} 个 MinerU 文件")
+    
+    for mineru_file in mineru_files:
+        # Look for the PaddleOCR file with the same name
+        paddle_file = paddle_path / mineru_file.name
+        
+        if not paddle_file.exists():
+            print(f"⚠️ 未找到对应的 PaddleOCR 文件: {paddle_file}")
+            continue
+        
+        print(f"处理: {mineru_file.name}")
+        
+        # Output file paths
+        merged_json_path = output_path / f"{mineru_file.stem}_merged.json"
+        merged_md_path = output_path / f"{mineru_file.stem}_merged.md"
+        cells_json_path = output_path / f"{mineru_file.stem}_cells.json"
+        
+        try:
+            # Merge the two result files
+            merged_data = merger.merge_table_with_bbox(
+                str(mineru_file),
+                str(paddle_file),
+                str(merged_json_path)
+            )
+            
+            # Generate the enhanced Markdown
+            merger.generate_enhanced_markdown(merged_data, str(merged_md_path))
+            
+            # Extract table-cell information
+            cells = merger.extract_table_cells_with_bbox(merged_data)
+            
+            with open(cells_json_path, 'w', encoding='utf-8') as f:
+                json.dump(cells, f, ensure_ascii=False, indent=2)
+            
+            print(f"  ✅ 合并完成")
+            print(f"  - 提取了 {len(cells)} 个表格单元格")
+            
+        except Exception as e:
+            print(f"  ❌ 处理失败: {e}")
+            import traceback
+            traceback.print_exc()
+
+
+if __name__ == "__main__":
+    # Example usage
+    mineru_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results"
+    paddle_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results"
+    output_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results"
+    
+    merge_mineru_paddle_batch(mineru_dir, paddle_dir, output_dir)
+
+    # Example: merge a single file
+    # mineru_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results/A用户_单元格扫描流水_page_001.json"
+    # paddle_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results/A用户_单元格扫描流水_page_001.json"
+    # output_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results/A用户_单元格扫描流水_page_001.json"
+
+    # merger = MinerUPaddleOCRMerger(look_ahead_window=10, similarity_threshold=80)
+    # merger.merge_table_with_bbox(mineru_json, paddle_json, output_json)