|
|
@@ -0,0 +1,360 @@
|
|
|
+"""
|
|
|
+配置管理器
|
|
|
+支持分层配置和自动发现数据源
|
|
|
+支持 Jinja2 模板变量
|
|
|
+"""
|
|
|
+
|
|
|
+import yaml
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, List, Optional, Any
|
|
|
+from dataclasses import dataclass, field
|
|
|
+import logging
|
|
|
+from jinja2 import Template # 🎯 新增
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class OCRToolConfig:
|
|
|
+ """OCR 工具配置"""
|
|
|
+ name: str
|
|
|
+ description: str
|
|
|
+ json_structure: str
|
|
|
+ text_field: str
|
|
|
+ bbox_field: str
|
|
|
+ category_field: str
|
|
|
+ confidence_field: str = "confidence"
|
|
|
+ parsing_results_field: Optional[str] = None
|
|
|
+ rec_texts_field: Optional[str] = None
|
|
|
+ rec_boxes_field: Optional[str] = None
|
|
|
+ table_body_field: Optional[str] = None
|
|
|
+ table_cells_field: Optional[str] = None
|
|
|
+ img_path_field: Optional[str] = None
|
|
|
+ rotation: Dict[str, Any] = field(default_factory=dict)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_dict(cls, tool_id: str, data: Dict) -> 'OCRToolConfig':
|
|
|
+ """从字典创建"""
|
|
|
+ return cls(
|
|
|
+ name=data.get('name', tool_id),
|
|
|
+ description=data.get('description', ''),
|
|
|
+ json_structure=data.get('json_structure', 'object'),
|
|
|
+ text_field=data.get('text_field', 'text'),
|
|
|
+ bbox_field=data.get('bbox_field', 'bbox'),
|
|
|
+ category_field=data.get('category_field', 'category'),
|
|
|
+ confidence_field=data.get('confidence_field', 'confidence'),
|
|
|
+ parsing_results_field=data.get('parsing_results_field'),
|
|
|
+ rec_texts_field=data.get('rec_texts_field'),
|
|
|
+ rec_boxes_field=data.get('rec_boxes_field'),
|
|
|
+ table_body_field=data.get('table_body_field'),
|
|
|
+ table_cells_field=data.get('table_cells_field'),
|
|
|
+ img_path_field=data.get('img_path_field'),
|
|
|
+ rotation=data.get('rotation', {})
|
|
|
+ )
|
|
|
+
|
|
|
+ def to_dict(self) -> Dict:
|
|
|
+ """转换为字典(用于 OCRValidator)"""
|
|
|
+ config_dict = {
|
|
|
+ 'name': self.name,
|
|
|
+ 'description': self.description,
|
|
|
+ 'json_structure': self.json_structure,
|
|
|
+ 'text_field': self.text_field,
|
|
|
+ 'bbox_field': self.bbox_field,
|
|
|
+ 'category_field': self.category_field,
|
|
|
+ 'confidence_field': self.confidence_field,
|
|
|
+ 'rotation': self.rotation
|
|
|
+ }
|
|
|
+
|
|
|
+ # 添加可选字段
|
|
|
+ if self.parsing_results_field:
|
|
|
+ config_dict['parsing_results_field'] = self.parsing_results_field
|
|
|
+ if self.rec_texts_field:
|
|
|
+ config_dict['rec_texts_field'] = self.rec_texts_field
|
|
|
+ if self.rec_boxes_field:
|
|
|
+ config_dict['rec_boxes_field'] = self.rec_boxes_field
|
|
|
+ if self.table_body_field:
|
|
|
+ config_dict['table_body_field'] = self.table_body_field
|
|
|
+ if self.table_cells_field:
|
|
|
+ config_dict['table_cells_field'] = self.table_cells_field
|
|
|
+ if self.img_path_field:
|
|
|
+ config_dict['img_path_field'] = self.img_path_field
|
|
|
+
|
|
|
+ return config_dict
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class OCRResultConfig:
|
|
|
+ """OCR 结果配置"""
|
|
|
+ tool: str
|
|
|
+ result_dir: str
|
|
|
+ image_dir: Optional[str]
|
|
|
+ description: str = ""
|
|
|
+ enabled: bool = True
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_dict(cls, data: Dict, context: Dict = None) -> 'OCRResultConfig':
|
|
|
+ """
|
|
|
+ 🎯 从字典创建(支持 Jinja2 模板)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ data: 配置数据
|
|
|
+ context: 模板上下文(如 {'name': '德_内蒙古银行照'})
|
|
|
+ """
|
|
|
+ # 🎯 渲染模板
|
|
|
+ if context:
|
|
|
+ result_dir = cls._render_template(data['result_dir'], context)
|
|
|
+ image_dir = cls._render_template(data.get('image_dir'), context) if data.get('image_dir') else None
|
|
|
+ description = cls._render_template(data.get('description', ''), context)
|
|
|
+ else:
|
|
|
+ result_dir = data['result_dir']
|
|
|
+ image_dir = data.get('image_dir')
|
|
|
+ description = data.get('description', '')
|
|
|
+
|
|
|
+ return cls(
|
|
|
+ tool=data['tool'],
|
|
|
+ result_dir=result_dir,
|
|
|
+ image_dir=image_dir,
|
|
|
+ description=description,
|
|
|
+ enabled=data.get('enabled', True)
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _render_template(template_str: Optional[str], context: Dict) -> Optional[str]:
|
|
|
+ """🎯 渲染 Jinja2 模板"""
|
|
|
+ if not template_str:
|
|
|
+ return None
|
|
|
+
|
|
|
+ try:
|
|
|
+ template = Template(template_str)
|
|
|
+ return template.render(context)
|
|
|
+ except Exception as e:
|
|
|
+ logging.warning(f"模板渲染失败: {template_str}, 错误: {e}")
|
|
|
+ return template_str
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class DocumentConfig:
|
|
|
+ """文档配置"""
|
|
|
+ name: str
|
|
|
+ base_dir: str
|
|
|
+ ocr_results: List[OCRResultConfig] = field(default_factory=list)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_dict(cls, data: Dict) -> 'DocumentConfig':
|
|
|
+ """从字典创建(支持 Jinja2 模板)"""
|
|
|
+ doc_data = data.get('document', data)
|
|
|
+
|
|
|
+ # 🎯 构建模板上下文
|
|
|
+ context = {
|
|
|
+ 'name': doc_data['name'],
|
|
|
+ 'base_dir': doc_data['base_dir']
|
|
|
+ }
|
|
|
+
|
|
|
+ return cls(
|
|
|
+ name=doc_data['name'],
|
|
|
+ base_dir=doc_data['base_dir'],
|
|
|
+ ocr_results=[
|
|
|
+ OCRResultConfig.from_dict(r, context)
|
|
|
+ for r in doc_data.get('ocr_results', [])
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class DataSource:
|
|
|
+ """数据源(用于 OCRValidator)"""
|
|
|
+ name: str
|
|
|
+ ocr_tool: str
|
|
|
+ ocr_out_dir: str
|
|
|
+ src_img_dir: str
|
|
|
+ description: str = ""
|
|
|
+
|
|
|
+
|
|
|
+class ConfigManager:
|
|
|
+ """配置管理器"""
|
|
|
+
|
|
|
+ def __init__(self, config_dir: str = "config"):
|
|
|
+ """
|
|
|
+ Args:
|
|
|
+ config_dir: 配置文件目录
|
|
|
+ """
|
|
|
+ self.config_dir = Path(config_dir)
|
|
|
+ self.logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+ # 加载配置
|
|
|
+ self.global_config = self._load_global_config()
|
|
|
+ self.ocr_tools = self._load_ocr_tools()
|
|
|
+ self.documents = self._load_documents()
|
|
|
+
|
|
|
+ def _load_global_config(self) -> Dict:
|
|
|
+ """加载全局配置"""
|
|
|
+ config_file = self.config_dir / "global.yaml"
|
|
|
+
|
|
|
+ if not config_file.exists():
|
|
|
+ self.logger.warning(f"全局配置文件不存在: {config_file}")
|
|
|
+ return {}
|
|
|
+
|
|
|
+ with open(config_file, 'r', encoding='utf-8') as f:
|
|
|
+ return yaml.safe_load(f) or {}
|
|
|
+
|
|
|
+ def _load_ocr_tools(self) -> Dict[str, OCRToolConfig]:
|
|
|
+ """加载 OCR 工具配置(从 global.yaml)"""
|
|
|
+ tools_data = self.global_config.get('ocr', {}).get('tools', {})
|
|
|
+
|
|
|
+ tools = {}
|
|
|
+ for tool_id, tool_data in tools_data.items():
|
|
|
+ tools[tool_id] = OCRToolConfig.from_dict(tool_id, tool_data)
|
|
|
+
|
|
|
+ return tools
|
|
|
+
|
|
|
+ def _load_documents(self) -> Dict[str, DocumentConfig]:
|
|
|
+ """加载文档配置(支持 Jinja2 模板)"""
|
|
|
+ documents = {}
|
|
|
+
|
|
|
+ # 从 global.yaml 读取文档配置文件列表
|
|
|
+ doc_files = self.global_config.get('data_sources', [])
|
|
|
+
|
|
|
+ for doc_file in doc_files:
|
|
|
+ # 支持相对路径和绝对路径
|
|
|
+ if not doc_file.endswith('.yaml'):
|
|
|
+ doc_file = f"{doc_file}.yaml"
|
|
|
+
|
|
|
+ yaml_path = self.config_dir / doc_file
|
|
|
+
|
|
|
+ if not yaml_path.exists():
|
|
|
+ self.logger.warning(f"文档配置文件不存在: {yaml_path}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ try:
|
|
|
+ with open(yaml_path, 'r', encoding='utf-8') as f:
|
|
|
+ data = yaml.safe_load(f)
|
|
|
+
|
|
|
+ # 🎯 使用支持 Jinja2 的解析方法
|
|
|
+ doc_config = DocumentConfig.from_dict(data)
|
|
|
+ documents[doc_config.name] = doc_config
|
|
|
+
|
|
|
+ self.logger.info(f"✅ 加载文档配置: {doc_config.name} ({len(doc_config.ocr_results)} 个 OCR 结果)")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"加载文档配置失败: {yaml_path}, 错误: {e}")
|
|
|
+
|
|
|
+ return documents
|
|
|
+
|
|
|
+ def get_ocr_tool(self, tool_id: str) -> Optional[OCRToolConfig]:
|
|
|
+ """获取 OCR 工具配置"""
|
|
|
+ return self.ocr_tools.get(tool_id)
|
|
|
+
|
|
|
+ def get_document(self, doc_name: str) -> Optional[DocumentConfig]:
|
|
|
+ """获取文档配置"""
|
|
|
+ return self.documents.get(doc_name)
|
|
|
+
|
|
|
+ def list_documents(self) -> List[str]:
|
|
|
+ """列出所有文档"""
|
|
|
+ return list(self.documents.keys())
|
|
|
+
|
|
|
+ def list_ocr_tools(self) -> List[str]:
|
|
|
+ """列出所有 OCR 工具"""
|
|
|
+ return list(self.ocr_tools.keys())
|
|
|
+
|
|
|
+ def get_data_sources(self) -> List[DataSource]:
|
|
|
+ """
|
|
|
+ 生成数据源列表(供 OCRValidator 使用)
|
|
|
+
|
|
|
+ 从文档配置自动生成 data_sources
|
|
|
+ """
|
|
|
+ data_sources = []
|
|
|
+
|
|
|
+ for doc_name, doc_config in self.documents.items():
|
|
|
+ base_dir = Path(doc_config.base_dir)
|
|
|
+
|
|
|
+ for ocr_result in doc_config.ocr_results:
|
|
|
+ if not ocr_result.enabled:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 构建完整路径
|
|
|
+ ocr_out_dir = str(base_dir / ocr_result.result_dir)
|
|
|
+
|
|
|
+ if ocr_result.image_dir:
|
|
|
+ src_img_dir = str(base_dir / ocr_result.image_dir)
|
|
|
+ else:
|
|
|
+ # 如果未指定图片目录,使用结果目录
|
|
|
+ src_img_dir = str(base_dir / ocr_result.result_dir / doc_name)
|
|
|
+
|
|
|
+ # 生成数据源名称
|
|
|
+ if ocr_result.description:
|
|
|
+ source_name = f"{doc_name}_{ocr_result.description}"
|
|
|
+ else:
|
|
|
+ tool_config = self.get_ocr_tool(ocr_result.tool)
|
|
|
+ tool_name = tool_config.name if tool_config else ocr_result.tool
|
|
|
+ source_name = f"{doc_name}_{tool_name}"
|
|
|
+
|
|
|
+ data_source = DataSource(
|
|
|
+ name=source_name,
|
|
|
+ ocr_tool=ocr_result.tool,
|
|
|
+ ocr_out_dir=ocr_out_dir,
|
|
|
+ src_img_dir=src_img_dir,
|
|
|
+ description=ocr_result.description or f"{doc_name} 使用 {ocr_result.tool}"
|
|
|
+ )
|
|
|
+
|
|
|
+ data_sources.append(data_source)
|
|
|
+
|
|
|
+ return data_sources
|
|
|
+
|
|
|
+ def get_config_value(self, key_path: str, default=None):
|
|
|
+ """
|
|
|
+ 获取配置值(支持点号路径)
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ get_config_value('styles.font_size')
|
|
|
+ get_config_value('ocr.min_text_length')
|
|
|
+ """
|
|
|
+ keys = key_path.split('.')
|
|
|
+ value = self.global_config
|
|
|
+
|
|
|
+ for key in keys:
|
|
|
+ if isinstance(value, dict):
|
|
|
+ value = value.get(key)
|
|
|
+ else:
|
|
|
+ return default
|
|
|
+
|
|
|
+ return value if value is not None else default
|
|
|
+
|
|
|
+ def to_validator_config(self) -> Dict:
|
|
|
+ """
|
|
|
+ 转换为 OCRValidator 所需的配置格式
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 包含 data_sources 和 ocr.tools 的配置字典
|
|
|
+ """
|
|
|
+ # 构建 data_sources 列表
|
|
|
+ data_sources_list = []
|
|
|
+ for ds in self.get_data_sources():
|
|
|
+ data_sources_list.append({
|
|
|
+ 'name': ds.name,
|
|
|
+ 'ocr_tool': ds.ocr_tool,
|
|
|
+ 'ocr_out_dir': ds.ocr_out_dir,
|
|
|
+ 'src_img_dir': ds.src_img_dir
|
|
|
+ })
|
|
|
+
|
|
|
+ # 构建 ocr.tools 字典
|
|
|
+ ocr_tools_dict = {}
|
|
|
+ for tool_id, tool_config in self.ocr_tools.items():
|
|
|
+ ocr_tools_dict[tool_id] = tool_config.to_dict()
|
|
|
+
|
|
|
+ # 返回完整配置
|
|
|
+ config = self.global_config.copy()
|
|
|
+ config['data_sources'] = data_sources_list
|
|
|
+
|
|
|
+ # 确保 ocr.tools 存在
|
|
|
+ if 'ocr' not in config:
|
|
|
+ config['ocr'] = {}
|
|
|
+ config['ocr']['tools'] = ocr_tools_dict
|
|
|
+
|
|
|
+ return config
|
|
|
+
|
|
|
+
|
|
|
+# ============================================================================
|
|
|
+# 便捷函数
|
|
|
+# ============================================================================
|
|
|
+
|
|
|
+def load_config(config_dir: str = "config") -> ConfigManager:
|
|
|
+ """加载配置"""
|
|
|
+ return ConfigManager(config_dir)
|