dotsocr_merger.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """
  2. DotsOCR 和 PaddleOCR 合并模块
  3. """
import json
import sys
from pathlib import Path
from typing import Dict, List, Optional
  8. # 添加 ocr_platform 根目录到 Python 路径(用于导入 ocr_utils)
  9. ocr_platform_root = Path(__file__).parents[3] # ocr_merger -> ocr_tools -> ocr_platform -> repository.git
  10. if str(ocr_platform_root) not in sys.path:
  11. sys.path.insert(0, str(ocr_platform_root))
  12. try:
  13. from .text_matcher import TextMatcher
  14. from ocr_utils import BBoxExtractor # 从 ocr_utils 导入
  15. from .data_processor import DataProcessor
  16. from .markdown_generator import MarkdownGenerator
  17. from .unified_output_converter import UnifiedOutputConverter
  18. except ImportError:
  19. from text_matcher import TextMatcher
  20. from ocr_utils import BBoxExtractor # 从 ocr_utils 导入
  21. from data_processor import DataProcessor
  22. from markdown_generator import MarkdownGenerator
  23. from unified_output_converter import UnifiedOutputConverter
  24. class DotsOCRMerger:
  25. """DotsOCR 和 PaddleOCR 结果合并器"""
  26. def __init__(self, look_ahead_window: int = 10, similarity_threshold: int = 90):
  27. """
  28. Args:
  29. look_ahead_window: 向前查找的窗口大小
  30. similarity_threshold: 文本相似度阈值
  31. """
  32. self.look_ahead_window = look_ahead_window
  33. self.similarity_threshold = similarity_threshold
  34. # 初始化子模块
  35. self.text_matcher = TextMatcher(similarity_threshold)
  36. self.bbox_extractor = BBoxExtractor()
  37. self.data_processor = DataProcessor(self.text_matcher, look_ahead_window)
  38. self.markdown_generator = MarkdownGenerator()
  39. self.output_converter = UnifiedOutputConverter()
  40. def merge_table_with_bbox(self, dotsocr_json_path: str,
  41. paddle_json_path: str,
  42. data_format: str = 'mineru') -> List[Dict]:
  43. """
  44. 合并 DotsOCR 和 PaddleOCR 的结果
  45. Args:
  46. dotsocr_json_path: DotsOCR 输出的 JSON 路径
  47. paddle_json_path: PaddleOCR 输出的 JSON 路径
  48. data_format: 输出格式(固定为 'mineru')
  49. Returns:
  50. MinerU 格式的合并数据
  51. """
  52. # 加载数据
  53. with open(dotsocr_json_path, 'r', encoding='utf-8') as f:
  54. dotsocr_data = json.load(f)
  55. with open(paddle_json_path, 'r', encoding='utf-8') as f:
  56. paddle_data = json.load(f)
  57. # 🎯 提取 PaddleOCR 的文字框信息
  58. paddle_text_boxes, rotation_angle, orig_image_size = self.bbox_extractor.extract_paddle_text_boxes(paddle_data)
  59. # 🎯 使用专门的 DotsOCR 处理方法(自动转换为 MinerU 格式)
  60. merged_data = self.data_processor.process_dotsocr_data(
  61. dotsocr_data, paddle_text_boxes, rotation_angle, orig_image_size
  62. )
  63. return merged_data
  64. def generate_enhanced_markdown(self, merged_data: List[Dict],
  65. output_path: str = None,
  66. source_file: str = None,
  67. data_format: str = 'mineru') -> str:
  68. """
  69. 生成增强的 Markdown(MinerU 格式)
  70. Args:
  71. merged_data: 合并后的数据(MinerU 格式)
  72. output_path: 输出路径
  73. source_file: 源文件路径
  74. data_format: 数据格式(固定为 'mineru')
  75. """
  76. # 🎯 强制使用 MinerU 格式生成 Markdown
  77. return self.markdown_generator.generate_enhanced_markdown(
  78. merged_data, output_path, source_file, data_format='mineru'
  79. )
  80. def extract_table_cells_with_bbox(self, merged_data: List[Dict]) -> List[Dict]:
  81. """提取所有表格单元格及其 bbox 信息"""
  82. # 🎯 直接复用 BBoxExtractor 的方法
  83. return self.bbox_extractor.extract_table_cells_with_bbox(merged_data)