Bladeren bron

feat: 新增配置管理器,支持分层配置和自动发现数据源,集成 Jinja2 模板变量

zhch158_admin 1 week geleden
bovenliggende
commit
afc9e3d481
1 gewijzigde bestanden met toevoegingen van 360 en 0 verwijderingen
  1. 360 0
      config_manager.py

+ 360 - 0
config_manager.py

@@ -0,0 +1,360 @@
+"""
+配置管理器
+支持分层配置和自动发现数据源
+支持 Jinja2 模板变量
+"""
+
+import yaml
+from pathlib import Path
+from typing import Dict, List, Optional, Any
+from dataclasses import dataclass, field
+import logging
+from jinja2 import Template  # 🎯 新增
+
+
+@dataclass
+class OCRToolConfig:
+    """OCR 工具配置"""
+    name: str
+    description: str
+    json_structure: str
+    text_field: str
+    bbox_field: str
+    category_field: str
+    confidence_field: str = "confidence"
+    parsing_results_field: Optional[str] = None
+    rec_texts_field: Optional[str] = None
+    rec_boxes_field: Optional[str] = None
+    table_body_field: Optional[str] = None
+    table_cells_field: Optional[str] = None
+    img_path_field: Optional[str] = None
+    rotation: Dict[str, Any] = field(default_factory=dict)
+    
+    @classmethod
+    def from_dict(cls, tool_id: str, data: Dict) -> 'OCRToolConfig':
+        """从字典创建"""
+        return cls(
+            name=data.get('name', tool_id),
+            description=data.get('description', ''),
+            json_structure=data.get('json_structure', 'object'),
+            text_field=data.get('text_field', 'text'),
+            bbox_field=data.get('bbox_field', 'bbox'),
+            category_field=data.get('category_field', 'category'),
+            confidence_field=data.get('confidence_field', 'confidence'),
+            parsing_results_field=data.get('parsing_results_field'),
+            rec_texts_field=data.get('rec_texts_field'),
+            rec_boxes_field=data.get('rec_boxes_field'),
+            table_body_field=data.get('table_body_field'),
+            table_cells_field=data.get('table_cells_field'),
+            img_path_field=data.get('img_path_field'),
+            rotation=data.get('rotation', {})
+        )
+    
+    def to_dict(self) -> Dict:
+        """转换为字典(用于 OCRValidator)"""
+        config_dict = {
+            'name': self.name,
+            'description': self.description,
+            'json_structure': self.json_structure,
+            'text_field': self.text_field,
+            'bbox_field': self.bbox_field,
+            'category_field': self.category_field,
+            'confidence_field': self.confidence_field,
+            'rotation': self.rotation
+        }
+        
+        # 添加可选字段
+        if self.parsing_results_field:
+            config_dict['parsing_results_field'] = self.parsing_results_field
+        if self.rec_texts_field:
+            config_dict['rec_texts_field'] = self.rec_texts_field
+        if self.rec_boxes_field:
+            config_dict['rec_boxes_field'] = self.rec_boxes_field
+        if self.table_body_field:
+            config_dict['table_body_field'] = self.table_body_field
+        if self.table_cells_field:
+            config_dict['table_cells_field'] = self.table_cells_field
+        if self.img_path_field:
+            config_dict['img_path_field'] = self.img_path_field
+        
+        return config_dict
+
+
+@dataclass
+class OCRResultConfig:
+    """OCR 结果配置"""
+    tool: str
+    result_dir: str
+    image_dir: Optional[str]
+    description: str = ""
+    enabled: bool = True
+    
+    @classmethod
+    def from_dict(cls, data: Dict, context: Dict = None) -> 'OCRResultConfig':
+        """
+        🎯 从字典创建(支持 Jinja2 模板)
+        
+        Args:
+            data: 配置数据
+            context: 模板上下文(如 {'name': '德_内蒙古银行照'})
+        """
+        # 🎯 渲染模板
+        if context:
+            result_dir = cls._render_template(data['result_dir'], context)
+            image_dir = cls._render_template(data.get('image_dir'), context) if data.get('image_dir') else None
+            description = cls._render_template(data.get('description', ''), context)
+        else:
+            result_dir = data['result_dir']
+            image_dir = data.get('image_dir')
+            description = data.get('description', '')
+        
+        return cls(
+            tool=data['tool'],
+            result_dir=result_dir,
+            image_dir=image_dir,
+            description=description,
+            enabled=data.get('enabled', True)
+        )
+    
+    @staticmethod
+    def _render_template(template_str: Optional[str], context: Dict) -> Optional[str]:
+        """🎯 渲染 Jinja2 模板"""
+        if not template_str:
+            return None
+        
+        try:
+            template = Template(template_str)
+            return template.render(context)
+        except Exception as e:
+            logging.warning(f"模板渲染失败: {template_str}, 错误: {e}")
+            return template_str
+
+
+@dataclass
+class DocumentConfig:
+    """文档配置"""
+    name: str
+    base_dir: str
+    ocr_results: List[OCRResultConfig] = field(default_factory=list)
+    
+    @classmethod
+    def from_dict(cls, data: Dict) -> 'DocumentConfig':
+        """从字典创建(支持 Jinja2 模板)"""
+        doc_data = data.get('document', data)
+        
+        # 🎯 构建模板上下文
+        context = {
+            'name': doc_data['name'],
+            'base_dir': doc_data['base_dir']
+        }
+        
+        return cls(
+            name=doc_data['name'],
+            base_dir=doc_data['base_dir'],
+            ocr_results=[
+                OCRResultConfig.from_dict(r, context) 
+                for r in doc_data.get('ocr_results', [])
+            ]
+        )
+
+
+@dataclass
+class DataSource:
+    """数据源(用于 OCRValidator)"""
+    name: str
+    ocr_tool: str
+    ocr_out_dir: str
+    src_img_dir: str
+    description: str = ""
+
+
+class ConfigManager:
+    """配置管理器"""
+    
+    def __init__(self, config_dir: str = "config"):
+        """
+        Args:
+            config_dir: 配置文件目录
+        """
+        self.config_dir = Path(config_dir)
+        self.logger = logging.getLogger(__name__)
+        
+        # 加载配置
+        self.global_config = self._load_global_config()
+        self.ocr_tools = self._load_ocr_tools()
+        self.documents = self._load_documents()
+    
+    def _load_global_config(self) -> Dict:
+        """加载全局配置"""
+        config_file = self.config_dir / "global.yaml"
+        
+        if not config_file.exists():
+            self.logger.warning(f"全局配置文件不存在: {config_file}")
+            return {}
+        
+        with open(config_file, 'r', encoding='utf-8') as f:
+            return yaml.safe_load(f) or {}
+    
+    def _load_ocr_tools(self) -> Dict[str, OCRToolConfig]:
+        """加载 OCR 工具配置(从 global.yaml)"""
+        tools_data = self.global_config.get('ocr', {}).get('tools', {})
+        
+        tools = {}
+        for tool_id, tool_data in tools_data.items():
+            tools[tool_id] = OCRToolConfig.from_dict(tool_id, tool_data)
+        
+        return tools
+    
+    def _load_documents(self) -> Dict[str, DocumentConfig]:
+        """加载文档配置(支持 Jinja2 模板)"""
+        documents = {}
+        
+        # 从 global.yaml 读取文档配置文件列表
+        doc_files = self.global_config.get('data_sources', [])
+        
+        for doc_file in doc_files:
+            # 支持相对路径和绝对路径
+            if not doc_file.endswith('.yaml'):
+                doc_file = f"{doc_file}.yaml"
+            
+            yaml_path = self.config_dir / doc_file
+            
+            if not yaml_path.exists():
+                self.logger.warning(f"文档配置文件不存在: {yaml_path}")
+                continue
+            
+            try:
+                with open(yaml_path, 'r', encoding='utf-8') as f:
+                    data = yaml.safe_load(f)
+                
+                # 🎯 使用支持 Jinja2 的解析方法
+                doc_config = DocumentConfig.from_dict(data)
+                documents[doc_config.name] = doc_config
+                
+                self.logger.info(f"✅ 加载文档配置: {doc_config.name} ({len(doc_config.ocr_results)} 个 OCR 结果)")
+                
+            except Exception as e:
+                self.logger.error(f"加载文档配置失败: {yaml_path}, 错误: {e}")
+        
+        return documents
+    
+    def get_ocr_tool(self, tool_id: str) -> Optional[OCRToolConfig]:
+        """获取 OCR 工具配置"""
+        return self.ocr_tools.get(tool_id)
+    
+    def get_document(self, doc_name: str) -> Optional[DocumentConfig]:
+        """获取文档配置"""
+        return self.documents.get(doc_name)
+    
+    def list_documents(self) -> List[str]:
+        """列出所有文档"""
+        return list(self.documents.keys())
+    
+    def list_ocr_tools(self) -> List[str]:
+        """列出所有 OCR 工具"""
+        return list(self.ocr_tools.keys())
+    
+    def get_data_sources(self) -> List[DataSource]:
+        """
+        生成数据源列表(供 OCRValidator 使用)
+        
+        从文档配置自动生成 data_sources
+        """
+        data_sources = []
+        
+        for doc_name, doc_config in self.documents.items():
+            base_dir = Path(doc_config.base_dir)
+            
+            for ocr_result in doc_config.ocr_results:
+                if not ocr_result.enabled:
+                    continue
+                
+                # 构建完整路径
+                ocr_out_dir = str(base_dir / ocr_result.result_dir)
+                
+                if ocr_result.image_dir:
+                    src_img_dir = str(base_dir / ocr_result.image_dir)
+                else:
+                    # 如果未指定图片目录,使用结果目录
+                    src_img_dir = str(base_dir / ocr_result.result_dir / doc_name)
+                
+                # 生成数据源名称
+                if ocr_result.description:
+                    source_name = f"{doc_name}_{ocr_result.description}"
+                else:
+                    tool_config = self.get_ocr_tool(ocr_result.tool)
+                    tool_name = tool_config.name if tool_config else ocr_result.tool
+                    source_name = f"{doc_name}_{tool_name}"
+                
+                data_source = DataSource(
+                    name=source_name,
+                    ocr_tool=ocr_result.tool,
+                    ocr_out_dir=ocr_out_dir,
+                    src_img_dir=src_img_dir,
+                    description=ocr_result.description or f"{doc_name} 使用 {ocr_result.tool}"
+                )
+                
+                data_sources.append(data_source)
+        
+        return data_sources
+    
+    def get_config_value(self, key_path: str, default=None):
+        """
+        获取配置值(支持点号路径)
+        
+        Examples:
+            get_config_value('styles.font_size')
+            get_config_value('ocr.min_text_length')
+        """
+        keys = key_path.split('.')
+        value = self.global_config
+        
+        for key in keys:
+            if isinstance(value, dict):
+                value = value.get(key)
+            else:
+                return default
+        
+        return value if value is not None else default
+    
+    def to_validator_config(self) -> Dict:
+        """
+        转换为 OCRValidator 所需的配置格式
+        
+        Returns:
+            包含 data_sources 和 ocr.tools 的配置字典
+        """
+        # 构建 data_sources 列表
+        data_sources_list = []
+        for ds in self.get_data_sources():
+            data_sources_list.append({
+                'name': ds.name,
+                'ocr_tool': ds.ocr_tool,
+                'ocr_out_dir': ds.ocr_out_dir,
+                'src_img_dir': ds.src_img_dir
+            })
+        
+        # 构建 ocr.tools 字典
+        ocr_tools_dict = {}
+        for tool_id, tool_config in self.ocr_tools.items():
+            ocr_tools_dict[tool_id] = tool_config.to_dict()
+        
+        # 返回完整配置
+        config = self.global_config.copy()
+        config['data_sources'] = data_sources_list
+        
+        # 确保 ocr.tools 存在
+        if 'ocr' not in config:
+            config['ocr'] = {}
+        config['ocr']['tools'] = ocr_tools_dict
+        
+        return config
+
+
+# ============================================================================
+# 便捷函数
+# ============================================================================
+
+def load_config(config_dir: str = "config") -> ConfigManager:
+    """加载配置"""
+    return ConfigManager(config_dir)