| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- """Tests for providers_hook bridge — 验证 Bridge 正确从 fileLoc 加载,不 fallback。"""
- import importlib
- import inspect
- import io
- import sys
- import pytest
- from loguru import logger
@pytest.fixture(autouse=True)
def _clean_providers_hook_cache():
    """Evict every cached ``providers_hook`` module before each test.

    Any ``sys.modules`` entry whose dotted name contains the substring
    ``providers_hook`` is dropped. The next import therefore re-executes the
    module from scratch, and one test's import state cannot leak into another's.
    """
    # Snapshot the key set first so the dict is never mutated while being iterated.
    stale_names = [name for name in list(sys.modules) if "providers_hook" in name]
    for name in stale_names:
        sys.modules.pop(name)
class TestBridgeLoadsFromFileLoc:
    """The providers_hook bridge must load from fileLoc without emitting any WARNING."""

    @staticmethod
    def _reload_hook():
        """Import, then reload, ``ocr_utils.hooks.providers_hook`` and return the module.

        This is the import-then-``importlib.reload`` sequence every test in this
        class performs, so the module body executes inside the calling test.
        """
        import ocr_utils.hooks.providers_hook

        return importlib.reload(ocr_utils.hooks.providers_hook)

    def test_no_warning_on_import(self):
        """Importing providers_hook must log no WARNING and must log the fileLoc DEBUG line."""
        buf = io.StringIO()
        # Attach a dedicated capture sink instead of calling a bare ``logger.remove()``.
        # A bare remove deletes *every* loguru handler (including the default stderr
        # sink and any conftest handlers) for the rest of the session.
        # ``{name}`` records the emitting module. Lines from providers_hook are then
        # kept by the filter below even when the message text does not mention it.
        handler_id = logger.add(buf, format="{level} | {name} | {message}", level="DEBUG")
        try:
            self._reload_hook()
        finally:
            # Detach only our sink so it does not leak into (and keep buffering) later tests.
            logger.remove(handler_id)

        hook_lines = [
            line
            for line in buf.getvalue().splitlines()
            if "providers_hook" in line and line.strip()
        ]
        # The level is the first " | "-separated field. Compare it exactly so that a
        # message which merely contains the word "WARNING"/"DEBUG" is not misclassified.
        warnings = [line for line in hook_lines if line.partition(" | ")[0] == "WARNING"]
        assert not warnings, f"providers_hook 不应输出 WARNING: {warnings}"
        debugs = [line for line in hook_lines if line.partition(" | ")[0] == "DEBUG"]
        assert any("loaded from fileLoc" in line for line in debugs), (
            f"应产生 'loaded from fileLoc' DEBUG 日志, 实际: {debugs}"
        )

    def test_symbols_from_fileloc(self):
        """resolve_providers and apply_providers_to_session must come from the fileLoc device_and_provider module."""
        hook = self._reload_hook()
        assert "device_and_provider" in hook.resolve_providers.__module__, (
            f"resolve_providers 来自 {hook.resolve_providers.__module__}, 期望 device_and_provider"
        )
        assert "device_and_provider" in hook.apply_providers_to_session.__module__, (
            f"apply_providers_to_session 来自 {hook.apply_providers_to_session.__module__}, 期望 device_and_provider"
        )

    def test_apply_providers_wrapper_signature(self):
        """The apply_providers wrapper must have the signature (model, config, session_paths)."""
        hook = self._reload_hook()
        param_names = list(inspect.signature(hook.apply_providers).parameters)
        assert param_names == ["model", "config", "session_paths"], (
            f"期望签名 (model, config, session_paths), 实际: {param_names}"
        )

    def test_default_providers_constants(self):
        """DEFAULT_CPU_PROVIDERS and DEFAULT_GPU_PROVIDERS must be exported with the expected values."""
        hook = self._reload_hook()
        assert hook.DEFAULT_CPU_PROVIDERS == ["CPUExecutionProvider"]
        assert hook.DEFAULT_GPU_PROVIDERS == ["CUDAExecutionProvider", "CPUExecutionProvider"]
|