__init__.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
"""
Model adapter module.

Provides a unified interface for adapting different model backends.
"""
from .base import (
    BaseAdapter,
    BasePreprocessor,
    BaseLayoutDetector,
    BaseVLRecognizer,
    BaseOCRRecognizer
)
from .paddle_layout_detector import PaddleLayoutDetector

# Optional import of the MinerU adapters. MINERU_AVAILABLE records whether
# the import succeeded; when it is False the MinerU* names below are not
# defined in this module.
try:
    from .mineru_adapter import (
        MinerUPreprocessor,
        MinerULayoutDetector,
        MinerUVLRecognizer,
        MinerUOCRRecognizer
    )
    MINERU_AVAILABLE = True
except ImportError:
    MINERU_AVAILABLE = False
  24. __all__ = [
  25. # 基类
  26. 'BaseAdapter',
  27. 'BasePreprocessor',
  28. 'BaseLayoutDetector',
  29. 'BaseVLRecognizer',
  30. 'BaseOCRRecognizer',
  31. # PaddleX 适配器
  32. 'PaddleLayoutDetector',
  33. ]
  34. # 如果 MinerU 可用,添加到导出列表
  35. if MINERU_AVAILABLE:
  36. __all__.extend([
  37. 'MinerUPreprocessor',
  38. 'MinerULayoutDetector',
  39. 'MinerUVLRecognizer',
  40. 'MinerUOCRRecognizer',
  41. ])
  42. def get_layout_detector(config: dict):
  43. """
  44. 根据配置获取布局检测器
  45. Args:
  46. config: 配置字典,包含 module 和其他参数
  47. Returns:
  48. BaseLayoutDetector 实例
  49. """
  50. module = config.get('module', 'paddle')
  51. if module == 'paddle':
  52. return PaddleLayoutDetector(config)
  53. elif module == 'mineru':
  54. if not MINERU_AVAILABLE:
  55. raise ImportError("MinerU adapter not available")
  56. return MinerULayoutDetector(config)
  57. else:
  58. raise ValueError(f"Unknown layout detection module: {module}")
  59. def get_preprocessor(config: dict):
  60. """根据配置获取预处理器"""
  61. module = config.get('module', 'mineru')
  62. if module == 'mineru':
  63. if not MINERU_AVAILABLE:
  64. raise ImportError("MinerU adapter not available")
  65. return MinerUPreprocessor(config)
  66. else:
  67. raise ValueError(f"Unknown preprocessor module: {module}")
  68. def get_vl_recognizer(config: dict):
  69. """根据配置获取VL识别器"""
  70. module = config.get('module', 'mineru')
  71. if module == 'mineru':
  72. if not MINERU_AVAILABLE:
  73. raise ImportError("MinerU adapter not available")
  74. return MinerUVLRecognizer(config)
  75. else:
  76. raise ValueError(f"Unknown VL recognizer module: {module}")
  77. def get_ocr_recognizer(config: dict):
  78. """根据配置获取OCR识别器"""
  79. module = config.get('module', 'mineru')
  80. if module == 'mineru':
  81. if not MINERU_AVAILABLE:
  82. raise ImportError("MinerU adapter not available")
  83. return MinerUOCRRecognizer(config)
  84. else:
  85. raise ValueError(f"Unknown OCR recognizer module: {module}")