| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355 |
- """
- 配置管理器
- 支持分层配置和自动发现数据源
- 支持 Jinja2 模板变量
- """
- import yaml
- from pathlib import Path
- from typing import Dict, List, Optional, Any
- from dataclasses import dataclass, field
- import logging
- from jinja2 import Template # 🎯 新增
@dataclass
class OCRToolConfig:
    """Configuration record describing one OCR tool.

    The ``*_field`` attributes hold names (presumably the keys used in the
    tool's JSON output -- confirm against the consumer of ``to_dict``).
    """
    name: str
    description: str
    json_structure: str
    text_field: str
    bbox_field: str
    category_field: str
    confidence_field: str = "confidence"
    # Optional names; omitted from ``to_dict`` output when unset or empty.
    parsing_results_field: Optional[str] = None
    rec_texts_field: Optional[str] = None
    rec_boxes_field: Optional[str] = None
    table_body_field: Optional[str] = None
    table_cells_field: Optional[str] = None
    img_path_field: Optional[str] = None
    # Free-form rotation settings, passed through unchanged.
    rotation: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, tool_id: str, data: Optional[Dict]) -> 'OCRToolConfig':
        """Build a tool config from a parsed mapping.

        Args:
            tool_id: Identifier of the tool; used as ``name`` when the
                mapping has no ``name`` key.
            data: Parsed settings for the tool. ``None`` (an entry with no
                body in YAML) is treated as an empty mapping.

        Returns:
            A populated ``OCRToolConfig`` with defaults for missing keys.
        """
        # An empty YAML entry parses to None; fall back to all defaults.
        data = data or {}
        return cls(
            name=data.get('name', tool_id),
            description=data.get('description', ''),
            json_structure=data.get('json_structure', 'object'),
            text_field=data.get('text_field', 'text'),
            bbox_field=data.get('bbox_field', 'bbox'),
            category_field=data.get('category_field', 'category'),
            confidence_field=data.get('confidence_field', 'confidence'),
            parsing_results_field=data.get('parsing_results_field'),
            rec_texts_field=data.get('rec_texts_field'),
            rec_boxes_field=data.get('rec_boxes_field'),
            table_body_field=data.get('table_body_field'),
            table_cells_field=data.get('table_cells_field'),
            img_path_field=data.get('img_path_field'),
            # ``rotation: null`` must not leak None into a dict-typed field.
            rotation=data.get('rotation') or {},
        )

    def to_dict(self) -> Dict:
        """Convert to a plain dict (the format handed to OCRValidator).

        Returns:
            A dict with the always-present keys followed by each optional
            ``*_field`` key whose value is set and non-empty.
        """
        config_dict = {
            'name': self.name,
            'description': self.description,
            'json_structure': self.json_structure,
            'text_field': self.text_field,
            'bbox_field': self.bbox_field,
            'category_field': self.category_field,
            'confidence_field': self.confidence_field,
            'rotation': self.rotation,
        }

        # Optional keys are emitted only when they carry a value.
        optional_names = (
            'parsing_results_field',
            'rec_texts_field',
            'rec_boxes_field',
            'table_body_field',
            'table_cells_field',
            'img_path_field',
        )
        for attr_name in optional_names:
            attr_value = getattr(self, attr_name)
            if attr_value:
                config_dict[attr_name] = attr_value

        return config_dict
@dataclass
class OCRResultConfig:
    """One OCR result entry of a document.

    ``result_dir``, ``image_dir`` and ``description`` may be written as
    Jinja2 templates; they are rendered when a context is supplied to
    ``from_dict``.
    """
    tool: str
    result_dir: str
    image_dir: Optional[str]
    description: str = ""
    enabled: bool = True

    @classmethod
    def from_dict(cls, data: Dict, context: Optional[Dict] = None) -> 'OCRResultConfig':
        """Build a result config from a parsed mapping, rendering templates.

        Args:
            data: Parsed settings; ``tool`` and ``result_dir`` are required.
            context: Template variables (e.g. ``{'name': ...}``). When empty
                or ``None`` the raw strings are used unrendered.

        Returns:
            A populated ``OCRResultConfig``.

        Raises:
            KeyError: If ``tool`` or ``result_dir`` is missing from ``data``.
        """
        result_dir = data['result_dir']
        image_dir = data.get('image_dir')
        # Normalise a missing/null description so the field is always a str.
        description = data.get('description') or ''

        if context:
            # _render_template maps empty input to None, so only render
            # non-empty values; empty ones keep the same value they have on
            # the no-context path.
            if result_dir:
                result_dir = cls._render_template(result_dir, context)
            image_dir = cls._render_template(image_dir, context) if image_dir else None
            if description:
                description = cls._render_template(description, context)

        return cls(
            tool=data['tool'],
            result_dir=result_dir,
            image_dir=image_dir,
            description=description,
            enabled=data.get('enabled', True),
        )

    @staticmethod
    def _render_template(template_str: Optional[str], context: Dict) -> Optional[str]:
        """Render ``template_str`` as a Jinja2 template.

        Args:
            template_str: Template text; empty or ``None`` yields ``None``.
            context: Variables made available to the template.

        Returns:
            The rendered text, ``None`` for empty input, or the unrendered
            ``template_str`` if rendering fails.
        """
        if not template_str:
            return None

        try:
            return Template(template_str).render(context)
        except Exception as e:
            # Deliberate best effort: a bad template is logged and the raw
            # text is used instead of aborting the config load.
            logging.warning("模板渲染失败: %s, 错误: %s", template_str, e)
            return template_str
@dataclass
class DocumentConfig:
    """Configuration of one document and its OCR result entries."""
    name: str
    base_dir: str
    ocr_results: List["OCRResultConfig"] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Dict) -> 'DocumentConfig':
        """Build a document config from a parsed mapping.

        The settings may sit under a top-level ``document`` key or directly
        at the top level. ``name`` and ``base_dir`` are exposed as template
        variables to every OCR result entry.

        Args:
            data: Parsed document settings.

        Returns:
            A populated ``DocumentConfig``.

        Raises:
            KeyError: If ``name`` or ``base_dir`` is missing.
        """
        doc_data = data.get('document', data)
        name = doc_data['name']
        base_dir = doc_data['base_dir']

        # Variables available to the Jinja2 templates of each result entry.
        context = {
            'name': name,
            'base_dir': base_dir,
        }

        # ``ocr_results: null`` parses to None; treat it like an empty list.
        raw_results = doc_data.get('ocr_results') or []

        return cls(
            name=name,
            base_dir=base_dir,
            ocr_results=[
                OCRResultConfig.from_dict(entry, context)
                for entry in raw_results
            ],
        )
@dataclass
class DataSource:
    """A single data source entry (consumed by OCRValidator)."""
    # Name identifying this data source.
    name: str
    # Identifier of the OCR tool that produced the results.
    ocr_tool: str
    # Directory holding the OCR output.
    ocr_out_dir: str
    # Directory holding the source images.
    src_img_dir: str
    # Optional human-readable description.
    description: str = ""
class ConfigManager:
    """Load layered YAML configuration and derive OCR data sources from it.

    On construction the manager reads ``global.yaml`` from ``config_dir``,
    builds the OCR tool table from its ``ocr.tools`` section, and loads every
    document file listed under its ``data_sources`` key.
    """

    def __init__(self, config_dir: str = "config"):
        """
        Args:
            config_dir: Directory containing the configuration files.
        """
        self.config_dir = Path(config_dir)
        self.logger = logging.getLogger(__name__)

        # Order matters: tools and documents are both read from global_config.
        self.global_config = self._load_global_config()
        self.ocr_tools = self._load_ocr_tools()
        self.documents = self._load_documents()

    def _load_global_config(self) -> Dict:
        """Read ``global.yaml``; return ``{}`` if it is missing or empty."""
        config_file = self.config_dir / "global.yaml"

        if not config_file.exists():
            self.logger.warning(f"全局配置文件不存在: {config_file}")
            return {}

        with open(config_file, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f) or {}

    def _load_ocr_tools(self) -> Dict[str, "OCRToolConfig"]:
        """Build the tool table from the ``ocr.tools`` section of the global config."""
        # Either level may be present but null in YAML; treat that as empty.
        ocr_section = self.global_config.get('ocr') or {}
        tools_data = ocr_section.get('tools') or {}

        return {
            tool_id: OCRToolConfig.from_dict(tool_id, tool_data)
            for tool_id, tool_data in tools_data.items()
        }

    def _load_documents(self) -> Dict[str, "DocumentConfig"]:
        """Load every document file listed under ``data_sources``.

        Entries are file names resolved against ``config_dir`` (an absolute
        path is used as-is). ``.yaml`` is appended when the entry has neither
        a ``.yaml`` nor a ``.yml`` suffix. Missing or unparsable files are
        logged and skipped.

        Returns:
            Mapping of document name to its configuration.
        """
        documents = {}

        doc_files = self.global_config.get('data_sources') or []

        for doc_file in doc_files:
            # Guard against non-file-name entries (e.g. inline mappings),
            # which would otherwise crash the constructor.
            if not isinstance(doc_file, str):
                self.logger.warning(f"跳过无效的数据源条目(应为文件名字符串): {doc_file!r}")
                continue

            if not doc_file.endswith(('.yaml', '.yml')):
                doc_file = f"{doc_file}.yaml"

            yaml_path = self.config_dir / doc_file

            if not yaml_path.exists():
                self.logger.warning(f"文档配置文件不存在: {yaml_path}")
                continue

            try:
                with open(yaml_path, 'r', encoding='utf-8') as f:
                    data = yaml.safe_load(f)

                doc_config = DocumentConfig.from_dict(data)
                documents[doc_config.name] = doc_config

                self.logger.info(f"✅ 加载文档配置: {doc_config.name} ({len(doc_config.ocr_results)} 个 OCR 结果)")

            except Exception as e:
                # Best effort: one broken document file must not prevent the
                # remaining ones from loading.
                self.logger.error(f"加载文档配置失败: {yaml_path}, 错误: {e}")

        return documents

    def get_ocr_tool(self, tool_id: str) -> Optional["OCRToolConfig"]:
        """Return the config of ``tool_id``, or ``None`` if unknown."""
        return self.ocr_tools.get(tool_id)

    def get_document(self, doc_name: str) -> Optional["DocumentConfig"]:
        """Return the config of document ``doc_name``, or ``None`` if unknown."""
        return self.documents.get(doc_name)

    def list_documents(self) -> List[str]:
        """Return the names of all loaded documents."""
        return list(self.documents)

    def list_ocr_tools(self) -> List[str]:
        """Return the ids of all configured OCR tools."""
        return list(self.ocr_tools)

    def get_data_sources(self) -> List["DataSource"]:
        """Derive the data source list (for OCRValidator) from the documents.

        One entry is produced per enabled OCR result of every document.

        Returns:
            The generated ``DataSource`` objects.
        """
        data_sources = []

        for doc_name, doc_config in self.documents.items():
            base_dir = Path(doc_config.base_dir)

            for ocr_result in doc_config.ocr_results:
                if not ocr_result.enabled:
                    continue

                result_root = base_dir / ocr_result.result_dir

                if ocr_result.image_dir:
                    image_root = base_dir / ocr_result.image_dir
                else:
                    # No image directory configured: fall back to a
                    # sub-directory of the result directory named after
                    # the document.
                    image_root = result_root / doc_name

                data_sources.append(DataSource(
                    # result_dir is part of the name so that several results
                    # of one document get distinct names.
                    name=f"{doc_name}_{ocr_result.result_dir}",
                    ocr_tool=ocr_result.tool,
                    ocr_out_dir=str(result_root),
                    src_img_dir=str(image_root),
                    description=ocr_result.description or ocr_result.result_dir,
                ))

        return data_sources

    def get_config_value(self, key_path: str, default=None):
        """Look up a value in the global config by dotted path.

        Examples:
            get_config_value('styles.font_size')
            get_config_value('ocr.min_text_length')

        Args:
            key_path: Dot-separated key path.
            default: Returned when the path is missing or the value is ``None``.

        Returns:
            The configured value, or ``default``.
        """
        value = self.global_config

        for key in key_path.split('.'):
            if not isinstance(value, dict):
                return default
            value = value.get(key)

        return default if value is None else value

    def to_validator_config(self) -> Dict:
        """Produce the configuration dict expected by OCRValidator.

        Returns:
            A copy of the global config whose ``data_sources`` is the
            generated source list and whose ``ocr.tools`` holds the
            normalised tool dicts. ``self.global_config`` is not modified.
        """
        data_sources_list = [
            {
                'name': ds.name,
                'ocr_tool': ds.ocr_tool,
                'ocr_out_dir': ds.ocr_out_dir,
                'src_img_dir': ds.src_img_dir,
            }
            for ds in self.get_data_sources()
        ]

        ocr_tools_dict = {
            tool_id: tool_config.to_dict()
            for tool_id, tool_config in self.ocr_tools.items()
        }

        config = self.global_config.copy()
        config['data_sources'] = data_sources_list

        # The copy above is shallow, so the 'ocr' section is rebuilt as a new
        # dict; assigning into the shared one would overwrite
        # self.global_config['ocr']['tools']. A null section counts as empty.
        ocr_section = dict(config.get('ocr') or {})
        ocr_section['tools'] = ocr_tools_dict
        config['ocr'] = ocr_section

        return config
# ============================================================================
# Convenience functions
# ============================================================================
def load_config(config_dir: str = "config") -> ConfigManager:
    """Create a ``ConfigManager`` for the given configuration directory.

    Args:
        config_dir: Directory containing the configuration files.

    Returns:
        The manager, with its configuration already loaded.
    """
    manager = ConfigManager(config_dir)
    return manager
|