瀏覽代碼

feat: 更新合并函数,添加data_format参数以支持格式转换为MinerU

zhch158_admin 4 周之前
父節點
當前提交
2dca41c351
共有 1 個文件被更改,包括 41 次插入和 6 次删除
  1. + 41 - 6
      merger/paddleocr_vl_merger.py

+ 41 - 6
merger/paddleocr_vl_merger.py

@@ -9,11 +9,13 @@ try:
     from .bbox_extractor import BBoxExtractor
     from .data_processor import DataProcessor
     from .markdown_generator import MarkdownGenerator
+    from .unified_output_converter import UnifiedOutputConverter
 except ImportError:
     from text_matcher import TextMatcher
     from bbox_extractor import BBoxExtractor
     from data_processor import DataProcessor
     from markdown_generator import MarkdownGenerator
+    from unified_output_converter import UnifiedOutputConverter
 
 
 class PaddleOCRVLMerger:
@@ -33,18 +35,21 @@ class PaddleOCRVLMerger:
         self.bbox_extractor = BBoxExtractor()
         self.data_processor = DataProcessor(self.text_matcher, look_ahead_window)
         self.markdown_generator = MarkdownGenerator()
+        self.output_converter = UnifiedOutputConverter()
     
     def merge_table_with_bbox(self, paddleocr_vl_json_path: str, 
-                             paddle_json_path: str) -> List[Dict]:
+                             paddle_json_path: str,
+                             data_format: str = 'mineru') -> List[Dict]:
         """
         合并 PaddleOCR_VL 和 PaddleOCR 的结果
         
         Args:
             paddleocr_vl_json_path: PaddleOCR_VL 输出的 JSON 路径
             paddle_json_path: PaddleOCR 输出的 JSON 路径
+            data_format: 输出格式 ('mineru' 或 'paddleocr_vl')
         
         Returns:
-            合并后的结果 (MinerU 格式)
+            合并后的结果 (默认MinerU格式)
         """
         # 加载数据
         with open(paddleocr_vl_json_path, 'r', encoding='utf-8') as f:
@@ -61,16 +66,46 @@ class PaddleOCRVLMerger:
             paddleocr_vl_data, paddle_text_boxes
         )
         
+        # 转换为指定格式
+        if data_format == 'mineru':
+            merged_data = self.output_converter.convert_to_mineru_format(
+                merged_data, data_source='paddleocr_vl'
+            )
+        
         return merged_data
     
     def generate_enhanced_markdown(self, merged_data: List[Dict], 
                                    output_path: str = None,
-                                   source_file: str = None) -> str:
-        """生成增强的 Markdown"""
-        return self.markdown_generator._generate_paddleocr_vl_markdown(
-            merged_data, output_path, source_file
+                                   source_file: str = None,
+                                   data_format: str = None) -> str:
+        """
+        生成增强的 Markdown
+        
+        Args:
+            merged_data: 合并后的数据
+            output_path: 输出路径
+            source_file: 源文件路径
+            data_format: 数据格式,None 则自动检测
+        """
+        # 如果data_format未指定,自动检测
+        if data_format is None:
+            data_format = self.markdown_generator.detect_data_format(merged_data)
+        
+        # 如果是PaddleOCR_VL格式,先转换为MinerU格式
+        if data_format == 'paddleocr_vl':
+            merged_data = self.output_converter.convert_to_mineru_format(
+                merged_data, data_source='paddleocr_vl'
+            )
+            data_format = 'mineru'
+        
+        return self.markdown_generator.generate_enhanced_markdown(
+            merged_data, output_path, source_file, data_format
         )
     
     def extract_table_cells_with_bbox(self, merged_data: List[Dict]) -> List[Dict]:
         """提取所有表格单元格及其 bbox 信息"""
+        # 确保数据是MinerU格式
+        if self.output_converter._detect_data_source(merged_data) != 'mineru':
+            merged_data = self.output_converter.convert_to_mineru_format(merged_data)
+        
         return self.bbox_extractor.extract_table_cells_with_bbox(merged_data)