|
@@ -0,0 +1,188 @@
|
|
|
|
|
+#!/opt/miniconda3/envs/mineru2/bin/python
|
|
|
|
|
+"""测试 GLM-OCR 适配器加载
|
|
|
|
|
+
|
|
|
|
|
+验证:
|
|
|
|
|
+1. 适配器类可以正确导入
|
|
|
|
|
+2. 配置文件可以正确解析
|
|
|
|
|
+3. 适配器可以正确初始化
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import sys
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+
|
|
|
|
|
+# 添加项目根目录到路径
|
|
|
|
|
+project_root = Path(__file__).parents[1]
|
|
|
|
|
+sys.path.insert(0, str(project_root))
|
|
|
|
|
+
|
|
|
|
|
+from loguru import logger
|
|
|
|
|
+
|
|
|
|
|
def test_import_adapter():
    """Check that the GLM-OCR adapter class can be imported.

    Returns:
        bool: True when ``GLMOCRVLRecognizer`` can be imported from
        ``models.adapters``; False when the import (or the success log)
        raises. The exception is logged and not propagated.
    """
    logger.info("测试 1: 导入 GLM-OCR 适配器...")
    imported = False
    try:
        # Imported only for its side effect of proving the symbol resolves.
        from models.adapters import GLMOCRVLRecognizer  # noqa: F401
        logger.success("✅ GLMOCRVLRecognizer 导入成功")
        imported = True
    except Exception as exc:
        logger.error(f"❌ 导入失败: {exc}")
    return imported
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def test_load_config():
    """Load the ``bank_statement_glm_vl.yaml`` scene config and log key fields.

    The file is looked up first under ``<project_root>/config`` and then under
    a ``config`` directory next to this script. A missing file is treated as a
    skipped test rather than a failure.

    Returns:
        tuple: ``(True, config)`` when the file was parsed into a mapping,
        ``(True, None)`` when no config file exists (test skipped), and
        ``(False, None)`` when reading or parsing fails. Errors are logged,
        not propagated.
    """
    logger.info("测试 2: 加载配置文件...")
    try:
        import yaml

        config_name = "bank_statement_glm_vl.yaml"
        candidates = (
            project_root / "config" / config_name,
            # Fallback: a config directory that sits beside this script.
            Path(__file__).parent / "config" / config_name,
        )
        # First existing candidate; otherwise keep the last one so the warning
        # below reports the same path the fallback lookup would have used.
        config_path = next(
            (path for path in candidates if path.exists()), candidates[-1]
        )

        if not config_path.exists():
            logger.warning(f"⚠️ 配置文件不存在: {config_path}")
            logger.warning("跳过配置文件测试")
            return True, None  # a missing file does not count as a failure

        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty document and may return a list or
        # scalar; fail with an explicit message instead of an AttributeError
        # on the .get() calls below.
        if not isinstance(config, dict):
            raise ValueError(
                f"配置文件顶层必须是映射(dict),实际为 {type(config).__name__}"
            )

        # A section written as `key:` with no value parses to None, so
        # `.get(key, {})` would still hand back None; fall back explicitly.
        vl_section = config.get('vl_recognition') or {}
        layout_section = config.get('layout_detection') or {}

        logger.info(f"配置场景: {config.get('scene_name')}")
        logger.info(f"VL模块: {vl_section.get('module')}")
        logger.info(f"Layout模块: {layout_section.get('module')}")

        logger.success("✅ 配置文件加载成功")
        return True, config
    except Exception as e:
        logger.error(f"❌ 配置加载失败: {e}")
        return False, None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def test_create_adapter():
    """Instantiate ``GLMOCRVLRecognizer`` directly from an inline config.

    Returns:
        tuple: ``(True, recognizer)`` when construction succeeds, otherwise
        ``(False, None)``. On failure the error is logged and the traceback
        is printed; nothing is raised to the caller.
    """
    logger.info("测试 3: 创建适配器实例...")
    try:
        from models.adapters import GLMOCRVLRecognizer

        # Prompt sent for each recognition task type.
        prompt_by_task = {
            'text': 'Text Recognition:',
            'table': 'Table Recognition:',
            'formula': 'Formula Recognition:',
            'seal': 'Seal Recognition:',
        }
        # HTTP client tuning passed through to the adapter.
        client_params = {
            'connection_pool_size': 128,
            'http_timeout': 300,
            'retry_max_attempts': 2,
        }
        # Minimal configuration for constructing the adapter.
        adapter_config = {
            'module': 'glmocr',
            'api_url': 'http://10.192.72.11:20036/v1/chat/completions',
            'model': 'glm-ocr',
            'max_image_size': 3500,
            'resize_mode': 'max',
            'task_prompt_mapping': prompt_by_task,
            'model_params': client_params,
        }

        adapter = GLMOCRVLRecognizer(adapter_config)
        logger.info(f"适配器类型: {type(adapter)}")
        logger.info(f"最大图片尺寸: {adapter.max_image_size}")
        logger.info(f"任务提示词: {adapter.task_prompt_mapping}")

        logger.success("✅ 适配器实例创建成功")
        return True, adapter
    except Exception as exc:
        logger.error(f"❌ 适配器创建失败: {exc}")
        import traceback

        traceback.print_exc()
        return False, None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def test_initialize_adapter(recognizer):
    """Call ``recognizer.initialize()`` and report whether it succeeded.

    Args:
        recognizer: Adapter instance exposing ``initialize()`` and a
            ``session`` attribute (as produced by ``test_create_adapter``).

    Returns:
        bool: True when ``initialize()`` and the follow-up logging complete
        without raising; False otherwise. A failure is logged as a warning
        (the API service may simply be unavailable) with the traceback printed.
    """
    logger.info("测试 4: 初始化适配器(需要 GLM-OCR API 服务)...")
    initialized = False
    try:
        recognizer.initialize()
        logger.success("✅ 适配器初始化成功")
        logger.info(f"HTTP Session: {recognizer.session}")
        initialized = True
    except Exception as exc:
        logger.warning(f"⚠️ 适配器初始化失败(可能是 API 服务不可用): {exc}")
        import traceback

        traceback.print_exc()
    return initialized
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def test_model_factory():
    """Create a VL recognizer through ``ModelFactory.create_vl_recognizer``.

    Returns:
        bool: True when the factory returns without raising; False when the
        import or the factory call fails. Failures are logged with the
        traceback printed and are not propagated.
    """
    logger.info("测试 5: 通过 ModelFactory 创建适配器...")
    created = False
    try:
        from core.model_factory import ModelFactory

        # HTTP client tuning passed through to the adapter.
        client_params = {
            'connection_pool_size': 128,
            'http_timeout': 300,
        }
        factory_config = {
            'module': 'glmocr',
            'api_url': 'http://10.192.72.11:20036/v1/chat/completions',
            'model': 'glm-ocr',
            'max_image_size': 3500,
            'model_params': client_params,
        }

        adapter = ModelFactory.create_vl_recognizer(factory_config)
        logger.info(f"适配器类型: {type(adapter).__name__}")
        logger.success("✅ ModelFactory 创建适配器成功")
        created = True
    except Exception as exc:
        logger.error(f"❌ ModelFactory 创建失败: {exc}")
        import traceback

        traceback.print_exc()
    return created
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def main():
    """Run every adapter check in order and log a pass/fail summary.

    Returns:
        bool: True only when every recorded check passed.
    """
    separator = "="*60
    logger.info(separator)
    logger.info("开始测试 GLM-OCR 适配器")
    logger.info(separator)

    # Test 1: import the adapter class.
    results = [("导入适配器", test_import_adapter())]

    # Test 2: load the scene config (the parsed config is not used further).
    load_ok, _config = test_load_config()
    results.append(("加载配置", load_ok))

    # Test 3: construct an adapter instance.
    create_ok, recognizer = test_create_adapter()
    results.append(("创建实例", create_ok))

    # Test 4: initialize the adapter; only runs when an instance exists.
    # NOTE(review): this step is described as optional (it needs the API
    # service), yet a failure here still counts toward the final result —
    # confirm that is intended.
    if create_ok and recognizer:
        results.append(("初始化适配器", test_initialize_adapter(recognizer)))

    # Test 5: construct the adapter through the factory.
    results.append(("ModelFactory", test_model_factory()))

    # Summary.
    logger.info(separator)
    logger.info("测试结果汇总:")
    logger.info(separator)
    for label, ok in results:
        if ok:
            status = "✅ 通过"
        else:
            status = "❌ 失败"
        logger.info(f"{label:20s}: {status}")

    total = len(results)
    passed = len([ok for _, ok in results if ok])
    logger.info(f"\n总计: {passed}/{total} 测试通过")

    return passed == total
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
if __name__ == "__main__":
    # Exit status 0 when every check passed, 1 otherwise.
    all_passed = main()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)
|