import base64
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List

from dotenv import load_dotenv
from openai import OpenAI

# Load environment variables (API key / base URL / model id) from a .env file.
load_dotenv()

# Routing prefix that may be present on the configured model id.
_MODEL_PREFIX = "openai/"

# Extension -> MIME type used when embedding the image as a data URL.
_MIME_TYPES = {
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.png': 'image/png',
    '.gif': 'image/gif',
    '.webp': 'image/webp',
}


def _strip_model_prefix(model_id: str) -> str:
    """Return *model_id* without a leading ``openai/`` prefix.

    Only a leading prefix is removed; an ``openai/`` substring elsewhere in
    the id is left untouched.
    """
    if model_id.startswith(_MODEL_PREFIX):
        return model_id[len(_MODEL_PREFIX):]
    return model_id


def _dict_items(container: Dict[str, Any], key: str) -> List[Dict[str, Any]]:
    """Return the dict entries of the list stored under *key*.

    The container is built from model-generated JSON, so the value may be
    missing, ``null``, not a list, or contain non-dict entries. Anything that
    is not a dict inside a list is ignored instead of raising.
    """
    value = container.get(key)
    if not isinstance(value, list):
        return []
    return [entry for entry in value if isinstance(entry, dict)]


def _count_high(items: List[Dict[str, Any]], field: str) -> int:
    """Count entries whose *field* equals the "high" level marker ('高')."""
    return sum(1 for entry in items if entry.get(field) == '高')


def _format_ocr_results(ocr_results: Any) -> str:
    """Render the OCR items as the text section embedded in the prompt.

    Raises:
        ValueError: if the OCR JSON is not a list of objects.
    """
    if not isinstance(ocr_results, list) or not all(
        isinstance(item, dict) for item in ocr_results
    ):
        raise ValueError("OCR结果文件格式错误: 顶层必须是由对象组成的列表")

    rows = []
    for item in ocr_results:
        bbox = item.get('bbox', [])
        category = item.get('category', '')
        text = item.get('text', '')
        rows.append(f"位置坐标[{bbox}] - 类别: {category} - 文本: {text}\n")
    return "OCR识别结果:\n" + "".join(rows)


def _build_prompt(ocr_text: str) -> str:
    """Build the verification prompt sent to the model together with the image."""
    return f"""请仔细分析这张图片,并与以下OCR识别结果进行逐项详细对比:

{ocr_text}

重要要求:
1. 对于表格中的每一个数据项(特别是数字、金额、项目名称),都必须逐一验证
2. 即使发现微小差异也要报告(如小数点位数、千分符、标点符号等)
3. 对于表格结构要仔细检查行列对应关系
4. 必须输出所有发现的问题,不要遗漏任何错误

请执行以下详细任务:
1. 逐行逐列验证表格中的每个数据项是否准确
2. 检查数字格式:小数点、千分符、负号等
3. 验证项目名称的完整性和准确性
4. 检查表格标题和表头信息
5. 验证行次编号的正确性
6. 识别任何遗漏的表格内容
7. 检查文本的位置坐标是否与实际位置匹配

请以JSON格式返回详细分析结果,对于表格中的每个识别项都要有明确的验证结果:
{{
    "table_verification": {{
        "total_items_checked": 数字,
        "accuracy_rate": "百分比",
        "table_structure_correct": true/false
    }},
    "accurate_items": [
        {{
            "bbox": [x1, y1, x2, y2],
            "original_text": "OCR识别的文本",
            "verified": true,
            "table_position": "行X列Y",
            "notes": "验证说明"
        }}
    ],
    "errors": [
        {{
            "bbox": [x1, y1, x2, y2],
            "original_text": "OCR识别的文本",
            "correct_text": "正确的文本",
            "error_type": "数字错误|格式错误|字符错误|缺失内容",
            "table_position": "行X列Y",
            "severity": "高|中|低",
            "confidence": "高|中|低",
            "detailed_description": "详细错误描述"
        }}
    ],
    "missing_items": [
        {{
            "estimated_bbox": [x1, y1, x2, y2],
            "missing_text": "遗漏的文本",
            "category": "数字|文本|标题|其他",
            "table_position": "行X列Y",
            "importance": "高|中|低",
            "impact": "对数据准确性的影响描述"
        }}
    ],
    "position_errors": [
        {{
            "text": "文本内容",
            "ocr_bbox": [x1, y1, x2, y2],
            "actual_bbox": [x1, y1, x2, y2],
            "table_position": "行X列Y",
            "deviation": "具体偏差描述"
        }}
    ],
    "format_issues": [
        {{
            "item": "受影响的项目",
            "issue_type": "千分符缺失|小数位错误|符号错误",
            "original_format": "OCR识别的格式",
            "correct_format": "正确格式",
            "table_position": "行X列Y"
        }}
    ]
}}

特别注意:
- 表格中的所有金额数据都必须精确验证
- 对于大额数字,要特别注意千分符和小数点
- 项目名称要检查是否完整,包括括号、冒号等标点符号
- 行次编号必须与对应项目正确匹配
- 任何发现的问题都要详细记录,包括位置和具体错误内容"""


def _parse_model_json(generated_text: str) -> Dict[str, Any]:
    """Extract the JSON object from the model reply.

    The span from the first ``{`` to the last ``}`` is parsed. When no such
    span exists or it is not valid JSON, a fallback dict carrying the raw
    reply (``raw_analysis``) and a ``parsing_note`` is returned instead.
    """
    json_start = generated_text.find('{')
    json_end = generated_text.rfind('}') + 1
    if json_start == -1 or json_end <= json_start:
        return {
            "raw_analysis": generated_text,
            "parsing_note": "无法解析为标准JSON格式,返回原始分析文本"
        }
    try:
        return json.loads(generated_text[json_start:json_end])
    except json.JSONDecodeError:
        return {
            "raw_analysis": generated_text,
            "parsing_note": "JSON解析失败,返回原始分析文本"
        }


def _print_verification_stats(verification_result: Dict[str, Any]) -> None:
    """Print the short per-section statistics after a verification run."""
    print("\n📊 验证结果统计:")

    table_info = verification_result.get("table_verification")
    if isinstance(table_info, dict):
        print(f" 📋 表格检查项目: {table_info.get('total_items_checked', 'N/A')}")
        print(f" 📈 准确率: {table_info.get('accuracy_rate', 'N/A')}")

    if "accurate_items" in verification_result:
        accurate = _dict_items(verification_result, "accurate_items")
        print(f" ✅ 准确项目数量: {len(accurate)}")

    if "errors" in verification_result:
        errors = _dict_items(verification_result, "errors")
        print(f" ❌ 发现错误数量: {len(errors)}")
        high_errors = _count_high(errors, 'severity')
        if high_errors > 0:
            print(f" 🔴 高严重程度错误: {high_errors}")

    if "missing_items" in verification_result:
        missing = _dict_items(verification_result, "missing_items")
        print(f" 📝 遗漏项目数量: {len(missing)}")
        high_missing = _count_high(missing, 'importance')
        if high_missing > 0:
            print(f" 🔴 重要遗漏项目: {high_missing}")

    if "position_errors" in verification_result:
        position_errors = _dict_items(verification_result, "position_errors")
        print(f" 📍 位置错误数量: {len(position_errors)}")

    if "format_issues" in verification_result:
        format_issues = _dict_items(verification_result, "format_issues")
        print(f" 🎨 格式问题数量: {len(format_issues)}")


def verify_ocr_with_vlm(image_path, ocr_json_path, output_path="ocr_differences.json",
                        api_key=None, api_base=None, model_id=None,
                        temperature=0.1, max_tokens=4096, timeout=180):
    """Compare OCR output against the source image using a vision-language model.

    The image is sent (base64-encoded) together with a textual dump of the OCR
    items; the model is asked to return a JSON report of differences. The
    report, plus run metadata, is written to *output_path* and returned.

    Args:
        image_path: Path of the original image.
        ocr_json_path: Path of the OCR result JSON file (a list of objects
            from which ``bbox``, ``category`` and ``text`` are read).
        output_path: Where the difference report is written.
        api_key: API key; falls back to ``YUSYS_MULTIMODAL_API_KEY``.
        api_base: API base URL; falls back to ``YUSYS_MULTIMODAL_API_BASE``.
        model_id: Model id; falls back to ``YUSYS_MULTIMODAL_ID``.
        temperature: Sampling temperature (default 0.1).
        max_tokens: Maximum number of output tokens (default 4096).
        timeout: Request timeout in seconds (default 180).

    Returns:
        The parsed report dict. When the reply is not valid JSON the dict
        holds ``raw_analysis`` and ``parsing_note`` instead of the report
        sections. A ``metadata`` entry is always added.

    Raises:
        ValueError: if the API configuration is missing or the OCR JSON is
            not a list of objects.
        FileNotFoundError: if the image or OCR file does not exist.
        Exception: if the API call, reply handling or saving fails.
    """
    # Resolve API configuration from arguments first, then the environment.
    api_key = api_key or os.getenv("YUSYS_MULTIMODAL_API_KEY")
    api_base = api_base or os.getenv("YUSYS_MULTIMODAL_API_BASE")
    model_id = model_id or os.getenv("YUSYS_MULTIMODAL_ID")

    if not api_key:
        raise ValueError("未找到API密钥,请通过参数传入或设置YUSYS_MULTIMODAL_API_KEY环境变量")
    if not api_base:
        raise ValueError("未找到API基础URL,请通过参数传入或设置YUSYS_MULTIMODAL_API_BASE环境变量")
    if not model_id:
        raise ValueError("未找到模型ID,请通过参数传入或设置YUSYS_MULTIMODAL_ID环境变量")

    # Strip only a leading "openai/" so ids containing it elsewhere survive.
    model_name = _strip_model_prefix(model_id)

    # Read the image and encode it as base64 for the data URL.
    try:
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到图片文件: {image_path}") from err

    # Read the OCR results.
    try:
        with open(ocr_json_path, "r", encoding='utf-8') as f:
            ocr_results = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到OCR结果文件: {ocr_json_path}") from err

    # Unknown extensions fall back to JPEG.
    file_extension = Path(image_path).suffix.lower()
    mime_type = _MIME_TYPES.get(file_extension, 'image/jpeg')

    # Build the OCR text (with positions) and the analysis prompt.
    ocr_text = _format_ocr_results(ocr_results)
    prompt = _build_prompt(ocr_text)

    client = OpenAI(
        api_key=api_key,
        base_url=api_base
    )

    # One user message carrying the prompt text and the inline image.
    messages: List[Dict[str, Any]] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{image_data}"
                    }
                }
            ]
        }
    ]

    try:
        print(f"正在使用模型 {model_name} 进行OCR验证...")
        print(f"API地址: {api_base}")

        response = client.chat.completions.create(
            model=model_name,
            messages=messages,  # type: ignore
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout
        )

        generated_text = response.choices[0].message.content
        if not generated_text:
            raise Exception("模型没有生成文本内容")

        print(f"成功使用模型 {model_name} 完成OCR验证!")

        verification_result: Dict[str, Any] = _parse_model_json(generated_text)

        # Paths are stored as strings so pathlib.Path arguments stay
        # JSON-serializable.
        verification_result["metadata"] = {
            "model_used": model_name,
            "model_id": model_id,
            "api_base": api_base,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "timeout": timeout,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "original_image": str(image_path),
            "ocr_json": str(ocr_json_path)
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(verification_result, f, ensure_ascii=False, indent=2)

        print(f"OCR验证结果已保存到: {output_path}")

        _print_verification_stats(verification_result)

        return verification_result

    except Exception as e:
        print(f"OCR验证失败: {e}")
        # Keep the original failure attached as the explicit cause.
        raise Exception(f"OCR验证任务失败: {e}") from e


def analyze_differences(verification_result_path):
    """Print a detailed difference report from a saved verification result.

    Sections whose value is missing, not a list, or contains non-dict entries
    are treated as empty / filtered rather than raising, because the file
    content originates from model output.

    Args:
        verification_result_path: Path of the verification result JSON file.

    Raises:
        FileNotFoundError: if the result file does not exist.
    """
    try:
        with open(verification_result_path, 'r', encoding='utf-8') as f:
            result = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"找不到验证结果文件: {verification_result_path}") from err

    print("\n" + "="*60)
    print("OCR验证详细差异分析报告")
    print("="*60)

    metadata = result.get("metadata")
    if isinstance(metadata, dict):
        print(f"分析时间: {metadata.get('timestamp', 'N/A')}")
        print(f"使用模型: {metadata.get('model_used', 'N/A')}")
        print(f"原始图片: {metadata.get('original_image', 'N/A')}")
        print(f"OCR结果: {metadata.get('ocr_json', 'N/A')}")
        print(f"分析配置: 温度={metadata.get('temperature', 'N/A')}, "
              f"最大tokens={metadata.get('max_tokens', 'N/A')}")

    # Table verification overview.
    tv = result.get("table_verification")
    if isinstance(tv, dict):
        print("\n📊 表格验证总览:")
        print(f" 检查项目总数: {tv.get('total_items_checked', 'N/A')}")
        print(f" 准确率: {tv.get('accuracy_rate', 'N/A')}")
        print(f" 表格结构正确: {'✅' if tv.get('table_structure_correct') else '❌'}")

    accurate_items = _dict_items(result, "accurate_items")
    errors = _dict_items(result, "errors")
    missing_items = _dict_items(result, "missing_items")
    position_errors = _dict_items(result, "position_errors")
    format_issues = _dict_items(result, "format_issues")
    level_icons = {"高": "🔴", "中": "🟡", "低": "🟢"}

    print("\n1. ✅ 准确识别项目:")
    if accurate_items:
        for i, item in enumerate(accurate_items, 1):
            pos = item.get('table_position', '未知位置')
            bbox = item.get('bbox', 'N/A')
            text = item.get('original_text', 'N/A')
            print(f" {i}. {pos} - 位置{bbox}")
            print(f" 内容: {text}")
            if item.get('notes'):
                print(f" 说明: {item['notes']}")
            print()
    else:
        print(" 无准确识别项目或未能解析")

    print("\n2. ❌ 识别错误 (需要重点关注):")
    if errors:
        for i, error in enumerate(errors, 1):
            pos = error.get('table_position', '未知位置')
            severity = error.get('severity', '未知')
            severity_icon = level_icons.get(severity, "⚪")

            print(f" {severity_icon} 错误 {i} - {pos}")
            print(f" 位置坐标: {error.get('bbox', 'N/A')}")
            print(f" OCR识别: 「{error.get('original_text', 'N/A')}」")
            print(f" 正确内容: 「{error.get('correct_text', 'N/A')}」")
            print(f" 错误类型: {error.get('error_type', 'N/A')}")
            print(f" 严重程度: {severity}")
            print(f" 置信度: {error.get('confidence', 'N/A')}")
            if error.get('detailed_description'):
                print(f" 详细说明: {error['detailed_description']}")
            print()
    else:
        print(" ✅ 未发现识别错误")

    print("\n3. 📝 遗漏项目:")
    if missing_items:
        for i, missing in enumerate(missing_items, 1):
            pos = missing.get('table_position', '未知位置')
            importance = missing.get('importance', '未知')
            importance_icon = level_icons.get(importance, "⚪")

            print(f" {importance_icon} 遗漏 {i} - {pos}")
            print(f" 预估位置: {missing.get('estimated_bbox', 'N/A')}")
            print(f" 遗漏内容: 「{missing.get('missing_text', 'N/A')}」")
            print(f" 内容类别: {missing.get('category', 'N/A')}")
            print(f" 重要程度: {importance}")
            if missing.get('impact'):
                print(f" 影响说明: {missing['impact']}")
            print()
    else:
        print(" ✅ 未发现遗漏项目")

    print("\n4. 📍 位置错误:")
    if position_errors:
        for i, pos_error in enumerate(position_errors, 1):
            table_pos = pos_error.get('table_position', '未知位置')
            print(f" 📍 位置错误 {i} - {table_pos}")
            print(f" 文本内容: 「{pos_error.get('text', 'N/A')}」")
            print(f" OCR位置: {pos_error.get('ocr_bbox', 'N/A')}")
            print(f" 实际位置: {pos_error.get('actual_bbox', 'N/A')}")
            print(f" 偏差描述: {pos_error.get('deviation', 'N/A')}")
            print()
    else:
        print(" ✅ 未发现位置错误")

    # Format issues.
    print("\n5. 🎨 格式问题:")
    if format_issues:
        for i, format_issue in enumerate(format_issues, 1):
            pos = format_issue.get('table_position', '未知位置')
            print(f" 🎨 格式问题 {i} - {pos}")
            print(f" 受影响项目: 「{format_issue.get('item', 'N/A')}」")
            print(f" 问题类型: {format_issue.get('issue_type', 'N/A')}")
            print(f" OCR格式: 「{format_issue.get('original_format', 'N/A')}」")
            print(f" 正确格式: 「{format_issue.get('correct_format', 'N/A')}」")
            print()
    else:
        print(" ✅ 未发现格式问题")

    # Error summary across all problem sections.
    total_errors = (
        len(errors) + len(missing_items) + len(position_errors) + len(format_issues)
    )
    # A raw (unparsed) reply has no sections at all; zero counted problems
    # then means "unknown", not "clean".
    unparsed = "raw_analysis" in result

    print("\n" + "="*60)
    print("📈 验证摘要:")
    print(f" 总错误数量: {total_errors}")

    high_severity = _count_high(errors, 'severity')
    if high_severity > 0:
        print(f" 🔴 高严重程度错误: {high_severity} 个")

    if total_errors > 0:
        print(" ⚠️ 建议:仔细检查所有标记为'高'严重程度的错误")
    elif not unparsed:
        print(" 🎉 恭喜!未发现任何OCR错误")

    # Show the raw model reply when it could not be parsed into a report.
    if unparsed:
        print("\n6. 📄 VLM原始分析内容:")
        if result.get("parsing_note"):
            print(f" ⚠️ {result['parsing_note']}")
        print("-" * 50)
        print(result["raw_analysis"])
        print("-" * 50)


if __name__ == "__main__":
    # Example usage.
    image_path = "至远彩色印刷工业有限公司-2022年母公司_2.png"  # presumably an income-statement image
    ocr_json_path = "demo_54fa7ad0_page_1.json"  # OCR result file
    output_path = "ocr_differences.json"  # report written by the verification step

    try:
        # Run the OCR verification.
        verify_ocr_with_vlm(image_path, ocr_json_path, output_path)

        # Analyze the differences from the file that was just written.
        analyze_differences(output_path)

    except Exception as e:
        print(f"OCR验证失败: {e}")