config_manager.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. """配置管理器 - 加载和验证配置文件"""
  2. import yaml
  3. from pathlib import Path
  4. from typing import Dict, Any, Optional
  5. class ConfigManager:
  6. """配置管理器"""
  7. _config_cache = {}
  8. @classmethod
  9. def load_config(cls, config_path: str) -> Dict[str, Any]:
  10. """加载配置文件"""
  11. config_path = Path(config_path)
  12. # 缓存机制
  13. cache_key = str(config_path.absolute())
  14. if cache_key in cls._config_cache:
  15. return cls._config_cache[cache_key]
  16. if not config_path.exists():
  17. raise FileNotFoundError(f"Config file not found: {config_path}")
  18. with open(config_path, 'r', encoding='utf-8') as f:
  19. config = yaml.safe_load(f)
  20. # 配置验证和默认值设置
  21. config = cls._validate_and_set_defaults(config)
  22. # 缓存配置
  23. cls._config_cache[cache_key] = config
  24. return config
  25. @classmethod
  26. def _validate_and_set_defaults(cls, config: Dict[str, Any]) -> Dict[str, Any]:
  27. """验证配置并设置默认值"""
  28. # 设置默认场景名称
  29. if 'scene_name' not in config:
  30. config['scene_name'] = 'unknown'
  31. # 验证必需的配置项
  32. required_sections = ['preprocessor', 'layout_detection', 'vl_recognition', 'ocr_recognition']
  33. for section in required_sections:
  34. if section not in config:
  35. config[section] = {'module': 'mineru'}
  36. # 设置预处理器默认配置
  37. preprocessor_defaults = {
  38. 'module': 'mineru',
  39. 'orientation_classifier': {'enabled': True},
  40. 'unwarping': {'enabled': False}
  41. }
  42. config['preprocessor'] = cls._merge_defaults(
  43. config.get('preprocessor', {}), preprocessor_defaults
  44. )
  45. # 设置版式检测默认配置
  46. layout_defaults = {
  47. 'module': 'mineru',
  48. 'model_name': 'RT-DETR-H_layout_17cls',
  49. 'device': 'cpu',
  50. 'batch_size': 1,
  51. 'conf': 0.25,
  52. 'iou': 0.45
  53. }
  54. config['layout_detection'] = cls._merge_defaults(
  55. config.get('layout_detection', {}), layout_defaults
  56. )
  57. # 设置VL识别默认配置
  58. vl_defaults = {
  59. 'module': 'mineru',
  60. 'backend': 'vllm-http-client',
  61. 'server_url': 'http://localhost:8111/v1',
  62. 'device': 'cpu',
  63. 'batch_size': 1,
  64. 'model_params': {'max_concurrency': 10, 'http_timeout': 600}
  65. }
  66. config['vl_recognition'] = cls._merge_defaults(
  67. config.get('vl_recognition', {}), vl_defaults
  68. )
  69. # 设置OCR默认配置
  70. ocr_defaults = {
  71. 'module': 'mineru',
  72. 'language': 'ch',
  73. 'det_threshold': 0.3,
  74. 'unclip_ratio': 1.8,
  75. 'batch_size': 8,
  76. 'device': 'cpu'
  77. }
  78. config['ocr_recognition'] = cls._merge_defaults(
  79. config.get('ocr_recognition', {}), ocr_defaults
  80. )
  81. # 设置输出默认配置
  82. output_defaults = {
  83. 'format': 'enhanced_json',
  84. 'save_json': True,
  85. 'save_markdown': True,
  86. 'save_html': True,
  87. 'save_images': {'layout': True, 'ocr': True, 'table_cells': True},
  88. 'coordinate_precision': 2
  89. }
  90. config['output'] = cls._merge_defaults(
  91. config.get('output', {}), output_defaults
  92. )
  93. return config
  94. @classmethod
  95. def _merge_defaults(cls, user_config: Dict[str, Any], defaults: Dict[str, Any]) -> Dict[str, Any]:
  96. """合并用户配置和默认配置"""
  97. result = defaults.copy()
  98. for key, value in user_config.items():
  99. if isinstance(value, dict) and key in result and isinstance(result[key], dict):
  100. result[key] = cls._merge_defaults(value, result[key])
  101. else:
  102. result[key] = value
  103. return result
  104. @classmethod
  105. def save_config(cls, config: Dict[str, Any], config_path: str):
  106. """保存配置文件"""
  107. config_path = Path(config_path)
  108. config_path.parent.mkdir(parents=True, exist_ok=True)
  109. with open(config_path, 'w', encoding='utf-8') as f:
  110. yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
  111. @classmethod
  112. def clear_cache(cls):
  113. """清空配置缓存"""
  114. cls._config_cache.clear()