"""
Configuration manager.

Supports layered configuration (a ``global.yaml`` plus per-document YAML
files), automatic generation of data sources for OCRValidator, and Jinja2
template variables inside document result paths and descriptions.
"""
import yaml
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
import logging
from jinja2 import Template

# Shared module logger (ConfigManager.logger is this same logger object).
logger = logging.getLogger(__name__)


@dataclass
class OCRToolConfig:
    """Describes the JSON output format of one OCR tool (from ``ocr.tools``)."""

    name: str
    description: str
    json_structure: str
    text_field: str
    bbox_field: str
    category_field: str
    confidence_field: str = "confidence"
    parsing_results_field: Optional[str] = None
    rec_texts_field: Optional[str] = None
    rec_boxes_field: Optional[str] = None
    table_body_field: Optional[str] = None
    table_cells_field: Optional[str] = None
    img_path_field: Optional[str] = None
    rotation: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, tool_id: str, data: Dict) -> 'OCRToolConfig':
        """Build a tool config from its raw YAML mapping.

        Args:
            tool_id: Key of the tool under ``ocr.tools``; used as the name
                when ``data`` has no ``name`` entry.
            data: Raw mapping for the tool.

        Returns:
            The parsed config. Missing keys fall back to the defaults shown
            below (e.g. ``'object'``, ``'text'``, ``'bbox'``, ``'category'``).
        """
        return cls(
            name=data.get('name', tool_id),
            description=data.get('description', ''),
            json_structure=data.get('json_structure', 'object'),
            text_field=data.get('text_field', 'text'),
            bbox_field=data.get('bbox_field', 'bbox'),
            category_field=data.get('category_field', 'category'),
            confidence_field=data.get('confidence_field', 'confidence'),
            parsing_results_field=data.get('parsing_results_field'),
            rec_texts_field=data.get('rec_texts_field'),
            rec_boxes_field=data.get('rec_boxes_field'),
            table_body_field=data.get('table_body_field'),
            table_cells_field=data.get('table_cells_field'),
            img_path_field=data.get('img_path_field'),
            # An empty ``rotation:`` key loads as None; keep the declared dict type.
            rotation=data.get('rotation') or {},
        )

    def to_dict(self) -> Dict:
        """Convert to the plain dict format consumed by OCRValidator.

        The required fields and ``rotation`` are always present; the optional
        ``*_field`` attributes are included only when set (truthy).
        """
        config_dict = {
            'name': self.name,
            'description': self.description,
            'json_structure': self.json_structure,
            'text_field': self.text_field,
            'bbox_field': self.bbox_field,
            'category_field': self.category_field,
            'confidence_field': self.confidence_field,
            'rotation': self.rotation,
        }
        for key in (
            'parsing_results_field',
            'rec_texts_field',
            'rec_boxes_field',
            'table_body_field',
            'table_cells_field',
            'img_path_field',
        ):
            value = getattr(self, key)
            if value:
                config_dict[key] = value
        return config_dict


@dataclass
class OCRResultConfig:
    """One OCR result set (tool output directory) belonging to a document."""

    tool: str
    result_dir: str
    image_dir: Optional[str]
    description: str = ""
    enabled: bool = True

    @classmethod
    def from_dict(cls, data: Dict, context: Optional[Dict] = None) -> 'OCRResultConfig':
        """Build a result config, rendering Jinja2 templates when ``context`` is given.

        Args:
            data: Raw mapping; must contain ``tool`` and ``result_dir``.
            context: Template variables. DocumentConfig.from_dict passes
                ``name`` and ``base_dir``. When falsy, values are used verbatim.

        Returns:
            The parsed config; ``description`` is always a string.

        Raises:
            KeyError: If ``result_dir`` or ``tool`` is missing from ``data``.
        """
        raw_image_dir = data.get('image_dir')
        raw_description = data.get('description', '')
        if context:
            result_dir = cls._render_template(data['result_dir'], context)
            # _render_template() already maps an empty/None value to None.
            image_dir = cls._render_template(raw_image_dir, context)
            description = cls._render_template(raw_description, context)
        else:
            result_dir = data['result_dir']
            image_dir = raw_image_dir
            description = raw_description
        return cls(
            tool=data['tool'],
            result_dir=result_dir,
            image_dir=image_dir,
            # Rendering an empty description yields None, and an empty YAML
            # value is None too; normalise to the declared ``str`` default.
            description=description or '',
            enabled=data.get('enabled', True),
        )

    @staticmethod
    def _render_template(template_str: Optional[str], context: Dict) -> Optional[str]:
        """Render ``template_str`` as a Jinja2 template with ``context``.

        Returns None for empty/None input. On any rendering error a warning is
        logged and the unrendered string is returned (best effort).
        """
        if not template_str:
            return None
        try:
            return Template(template_str).render(context)
        except Exception as e:
            logger.warning(f"模板渲染失败: {template_str}, 错误: {e}")
            return template_str


@dataclass
class DocumentConfig:
    """A document together with the OCR result sets produced for it."""

    name: str
    base_dir: str
    ocr_results: List[OCRResultConfig] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict) -> 'DocumentConfig':
        """Build from a document YAML mapping, rendering result templates.

        The mapping may optionally be wrapped in a top-level ``document`` key.
        The document's ``name`` and ``base_dir`` are exposed as Jinja2
        variables to every entry of ``ocr_results``.

        Raises:
            KeyError: If ``name`` or ``base_dir`` is missing.
        """
        doc_data = data.get('document', data)
        name = doc_data['name']
        base_dir = doc_data['base_dir']
        context = {'name': name, 'base_dir': base_dir}
        # An empty ``ocr_results:`` key loads as None; treat it as no results.
        raw_results = doc_data.get('ocr_results') or []
        return cls(
            name=name,
            base_dir=base_dir,
            ocr_results=[OCRResultConfig.from_dict(r, context) for r in raw_results],
        )


@dataclass
class DataSource:
    """A flattened data source entry as consumed by OCRValidator."""

    name: str
    ocr_tool: str
    ocr_out_dir: str
    src_img_dir: str
    description: str = ""


class ConfigManager:
    """Loads ``global.yaml``, its OCR tool definitions and referenced documents."""

    def __init__(self, config_dir: str = "config"):
        """
        Args:
            config_dir: Directory containing ``global.yaml`` and the document
                YAML files listed under its ``data_sources`` key.
        """
        self.config_dir = Path(config_dir)
        self.logger = logging.getLogger(__name__)

        # global_config must be loaded first: both later loaders read from it.
        self.global_config = self._load_global_config()
        self.ocr_tools = self._load_ocr_tools()
        self.documents = self._load_documents()

    def _load_global_config(self) -> Dict:
        """Read ``<config_dir>/global.yaml``; return {} if missing or empty."""
        config_file = self.config_dir / "global.yaml"
        if not config_file.exists():
            self.logger.warning(f"全局配置文件不存在: {config_file}")
            return {}

        with open(config_file, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f) or {}

    def _load_ocr_tools(self) -> Dict[str, OCRToolConfig]:
        """Parse the ``ocr.tools`` section of global.yaml into tool configs."""
        # Empty ``ocr:`` / ``tools:`` / tool entries load as None; treat as empty.
        ocr_section = self.global_config.get('ocr') or {}
        tools_data = ocr_section.get('tools') or {}
        return {
            tool_id: OCRToolConfig.from_dict(tool_id, tool_data or {})
            for tool_id, tool_data in tools_data.items()
        }

    def _load_documents(self) -> Dict[str, DocumentConfig]:
        """Load each document config file listed under ``data_sources``.

        Entries without a ``.yaml``/``.yml`` suffix get ``.yaml`` appended and
        are resolved against ``config_dir`` (an absolute entry stays absolute,
        per ``pathlib`` joining rules). Missing or unparsable files are logged
        and skipped. Documents are keyed by name, so a later file with the
        same name replaces an earlier one.
        """
        documents: Dict[str, DocumentConfig] = {}

        doc_files = self.global_config.get('data_sources') or []
        for doc_file in doc_files:
            doc_file = str(doc_file)
            if not doc_file.endswith(('.yaml', '.yml')):
                doc_file = f"{doc_file}.yaml"

            yaml_path = self.config_dir / doc_file
            if not yaml_path.exists():
                self.logger.warning(f"文档配置文件不存在: {yaml_path}")
                continue

            # Best effort per file: one broken document must not stop the rest.
            try:
                with open(yaml_path, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)
                doc_config = DocumentConfig.from_dict(data)
                documents[doc_config.name] = doc_config
                self.logger.info(f"✅ 加载文档配置: {doc_config.name} ({len(doc_config.ocr_results)} 个 OCR 结果)")
            except Exception as e:
                self.logger.error(f"加载文档配置失败: {yaml_path}, 错误: {e}")

        return documents

    def get_ocr_tool(self, tool_id: str) -> Optional[OCRToolConfig]:
        """Return the tool config for ``tool_id``, or None if unknown."""
        return self.ocr_tools.get(tool_id)

    def get_document(self, doc_name: str) -> Optional[DocumentConfig]:
        """Return the document config named ``doc_name``, or None if unknown."""
        return self.documents.get(doc_name)

    def list_documents(self) -> List[str]:
        """Return the names of all loaded documents."""
        return list(self.documents.keys())

    def list_ocr_tools(self) -> List[str]:
        """Return the ids of all configured OCR tools."""
        return list(self.ocr_tools.keys())

    def get_data_sources(self) -> List[DataSource]:
        """Build the OCRValidator data-source list from the loaded documents.

        One DataSource is produced per *enabled* OCR result, named
        ``<doc_name>_<result_dir>``. Directories are joined onto the document's
        ``base_dir``. When a result has no ``image_dir``, the image directory
        falls back to ``<base_dir>/<result_dir>/<doc_name>``.
        """
        data_sources = []

        for doc_name, doc_config in self.documents.items():
            base_dir = Path(doc_config.base_dir)

            for ocr_result in doc_config.ocr_results:
                if not ocr_result.enabled:
                    continue

                ocr_out_dir = str(base_dir / ocr_result.result_dir)
                if ocr_result.image_dir:
                    src_img_dir = str(base_dir / ocr_result.image_dir)
                else:
                    src_img_dir = str(base_dir / ocr_result.result_dir / doc_name)

                data_sources.append(DataSource(
                    # result_dir makes the name unique across a document's results.
                    name=f"{doc_name}_{ocr_result.result_dir}",
                    ocr_tool=ocr_result.tool,
                    ocr_out_dir=ocr_out_dir,
                    src_img_dir=src_img_dir,
                    description=ocr_result.description or ocr_result.result_dir,
                ))

        return data_sources

    def get_config_value(self, key_path: str, default: Any = None) -> Any:
        """Look up a value in global.yaml by dotted path.

        Returns ``default`` when any segment is missing, a non-dict is reached
        before the path ends, or the final value is None.

        Examples:
            get_config_value('styles.font_size')
            get_config_value('ocr.min_text_length')
        """
        value: Any = self.global_config
        for key in key_path.split('.'):
            if not isinstance(value, dict):
                return default
            value = value.get(key)
        return default if value is None else value

    def to_validator_config(self) -> Dict:
        """Convert to the configuration format expected by OCRValidator.

        Returns:
            A copy of global.yaml in which ``data_sources`` is replaced by the
            generated data-source dicts and ``ocr.tools`` by the normalised
            tool dicts. ``self.global_config`` itself is left unmodified.
        """
        data_sources_list = [
            {
                'name': ds.name,
                'ocr_tool': ds.ocr_tool,
                'ocr_out_dir': ds.ocr_out_dir,
                'src_img_dir': ds.src_img_dir,
            }
            for ds in self.get_data_sources()
        ]

        ocr_tools_dict = {
            tool_id: tool_config.to_dict()
            for tool_id, tool_config in self.ocr_tools.items()
        }

        config = dict(self.global_config)
        config['data_sources'] = data_sources_list
        # Copy the ``ocr`` section as well: with only a shallow copy of the top
        # level it would be shared with self.global_config, and assigning
        # ``tools`` below would overwrite the loaded settings in place. An
        # empty ``ocr:`` key (None) is treated as an empty section.
        config['ocr'] = dict(config.get('ocr') or {})
        config['ocr']['tools'] = ocr_tools_dict

        return config


# ============================================================================
# Convenience functions
# ============================================================================

def load_config(config_dir: str = "config") -> ConfigManager:
    """Create a ConfigManager for ``config_dir`` (loads everything eagerly)."""
    return ConfigManager(config_dir)