| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- """
- Providers Hook - Bridge 模块
- 本文件作为 bridge,当 fileLoc 安装时,从 src.shared.model_manager.device_and_provider 导入,
- 当 fileLoc 未安装时,使用下方 fallback 实现。
- """
- import sys
- from pathlib import Path
- from typing import Any
- from loguru import logger
# Walk up from this file looking for an ancestor that contains a ``src``
# directory and put that ancestor at the front of sys.path, so that the
# ``src.shared...`` import attempted below has a chance to resolve.
_current_file = Path(__file__).resolve()
for _parent in _current_file.parents:
    _src_dir = _parent / "src"
    # Only a real directory can be a package root; a regular file that happens
    # to be named ``src`` is skipped and the search continues upwards.
    if not _src_dir.is_dir():
        continue
    # NOTE(review): an ancestor that is already on sys.path does not end the
    # search; the walk keeps going and may add a higher ancestor instead.
    # Behaviour kept as-is; confirm it is intended.
    if str(_parent) not in sys.path:
        sys.path.insert(0, str(_parent))
        break
- try:
- from src.shared.model_manager.device_and_provider import (
- resolve_providers,
- apply_providers_to_session,
- apply_providers_to_model,
- apply_device_strategy, # 2026-06-03 新增
- DEFAULT_CPU_PROVIDERS,
- DEFAULT_GPU_PROVIDERS,
- )
# Adapter: fileLoc's apply_providers_to_model takes a providers list, whereas
# ocr_platform consumers pass a config dict; the bridge converts in between.
def apply_providers(
    model: Any,
    config: dict,
    session_paths: list[str],
) -> bool:
    """Resolve the providers described by *config* and apply them to *model*.

    Args:
        model: Object handed through to ``apply_providers_to_model``.
        config: Consumer configuration passed to ``resolve_providers``.
        session_paths: Forwarded unchanged to ``apply_providers_to_model``.

    Returns:
        ``False`` when ``resolve_providers`` yields ``None``; otherwise
        whatever ``apply_providers_to_model`` returns.
    """
    resolved = resolve_providers(config)
    if resolved is not None:
        return apply_providers_to_model(model, resolved, session_paths)
    return False


logger.debug("[providers_hook] Bridge: loaded from fileLoc")
- except ImportError:
- # Fallback: 使用本地实现
- logger.info("[providers_hook] fileLoc not found, using local implementation")
- from typing import List, Optional
# Provider lists used when the config does not name providers explicitly.
DEFAULT_CPU_PROVIDERS = ["CPUExecutionProvider"]
DEFAULT_GPU_PROVIDERS = ["CUDAExecutionProvider", "CPUExecutionProvider"]


def resolve_providers(config: dict) -> Optional[List[str]]:
    """Pick the execution-provider list described by *config*.

    An explicit ``config['providers']`` value wins and is returned as-is.
    Otherwise ``config['device']`` (default ``'cpu'``) selects a default:
    ``'gpu'`` or ``'cuda'`` maps to ``DEFAULT_GPU_PROVIDERS``, any other value
    to ``DEFAULT_CPU_PROVIDERS``.

    Args:
        config: Mapping that may contain ``'providers'`` and/or ``'device'``.

    Returns:
        The caller-supplied providers object, or a fresh copy of the matching
        default list.
    """
    providers = config.get("providers")
    if providers is not None:
        return providers

    device = config.get("device", "cpu")
    if device in ("gpu", "cuda"):
        defaults = DEFAULT_GPU_PROVIDERS
    else:
        defaults = DEFAULT_CPU_PROVIDERS
    # Hand out a copy so a caller that mutates the result cannot corrupt the
    # module-level defaults for every later call.
    return list(defaults)
def apply_providers_to_session(session: Any, providers: List[str]) -> bool:
    """Call ``set_providers(providers)`` on *session* or on its inner session.

    The call goes to ``session.set_providers`` when that attribute exists;
    otherwise to ``session.session.set_providers`` when the wrapped object
    offers it.

    Args:
        session: Object to configure; ``None`` is tolerated.
        providers: Provider names to pass to ``set_providers``.

    Returns:
        ``True`` when a ``set_providers`` call completed without raising.
        ``False`` when *session* is ``None``, *providers* is empty, no
        ``set_providers`` method was found, or the call raised.
    """
    # An empty/None provider list carries nothing to apply: report failure
    # without touching the session at all.
    if session is None or not providers:
        return False
    try:
        if hasattr(session, "set_providers"):
            session.set_providers(providers)
            return True
        if hasattr(session, "session") and hasattr(session.session, "set_providers"):
            session.session.set_providers(providers)
            return True
    except Exception as e:
        # Deliberately best-effort: a backend that rejects the providers must
        # not crash the caller; the failure is logged and reported as False.
        logger.warning(f"Failed to set providers: {e}")
    return False
def apply_providers(model: Any, config: dict, session_paths: List[str]) -> bool:
    """Apply the providers resolved from *config* to sessions found on *model*.

    Each entry of *session_paths* is a dotted attribute path looked up on
    *model* with ``_get_nested_attr``; every object found is handed to
    ``apply_providers_to_session``.

    Args:
        model: Object whose attributes are searched for sessions.
        config: Configuration passed to ``resolve_providers``.
        session_paths: Dotted attribute paths to try. A single string is
            treated as one path; ``None`` or an empty sequence means there is
            nothing to try.

    Returns:
        ``True`` if providers were applied through at least one path,
        ``False`` otherwise.
    """
    providers = resolve_providers(config)
    if providers is None or not session_paths:
        return False
    # A bare string would otherwise be iterated character by character and
    # never match any real attribute path.
    if isinstance(session_paths, str):
        session_paths = [session_paths]

    success = False
    for path in session_paths:
        session = _get_nested_attr(model, path)
        if session is not None and apply_providers_to_session(session, providers):
            logger.info(f"Applied providers via '{path}': {providers}")
            success = True
    return success


# Alias kept for backward compatibility.
# NOTE(review): this alias takes a config dict, while the adapter comment in
# the try-branch says fileLoc's apply_providers_to_model takes a providers
# list. Confirm no caller depends on the list form in fallback mode.
apply_providers_to_model = apply_providers
def apply_device_strategy(model: Any, config: dict, model_kind: str) -> bool:
    """[fallback] Without fileLoc, only log a warning; no device move is done.

    Args:
        model: Unused by this fallback.
        config: Read only for its ``'device'`` entry (default ``'cpu'``),
            which is included in the log message.
        model_kind: Included in the log message.

    Returns:
        Always ``True``.
    """
    message = (
        f"[providers_hook] fileLoc not found, apply_device_strategy noop: "
        f"model_kind={model_kind}, device={config.get('device', 'cpu')}"
    )
    logger.warning(message)
    return True
def _get_nested_attr(obj: Any, attr_path: str) -> Any:
    """Follow a dotted attribute path starting at *obj*.

    Args:
        obj: Starting object; ``None`` yields ``None``.
        attr_path: Attribute names joined by ``'.'``, e.g. ``"a.b.c"``.

    Returns:
        The object reached at the end of the path, or ``None`` when *obj* is
        ``None``, the path is empty/``None``, or any step raises
        ``AttributeError``.
    """
    # An empty or missing path is reported like any other unresolved path
    # instead of failing inside ``attr_path.split``.
    if obj is None or not attr_path:
        return None
    current = obj
    for attr_name in attr_path.split("."):
        try:
            current = getattr(current, attr_name)
        except AttributeError:
            return None
    return current
|