test_providers_hook.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. """Tests for providers_hook bridge — 验证 Bridge 正确从 fileLoc 加载,不 fallback。"""
  2. import importlib
  3. import inspect
  4. import io
  5. import sys
  6. import pytest
  7. from loguru import logger
  8. @pytest.fixture(autouse=True)
  9. def _clean_providers_hook_cache():
  10. """每个用例前清除 providers_hook 缓存,避免交叉污染。"""
  11. for mod in list(sys.modules.keys()):
  12. if "providers_hook" in mod:
  13. del sys.modules[mod]
  14. class TestBridgeLoadsFromFileLoc:
  15. """Bridge 应成功从 fileLoc 加载,不触发 WARNING。"""
  16. def test_no_warning_on_import(self):
  17. """导入 providers_hook 不应产生 WARNING 日志。"""
  18. # 捕获 loguru 输出
  19. logger.remove()
  20. buf = io.StringIO()
  21. logger.add(buf, format="{level} | {message}", level="DEBUG")
  22. import ocr_utils.hooks.providers_hook # noqa: F811
  23. importlib.reload(ocr_utils.hooks.providers_hook)
  24. log_output = buf.getvalue()
  25. hook_lines = [l for l in log_output.split("\n") if "providers_hook" in l and l.strip()]
  26. warnings = [l for l in hook_lines if "WARNING" in l]
  27. assert not warnings, f"providers_hook 不应输出 WARNING: {warnings}"
  28. debugs = [l for l in hook_lines if "DEBUG" in l]
  29. assert any("loaded from fileLoc" in l for l in debugs), \
  30. f"应产生 'loaded from fileLoc' DEBUG 日志, 实际: {debugs}"
  31. def test_symbols_from_fileloc(self):
  32. """resolve_providers 和 apply_providers_to_session 应来自 fileLoc device_and_provider 模块。"""
  33. import ocr_utils.hooks.providers_hook
  34. importlib.reload(ocr_utils.hooks.providers_hook)
  35. hook = ocr_utils.hooks.providers_hook
  36. assert "device_and_provider" in hook.resolve_providers.__module__, \
  37. f"resolve_providers 来自 {hook.resolve_providers.__module__}, 期望 device_and_provider"
  38. assert "device_and_provider" in hook.apply_providers_to_session.__module__, \
  39. f"apply_providers_to_session 来自 {hook.apply_providers_to_session.__module__}, 期望 device_and_provider"
  40. def test_apply_providers_wrapper_signature(self):
  41. """apply_providers 包装函数签名应为 (model, config, session_paths)。"""
  42. import ocr_utils.hooks.providers_hook
  43. importlib.reload(ocr_utils.hooks.providers_hook)
  44. hook = ocr_utils.hooks.providers_hook
  45. sig = inspect.signature(hook.apply_providers)
  46. param_names = list(sig.parameters.keys())
  47. assert param_names == ["model", "config", "session_paths"], \
  48. f"期望签名 (model, config, session_paths), 实际: {param_names}"
  49. def test_default_providers_constants(self):
  50. """DEFAULT_CPU_PROVIDERS 和 DEFAULT_GPU_PROVIDERS 应正确导出。"""
  51. import ocr_utils.hooks.providers_hook
  52. importlib.reload(ocr_utils.hooks.providers_hook)
  53. hook = ocr_utils.hooks.providers_hook
  54. assert hook.DEFAULT_CPU_PROVIDERS == ["CPUExecutionProvider"]
  55. assert hook.DEFAULT_GPU_PROVIDERS == ["CUDAExecutionProvider", "CPUExecutionProvider"]