dotsocr_merger.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. DotsOCR 和 PaddleOCR 合并模块
  3. """
  4. import json
  5. from typing import List, Dict
  6. try:
  7. from .text_matcher import TextMatcher
  8. from .bbox_extractor import BBoxExtractor
  9. from .data_processor import DataProcessor
  10. from .markdown_generator import MarkdownGenerator
  11. from .unified_output_converter import UnifiedOutputConverter
  12. except ImportError:
  13. from text_matcher import TextMatcher
  14. from bbox_extractor import BBoxExtractor
  15. from data_processor import DataProcessor
  16. from markdown_generator import MarkdownGenerator
  17. from unified_output_converter import UnifiedOutputConverter
  18. class DotsOCRMerger:
  19. """DotsOCR 和 PaddleOCR 结果合并器"""
  20. def __init__(self, look_ahead_window: int = 10, similarity_threshold: int = 90):
  21. """
  22. Args:
  23. look_ahead_window: 向前查找的窗口大小
  24. similarity_threshold: 文本相似度阈值
  25. """
  26. self.look_ahead_window = look_ahead_window
  27. self.similarity_threshold = similarity_threshold
  28. # 初始化子模块
  29. self.text_matcher = TextMatcher(similarity_threshold)
  30. self.bbox_extractor = BBoxExtractor()
  31. self.data_processor = DataProcessor(self.text_matcher, look_ahead_window)
  32. self.markdown_generator = MarkdownGenerator()
  33. self.output_converter = UnifiedOutputConverter()
  34. def merge_table_with_bbox(self, dotsocr_json_path: str,
  35. paddle_json_path: str,
  36. data_format: str = 'mineru') -> List[Dict]:
  37. """
  38. 合并 DotsOCR 和 PaddleOCR 的结果
  39. Args:
  40. dotsocr_json_path: DotsOCR 输出的 JSON 路径
  41. paddle_json_path: PaddleOCR 输出的 JSON 路径
  42. data_format: 输出格式(固定为 'mineru')
  43. Returns:
  44. MinerU 格式的合并数据
  45. """
  46. # 加载数据
  47. with open(dotsocr_json_path, 'r', encoding='utf-8') as f:
  48. dotsocr_data = json.load(f)
  49. with open(paddle_json_path, 'r', encoding='utf-8') as f:
  50. paddle_data = json.load(f)
  51. # 🎯 提取 PaddleOCR 的文字框信息
  52. paddle_text_boxes = self.bbox_extractor.extract_paddle_text_boxes(paddle_data)
  53. # 🎯 使用专门的 DotsOCR 处理方法(自动转换为 MinerU 格式)
  54. merged_data = self.data_processor.process_dotsocr_data(
  55. dotsocr_data, paddle_text_boxes
  56. )
  57. return merged_data
  58. def generate_enhanced_markdown(self, merged_data: List[Dict],
  59. output_path: str = None,
  60. source_file: str = None,
  61. data_format: str = 'mineru') -> str:
  62. """
  63. 生成增强的 Markdown(MinerU 格式)
  64. Args:
  65. merged_data: 合并后的数据(MinerU 格式)
  66. output_path: 输出路径
  67. source_file: 源文件路径
  68. data_format: 数据格式(固定为 'mineru')
  69. """
  70. # 🎯 强制使用 MinerU 格式生成 Markdown
  71. return self.markdown_generator.generate_enhanced_markdown(
  72. merged_data, output_path, source_file, data_format='mineru'
  73. )
  74. def extract_table_cells_with_bbox(self, merged_data: List[Dict]) -> List[Dict]:
  75. """提取所有表格单元格及其 bbox 信息"""
  76. # 🎯 直接复用 BBoxExtractor 的方法
  77. return self.bbox_extractor.extract_table_cells_with_bbox(merged_data)