|
|
@@ -0,0 +1,440 @@
|
|
|
+import os
|
|
|
+import base64
|
|
|
+import json
|
|
|
+import time
|
|
|
+from pathlib import Path
|
|
|
+from openai import OpenAI
|
|
|
+from dotenv import load_dotenv
|
|
|
+from typing import Any, Dict, List
|
|
|
+
|
|
|
# Load environment variables via python-dotenv at import time, so the
# os.getenv() lookups for the YUSYS_MULTIMODAL_* settings below can see them.
load_dotenv()
|
|
|
+
|
|
|
def _strip_openai_prefix(model_id):
    """Return *model_id* without a leading ``openai/`` routing prefix."""
    prefix = "openai/"
    # Only a leading prefix is removed; an "openai/" fragment elsewhere in the
    # identifier is part of the model name and must be kept.
    if model_id.startswith(prefix):
        return model_id[len(prefix):]
    return model_id


def _format_ocr_items(ocr_results, ocr_json_path):
    """Render the OCR items as the text listing embedded in the prompt.

    Raises:
        ValueError: if the OCR payload is not a JSON array of objects.
    """
    if not isinstance(ocr_results, list) or not all(
        isinstance(item, dict) for item in ocr_results
    ):
        raise ValueError(
            f"OCR结果文件格式不正确,应为由对象组成的JSON数组: {ocr_json_path}"
        )

    lines = [
        f"位置坐标[{item.get('bbox', [])}] - 类别: {item.get('category', '')}"
        f" - 文本: {item.get('text', '')}\n"
        for item in ocr_results
    ]
    return "OCR识别结果:\n" + "".join(lines)


def _build_verification_prompt(ocr_text):
    """Return the instruction prompt that accompanies the image in the request."""
    return f"""请仔细分析这张图片,并与以下OCR识别结果进行逐项详细对比:

{ocr_text}

重要要求:
1. 对于表格中的每一个数据项(特别是数字、金额、项目名称),都必须逐一验证
2. 即使发现微小差异也要报告(如小数点位数、千分符、标点符号等)
3. 对于表格结构要仔细检查行列对应关系
4. 必须输出所有发现的问题,不要遗漏任何错误

请执行以下详细任务:
1. 逐行逐列验证表格中的每个数据项是否准确
2. 检查数字格式:小数点、千分符、负号等
3. 验证项目名称的完整性和准确性
4. 检查表格标题和表头信息
5. 验证行次编号的正确性
6. 识别任何遗漏的表格内容
7. 检查文本的位置坐标是否与实际位置匹配

请以JSON格式返回详细分析结果,对于表格中的每个识别项都要有明确的验证结果:
{{
  "table_verification": {{
    "total_items_checked": 数字,
    "accuracy_rate": "百分比",
    "table_structure_correct": true/false
  }},
  "accurate_items": [
    {{
      "bbox": [x1, y1, x2, y2],
      "original_text": "OCR识别的文本",
      "verified": true,
      "table_position": "行X列Y",
      "notes": "验证说明"
    }}
  ],
  "errors": [
    {{
      "bbox": [x1, y1, x2, y2],
      "original_text": "OCR识别的文本",
      "correct_text": "正确的文本",
      "error_type": "数字错误|格式错误|字符错误|缺失内容",
      "table_position": "行X列Y",
      "severity": "高|中|低",
      "confidence": "高|中|低",
      "detailed_description": "详细错误描述"
    }}
  ],
  "missing_items": [
    {{
      "estimated_bbox": [x1, y1, x2, y2],
      "missing_text": "遗漏的文本",
      "category": "数字|文本|标题|其他",
      "table_position": "行X列Y",
      "importance": "高|中|低",
      "impact": "对数据准确性的影响描述"
    }}
  ],
  "position_errors": [
    {{
      "text": "文本内容",
      "ocr_bbox": [x1, y1, x2, y2],
      "actual_bbox": [x1, y1, x2, y2],
      "table_position": "行X列Y",
      "deviation": "具体偏差描述"
    }}
  ],
  "format_issues": [
    {{
      "item": "受影响的项目",
      "issue_type": "千分符缺失|小数位错误|符号错误",
      "original_format": "OCR识别的格式",
      "correct_format": "正确格式",
      "table_position": "行X列Y"
    }}
  ]
}}

特别注意:
- 表格中的所有金额数据都必须精确验证
- 对于大额数字,要特别注意千分符和小数点
- 项目名称要检查是否完整,包括括号、冒号等标点符号
- 行次编号必须与对应项目正确匹配
- 任何发现的问题都要详细记录,包括位置和具体错误内容"""


def _parse_model_json(generated_text):
    """Extract the outermost ``{...}`` span of the reply and parse it as JSON.

    When no such span exists, or it is not valid JSON, the raw reply is
    returned wrapped in a dict with a ``parsing_note`` instead of raising.
    """
    start = generated_text.find('{')
    end = generated_text.rfind('}') + 1
    if start == -1 or end <= start:
        return {
            "raw_analysis": generated_text,
            "parsing_note": "无法解析为标准JSON格式,返回原始分析文本"
        }
    try:
        return json.loads(generated_text[start:end])
    except json.JSONDecodeError:
        return {
            "raw_analysis": generated_text,
            "parsing_note": "JSON解析失败,返回原始分析文本"
        }


def _print_verification_stats(result):
    """Print a short count summary of the verification result.

    The model's output shape is not guaranteed, so sections whose value is not
    of the expected type are skipped instead of raising after the result file
    has already been written.
    """
    def section(key):
        value = result.get(key)
        return value if isinstance(value, list) else None

    def count_high(items, field):
        return sum(
            1 for entry in items
            if isinstance(entry, dict) and entry.get(field) == '高'
        )

    print("\n📊 验证结果统计:")
    tv = result.get("table_verification")
    if isinstance(tv, dict):
        print(f" 📋 表格检查项目: {tv.get('total_items_checked', 'N/A')}")
        print(f" 📈 准确率: {tv.get('accuracy_rate', 'N/A')}")

    accurate = section("accurate_items")
    if accurate is not None:
        print(f" ✅ 准确项目数量: {len(accurate)}")

    errors = section("errors")
    if errors is not None:
        print(f" ❌ 发现错误数量: {len(errors)}")
        high_errors = count_high(errors, 'severity')
        if high_errors > 0:
            print(f" 🔴 高严重程度错误: {high_errors}")

    missing = section("missing_items")
    if missing is not None:
        print(f" 📝 遗漏项目数量: {len(missing)}")
        high_missing = count_high(missing, 'importance')
        if high_missing > 0:
            print(f" 🔴 重要遗漏项目: {high_missing}")

    position_errors = section("position_errors")
    if position_errors is not None:
        print(f" 📍 位置错误数量: {len(position_errors)}")

    format_issues = section("format_issues")
    if format_issues is not None:
        print(f" 🎨 格式问题数量: {len(format_issues)}")


def verify_ocr_with_vlm(image_path, ocr_json_path, output_path="ocr_differences.json",
                        api_key=None, api_base=None, model_id=None,
                        temperature=0.1, max_tokens=4096, timeout=180) -> Dict[str, Any]:
    """Ask a vision-language model to compare OCR output against the source image.

    The image is sent base64-encoded together with a text listing of the OCR
    items; the model's reply is parsed as JSON when possible, enriched with
    run metadata, written to ``output_path`` and returned.

    Args:
        image_path: Path of the original image.
        ocr_json_path: Path of the OCR result JSON file (an array of objects
            from which ``bbox``, ``category`` and ``text`` are read).
        output_path: Where the verification result JSON is written.
        api_key: API key; falls back to ``YUSYS_MULTIMODAL_API_KEY``.
        api_base: API base URL; falls back to ``YUSYS_MULTIMODAL_API_BASE``.
        model_id: Model id; falls back to ``YUSYS_MULTIMODAL_ID``. A leading
            ``openai/`` prefix is stripped before the request is made.
        temperature: Sampling temperature passed to the API.
        max_tokens: Maximum number of output tokens passed to the API.
        timeout: Request timeout passed to the API client call.

    Returns:
        The verification result dict (parsed model JSON, or the raw reply under
        ``raw_analysis`` when parsing failed) including a ``metadata`` entry.

    Raises:
        ValueError: if a required API setting is missing or the OCR JSON is
            not an array of objects.
        FileNotFoundError: if the image or the OCR JSON file does not exist.
        RuntimeError: if the API call, result handling or saving fails; the
            original error is chained as ``__cause__``.
    """
    # Explicit arguments win; environment variables are the fallback.
    api_key = api_key or os.getenv("YUSYS_MULTIMODAL_API_KEY")
    api_base = api_base or os.getenv("YUSYS_MULTIMODAL_API_BASE")
    model_id = model_id or os.getenv("YUSYS_MULTIMODAL_ID")

    if not api_key:
        raise ValueError("未找到API密钥,请通过参数传入或设置YUSYS_MULTIMODAL_API_KEY环境变量")
    if not api_base:
        raise ValueError("未找到API基础URL,请通过参数传入或设置YUSYS_MULTIMODAL_API_BASE环境变量")
    if not model_id:
        raise ValueError("未找到模型ID,请通过参数传入或设置YUSYS_MULTIMODAL_ID环境变量")

    model_name = _strip_openai_prefix(model_id)

    # Read the image and base64-encode it for a data: URL.
    try:
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到图片文件: {image_path}") from err

    # Read the OCR result.
    try:
        with open(ocr_json_path, "r", encoding='utf-8') as f:
            ocr_results = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到OCR结果文件: {ocr_json_path}") from err

    # Pick the MIME type from the file extension; unknown extensions are sent as JPEG.
    mime_type_map = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp',
    }
    file_extension = Path(image_path).suffix.lower()
    mime_type = mime_type_map.get(file_extension, 'image/jpeg')

    # Validate and format the OCR items before any network work is attempted.
    ocr_text = _format_ocr_items(ocr_results, ocr_json_path)
    prompt = _build_verification_prompt(ocr_text)

    client = OpenAI(
        api_key=api_key,
        base_url=api_base
    )

    messages: List[Dict[str, Any]] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{image_data}"
                    }
                }
            ]
        }
    ]

    try:
        print(f"正在使用模型 {model_name} 进行OCR验证...")
        print(f"API地址: {api_base}")

        response = client.chat.completions.create(
            model=model_name,
            messages=messages,  # type: ignore
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout
        )

        generated_text = response.choices[0].message.content
        if not generated_text:
            raise RuntimeError("模型没有生成文本内容")

        print(f"成功使用模型 {model_name} 完成OCR验证!")

        verification_result: Dict[str, Any] = _parse_model_json(generated_text)

        verification_result["metadata"] = {
            "model_used": model_name,
            "model_id": model_id,
            "api_base": api_base,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "timeout": timeout,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            # str() keeps the metadata JSON-serializable when Path objects are passed.
            "original_image": str(image_path),
            "ocr_json": str(ocr_json_path)
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(verification_result, f, ensure_ascii=False, indent=2)

        print(f"OCR验证结果已保存到: {output_path}")

        _print_verification_stats(verification_result)

        return verification_result

    except Exception as e:
        print(f"OCR验证失败: {e}")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"OCR验证任务失败: {e}") from e
|
|
|
+
|
|
|
def analyze_differences(verification_result_path):
    """Print a detailed difference report for a saved OCR verification result.

    The report covers run metadata, the table overview, accurately recognised
    items, recognition errors, missing items, position errors, format issues,
    a summary, and the model's raw reply when one was stored.

    When the file holds none of the structured sections (for example because
    the model reply could not be parsed as JSON and only ``raw_analysis`` was
    saved), the report says so instead of claiming that no errors were found.

    Args:
        verification_result_path: Path of the verification result JSON file.

    Raises:
        FileNotFoundError: if the result file does not exist.
        ValueError: if the file does not contain a JSON object.
    """
    try:
        with open(verification_result_path, 'r', encoding='utf-8') as f:
            result = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到验证结果文件: {verification_result_path}") from err

    if not isinstance(result, dict):
        raise ValueError(f"验证结果文件格式不正确,应为JSON对象: {verification_result_path}")

    def entries(key):
        # Model output is not schema-checked: tolerate null / non-list sections
        # and skip entries that are not objects.
        value = result.get(key)
        if not isinstance(value, list):
            return []
        return [entry for entry in value if isinstance(entry, dict)]

    structured_keys = (
        "table_verification", "accurate_items", "errors",
        "missing_items", "position_errors", "format_issues",
    )
    # False when nothing structured was saved, i.e. "no entries" does not mean "no errors".
    parsed = any(key in result for key in structured_keys)
    unparsed_note = " ⚪ 未能解析结构化结果,无法判断"
    level_icons = {"高": "🔴", "中": "🟡", "低": "🟢"}

    print("\n" + "="*60)
    print("OCR验证详细差异分析报告")
    print("="*60)

    metadata = result.get("metadata")
    if isinstance(metadata, dict):
        print(f"分析时间: {metadata.get('timestamp', 'N/A')}")
        print(f"使用模型: {metadata.get('model_used', 'N/A')}")
        print(f"原始图片: {metadata.get('original_image', 'N/A')}")
        print(f"OCR结果: {metadata.get('ocr_json', 'N/A')}")
        print(f"分析配置: 温度={metadata.get('temperature', 'N/A')}, "
              f"最大tokens={metadata.get('max_tokens', 'N/A')}")

    # Table verification overview
    tv = result.get("table_verification")
    if isinstance(tv, dict):
        print("\n📊 表格验证总览:")
        print(f" 检查项目总数: {tv.get('total_items_checked', 'N/A')}")
        print(f" 准确率: {tv.get('accuracy_rate', 'N/A')}")
        print(f" 表格结构正确: {'✅' if tv.get('table_structure_correct') else '❌'}")

    print("\n1. ✅ 准确识别项目:")
    accurate_items = entries("accurate_items")
    if accurate_items:
        for i, item in enumerate(accurate_items, 1):
            pos = item.get('table_position', '未知位置')
            bbox = item.get('bbox', 'N/A')
            text = item.get('original_text', 'N/A')
            print(f" {i}. {pos} - 位置{bbox}")
            print(f" 内容: {text}")
            if item.get('notes'):
                print(f" 说明: {item['notes']}")
            print()
    else:
        print(" 无准确识别项目或未能解析")

    print("\n2. ❌ 识别错误 (需要重点关注):")
    errors = entries("errors")
    if errors:
        for i, error in enumerate(errors, 1):
            pos = error.get('table_position', '未知位置')
            severity = error.get('severity', '未知')
            severity_icon = level_icons.get(severity, "⚪")

            print(f" {severity_icon} 错误 {i} - {pos}")
            print(f" 位置坐标: {error.get('bbox', 'N/A')}")
            print(f" OCR识别: 「{error.get('original_text', 'N/A')}」")
            print(f" 正确内容: 「{error.get('correct_text', 'N/A')}」")
            print(f" 错误类型: {error.get('error_type', 'N/A')}")
            print(f" 严重程度: {severity}")
            print(f" 置信度: {error.get('confidence', 'N/A')}")
            if error.get('detailed_description'):
                print(f" 详细说明: {error['detailed_description']}")
            print()
    else:
        print(" ✅ 未发现识别错误" if parsed else unparsed_note)

    print("\n3. 📝 遗漏项目:")
    missing_items = entries("missing_items")
    if missing_items:
        for i, missing in enumerate(missing_items, 1):
            pos = missing.get('table_position', '未知位置')
            importance = missing.get('importance', '未知')
            importance_icon = level_icons.get(importance, "⚪")

            print(f" {importance_icon} 遗漏 {i} - {pos}")
            print(f" 预估位置: {missing.get('estimated_bbox', 'N/A')}")
            print(f" 遗漏内容: 「{missing.get('missing_text', 'N/A')}」")
            print(f" 内容类别: {missing.get('category', 'N/A')}")
            print(f" 重要程度: {importance}")
            if missing.get('impact'):
                print(f" 影响说明: {missing['impact']}")
            print()
    else:
        print(" ✅ 未发现遗漏项目" if parsed else unparsed_note)

    print("\n4. 📍 位置错误:")
    position_errors = entries("position_errors")
    if position_errors:
        for i, pos_error in enumerate(position_errors, 1):
            table_pos = pos_error.get('table_position', '未知位置')
            print(f" 📍 位置错误 {i} - {table_pos}")
            print(f" 文本内容: 「{pos_error.get('text', 'N/A')}」")
            print(f" OCR位置: {pos_error.get('ocr_bbox', 'N/A')}")
            print(f" 实际位置: {pos_error.get('actual_bbox', 'N/A')}")
            print(f" 偏差描述: {pos_error.get('deviation', 'N/A')}")
            print()
    else:
        print(" ✅ 未发现位置错误" if parsed else unparsed_note)

    print("\n5. 🎨 格式问题:")
    format_issues = entries("format_issues")
    if format_issues:
        for i, format_issue in enumerate(format_issues, 1):
            pos = format_issue.get('table_position', '未知位置')
            print(f" 🎨 格式问题 {i} - {pos}")
            print(f" 受影响项目: 「{format_issue.get('item', 'N/A')}」")
            print(f" 问题类型: {format_issue.get('issue_type', 'N/A')}")
            print(f" OCR格式: 「{format_issue.get('original_format', 'N/A')}」")
            print(f" 正确格式: 「{format_issue.get('correct_format', 'N/A')}」")
            print()
    else:
        print(" ✅ 未发现格式问题" if parsed else unparsed_note)

    # Summary across all problem categories
    total_errors = (
        len(errors) + len(missing_items) + len(position_errors) + len(format_issues)
    )

    print("\n" + "="*60)
    print("📈 验证摘要:")
    print(f" 总错误数量: {total_errors}")
    high_severity = sum(1 for e in errors if e.get('severity') == '高')
    if high_severity > 0:
        print(f" 🔴 高严重程度错误: {high_severity} 个")

    if not parsed:
        # Nothing structured was parsed, so a zero count proves nothing.
        print(" ⚠️ 未解析到结构化验证结果,无法判断OCR是否存在错误")
    elif total_errors == 0:
        print(" 🎉 恭喜!未发现任何OCR错误")
    else:
        print(" ⚠️ 建议:仔细检查所有标记为'高'严重程度的错误")

    # Show the model's raw reply as well when it was stored.
    if "raw_analysis" in result:
        print("\n6. 📄 VLM原始分析内容:")
        print("-" * 50)
        print(result["raw_analysis"])
        print("-" * 50)
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Example usage: verify one OCR result against its source image, then
    # print the detailed report from the file that verification just wrote.
    image_path = "至远彩色印刷工业有限公司-2022年母公司_2.png"  # presumably an income-statement image
    ocr_json_path = "demo_54fa7ad0_page_1.json"  # OCR result file
    # One shared path, so the report always reads the file just written.
    output_path = "ocr_differences.json"

    try:
        # Run the OCR verification
        verify_ocr_with_vlm(image_path, ocr_json_path, output_path=output_path)

        # Analyse the differences
        analyze_differences(output_path)

    except Exception as e:
        print(f"OCR验证失败: {e}")
        # Signal the failure to the calling shell instead of exiting with status 0.
        raise SystemExit(1) from e
|