providers_hook.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. """
  2. Providers Hook - Bridge 模块
  3. 本文件作为 bridge,当 fileLoc 安装时,从 src.shared.model_manager.device_and_provider 导入,
  4. 当 fileLoc 未安装时,使用下方 fallback 实现。
  5. """
  6. import sys
  7. from pathlib import Path
  8. from typing import Any
  9. from loguru import logger
  # Locate the fileLoc checkout dynamically: walk upward from this file and put
  # the first ancestor that contains a "src" entry and is not already listed in
  # sys.path at the front of sys.path. Ancestors that are already on sys.path
  # are skipped and the search continues further up.
  _current_file = Path(__file__).resolve()
  for _parent in _current_file.parents:
      _src_dir = _parent / "src"
      if not _src_dir.exists():
          continue
      if str(_parent) in sys.path:
          continue
      sys.path.insert(0, str(_parent))
      break
  17. try:
  18. from src.shared.model_manager.device_and_provider import (
  19. resolve_providers,
  20. apply_providers_to_session,
  21. apply_providers_to_model,
  22. apply_device_strategy, # 2026-06-03 新增
  23. DEFAULT_CPU_PROVIDERS,
  24. DEFAULT_GPU_PROVIDERS,
  25. )
  26. # 适配: fileLoc 的 apply_providers_to_model 接受 providers list,
  27. # 而 ocr_platform 消费者传入 config dict,由 bridge 在中间层处理
  def apply_providers(
      model: Any,
      config: dict,
      session_paths: list[str],
  ) -> bool:
      """Adapt config-dict callers to the fileLoc providers API.

      The config is turned into a providers list with ``resolve_providers``
      and that list is forwarded to ``apply_providers_to_model``.

      Args:
          model: Object handed through to ``apply_providers_to_model``.
          config: Configuration dict passed to ``resolve_providers``.
          session_paths: Paths handed through to ``apply_providers_to_model``.

      Returns:
          ``False`` when ``resolve_providers`` yields ``None``; otherwise the
          result of ``apply_providers_to_model``.
      """
      resolved = resolve_providers(config)
      if resolved is not None:
          return apply_providers_to_model(model, resolved, session_paths)
      return False
  37. logger.debug("[providers_hook] Bridge: loaded from fileLoc")
  38. except ImportError:
  39. # Fallback: 使用本地实现
  40. logger.info("[providers_hook] fileLoc not found, using local implementation")
  41. from typing import List, Optional
  # Default provider lists used when the config does not name providers
  # explicitly.
  DEFAULT_CPU_PROVIDERS = ["CPUExecutionProvider"]
  DEFAULT_GPU_PROVIDERS = ["CUDAExecutionProvider", "CPUExecutionProvider"]

  def resolve_providers(config: dict) -> Optional[List[str]]:
      """Resolve the execution-provider list described by *config*.

      Args:
          config: Mapping that may contain ``'providers'`` (an explicit
              provider list) and/or ``'device'`` (defaults to ``'cpu'``).

      Returns:
          The explicit ``config['providers']`` object when it is present and
          not ``None``; otherwise a fresh copy of ``DEFAULT_GPU_PROVIDERS``
          for a GPU/CUDA device, or of ``DEFAULT_CPU_PROVIDERS`` for anything
          else.
      """
      providers = config.get('providers')
      if providers is not None:
          return providers

      # Normalise the device so that spellings such as "GPU", " cuda " or
      # "cuda:0" (also what str() gives for device objects) select the GPU
      # defaults instead of silently falling back to CPU.
      device = str(config.get('device', 'cpu')).strip().lower().split(':', 1)[0]

      # Return copies: handing out the module-level lists would let a caller
      # that mutates the result change the defaults for every later call.
      if device in ('gpu', 'cuda'):
          return list(DEFAULT_GPU_PROVIDERS)
      return list(DEFAULT_CPU_PROVIDERS)
  def apply_providers_to_session(session: Any, providers: List[str]) -> bool:
      """Call ``set_providers(providers)`` on *session* or on ``session.session``.

      Args:
          session: Object that exposes ``set_providers`` itself or through a
              ``session`` attribute; ``None`` is rejected.
          providers: Provider list handed to ``set_providers``.

      Returns:
          ``True`` when a ``set_providers`` call completed, ``False`` when
          *session* is ``None``, no ``set_providers`` was found, or the call
          raised (the exception is logged as a warning, not propagated).
      """
      if session is None:
          return False
      try:
          target = None
          if hasattr(session, 'set_providers'):
              target = session
          elif hasattr(session, 'session') and hasattr(session.session, 'set_providers'):
              # Wrapper object: the real session sits one level down.
              target = session.session
          if target is not None:
              target.set_providers(providers)
              return True
      except Exception as e:
          logger.warning(f"Failed to set providers: {e}")
      return False
  def apply_providers(model: Any, config: Any, session_paths: List[str]) -> bool:
      """Apply providers to every session reachable from *model*.

      This function is also exported as ``apply_providers_to_model``. The
      bridge branch of this module calls the fileLoc function of that name
      with an already-resolved providers list, so the second argument accepts
      either form to keep both call styles working in fallback mode.

      Args:
          model: Object whose attributes are looked up via *session_paths*.
          config: Either a config mapping (anything exposing ``.get``), which
              is resolved with ``resolve_providers``, or an already-resolved
              providers list, which is used as-is.
          session_paths: Dotted attribute paths, relative to *model*, of the
              sessions to update. ``None`` is treated as empty.

      Returns:
          ``True`` if providers were applied through at least one path,
          ``False`` otherwise (including when no providers could be resolved).
      """
      # Lists/tuples have no .get(); only mapping-like configs are resolved.
      providers = resolve_providers(config) if hasattr(config, 'get') else config
      if providers is None:
          return False

      success = False
      # Every path is tried; there is deliberately no early exit on success.
      for path in session_paths or ():
          session = _get_nested_attr(model, path)
          if session is not None and apply_providers_to_session(session, providers):
              logger.info(f"Applied providers via '{path}': {providers}")
              success = True
      return success

  # Alias kept for backward compatibility.
  apply_providers_to_model = apply_providers
  def apply_device_strategy(model: Any, config: dict, model_kind: str) -> bool:
      """[fallback] Without fileLoc this only logs; no device move is performed.

      Args:
          model: Unused by this fallback.
          config: Mapping whose ``'device'`` entry (default ``'cpu'``) is
              reported in the warning.
          model_kind: Label reported in the warning.

      Returns:
          Always ``True``.
      """
      message = (
          f"[providers_hook] fileLoc not found, apply_device_strategy noop: "
          f"model_kind={model_kind}, device={config.get('device', 'cpu')}"
      )
      logger.warning(message)
      return True
  def _get_nested_attr(obj: Any, attr_path: str) -> Any:
      """Follow a dotted attribute path starting at *obj*.

      Args:
          obj: Root object; ``None`` short-circuits to ``None``.
          attr_path: Attribute names separated by ``'.'``.

      Returns:
          The value at the end of the path, or ``None`` as soon as any
          attribute along the way is missing.
      """
      if obj is None:
          return None
      # Private sentinel so that an attribute whose value is None is still
      # distinguished from a missing attribute.
      missing = object()
      node = obj
      for part in attr_path.split('.'):
          node = getattr(node, part, missing)
          if node is missing:
              return None
      return node