| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440 |
- import os
- import base64
- import json
- import time
- from pathlib import Path
- from openai import OpenAI
- from dotenv import load_dotenv
- from typing import Any, Dict, List
# Load variables from a .env file (if one is found) into the process
# environment at import time, so the os.getenv() lookups in
# verify_ocr_with_vlm can see them. Variables already set in the
# environment keep their values.
load_dotenv(override=False)
def _as_list(value):
    """Return *value* when it is a list, otherwise an empty list.

    The report comes from a model, so a section may be missing, ``null`` or
    of an unexpected type; counting it must not crash after the report has
    already been saved.
    """
    return value if isinstance(value, list) else []


def _build_ocr_listing(ocr_results):
    """Render the OCR items as the text block embedded in the prompt.

    Each item is read with ``.get('bbox')``, ``.get('category')`` and
    ``.get('text')``; items are presumably dicts (TODO confirm against the
    OCR producer's schema).
    """
    lines = ["OCR识别结果:\n"]
    for item in ocr_results:
        bbox = item.get('bbox', [])
        category = item.get('category', '')
        text = item.get('text', '')
        lines.append(f"位置坐标[{bbox}] - 类别: {category} - 文本: {text}\n")
    return "".join(lines)


def _parse_verification_json(generated_text):
    """Extract the JSON report object from the model reply.

    Decoding starts at the first '{' and stops at the end of that object, so
    Markdown fences or trailing commentary (even commentary containing '}')
    do not break parsing. When no object can be decoded, a dict holding the
    raw reply and a ``parsing_note`` is returned instead.
    """
    json_start = generated_text.find('{')
    if json_start == -1 or generated_text.rfind('}') < json_start:
        # No '{' ... '}' span at all: keep the raw text for manual review.
        return {
            "raw_analysis": generated_text,
            "parsing_note": "无法解析为标准JSON格式,返回原始分析文本"
        }
    try:
        # raw_decode ignores whatever follows the first complete object.
        parsed, _ = json.JSONDecoder().raw_decode(generated_text[json_start:])
    except json.JSONDecodeError:
        return {
            "raw_analysis": generated_text,
            "parsing_note": "JSON解析失败,返回原始分析文本"
        }
    # The decoded text starts with '{', so a successful decode is a dict.
    return parsed


def _print_verification_stats(verification_result):
    """Print per-section counts of the verification report to stdout."""
    print("\n📊 验证结果统计:")
    tv = verification_result.get("table_verification")
    if isinstance(tv, dict):
        print(f" 📋 表格检查项目: {tv.get('total_items_checked', 'N/A')}")
        print(f" 📈 准确率: {tv.get('accuracy_rate', 'N/A')}")

    if "accurate_items" in verification_result:
        print(f" ✅ 准确项目数量: {len(_as_list(verification_result['accurate_items']))}")
    if "errors" in verification_result:
        errors = _as_list(verification_result["errors"])
        print(f" ❌ 发现错误数量: {len(errors)}")
        high_errors = sum(
            1 for e in errors if isinstance(e, dict) and e.get('severity') == '高'
        )
        if high_errors > 0:
            print(f" 🔴 高严重程度错误: {high_errors}")
    if "missing_items" in verification_result:
        missing_items = _as_list(verification_result["missing_items"])
        print(f" 📝 遗漏项目数量: {len(missing_items)}")
        high_missing = sum(
            1 for m in missing_items if isinstance(m, dict) and m.get('importance') == '高'
        )
        if high_missing > 0:
            print(f" 🔴 重要遗漏项目: {high_missing}")
    if "position_errors" in verification_result:
        print(f" 📍 位置错误数量: {len(_as_list(verification_result['position_errors']))}")
    if "format_issues" in verification_result:
        print(f" 🎨 格式问题数量: {len(_as_list(verification_result['format_issues']))}")


def verify_ocr_with_vlm(image_path, ocr_json_path, output_path="ocr_differences.json",
                        api_key=None, api_base=None, model_id=None,
                        temperature=0.1, max_tokens=4096, timeout=180):
    """Compare OCR results against the source image with a VLM and save the differences.

    The image is sent as a base64 data URL, together with a text listing of
    the OCR items, to an OpenAI-compatible chat-completions endpoint. The
    model is asked to answer with a JSON report. That report is parsed, or
    wrapped as raw text when it cannot be parsed. It is then augmented with
    run metadata, written to ``output_path`` and returned.

    Args:
        image_path: Path to the original image.
        ocr_json_path: Path to the OCR result JSON file.
        output_path: Where the verification report JSON is written.
        api_key: API key; falls back to ``YUSYS_MULTIMODAL_API_KEY``.
        api_base: API base URL; falls back to ``YUSYS_MULTIMODAL_API_BASE``.
        model_id: Model ID; falls back to ``YUSYS_MULTIMODAL_ID``. A leading
            ``openai/`` prefix is stripped before calling the API.
        temperature: Sampling temperature (default 0.1).
        max_tokens: Maximum number of output tokens (default 4096).
        timeout: Request timeout in seconds (default 180).

    Returns:
        The verification report dict, including a ``metadata`` entry.

    Raises:
        ValueError: If the API key, base URL or model ID is missing.
        FileNotFoundError: If the image or OCR JSON file does not exist
            (the original error is chained as ``__cause__``).
        Exception: If the API call, response handling or saving fails
            (the original error is chained as ``__cause__``).
    """
    # Explicit arguments win over environment variables.
    api_key = api_key or os.getenv("YUSYS_MULTIMODAL_API_KEY")
    api_base = api_base or os.getenv("YUSYS_MULTIMODAL_API_BASE")
    model_id = model_id or os.getenv("YUSYS_MULTIMODAL_ID")

    if not api_key:
        raise ValueError("未找到API密钥,请通过参数传入或设置YUSYS_MULTIMODAL_API_KEY环境变量")
    if not api_base:
        raise ValueError("未找到API基础URL,请通过参数传入或设置YUSYS_MULTIMODAL_API_BASE环境变量")
    if not model_id:
        raise ValueError("未找到模型ID,请通过参数传入或设置YUSYS_MULTIMODAL_ID环境变量")

    # Strip only a *leading* "openai/" prefix; str.replace would also remove
    # the substring from the middle of a model id.
    prefix = "openai/"
    model_name = model_id[len(prefix):] if model_id.startswith(prefix) else model_id

    # Read the image and encode it as base64 for a data URL.
    try:
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到图片文件: {image_path}") from err

    # Read the OCR results.
    try:
        with open(ocr_json_path, "r", encoding='utf-8') as f:
            ocr_results = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到OCR结果文件: {ocr_json_path}") from err

    # MIME type for the data URL; unknown extensions fall back to JPEG.
    mime_type_map = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp'
    }
    mime_type = mime_type_map.get(Path(image_path).suffix.lower(), 'image/jpeg')

    # OCR listing, including position information, for the prompt.
    ocr_text = _build_ocr_listing(ocr_results)

    # Analysis prompt; doubled braces are literal braces of the JSON template.
    prompt = f"""请仔细分析这张图片,并与以下OCR识别结果进行逐项详细对比:
{ocr_text}
重要要求:
1. 对于表格中的每一个数据项(特别是数字、金额、项目名称),都必须逐一验证
2. 即使发现微小差异也要报告(如小数点位数、千分符、标点符号等)
3. 对于表格结构要仔细检查行列对应关系
4. 必须输出所有发现的问题,不要遗漏任何错误
请执行以下详细任务:
1. 逐行逐列验证表格中的每个数据项是否准确
2. 检查数字格式:小数点、千分符、负号等
3. 验证项目名称的完整性和准确性
4. 检查表格标题和表头信息
5. 验证行次编号的正确性
6. 识别任何遗漏的表格内容
7. 检查文本的位置坐标是否与实际位置匹配
请以JSON格式返回详细分析结果,对于表格中的每个识别项都要有明确的验证结果:
{{
"table_verification": {{
"total_items_checked": 数字,
"accuracy_rate": "百分比",
"table_structure_correct": true/false
}},
"accurate_items": [
{{
"bbox": [x1, y1, x2, y2],
"original_text": "OCR识别的文本",
"verified": true,
"table_position": "行X列Y",
"notes": "验证说明"
}}
],
"errors": [
{{
"bbox": [x1, y1, x2, y2],
"original_text": "OCR识别的文本",
"correct_text": "正确的文本",
"error_type": "数字错误|格式错误|字符错误|缺失内容",
"table_position": "行X列Y",
"severity": "高|中|低",
"confidence": "高|中|低",
"detailed_description": "详细错误描述"
}}
],
"missing_items": [
{{
"estimated_bbox": [x1, y1, x2, y2],
"missing_text": "遗漏的文本",
"category": "数字|文本|标题|其他",
"table_position": "行X列Y",
"importance": "高|中|低",
"impact": "对数据准确性的影响描述"
}}
],
"position_errors": [
{{
"text": "文本内容",
"ocr_bbox": [x1, y1, x2, y2],
"actual_bbox": [x1, y1, x2, y2],
"table_position": "行X列Y",
"deviation": "具体偏差描述"
}}
],
"format_issues": [
{{
"item": "受影响的项目",
"issue_type": "千分符缺失|小数位错误|符号错误",
"original_format": "OCR识别的格式",
"correct_format": "正确格式",
"table_position": "行X列Y"
}}
]
}}
特别注意:
- 表格中的所有金额数据都必须精确验证
- 对于大额数字,要特别注意千分符和小数点
- 项目名称要检查是否完整,包括括号、冒号等标点符号
- 行次编号必须与对应项目正确匹配
- 任何发现的问题都要详细记录,包括位置和具体错误内容"""

    client = OpenAI(
        api_key=api_key,
        base_url=api_base
    )

    # One user message carrying the prompt text and the image.
    messages: List[Dict[str, Any]] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{image_data}"
                    }
                }
            ]
        }
    ]

    try:
        print(f"正在使用模型 {model_name} 进行OCR验证...")
        print(f"API地址: {api_base}")

        response = client.chat.completions.create(
            model=model_name,
            messages=messages,  # type: ignore
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout
        )

        generated_text = response.choices[0].message.content
        if not generated_text:
            raise Exception("模型没有生成文本内容")

        print(f"成功使用模型 {model_name} 完成OCR验证!")

        verification_result: Dict[str, Any] = _parse_verification_json(generated_text)

        # Record how the report was produced.
        verification_result["metadata"] = {
            "model_used": model_name,
            "model_id": model_id,
            "api_base": api_base,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "timeout": timeout,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "original_image": image_path,
            "ocr_json": ocr_json_path
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(verification_result, f, ensure_ascii=False, indent=2)

        print(f"OCR验证结果已保存到: {output_path}")

        _print_verification_stats(verification_result)

        return verification_result

    except Exception as e:
        print(f"OCR验证失败: {e}")
        # Keep the original exception as __cause__ so its type and traceback survive.
        raise Exception(f"OCR验证任务失败: {e}") from e
def analyze_differences(verification_result_path):
    """Print a detailed, human-readable difference report for a saved verification result.

    Reads the report JSON (as written by ``verify_ocr_with_vlm``) and prints:
    its metadata, the table overview, the accurate items, errors, missing
    items, position errors and format issues, a summary, and the raw model
    text when present.

    Sections that are missing, ``null`` or not lists are treated as empty
    instead of crashing. Non-dict entries are skipped when rendering, and
    absent metadata fields print as ``N/A``. When the report only holds the
    raw model text (its JSON could not be parsed), the summary says so instead
    of claiming that no errors were found.

    Args:
        verification_result_path: Path to the verification result JSON file.

    Returns:
        None; the report is written to stdout.

    Raises:
        FileNotFoundError: If the file does not exist (the original error is
            chained as ``__cause__``).
    """
    try:
        with open(verification_result_path, 'r', encoding='utf-8') as f:
            result = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到验证结果文件: {verification_result_path}") from err

    # A top-level list/scalar carries none of the expected report sections.
    if not isinstance(result, dict):
        result = {}

    def section(key):
        # Model output is untrusted: a section may be missing, null or not a list.
        value = result.get(key)
        return value if isinstance(value, list) else []

    def entries(key):
        # Only dict entries can be rendered field by field.
        return [entry for entry in section(key) if isinstance(entry, dict)]

    level_icons = {"高": "🔴", "中": "🟡", "低": "🟢"}

    print("\n" + "=" * 60)
    print("OCR验证详细差异分析报告")
    print("=" * 60)

    metadata = result.get("metadata")
    if isinstance(metadata, dict):
        print(f"分析时间: {metadata.get('timestamp', 'N/A')}")
        print(f"使用模型: {metadata.get('model_used', 'N/A')}")
        print(f"原始图片: {metadata.get('original_image', 'N/A')}")
        print(f"OCR结果: {metadata.get('ocr_json', 'N/A')}")
        print(f"分析配置: 温度={metadata.get('temperature', 'N/A')}, "
              f"最大tokens={metadata.get('max_tokens', 'N/A')}")

    # Table overview
    tv = result.get("table_verification")
    if isinstance(tv, dict):
        print("\n📊 表格验证总览:")
        print(f" 检查项目总数: {tv.get('total_items_checked', 'N/A')}")
        print(f" 准确率: {tv.get('accuracy_rate', 'N/A')}")
        print(f" 表格结构正确: {'✅' if tv.get('table_structure_correct') else '❌'}")

    print("\n1. ✅ 准确识别项目:")
    accurate_items = entries("accurate_items")
    if accurate_items:
        for i, item in enumerate(accurate_items, 1):
            pos = item.get('table_position', '未知位置')
            bbox = item.get('bbox', 'N/A')
            text = item.get('original_text', 'N/A')
            print(f" {i}. {pos} - 位置{bbox}")
            print(f" 内容: {text}")
            if item.get('notes'):
                print(f" 说明: {item['notes']}")
            print()
    else:
        print(" 无准确识别项目或未能解析")

    print("\n2. ❌ 识别错误 (需要重点关注):")
    errors = entries("errors")
    if errors:
        for i, error in enumerate(errors, 1):
            pos = error.get('table_position', '未知位置')
            severity = error.get('severity', '未知')
            severity_icon = level_icons.get(severity, "⚪")

            print(f" {severity_icon} 错误 {i} - {pos}")
            print(f" 位置坐标: {error.get('bbox', 'N/A')}")
            print(f" OCR识别: 「{error.get('original_text', 'N/A')}」")
            print(f" 正确内容: 「{error.get('correct_text', 'N/A')}」")
            print(f" 错误类型: {error.get('error_type', 'N/A')}")
            print(f" 严重程度: {severity}")
            print(f" 置信度: {error.get('confidence', 'N/A')}")
            if error.get('detailed_description'):
                print(f" 详细说明: {error['detailed_description']}")
            print()
    else:
        print(" ✅ 未发现识别错误")

    print("\n3. 📝 遗漏项目:")
    missing_items = entries("missing_items")
    if missing_items:
        for i, missing in enumerate(missing_items, 1):
            pos = missing.get('table_position', '未知位置')
            importance = missing.get('importance', '未知')
            importance_icon = level_icons.get(importance, "⚪")

            print(f" {importance_icon} 遗漏 {i} - {pos}")
            print(f" 预估位置: {missing.get('estimated_bbox', 'N/A')}")
            print(f" 遗漏内容: 「{missing.get('missing_text', 'N/A')}」")
            print(f" 内容类别: {missing.get('category', 'N/A')}")
            print(f" 重要程度: {importance}")
            if missing.get('impact'):
                print(f" 影响说明: {missing['impact']}")
            print()
    else:
        print(" ✅ 未发现遗漏项目")

    print("\n4. 📍 位置错误:")
    position_errors = entries("position_errors")
    if position_errors:
        for i, pos_error in enumerate(position_errors, 1):
            table_pos = pos_error.get('table_position', '未知位置')
            print(f" 📍 位置错误 {i} - {table_pos}")
            print(f" 文本内容: 「{pos_error.get('text', 'N/A')}」")
            print(f" OCR位置: {pos_error.get('ocr_bbox', 'N/A')}")
            print(f" 实际位置: {pos_error.get('actual_bbox', 'N/A')}")
            print(f" 偏差描述: {pos_error.get('deviation', 'N/A')}")
            print()
    else:
        print(" ✅ 未发现位置错误")

    print("\n5. 🎨 格式问题:")
    format_issues = entries("format_issues")
    if format_issues:
        for i, format_issue in enumerate(format_issues, 1):
            pos = format_issue.get('table_position', '未知位置')
            print(f" 🎨 格式问题 {i} - {pos}")
            print(f" 受影响项目: 「{format_issue.get('item', 'N/A')}」")
            print(f" 问题类型: {format_issue.get('issue_type', 'N/A')}")
            print(f" OCR格式: 「{format_issue.get('original_format', 'N/A')}」")
            print(f" 正确格式: 「{format_issue.get('correct_format', 'N/A')}」")
            print()
    else:
        print(" ✅ 未发现格式问题")

    # Summary: every reported issue of any kind counts toward the total.
    total_errors = sum(
        len(section(key))
        for key in ("errors", "missing_items", "position_errors", "format_issues")
    )

    print("\n" + "=" * 60)
    print("📈 验证摘要:")
    print(f" 总错误数量: {total_errors}")
    high_severity = sum(1 for e in errors if e.get('severity') == '高')
    if high_severity > 0:
        print(f" 🔴 高严重程度错误: {high_severity} 个")

    if "raw_analysis" in result and total_errors == 0:
        # Parsing failed, so zero counted errors does NOT mean a clean result.
        print(" ⚠️ VLM输出未能解析为JSON,无法统计错误,请查看下方原始分析内容")
    elif total_errors == 0:
        print(" 🎉 恭喜!未发现任何OCR错误")
    else:
        print(" ⚠️ 建议:仔细检查所有标记为'高'严重程度的错误")

    # Show the raw model text too, when the report carries it.
    if "raw_analysis" in result:
        print("\n6. 📄 VLM原始分析内容:")
        print("-" * 50)
        print(result["raw_analysis"])
        print("-" * 50)
if __name__ == "__main__":
    # Example usage: verify one page image against its OCR output, then print
    # the detailed difference report from the saved result.
    image_path = "至远彩色印刷工业有限公司-2022年母公司_2.png"  # presumably an income-statement (利润表) image
    ocr_json_path = "demo_54fa7ad0_page_1.json"  # OCR result file
    output_path = "ocr_differences.json"  # single source of truth for both steps

    try:
        verify_ocr_with_vlm(image_path, ocr_json_path, output_path=output_path)
        analyze_differences(output_path)
    except Exception as e:
        print(f"OCR验证失败: {e}")
        # Exit non-zero so shells and CI can detect the failure.
        raise SystemExit(1) from e
|