"""
Providers Hook - bridge module.

When fileLoc is installed, the public helpers are imported from
``src.shared.model_manager.device_and_provider``; when it is not installed,
the local fallback implementation below is used instead.
"""
import sys
from pathlib import Path
from typing import Any

from loguru import logger

# Dynamically locate the fileLoc root (an ancestor directory containing a
# ``src`` directory) and make it importable.
# NOTE(review): ``break`` only runs after an insertion, so when a matching
# ancestor is already on sys.path the search keeps climbing and may insert a
# higher ancestor at sys.path[0]. Confirm this is the intended behaviour.
_current_file = Path(__file__).resolve()
for _parent in _current_file.parents:
    _src_dir = _parent / "src"
    if _src_dir.exists() and str(_parent) not in sys.path:
        sys.path.insert(0, str(_parent))
        break

try:
    from src.shared.model_manager.device_and_provider import (
        resolve_providers,
        apply_providers_to_session,
        apply_providers_to_model,
        apply_device_strategy,  # added 2026-06-03
        DEFAULT_CPU_PROVIDERS,
        DEFAULT_GPU_PROVIDERS,
    )

    # Adapter: fileLoc's apply_providers_to_model takes a providers list,
    # whereas ocr_platform consumers pass a config dict; the bridge resolves
    # the list in between.
    def apply_providers(
        model: Any,
        config: dict,
        session_paths: list[str],
    ) -> bool:
        """Resolve providers from *config* and apply them through fileLoc.

        Args:
            model: Object whose (possibly nested) sessions should be updated.
            config: Consumer config dict passed to ``resolve_providers``.
            session_paths: Dotted attribute paths locating sessions on *model*.

        Returns:
            False if no providers could be resolved, otherwise the result of
            fileLoc's ``apply_providers_to_model``.
        """
        providers = resolve_providers(config)
        if providers is None:
            return False
        return apply_providers_to_model(model, providers, session_paths)

    logger.debug("[providers_hook] Bridge: loaded from fileLoc")

except ImportError as _import_error:
    # Fallback: use the local implementation.
    logger.info("[providers_hook] fileLoc not found, using local implementation")
    # The ImportError may be raised by a dependency *inside* fileLoc rather
    # than fileLoc being absent; keep the real cause visible when debugging.
    logger.debug(f"[providers_hook] fileLoc import error: {_import_error!r}")

    from typing import List, Optional

    DEFAULT_CPU_PROVIDERS = ["CPUExecutionProvider"]
    DEFAULT_GPU_PROVIDERS = ["CUDAExecutionProvider", "CPUExecutionProvider"]

    def resolve_providers(config: dict) -> Optional[List[str]]:
        """Return the execution providers to use for *config*.

        An explicit ``config['providers']`` takes precedence; a single
        provider name given as a plain string is wrapped in a list. Otherwise
        ``config['device']`` (default ``'cpu'``) selects a default list:
        ``'gpu'``, ``'cuda'`` or ``'cuda:<index>'`` (case-insensitive) map to
        the GPU providers, anything else to the CPU providers.

        A fresh list is always returned, so callers may mutate the result
        without altering the module-level defaults or their own config.
        """
        providers = config.get('providers')
        if providers is not None:
            if isinstance(providers, str):
                return [providers]
            return list(providers)
        device = str(config.get('device', 'cpu')).strip().lower()
        if device in ('gpu', 'cuda') or device.startswith('cuda:'):
            return list(DEFAULT_GPU_PROVIDERS)
        return list(DEFAULT_CPU_PROVIDERS)

    def apply_providers_to_session(session: Any, providers: List[str]) -> bool:
        """Best-effort: call ``set_providers`` on *session* or ``session.session``.

        Returns:
            True if ``set_providers`` was called without raising; False if
            *session* is None, neither object exposes ``set_providers``, or
            the call raised (the error is logged as a warning).
        """
        if session is None:
            return False
        try:
            if hasattr(session, 'set_providers'):
                session.set_providers(providers)
                return True
            # Some wrappers hold the real session on a ``.session`` attribute.
            if hasattr(session, 'session') and hasattr(session.session, 'set_providers'):
                session.session.set_providers(providers)
                return True
        except Exception as e:
            # Deliberately best-effort: a failing session must not abort loading.
            logger.warning(f"Failed to set providers: {e}")
        return False

    def apply_providers(model: Any, config: dict, session_paths: List[str]) -> bool:
        """Resolve providers from *config* and apply them to every session found.

        Args:
            model: Object whose nested sessions should be updated.
            config: Config dict passed to ``resolve_providers``.
            session_paths: Dotted attribute paths (e.g. ``'a.b.session'``)
                resolved against *model*; missing paths are skipped.

        Returns:
            True if at least one session accepted the providers.
        """
        providers = resolve_providers(config)
        if providers is None:
            return False
        success = False
        for path in session_paths:
            session = _get_nested_attr(model, path)
            if session is not None and apply_providers_to_session(session, providers):
                logger.info(f"Applied providers via '{path}': {providers}")
                success = True
        return success

    # Alias kept for backward compatibility.
    apply_providers_to_model = apply_providers

    def apply_device_strategy(model: Any, config: dict, model_kind: str) -> bool:
        """[fallback] Without fileLoc this only logs; no device placement is done.

        Always returns True.
        """
        logger.warning(
            f"[providers_hook] fileLoc not found, apply_device_strategy noop: "
            f"model_kind={model_kind}, device={config.get('device', 'cpu')}"
        )
        return True


def _get_nested_attr(obj: Any, attr_path: str) -> Any:
    """Follow a dotted attribute path (e.g. ``'a.b.c'``) starting at *obj*.

    Returns None if *obj* is None or any attribute along the path is missing.
    """
    if obj is None:
        return None
    current = obj
    for attr_name in attr_path.split('.'):
        try:
            current = getattr(current, attr_name)
        except AttributeError:
            return None
    return current