Эх сурвалжийг харах

feat: 添加命令行参数支持,优化单文件和批量处理功能

zhch158_admin 1 сар өмнө
parent
commit
4852e659fe

+ 277 - 65
merge_mineru_paddle_ocr.py

@@ -4,6 +4,7 @@
 """
 """
 import json
 import json
 import re
 import re
+import argparse
 from pathlib import Path
 from pathlib import Path
 from typing import List, Dict, Tuple, Optional
 from typing import List, Dict, Tuple, Optional
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
@@ -22,8 +23,7 @@ class MinerUPaddleOCRMerger:
         self.look_ahead_window = look_ahead_window
         self.look_ahead_window = look_ahead_window
         self.similarity_threshold = similarity_threshold
         self.similarity_threshold = similarity_threshold
     
     
-    def merge_table_with_bbox(self, mineru_json_path: str, paddle_json_path: str, 
-                              output_path: Optional[str] = None) -> Dict:
+    def merge_table_with_bbox(self, mineru_json_path: str, paddle_json_path: str) -> List[Dict]:
         """
         """
         合并 MinerU 和 PaddleOCR 的结果
         合并 MinerU 和 PaddleOCR 的结果
         
         
@@ -35,6 +35,7 @@ class MinerUPaddleOCRMerger:
         Returns:
         Returns:
             合并后的结果字典
             合并后的结果字典
         """
         """
+        merged_data = None
         # 加载数据
         # 加载数据
         with open(mineru_json_path, 'r', encoding='utf-8') as f:
         with open(mineru_json_path, 'r', encoding='utf-8') as f:
             mineru_data = json.load(f)
             mineru_data = json.load(f)
@@ -48,13 +49,6 @@ class MinerUPaddleOCRMerger:
         # 处理 MinerU 的数据
         # 处理 MinerU 的数据
         merged_data = self._process_mineru_data(mineru_data, paddle_text_boxes)
         merged_data = self._process_mineru_data(mineru_data, paddle_text_boxes)
         
         
-        # 保存结果
-        if output_path:
-            output_path = Path(output_path).resolve()
-            output_path.parent.mkdir(parents=True, exist_ok=True)
-            with open(str(output_path), 'w', encoding='utf-8') as f:
-                json.dump(merged_data, f, ensure_ascii=False, indent=2)
-        
         return merged_data
         return merged_data
     
     
     def _extract_paddle_text_boxes(self, paddle_data: Dict) -> List[Dict]:
     def _extract_paddle_text_boxes(self, paddle_data: Dict) -> List[Dict]:
@@ -89,6 +83,7 @@ class MinerUPaddleOCRMerger:
                             paddle_text_boxes: List[Dict]) -> List[Dict]:
                             paddle_text_boxes: List[Dict]) -> List[Dict]:
         """处理 MinerU 数据,添加 bbox 信息"""
         """处理 MinerU 数据,添加 bbox 信息"""
         merged_data = []
         merged_data = []
+        cells = None  # 存储所有表格单元格信息
         paddle_pointer = 0  # PaddleOCR 文字框指针
         paddle_pointer = 0  # PaddleOCR 文字框指针
         
         
         for item in mineru_data:
         for item in mineru_data:
@@ -98,7 +93,7 @@ class MinerUPaddleOCRMerger:
                 table_html = item.get('table_body', '')
                 table_html = item.get('table_body', '')
                 
                 
                 # 解析 HTML 表格并添加 bbox
                 # 解析 HTML 表格并添加 bbox
-                enhanced_html, paddle_pointer = self._enhance_table_html_with_bbox(
+                enhanced_html, cells, paddle_pointer = self._enhance_table_html_with_bbox(
                     table_html, paddle_text_boxes, paddle_pointer
                     table_html, paddle_text_boxes, paddle_pointer
                 )
                 )
                 
                 
@@ -131,10 +126,13 @@ class MinerUPaddleOCRMerger:
                 # 其他类型直接复制
                 # 其他类型直接复制
                 merged_data.append(item.copy())
                 merged_data.append(item.copy())
         
         
+        if cells:
+            merged_data.extend(cells)
+
         return merged_data
         return merged_data
     
     
     def _enhance_table_html_with_bbox(self, html: str, paddle_text_boxes: List[Dict], 
     def _enhance_table_html_with_bbox(self, html: str, paddle_text_boxes: List[Dict], 
-                                      start_pointer: int) -> Tuple[str, int]:
+                                      start_pointer: int) -> Tuple[str, List[Dict], int]:
         """
         """
         为 HTML 表格添加 bbox 信息
         为 HTML 表格添加 bbox 信息
         
         
@@ -144,11 +142,12 @@ class MinerUPaddleOCRMerger:
             start_pointer: 起始指针位置
             start_pointer: 起始指针位置
         
         
         Returns:
         Returns:
-            (增强后的 HTML, 新的指针位置)
+            (增强后的 HTML, 单元格数组, 新的指针位置)
         """
         """
         soup = BeautifulSoup(html, 'html.parser')
         soup = BeautifulSoup(html, 'html.parser')
         current_pointer = start_pointer
         current_pointer = start_pointer
-        
+        cells = []  # 存储单元格的 bbox 信息
+
         # 遍历所有单元格
         # 遍历所有单元格
         for cell in soup.find_all(['td', 'th']):
         for cell in soup.find_all(['td', 'th']):
             cell_text = cell.get_text(strip=True)
             cell_text = cell.get_text(strip=True)
@@ -167,11 +166,18 @@ class MinerUPaddleOCRMerger:
                 cell['data-bbox'] = f"[{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]}]"
                 cell['data-bbox'] = f"[{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]}]"
                 cell['data-score'] = f"{matched_bbox['score']:.4f}"
                 cell['data-score'] = f"{matched_bbox['score']:.4f}"
                 cell['data-paddle-index'] = str(matched_bbox['paddle_bbox_index'])
                 cell['data-paddle-index'] = str(matched_bbox['paddle_bbox_index'])
-                
+
+                cells.append({
+                    'type': 'table_cell',
+                    'text': cell_text,
+                    'bbox': bbox,
+                    'score': matched_bbox['score'],
+                    'paddle_bbox_index': matched_bbox['paddle_bbox_index']
+                })
                 # 标记为已使用
                 # 标记为已使用
                 matched_bbox['used'] = True
                 matched_bbox['used'] = True
         
         
-        return str(soup), current_pointer
+        return str(soup), cells, current_pointer
     
     
     def _find_matching_bbox(self, target_text: str, text_boxes: List[Dict], 
     def _find_matching_bbox(self, target_text: str, text_boxes: List[Dict], 
                            start_index: int) -> tuple[Optional[Dict], int]:
                            start_index: int) -> tuple[Optional[Dict], int]:
@@ -184,7 +190,7 @@ class MinerUPaddleOCRMerger:
             start_index: 起始索引
             start_index: 起始索引
         
         
         Returns:
         Returns:
-            匹配的文字框信息,如果未找到返回 None
+            (匹配的文字框信息, 新的指针位置)
         """
         """
         target_text = self._normalize_text(target_text)
         target_text = self._normalize_text(target_text)
         
         
@@ -202,7 +208,6 @@ class MinerUPaddleOCRMerger:
             box_text = self._normalize_text(text_boxes[i]['text'])
             box_text = self._normalize_text(text_boxes[i]['text'])
             
             
             # 计算相似度
             # 计算相似度
-            # similarity = fuzz.ratio(target_text, box_text)
             similarity = fuzz.token_set_ratio(target_text, box_text)
             similarity = fuzz.token_set_ratio(target_text, box_text)
             
             
             # 精确匹配优先
             # 精确匹配优先
@@ -319,7 +324,58 @@ class MinerUPaddleOCRMerger:
         return cells
         return cells
 
 
 
 
-def merge_mineru_paddle_batch(mineru_dir: str, paddle_dir: str, output_dir: str):
+def merge_single_file(mineru_file: Path, paddle_file: Path, output_dir: Path,
+                     merger: MinerUPaddleOCRMerger) -> bool:
+    """
+    Merge one MinerU result file with its PaddleOCR counterpart and save it.
+
+    The merged data is written to ``output_dir / f"{mineru_file.stem}.json"``.
+    That name is identical to the MinerU input's own file name, so the
+    function refuses to run when the output path resolves to either input
+    file instead of silently overwriting the source data.
+
+    Args:
+        mineru_file: Path of the MinerU JSON file.
+        paddle_file: Path of the PaddleOCR JSON file.
+        output_dir: Directory that receives the merged JSON file. It is not
+            created here.
+        merger: Merger instance whose ``merge_table_with_bbox`` is called.
+
+    Returns:
+        True when the merged file was written; False when the output would
+        overwrite an input or when any exception was raised while merging
+        or writing (the traceback is printed).
+    """
+    print(f"📄 处理: {mineru_file.name}")
+
+    # The output keeps the input's base name, so it can collide with an input.
+    merged_json_path = output_dir / f"{mineru_file.stem}.json"
+
+    try:
+        resolved_output = merged_json_path.resolve()
+        if resolved_output in (mineru_file.resolve(), paddle_file.resolve()):
+            print(f"  ❌ 处理失败: 输出文件与输入文件相同,拒绝覆盖: {merged_json_path}")
+            return False
+
+        merged_data = merger.merge_table_with_bbox(
+            str(mineru_file),
+            str(paddle_file)
+        )
+
+        with open(merged_json_path, 'w', encoding='utf-8') as f:
+            json.dump(merged_data, f, ensure_ascii=False, indent=2)
+
+        print("  ✅ 合并完成")
+        print(f"  📊 共处理了 {len(merged_data)} 个对象")
+        print("  💾 输出文件:")
+        print(f"    - {merged_json_path.name}")
+
+        return True
+
+    except Exception as e:
+        # Per-file boundary: report the failure and let the caller decide
+        # whether to continue with other files.
+        print(f"  ❌ 处理失败: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def merge_mineru_paddle_batch(mineru_dir: str, paddle_dir: str, output_dir: str,
+                              look_ahead_window: int = 10, 
+                              similarity_threshold: int = 80):
     """
     """
     批量合并 MinerU 和 PaddleOCR 的结果
     批量合并 MinerU 和 PaddleOCR 的结果
     
     
@@ -327,73 +383,229 @@ def merge_mineru_paddle_batch(mineru_dir: str, paddle_dir: str, output_dir: str)
         mineru_dir: MinerU 结果目录
         mineru_dir: MinerU 结果目录
         paddle_dir: PaddleOCR 结果目录
         paddle_dir: PaddleOCR 结果目录
         output_dir: 输出目录
         output_dir: 输出目录
+        look_ahead_window: 向前查找窗口大小
+        similarity_threshold: 相似度阈值
     """
     """
     mineru_path = Path(mineru_dir)
     mineru_path = Path(mineru_dir)
     paddle_path = Path(paddle_dir)
     paddle_path = Path(paddle_dir)
     output_path = Path(output_dir)
     output_path = Path(output_dir)
     output_path.mkdir(parents=True, exist_ok=True)
     output_path.mkdir(parents=True, exist_ok=True)
     
     
-    merger = MinerUPaddleOCRMerger(look_ahead_window=10, similarity_threshold=80)
+    merger = MinerUPaddleOCRMerger(
+        look_ahead_window=look_ahead_window, 
+        similarity_threshold=similarity_threshold
+    )
     
     
-    # 查找所有 MinerU 的 JSON 文件, page_001.json
+    # 查找所有 MinerU 的 JSON 文件
     mineru_files = list(mineru_path.glob('*_page_*[0-9].json'))
     mineru_files = list(mineru_path.glob('*_page_*[0-9].json'))
     mineru_files.sort()
     mineru_files.sort()
     
     
-    print(f"找到 {len(mineru_files)} 个 MinerU 文件")
+    print(f"\n🔍 找到 {len(mineru_files)} 个 MinerU 文件")
+    print(f"📂 MinerU 目录: {mineru_dir}")
+    print(f"📂 PaddleOCR 目录: {paddle_dir}")
+    print(f"📂 输出目录: {output_dir}")
+    print(f"⚙️  查找窗口: {look_ahead_window}")
+    print(f"⚙️  相似度阈值: {similarity_threshold}%\n")
+    
+    success_count = 0
+    failed_count = 0
     
     
     for mineru_file in mineru_files:
     for mineru_file in mineru_files:
         # 查找对应的 PaddleOCR 文件
         # 查找对应的 PaddleOCR 文件
         paddle_file = paddle_path / mineru_file.name
         paddle_file = paddle_path / mineru_file.name
         
         
         if not paddle_file.exists():
         if not paddle_file.exists():
-            print(f"⚠️ 未找到对应的 PaddleOCR 文件: {paddle_file}")
+            print(f"⚠️  跳过: 未找到对应的 PaddleOCR 文件: {paddle_file.name}\n")
+            failed_count += 1
             continue
             continue
         
         
-        print(f"处理: {mineru_file.name}")
-        
-        # 输出文件路径
-        merged_json_path = output_path / f"{mineru_file.stem}_merged.json"
-        merged_md_path = output_path / f"{mineru_file.stem}_merged.md"
-        cells_json_path = output_path / f"{mineru_file.stem}_cells.json"
+        if merge_single_file(mineru_file, paddle_file, output_path, merger):
+            success_count += 1
+        else:
+            failed_count += 1
         
         
-        try:
-            # 合并数据
-            merged_data = merger.merge_table_with_bbox(
-                str(mineru_file),
-                str(paddle_file),
-                str(merged_json_path)
-            )
-            
-            # 生成 Markdown
-            merger.generate_enhanced_markdown(merged_data, str(merged_md_path))
-            
-            # 提取单元格信息
-            cells = merger.extract_table_cells_with_bbox(merged_data)
-            
-            with open(cells_json_path, 'w', encoding='utf-8') as f:
-                json.dump(cells, f, ensure_ascii=False, indent=2)
-            
-            print(f"  ✅ 合并完成")
-            print(f"  - 提取了 {len(cells)} 个表格单元格")
-            
-        except Exception as e:
-            print(f"  ❌ 处理失败: {e}")
-            import traceback
-            traceback.print_exc()
+        print()  # 空行分隔
+    
+    # 打印统计信息
+    print("=" * 60)
+    print(f"✅ 处理完成!")
+    print(f"📊 统计信息:")
+    print(f"  - 总文件数: {len(mineru_files)}")
+    print(f"  - 成功: {success_count}")
+    print(f"  - 失败: {failed_count}")
+    print("=" * 60)
 
 
 
 
-if __name__ == "__main__":
-    # 示例用法
-    mineru_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results"
-    paddle_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results"
-    output_dir = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results"
-    
-    merge_mineru_paddle_batch(mineru_dir, paddle_dir, output_dir)
+def main() -> int:
+    """
+    Parse the command line and run the merge in single-file or batch mode.
+
+    Single-file mode is selected when both ``--mineru-file`` and
+    ``--paddle-file`` are given; otherwise batch mode is selected when both
+    ``--mineru-dir`` and ``--paddle-dir`` are given. If neither pair is
+    complete, the help text and an error message are printed.
+
+    Returns:
+        Process exit status, suitable for ``sys.exit``: 0 on success, 1 when
+        an input path does not exist or the single-file merge failed, 2 when
+        no valid mode was specified. Batch mode returns 0 once the batch
+        helper finishes, because that helper prints its own per-file
+        statistics and does not report a status back.
+    """
+    parser = argparse.ArgumentParser(
+        description='合并 MinerU 和 PaddleOCR 的识别结果,添加 bbox 坐标信息',
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+示例用法:


-    # 示例:合并1个文件
-    # mineru_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results/A用户_单元格扫描流水_page_001.json"
-    # paddle_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results/A用户_单元格扫描流水_page_001.json"
-    # output_json = "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results/A用户_单元格扫描流水_page_001.json"
+  1. 批量处理整个目录:
+     python merge_mineru_paddle_ocr.py \\
+         --mineru-dir /path/to/mineru/results \\
+         --paddle-dir /path/to/paddle/results \\
+         --output-dir /path/to/output


-    # merger = MinerUPaddleOCRMerger(look_ahead_window=10, similarity_threshold=80)
-    # merger.merge_table_with_bbox(mineru_json, paddle_json, output_json)
+  2. 处理单个文件:
+     python merge_mineru_paddle_ocr.py \\
+         --mineru-file /path/to/file_page_001.json \\
+         --paddle-file /path/to/file_page_001.json \\
+         --output-dir /path/to/output
+
+  3. 自定义参数:
+     python merge_mineru_paddle_ocr.py \\
+         --mineru-dir /path/to/mineru \\
+         --paddle-dir /path/to/paddle \\
+         --output-dir /path/to/output \\
+         --window 15 \\
+         --threshold 85
+        """
+    )
+
+    # Input files (single-file mode)
+    file_group = parser.add_argument_group('文件参数')
+    file_group.add_argument(
+        '--mineru-file',
+        type=str,
+        help='MinerU 输出的 JSON 文件路径(单文件模式)'
+    )
+    file_group.add_argument(
+        '--paddle-file',
+        type=str,
+        help='PaddleOCR 输出的 JSON 文件路径(单文件模式)'
+    )
+
+    # Input directories (batch mode)
+    dir_group = parser.add_argument_group('目录参数')
+    dir_group.add_argument(
+        '--mineru-dir',
+        type=str,
+        help='MinerU 结果目录(批量模式)'
+    )
+    dir_group.add_argument(
+        '--paddle-dir',
+        type=str,
+        help='PaddleOCR 结果目录(批量模式)'
+    )
+
+    # Output location (required in both modes)
+    output_group = parser.add_argument_group('输出参数')
+    output_group.add_argument(
+        '-o', '--output-dir',
+        type=str,
+        required=True,
+        help='输出目录(必需)'
+    )
+
+    # Matching algorithm parameters
+    algo_group = parser.add_argument_group('算法参数')
+    algo_group.add_argument(
+        '-w', '--window',
+        type=int,
+        default=10,
+        help='向前查找的窗口大小(默认: 10)'
+    )
+    algo_group.add_argument(
+        '-t', '--threshold',
+        type=int,
+        default=80,
+        help='文本相似度阈值(0-100,默认: 80)'
+    )
+
+    args = parser.parse_args()
+
+    if args.mineru_file and args.paddle_file:
+        # Single-file mode
+        mineru_file = Path(args.mineru_file)
+        paddle_file = Path(args.paddle_file)
+        output_dir = Path(args.output_dir)
+
+        if not mineru_file.exists():
+            print(f"❌ 错误: MinerU 文件不存在: {mineru_file}")
+            return 1
+
+        if not paddle_file.exists():
+            print(f"❌ 错误: PaddleOCR 文件不存在: {paddle_file}")
+            return 1
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        print("\n🔧 单文件处理模式")
+        print(f"📄 MinerU 文件: {mineru_file}")
+        print(f"📄 PaddleOCR 文件: {paddle_file}")
+        print(f"📂 输出目录: {output_dir}")
+        print(f"⚙️  查找窗口: {args.window}")
+        print(f"⚙️  相似度阈值: {args.threshold}%\n")
+
+        merger = MinerUPaddleOCRMerger(
+            look_ahead_window=args.window,
+            similarity_threshold=args.threshold
+        )
+
+        success = merge_single_file(mineru_file, paddle_file, output_dir, merger)
+
+        if success:
+            print("\n✅ 处理完成!")
+            return 0
+
+        print("\n❌ 处理失败!")
+        return 1
+
+    if args.mineru_dir and args.paddle_dir:
+        # Batch mode
+        if not Path(args.mineru_dir).exists():
+            print(f"❌ 错误: MinerU 目录不存在: {args.mineru_dir}")
+            return 1
+
+        if not Path(args.paddle_dir).exists():
+            print(f"❌ 错误: PaddleOCR 目录不存在: {args.paddle_dir}")
+            return 1
+
+        print("\n🔧 批量处理模式")
+
+        merge_mineru_paddle_batch(
+            args.mineru_dir,
+            args.paddle_dir,
+            args.output_dir,
+            look_ahead_window=args.window,
+            similarity_threshold=args.threshold
+        )
+        return 0
+
+    # Neither argument pair is complete: show usage and signal a usage
+    # error with status 2, the status argparse itself uses.
+    parser.print_help()
+    print("\n❌ 错误: 请指定单文件模式或批量模式的参数")
+    print("  单文件模式: --mineru-file 和 --paddle-file")
+    print("  批量模式: --mineru-dir 和 --paddle-dir")
+    return 2
+
+if __name__ == "__main__":
+    import sys
+
+    print("🚀 启动 MinerU + PaddleOCR 合并程序...")
+
+    # With no command-line arguments, fall back to a built-in configuration
+    # by rewriting sys.argv before main() parses it.
+    if len(sys.argv) == 1:
+        print("ℹ️  未提供命令行参数,使用默认配置运行...")
+
+        # Option name (without the leading "--") -> value
+        default_config = {
+            "mineru-dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/mineru-vlm-2.5.3_Results",
+            "paddle-dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/data_PPStructureV3_Results",
+            "output-dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/merged_results",
+            "window": "15",
+            "threshold": "85"
+        }
+
+        print(f"📂 MinerU 目录: {default_config['mineru-dir']}")
+        print(f"📂 PaddleOCR 目录: {default_config['paddle-dir']}")
+        print(f"📂 输出目录: {default_config['output-dir']}")
+        print(f"⚙️  查找窗口: {default_config['window']}")
+        print(f"⚙️  相似度阈值: {default_config['threshold']}%\n")
+
+        # Flatten the mapping into ["--name", "value", ...] after the program name.
+        default_args = [
+            token
+            for option, setting in default_config.items()
+            for token in (f"--{option}", str(setting))
+        ]
+        sys.argv = [sys.argv[0], *default_args]
+
+    sys.exit(main())