Просмотр исходного кода

feat: 增强Markdown生成器,支持自动检测数据格式并生成PaddleOCR_VL格式的Markdown

zhch158_admin 4 недель назад
Родитель
Сommit
069753e209
1 измененных файлов с 226 добавлено и 38 удалено
  1. 226 38
      merger/markdown_generator.py

+ 226 - 38
merger/markdown_generator.py

@@ -11,36 +11,83 @@ class MarkdownGenerator:
     """Markdown 生成器"""
     
     @staticmethod
+    def detect_data_format(merged_data: List[Dict]) -> str:
+        """
+        Detect which data format ``merged_data`` uses.
+
+        Items are scanned in order and the first dict item carrying a
+        recognisable signature decides the result, so a leading item that
+        matches neither signature (or is not a dict) does not force the
+        MinerU default on its own.
+
+        Args:
+            merged_data: List of item dicts to inspect.
+
+        Returns:
+            'paddleocr_vl' when the first recognisable item has both
+            'block_label' and 'block_content'; otherwise 'mineru'
+            (also returned for empty input or when nothing is recognisable).
+        """
+        for item in merged_data or []:
+            # Skip malformed entries instead of failing on the `in` tests below.
+            if not isinstance(item, dict):
+                continue
+
+            # PaddleOCR_VL-specific fields are checked first.
+            if 'block_label' in item and 'block_content' in item:
+                return 'paddleocr_vl'
+
+            # MinerU-specific fields.
+            if 'type' in item and ('table_body' in item or 'text' in item):
+                return 'mineru'
+
+        # Nothing decisive (or empty input): treat the data as MinerU.
+        return 'mineru'
+    
+    @staticmethod
     def generate_enhanced_markdown(merged_data: List[Dict], 
                                    output_path: Optional[str] = None,
-                                   mineru_file: Optional[str] = None) -> str:
+                                   source_file: Optional[str] = None,
+                                   data_format: Optional[str] = None) -> str:
         """
         Generate enhanced Markdown (with bbox information embedded as comments).
         
         Args:
             merged_data: The merged data
             output_path: Output path
-            mineru_file: MinerU 源文件路径(用于复制图片)
+            source_file: Source file path (used for copying images)
+            data_format: Data format ('mineru' or 'paddleocr_vl'); auto-detected when None
         
         Returns:
             The Markdown content
         """
+        # ✅ Auto-detect the data format when the caller did not specify one
+        if data_format is None:
+            data_format = MarkdownGenerator.detect_data_format(merged_data)
+        
+        # Printed for both detected and caller-supplied formats
+        print(f"ℹ️  检测到数据格式: {data_format}")
+        
+        # ✅ Pick the generator for the format; any value other than
+        # 'paddleocr_vl' goes through the MinerU generator
+        if data_format == 'paddleocr_vl':
+            return MarkdownGenerator._generate_paddleocr_vl_markdown(
+                merged_data, output_path, source_file
+            )
+        else:
+            return MarkdownGenerator._generate_mineru_markdown(
+                merged_data, output_path, source_file
+            )
+    
+    @staticmethod
+    def _generate_mineru_markdown(merged_data: List[Dict],
+                                  output_path: Optional[str] = None,
+                                  source_file: Optional[str] = None) -> str:
+        """生成 MinerU 格式的 Markdown"""
         md_lines = []
         
         for item in merged_data:
             item_type = item.get('type', '')
             
             if item_type == 'title':
-                md_lines.extend(MarkdownGenerator._format_title(item))
+                md_lines.extend(MarkdownGenerator._format_mineru_title(item))
             elif item_type == 'text':
-                md_lines.extend(MarkdownGenerator._format_text(item))
+                md_lines.extend(MarkdownGenerator._format_mineru_text(item))
             elif item_type == 'list':
-                md_lines.extend(MarkdownGenerator._format_list(item))
+                md_lines.extend(MarkdownGenerator._format_mineru_list(item))
             elif item_type == 'table':
-                md_lines.extend(MarkdownGenerator._format_table(item))
+                md_lines.extend(MarkdownGenerator._format_mineru_table(item))
             elif item_type == 'image':
-                md_lines.extend(MarkdownGenerator._format_image(
-                    item, output_path, mineru_file
+                md_lines.extend(MarkdownGenerator._format_mineru_image(
+                    item, output_path, source_file
                 ))
             elif item_type == 'equation':
                 md_lines.extend(MarkdownGenerator._format_equation(item))
@@ -62,13 +109,43 @@ class MarkdownGenerator:
         return markdown_content
     
     @staticmethod
-    def _add_bbox_comment(bbox: List) -> str:
-        """添加 bbox 注释"""
-        return f"<!-- bbox: {bbox} -->"
+    def _generate_paddleocr_vl_markdown(merged_data: List[Dict],
+                                        output_path: Optional[str] = None,
+                                        source_file: Optional[str] = None) -> str:
+        """
+        Generate Markdown from PaddleOCR_VL-format items.
+
+        Args:
+            merged_data: Items to render; each is dispatched on its
+                'block_label' value.
+            output_path: When given, the Markdown is also written to this
+                file as UTF-8.
+            source_file: Accepted so the signature matches the MinerU
+                generator; not used by this method.
+
+        Returns:
+            The Markdown content (formatter output lines joined by newlines).
+        """
+        # Labels matched exactly; any label containing 'title' is routed to
+        # the title formatter before this table is consulted.
+        formatters = {
+            'text': MarkdownGenerator._format_paddleocr_vl_text,
+            'table': MarkdownGenerator._format_paddleocr_vl_table,
+            'image': MarkdownGenerator._format_paddleocr_vl_figure,
+            'equation': MarkdownGenerator._format_paddleocr_vl_equation,
+            'reference': MarkdownGenerator._format_paddleocr_vl_reference,
+        }
+        md_lines = []
+
+        for item in merged_data:
+            block_label = item.get('block_label', '')
+            # A None / non-string label would make the substring test below
+            # raise TypeError; treat it as unlabeled so the item falls through
+            # to the unknown-type formatter.
+            if not isinstance(block_label, str):
+                block_label = ''
+
+            if 'title' in block_label:
+                formatter = MarkdownGenerator._format_paddleocr_vl_title
+            else:
+                formatter = formatters.get(
+                    block_label, MarkdownGenerator._format_paddleocr_vl_unknown
+                )
+            md_lines.extend(formatter(item))
+
+        markdown_content = '\n'.join(md_lines)
+
+        if output_path:
+            with open(output_path, 'w', encoding='utf-8') as f:
+                f.write(markdown_content)
+
+        return markdown_content
+    
+    # ================== MinerU 格式化方法 ==================
     
     @staticmethod
-    def _format_title(item: Dict) -> List[str]:
-        """格式化标题"""
+    def _format_mineru_title(item: Dict) -> List[str]:
+        """格式化 MinerU 标题"""
         lines = []
         bbox = item.get('bbox', [])
         if bbox:
@@ -82,8 +159,8 @@ class MarkdownGenerator:
         return lines
     
     @staticmethod
-    def _format_text(item: Dict) -> List[str]:
-        """格式化文本"""
+    def _format_mineru_text(item: Dict) -> List[str]:
+        """格式化 MinerU 文本"""
         lines = []
         bbox = item.get('bbox', [])
         if bbox:
@@ -101,8 +178,8 @@ class MarkdownGenerator:
         return lines
     
     @staticmethod
-    def _format_list(item: Dict) -> List[str]:
-        """格式化列表"""
+    def _format_mineru_list(item: Dict) -> List[str]:
+        """格式化 MinerU 列表"""
         lines = []
         bbox = item.get('bbox', [])
         if bbox:
@@ -116,8 +193,8 @@ class MarkdownGenerator:
         return lines
     
     @staticmethod
-    def _format_table(item: Dict) -> List[str]:
-        """格式化表格"""
+    def _format_mineru_table(item: Dict) -> List[str]:
+        """格式化 MinerU 表格"""
         lines = []
         bbox = item.get('bbox', [])
         if bbox:
@@ -146,9 +223,9 @@ class MarkdownGenerator:
         return lines
     
     @staticmethod
-    def _format_image(item: Dict, output_path: Optional[str],
-                     mineru_file: Optional[str]) -> List[str]:
-        """格式化图片"""
+    def _format_mineru_image(item: Dict, output_path: Optional[str],
+                            source_file: Optional[str]) -> List[str]:
+        """格式化 MinerU 图片"""
         lines = []
         bbox = item.get('bbox', [])
         if bbox:
@@ -157,8 +234,8 @@ class MarkdownGenerator:
         img_path = item.get('img_path', '')
         
         # 复制图片
-        if img_path and mineru_file and output_path:
-            MarkdownGenerator._copy_image(img_path, mineru_file, output_path)
+        if img_path and source_file and output_path:
+            MarkdownGenerator._copy_image(img_path, source_file, output_path)
         
         # 图片标题
         image_caption = item.get('image_caption', [])
@@ -178,19 +255,120 @@ class MarkdownGenerator:
         
         return lines
     
+    # ================== PaddleOCR_VL 格式化方法 ==================
+    
     @staticmethod
-    def _copy_image(img_path: str, mineru_file: str, output_path: str):
-        """复制图片到输出目录"""
-        mineru_dir = Path(mineru_file).parent
-        img_full_path = mineru_dir / img_path
-        if img_full_path.exists():
-            output_img_path = Path(output_path).parent / img_path
-            output_img_path.parent.mkdir(parents=True, exist_ok=True)
-            shutil.copy(img_full_path, output_img_path)
+    def _format_paddleocr_vl_title(item: Dict) -> List[str]:
+        """
+        Format a PaddleOCR_VL title block as a Markdown heading.
+
+        Returns the bbox comment (only when 'block_bbox' is non-empty)
+        followed by the heading line built from 'block_content'.
+        """
+        # Heading level per block_label; labels not listed here use level 1.
+        label_to_level = {
+            'paragraph_title': 1,
+            'figure_title': 2,
+            'title': 1,
+        }
+        level = label_to_level.get(item.get('block_label', ''), 1)
+        heading = '#' * min(level, 6)
+        text = item.get('block_content', '')
+        heading_line = f"{heading} {text}\n"
+
+        bbox = item.get('block_bbox', [])
+        if not bbox:
+            return [heading_line]
+        return [MarkdownGenerator._add_bbox_comment(bbox), heading_line]
+    
+    @staticmethod
+    def _format_paddleocr_vl_text(item: Dict) -> List[str]:
+        """
+        Format a PaddleOCR_VL text block.
+
+        Returns the bbox comment (only when 'block_bbox' is non-empty)
+        followed by the 'block_content' text line, which is always emitted.
+        """
+        text = item.get('block_content', '')
+        bbox = item.get('block_bbox', [])
+        prefix = [MarkdownGenerator._add_bbox_comment(bbox)] if bbox else []
+        return prefix + [f"{text}\n"]
+    
+    @staticmethod
+    def _format_paddleocr_vl_table(item: Dict) -> List[str]:
+        """
+        Format a PaddleOCR_VL table block.
+
+        The table content is taken from 'block_content_with_bbox', falling
+        back to 'block_content' when the former is missing or empty.
+
+        Returns the bbox comment (only when 'block_bbox' is non-empty),
+        then the table content and an empty separator line when there is
+        any content.
+        """
+        lines = []
+        bbox = item.get('block_bbox', [])
+        if bbox:
+            lines.append(MarkdownGenerator._add_bbox_comment(bbox))
+
+        # `or` instead of a .get() default: a present-but-empty
+        # 'block_content_with_bbox' must still fall back to the plain
+        # content rather than dropping the table.
+        table_content = (item.get('block_content_with_bbox')
+                         or item.get('block_content', ''))
+        if table_content:
+            lines.append(table_content)
+            lines.append("")
+
+        return lines
+    
+    @staticmethod
+    def _format_paddleocr_vl_figure(item: Dict) -> List[str]:
+        """
+        Format a PaddleOCR_VL image block as a Markdown image line.
+
+        Returns the bbox comment (only when 'block_bbox' is non-empty)
+        followed by the image line, which is always emitted.
+        """
+        # For PaddleOCR_VL the image information is read from block_content.
+        content = item.get('block_content', '')
+        image_line = f"![Figure]({content})\n"
+
+        bbox = item.get('block_bbox', [])
+        if bbox:
+            return [MarkdownGenerator._add_bbox_comment(bbox), image_line]
+        return [image_line]
+    
+    @staticmethod
+    def _format_paddleocr_vl_equation(item: Dict) -> List[str]:
+        """
+        Format a PaddleOCR_VL equation block as a ``$$`` display block.
+
+        Returns the bbox comment (only when 'block_bbox' is non-empty)
+        followed by the formula block, which is emitted only when
+        'block_content' is non-empty.
+        """
+        bbox = item.get('block_bbox', [])
+        lines = [MarkdownGenerator._add_bbox_comment(bbox)] if bbox else []
+
+        latex = item.get('block_content', '')
+        if not latex:
+            return lines
+
+        lines.append(f"$$\n{latex}\n$$\n")
+        return lines
+    
+    @staticmethod
+    def _format_paddleocr_vl_reference(item: Dict) -> List[str]:
+        """Format a PaddleOCR_VL reference block as a single blockquote line."""
+        text = item.get('block_content', '')
+        quoted = f"> {text}\n"
+        return [quoted]
+    
+    @staticmethod
+    def _format_paddleocr_vl_unknown(item: Dict) -> List[str]:
+        """
+        Format a PaddleOCR_VL block whose label has no dedicated formatter.
+
+        Returns the bbox comment (only when 'block_bbox' is non-empty)
+        followed by the 'block_content' text line when that text is
+        non-empty.
+        """
+        text = item.get('block_content', '')
+        bbox = item.get('block_bbox', [])
+
+        lines = [MarkdownGenerator._add_bbox_comment(bbox)] if bbox else []
+        if text:
+            lines.append(f"{text}\n")
+        return lines
+    
+    # ================== 通用方法 ==================
+    
+    @staticmethod
+    def _add_bbox_comment(bbox: List) -> str:
+        """Return an HTML comment string that embeds ``bbox``."""
+        comment = f"<!-- bbox: {bbox} -->"
+        return comment
     
     @staticmethod
     def _format_equation(item: Dict) -> List[str]:
-        """格式化公式"""
+        """格式化公式(通用)"""
         latex = item.get('latex', '')
         if latex:
             return [f"$$\n{latex}\n$$\n"]
@@ -198,7 +376,7 @@ class MarkdownGenerator:
     
     @staticmethod
     def _format_inline_equation(item: Dict) -> List[str]:
-        """格式化行内公式"""
+        """格式化行内公式(通用)"""
         latex = item.get('latex', '')
         if latex:
             return [f"${latex}$\n"]
@@ -206,7 +384,7 @@ class MarkdownGenerator:
     
     @staticmethod
     def _format_metadata(item: Dict, item_type: str) -> List[str]:
-        """格式化元数据(页码、页眉、页脚)"""
+        """格式化元数据(通用)"""
         text = item.get('text', '')
         type_map = {
             'page_number': '页码',
@@ -219,14 +397,24 @@ class MarkdownGenerator:
     
     @staticmethod
     def _format_reference(item: Dict) -> List[str]:
-        """格式化参考文献"""
+        """Format a reference item (MinerU) as a single blockquote line built from 'text'."""
         text = item.get('text', '')
         return [f"> {text}\n"]
     
     @staticmethod
     def _format_unknown(item: Dict) -> List[str]:
-        """格式化未知类型"""
+        """Format an item of unknown type (MinerU): its 'text' as one line, or nothing when empty."""
         text = item.get('text', '')
         if text:
             return [f"{text}\n"]
-        return []
+        return []
+    
+    @staticmethod
+    def _copy_image(img_path: str, source_file: str, output_path: str):
+        """
+        Copy an image into the output directory.
+
+        Args:
+            img_path: Image path; joined onto the directory of ``source_file``
+                to find the image and onto the directory of ``output_path``
+                to place the copy.
+            source_file: Source file path; only its parent directory is used.
+            output_path: Output file path; only its parent directory is used.
+
+        Nothing is copied when the image is not an existing regular file, or
+        when source and destination resolve to the same file.
+        """
+        img_full_path = Path(source_file).parent / img_path
+        # Only regular files can be copied; a missing path or a directory is
+        # skipped instead of crashing the copy.
+        if not img_full_path.is_file():
+            return
+
+        output_img_path = Path(output_path).parent / img_path
+        # When source and output share a directory (or img_path is absolute)
+        # both paths point at the same file: shutil.copy would raise
+        # SameFileError, and the image is already in place.
+        if output_img_path.resolve() == img_full_path.resolve():
+            return
+
+        output_img_path.parent.mkdir(parents=True, exist_ok=True)
+        shutil.copy(img_full_path, output_img_path)