Browse Source

feat: 添加命令行参数支持,优化单文件和批量处理功能

zhch158_admin 1 month ago
parent
commit
4852e659fe
1 changed file with 277 additions and 65 deletions
  1. merge_mineru_paddle_ocr.py (277 additions, 65 deletions)

+ 277 - 65
merge_mineru_paddle_ocr.py

@@ -4,6 +4,7 @@
 """
 import json
 import re
+import argparse
 from pathlib import Path
 from typing import List, Dict, Tuple, Optional
 from bs4 import BeautifulSoup
@@ -22,8 +23,7 @@ class MinerUPaddleOCRMerger:
         self.look_ahead_window = look_ahead_window
         self.similarity_threshold = similarity_threshold
     
-    def merge_table_with_bbox(self, mineru_json_path: str, paddle_json_path: str, 
-                              output_path: Optional[str] = None) -> Dict:
+    def merge_table_with_bbox(self, mineru_json_path: str, paddle_json_path: str) -> List[Dict]:
         """
         合并 MinerU 和 PaddleOCR 的结果
         
@@ -35,6 +35,7 @@ class MinerUPaddleOCRMerger:
         Returns:
             合并后的结果字典
         """
+        merged_data = None
         # 加载数据
         with open(mineru_json_path, 'r', encoding='utf-8') as f:
             mineru_data = json.load(f)
@@ -48,13 +49,6 @@ class MinerUPaddleOCRMerger:
         # 处理 MinerU 的数据
         merged_data = self._process_mineru_data(mineru_data, paddle_text_boxes)
         
-        # 保存结果
-        if output_path:
-            output_path = Path(output_path).resolve()
-            output_path.parent.mkdir(parents=True, exist_ok=True)
-            with open(str(output_path), 'w', encoding='utf-8') as f:
-                json.dump(merged_data, f, ensure_ascii=False, indent=2)
-        
         return merged_data
     
     def _extract_paddle_text_boxes(self, paddle_data: Dict) -> List[Dict]:
@@ -89,6 +83,7 @@ class MinerUPaddleOCRMerger:
                             paddle_text_boxes: List[Dict]) -> List[Dict]:
         """处理 MinerU 数据,添加 bbox 信息"""
         merged_data = []
+        cells = None  # 存储所有表格单元格信息
         paddle_pointer = 0  # PaddleOCR 文字框指针
         
         for item in mineru_data:
@@ -98,7 +93,7 @@ class MinerUPaddleOCRMerger:
                 table_html = item.get('table_body', '')
                 
                 # 解析 HTML 表格并添加 bbox
-                enhanced_html, paddle_pointer = self._enhance_table_html_with_bbox(
+                enhanced_html, cells, paddle_pointer = self._enhance_table_html_with_bbox(
                     table_html, paddle_text_boxes, paddle_pointer
                 )
                 
@@ -131,10 +126,13 @@ class MinerUPaddleOCRMerger:
                 # 其他类型直接复制
                 merged_data.append(item.copy())
         
+        if cells:
+            merged_data.extend(cells)
+
         return merged_data
     
     def _enhance_table_html_with_bbox(self, html: str, paddle_text_boxes: List[Dict], 
-                                      start_pointer: int) -> Tuple[str, int]:
+                                      start_pointer: int) -> Tuple[str, List[Dict], int]:
         """
         为 HTML 表格添加 bbox 信息
         
@@ -144,11 +142,12 @@ class MinerUPaddleOCRMerger:
             start_pointer: 起始指针位置
         
         Returns:
-            (增强后的 HTML, 新的指针位置)
+            (增强后的 HTML, 单元格数组, 新的指针位置)
         """
         soup = BeautifulSoup(html, 'html.parser')
         current_pointer = start_pointer
-        
+        cells = []  # 存储单元格的 bbox 信息
+
         # 遍历所有单元格
         for cell in soup.find_all(['td', 'th']):
             cell_text = cell.get_text(strip=True)
@@ -167,11 +166,18 @@ class MinerUPaddleOCRMerger:
                 cell['data-bbox'] = f"[{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]}]"
                 cell['data-score'] = f"{matched_bbox['score']:.4f}"
                 cell['data-paddle-index'] = str(matched_bbox['paddle_bbox_index'])
-                
+
+                cells.append({
+                    'type': 'table_cell',
+                    'text': cell_text,
+                    'bbox': bbox,
+                    'score': matched_bbox['score'],
+                    'paddle_bbox_index': matched_bbox['paddle_bbox_index']
+                })
                 # 标记为已使用
                 matched_bbox['used'] = True
         
-        return str(soup), current_pointer
+        return str(soup), cells, current_pointer
     
     def _find_matching_bbox(self, target_text: str, text_boxes: List[Dict], 
                            start_index: int) -> tuple[Optional[Dict], int]:
@@ -184,7 +190,7 @@ class MinerUPaddleOCRMerger:
             start_index: 起始索引
         
         Returns:
-            匹配的文字框信息,如果未找到返回 None
+            (匹配的文字框信息, 新的指针位置)
         """
         target_text = self._normalize_text(target_text)
         
@@ -202,7 +208,6 @@ class MinerUPaddleOCRMerger:
             box_text = self._normalize_text(text_boxes[i]['text'])
             
             # 计算相似度
-            # similarity = fuzz.ratio(target_text, box_text)
             similarity = fuzz.token_set_ratio(target_text, box_text)
             
             # 精确匹配优先
@@ -319,7 +324,58 @@ class MinerUPaddleOCRMerger:
         return cells
 
 
-def merge_mineru_paddle_batch(mineru_dir: str, paddle_dir: str, output_dir: str):
+def merge_single_file(mineru_file: Path, paddle_file: Path, output_dir: Path, 
+                     merger: MinerUPaddleOCRMerger) -> bool:
+    """
+    合并单个文件
+    
+    Args:
+        mineru_file: MinerU JSON 文件路径
+        paddle_file: PaddleOCR JSON 文件路径
+        output_dir: 输出目录
+        merger: 合并器实例
+    
+    Returns:
+        是否成功
+    """
+    print(f"📄 处理: {mineru_file.name}")
+    
+    # 输出文件路径
+    merged_json_path = output_dir / f"{mineru_file.stem}.json"
+    
+    try:
+        # 合并数据
+        merged_data = merger.merge_table_with_bbox(
+            str(mineru_file),
+            str(paddle_file)
+        )
+        
+        # 生成 Markdown
+        # merger.generate_enhanced_markdown(merged_data, str(merged_md_path))
+        
+        # 提取单元格信息
+        # cells = merger.extract_table_cells_with_bbox(merged_data)
+        
+        with open(merged_json_path, 'w', encoding='utf-8') as f:
+            json.dump(merged_data, f, ensure_ascii=False, indent=2)
+
+        print(f"  ✅ 合并完成")
+        print(f"  📊 共处理了 {len(merged_data)} 个对象")
+        print(f"  💾 输出文件:")
+        print(f"    - {merged_json_path.name}")
+        
+        return True
+        
+    except Exception as e:
+        print(f"  ❌ 处理失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def merge_mineru_paddle_batch(mineru_dir: str, paddle_dir: str, output_dir: str,
+                              look_ahead_window: int = 10, 
+                              similarity_threshold: int = 80):
     """
     批量合并 MinerU 和 PaddleOCR 的结果
     
@@ -327,73 +383,229 @@ def merge_mineru_paddle_batch(mineru_dir: str, paddle_dir: str, output_dir: str)
         mineru_dir: MinerU 结果目录
         paddle_dir: PaddleOCR 结果目录
         output_dir: 输出目录
+        look_ahead_window: 向前查找窗口大小
+        similarity_threshold: 相似度阈值
     """
     mineru_path = Path(mineru_dir)
     paddle_path = Path(paddle_dir)
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)
     
-    merger = MinerUPaddleOCRMerger(look_ahead_window=10, similarity_threshold=80)
+    merger = MinerUPaddleOCRMerger(
+        look_ahead_window=look_ahead_window, 
+        similarity_threshold=similarity_threshold
+    )
     
-    # 查找所有 MinerU 的 JSON 文件, page_001.json
+    # 查找所有 MinerU 的 JSON 文件
     mineru_files = list(mineru_path.glob('*_page_*[0-9].json'))
     mineru_files.sort()
     
-    print(f"找到 {len(mineru_files)} 个 MinerU 文件")
+    print(f"\n🔍 找到 {len(mineru_files)} 个 MinerU 文件")
+    print(f"📂 MinerU 目录: {mineru_dir}")
+    print(f"📂 PaddleOCR 目录: {paddle_dir}")
+    print(f"📂 输出目录: {output_dir}")
+    print(f"⚙️  查找窗口: {look_ahead_window}")
+    print(f"⚙️  相似度阈值: {similarity_threshold}%\n")
+    
+    success_count = 0
+    failed_count = 0
     
     for mineru_file in mineru_files:
         # 查找对应的 PaddleOCR 文件
         paddle_file = paddle_path / mineru_file.name
         
         if not paddle_file.exists():
-            print(f"⚠️ 未找到对应的 PaddleOCR 文件: {paddle_file}")
+            print(f"⚠️  跳过: 未找到对应的 PaddleOCR 文件: {paddle_file.name}\n")
+            failed_count += 1
             continue
         
-        print(f"处理: {mineru_file.name}")
-        
-        # 输出文件路径
-        merged_json_path = output_path / f"{mineru_file.stem}_merged.json"
-        merged_md_path = output_path / f"{mineru_file.stem}_merged.md"
-        cells_json_path = output_path / f"{mineru_file.stem}_cells.json"
+        if merge_single_file(mineru_file, paddle_file, output_path, merger):
+            success_count += 1
+        else:
+            failed_count += 1
         
-        try:
-            # 合并数据
-            merged_data = merger.merge_table_with_bbox(
-                str(mineru_file),
-                str(paddle_file),
-                str(merged_json_path)
-            )
-            
-            # 生成 Markdown
-            merger.generate_enhanced_markdown(merged_data, str(merged_md_path))
-            
-            # 提取单元格信息
-            cells = merger.extract_table_cells_with_bbox(merged_data)
-            
-            with open(cells_json_path, 'w', encoding='utf-8') as f:
-                json.dump(cells, f, ensure_ascii=False, indent=2)
-            
-            print(f"  ✅ 合并完成")
-            print(f"  - 提取了 {len(cells)} 个表格单元格")
-            
-        except Exception as e:
-            print(f"  ❌ 处理失败: {e}")
-            import traceback
-            traceback.print_exc()
+        print()  # 空行分隔
+    
+    # 打印统计信息
+    print("=" * 60)
+    print(f"✅ 处理完成!")
+    print(f"📊 统计信息:")
+    print(f"  - 总文件数: {len(mineru_files)}")
+    print(f"  - 成功: {success_count}")
+    print(f"  - 失败: {failed_count}")
+    print("=" * 60)
 
 
-if __name__ == "__main__":
-    # 示例用法
-    mineru_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results"
-    paddle_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results"
-    output_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results"
-    
-    merge_mineru_paddle_batch(mineru_dir, paddle_dir, output_dir)
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(
+        description='合并 MinerU 和 PaddleOCR 的识别结果,添加 bbox 坐标信息',
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+示例用法:
 
-    # 示例:合并1个文件
-    # mineru_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results/A用户_单元格扫描流水_page_001.json"
-    # paddle_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results/A用户_单元格扫描流水_page_001.json"
-    # output_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results/A用户_单元格扫描流水_page_001.json"
+  1. 批量处理整个目录:
+     python merge_mineru_paddle_ocr.py \\
+         --mineru-dir /path/to/mineru/results \\
+         --paddle-dir /path/to/paddle/results \\
+         --output-dir /path/to/output
 
-    # merger = MinerUPaddleOCRMerger(look_ahead_window=10, similarity_threshold=80)
-    # merger.merge_table_with_bbox(mineru_json, paddle_json, output_json)
+  2. 处理单个文件:
+     python merge_mineru_paddle_ocr.py \\
+         --mineru-file /path/to/file_page_001.json \\
+         --paddle-file /path/to/file_page_001.json \\
+         --output-dir /path/to/output
+
+  3. 自定义参数:
+     python merge_mineru_paddle_ocr.py \\
+         --mineru-dir /path/to/mineru \\
+         --paddle-dir /path/to/paddle \\
+         --output-dir /path/to/output \\
+         --window 15 \\
+         --threshold 85
+        """
+    )
+    
+    # 文件/目录参数
+    file_group = parser.add_argument_group('文件参数')
+    file_group.add_argument(
+        '--mineru-file', 
+        type=str,
+        help='MinerU 输出的 JSON 文件路径(单文件模式)'
+    )
+    file_group.add_argument(
+        '--paddle-file', 
+        type=str,
+        help='PaddleOCR 输出的 JSON 文件路径(单文件模式)'
+    )
+    
+    dir_group = parser.add_argument_group('目录参数')
+    dir_group.add_argument(
+        '--mineru-dir', 
+        type=str,
+        help='MinerU 结果目录(批量模式)'
+    )
+    dir_group.add_argument(
+        '--paddle-dir', 
+        type=str,
+        help='PaddleOCR 结果目录(批量模式)'
+    )
+    
+    # 输出参数
+    output_group = parser.add_argument_group('输出参数')
+    output_group.add_argument(
+        '-o', '--output-dir',
+        type=str,
+        required=True,
+        help='输出目录(必需)'
+    )
+    
+    # 算法参数
+    algo_group = parser.add_argument_group('算法参数')
+    algo_group.add_argument(
+        '-w', '--window',
+        type=int,
+        default=10,
+        help='向前查找的窗口大小(默认: 10)'
+    )
+    algo_group.add_argument(
+        '-t', '--threshold',
+        type=int,
+        default=80,
+        help='文本相似度阈值(0-100,默认: 80)'
+    )
+    
+    args = parser.parse_args()
+    
+    # 验证参数
+    if args.mineru_file and args.paddle_file:
+        # 单文件模式
+        mineru_file = Path(args.mineru_file)
+        paddle_file = Path(args.paddle_file)
+        output_dir = Path(args.output_dir)
+        
+        if not mineru_file.exists():
+            print(f"❌ 错误: MinerU 文件不存在: {mineru_file}")
+            return
+        
+        if not paddle_file.exists():
+            print(f"❌ 错误: PaddleOCR 文件不存在: {paddle_file}")
+            return
+        
+        output_dir.mkdir(parents=True, exist_ok=True)
+        
+        print("\n🔧 单文件处理模式")
+        print(f"📄 MinerU 文件: {mineru_file}")
+        print(f"📄 PaddleOCR 文件: {paddle_file}")
+        print(f"📂 输出目录: {output_dir}")
+        print(f"⚙️  查找窗口: {args.window}")
+        print(f"⚙️  相似度阈值: {args.threshold}%\n")
+        
+        merger = MinerUPaddleOCRMerger(
+            look_ahead_window=args.window,
+            similarity_threshold=args.threshold
+        )
+        
+        success = merge_single_file(mineru_file, paddle_file, output_dir, merger)
+        
+        if success:
+            print("\n✅ 处理完成!")
+        else:
+            print("\n❌ 处理失败!")
+    
+    elif args.mineru_dir and args.paddle_dir:
+        # 批量模式
+        if not Path(args.mineru_dir).exists():
+            print(f"❌ 错误: MinerU 目录不存在: {args.mineru_dir}")
+            return
+        
+        if not Path(args.paddle_dir).exists():
+            print(f"❌ 错误: PaddleOCR 目录不存在: {args.paddle_dir}")
+            return
+        
+        print("\n🔧 批量处理模式")
+        
+        merge_mineru_paddle_batch(
+            args.mineru_dir,
+            args.paddle_dir,
+            args.output_dir,
+            look_ahead_window=args.window,
+            similarity_threshold=args.threshold
+        )
+    
+    else:
+        parser.print_help()
+        print("\n❌ 错误: 请指定单文件模式或批量模式的参数")
+        print("  单文件模式: --mineru-file 和 --paddle-file")
+        print("  批量模式: --mineru-dir 和 --paddle-dir")
+
+if __name__ == "__main__":
+    print("🚀 启动 MinerU + PaddleOCR 合并程序...")
+    
+    import sys
+    
+    if len(sys.argv) == 1:
+        # 如果没有命令行参数,使用默认配置运行
+        print("ℹ️  未提供命令行参数,使用默认配置运行...")
+        
+        # 默认配置
+        default_config = {
+            "mineru-dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results",
+            "paddle-dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results",
+            "output-dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results",
+            "window": "15",
+            "threshold": "85"
+        }
+        
+        print(f"📂 MinerU 目录: {default_config['mineru-dir']}")
+        print(f"📂 PaddleOCR 目录: {default_config['paddle-dir']}")
+        print(f"📂 输出目录: {default_config['output-dir']}")
+        print(f"⚙️  查找窗口: {default_config['window']}")
+        print(f"⚙️  相似度阈值: {default_config['threshold']}%\n")
+        
+        # 构造参数
+        sys.argv = [sys.argv[0]]
+        for key, value in default_config.items():
+            sys.argv.extend([f"--{key}", str(value)])
+    
+    sys.exit(main())