# ocr_verification.py
import argparse
import base64
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List

from dotenv import load_dotenv
from openai import OpenAI
# Load environment variables via python-dotenv (by default from a .env file)
# so the os.getenv("YUSYS_MULTIMODAL_*") lookups in this module can find them.
# Runs once, at import time.
load_dotenv()
  11. def verify_ocr_with_vlm(image_path, ocr_json_path, output_path="ocr_differences.json",
  12. api_key=None, api_base=None, model_id=None,
  13. temperature=0.1, max_tokens=4096, timeout=180):
  14. """
  15. 使用VLM对比OCR识别结果和原图,找出差异部分
  16. Args:
  17. image_path: 原图路径
  18. ocr_json_path: OCR识别结果JSON文件路径
  19. output_path: 差异分析输出文件路径
  20. api_key: API密钥,如果为None则从环境变量获取
  21. api_base: API基础URL,如果为None则从环境变量获取
  22. model_id: 模型ID,如果为None则从环境变量获取
  23. temperature: 生成温度,默认0.1
  24. max_tokens: 最大输出token数,默认4096
  25. timeout: 请求超时时间,默认180秒
  26. """
  27. # 从参数或环境变量获取API配置
  28. api_key = api_key or os.getenv("YUSYS_MULTIMODAL_API_KEY")
  29. api_base = api_base or os.getenv("YUSYS_MULTIMODAL_API_BASE")
  30. model_id = model_id or os.getenv("YUSYS_MULTIMODAL_ID")
  31. if not api_key:
  32. raise ValueError("未找到API密钥,请通过参数传入或设置YUSYS_MULTIMODAL_API_KEY环境变量")
  33. if not api_base:
  34. raise ValueError("未找到API基础URL,请通过参数传入或设置YUSYS_MULTIMODAL_API_BASE环境变量")
  35. if not model_id:
  36. raise ValueError("未找到模型ID,请通过参数传入或设置YUSYS_MULTIMODAL_ID环境变量")
  37. # 去掉openai/前缀
  38. model_name = model_id.replace("openai/", "")
  39. # 读取图片文件并转换为base64
  40. try:
  41. with open(image_path, "rb") as image_file:
  42. image_data = base64.b64encode(image_file.read()).decode('utf-8')
  43. except FileNotFoundError:
  44. raise FileNotFoundError(f"找不到图片文件: {image_path}")
  45. # 读取OCR结果
  46. try:
  47. with open(ocr_json_path, "r", encoding='utf-8') as f:
  48. ocr_results = json.load(f)
  49. except FileNotFoundError:
  50. raise FileNotFoundError(f"找不到OCR结果文件: {ocr_json_path}")
  51. # 获取图片的MIME类型
  52. file_extension = Path(image_path).suffix.lower()
  53. mime_type_map = {
  54. '.jpg': 'image/jpeg',
  55. '.jpeg': 'image/jpeg',
  56. '.png': 'image/png',
  57. '.gif': 'image/gif',
  58. '.webp': 'image/webp'
  59. }
  60. mime_type = mime_type_map.get(file_extension, 'image/jpeg')
  61. # 构建详细的OCR结果文本,包含位置信息
  62. ocr_text = "OCR识别结果:\n"
  63. for item in ocr_results:
  64. bbox = item.get('bbox', [])
  65. category = item.get('category', '')
  66. text = item.get('text', '')
  67. ocr_text += f"位置坐标[{bbox}] - 类别: {category} - 文本: {text}\n"
  68. # 构建分析提示词
  69. prompt = f"""请仔细分析这张图片,并与以下OCR识别结果进行逐项详细对比:
  70. {ocr_text}
  71. 重要要求:
  72. 1. 对于表格中的每一个数据项(特别是数字、金额、项目名称),都必须逐一验证
  73. 2. 即使发现微小差异也要报告(如小数点位数、千分符、标点符号等)
  74. 3. 对于表格结构要仔细检查行列对应关系
  75. 4. 必须输出所有发现的问题,不要遗漏任何错误
  76. 请执行以下详细任务:
  77. 1. 逐行逐列验证表格中的每个数据项是否准确
  78. 2. 检查数字格式:小数点、千分符、负号等
  79. 3. 验证项目名称的完整性和准确性
  80. 4. 检查表格标题和表头信息
  81. 5. 验证行次编号的正确性
  82. 6. 识别任何遗漏的表格内容
  83. 7. 检查文本的位置坐标是否与实际位置匹配
  84. 请以JSON格式返回详细分析结果,对于表格中的每个识别项都要有明确的验证结果:
  85. {{
  86. "table_verification": {{
  87. "total_items_checked": 数字,
  88. "accuracy_rate": "百分比",
  89. "table_structure_correct": true/false
  90. }},
  91. "accurate_items": [
  92. {{
  93. "bbox": [x1, y1, x2, y2],
  94. "original_text": "OCR识别的文本",
  95. "verified": true,
  96. "table_position": "行X列Y",
  97. "notes": "验证说明"
  98. }}
  99. ],
  100. "errors": [
  101. {{
  102. "bbox": [x1, y1, x2, y2],
  103. "original_text": "OCR识别的文本",
  104. "correct_text": "正确的文本",
  105. "error_type": "数字错误|格式错误|字符错误|缺失内容",
  106. "table_position": "行X列Y",
  107. "severity": "高|中|低",
  108. "confidence": "高|中|低",
  109. "detailed_description": "详细错误描述"
  110. }}
  111. ],
  112. "missing_items": [
  113. {{
  114. "estimated_bbox": [x1, y1, x2, y2],
  115. "missing_text": "遗漏的文本",
  116. "category": "数字|文本|标题|其他",
  117. "table_position": "行X列Y",
  118. "importance": "高|中|低",
  119. "impact": "对数据准确性的影响描述"
  120. }}
  121. ],
  122. "position_errors": [
  123. {{
  124. "text": "文本内容",
  125. "ocr_bbox": [x1, y1, x2, y2],
  126. "actual_bbox": [x1, y1, x2, y2],
  127. "table_position": "行X列Y",
  128. "deviation": "具体偏差描述"
  129. }}
  130. ],
  131. "format_issues": [
  132. {{
  133. "item": "受影响的项目",
  134. "issue_type": "千分符缺失|小数位错误|符号错误",
  135. "original_format": "OCR识别的格式",
  136. "correct_format": "正确格式",
  137. "table_position": "行X列Y"
  138. }}
  139. ]
  140. }}
  141. 特别注意:
  142. - 表格中的所有金额数据都必须精确验证
  143. - 对于大额数字,要特别注意千分符和小数点
  144. - 项目名称要检查是否完整,包括括号、冒号等标点符号
  145. - 行次编号必须与对应项目正确匹配
  146. - 任何发现的问题都要详细记录,包括位置和具体错误内容"""
  147. # 创建OpenAI客户端
  148. client = OpenAI(
  149. api_key=api_key,
  150. base_url=api_base
  151. )
  152. # 构建消息内容
  153. messages: List[Dict[str, Any]] = [
  154. {
  155. "role": "user",
  156. "content": [
  157. {
  158. "type": "text",
  159. "text": prompt
  160. },
  161. {
  162. "type": "image_url",
  163. "image_url": {
  164. "url": f"data:{mime_type};base64,{image_data}"
  165. }
  166. }
  167. ]
  168. }
  169. ]
  170. try:
  171. print(f"正在使用模型 {model_name} 进行OCR验证...")
  172. print(f"API地址: {api_base}")
  173. # 调用API
  174. response = client.chat.completions.create(
  175. model=model_name,
  176. messages=messages, # type: ignore
  177. temperature=temperature,
  178. max_tokens=max_tokens,
  179. timeout=timeout
  180. )
  181. # 提取响应内容
  182. generated_text = response.choices[0].message.content
  183. if not generated_text:
  184. raise Exception("模型没有生成文本内容")
  185. print(f"成功使用模型 {model_name} 完成OCR验证!")
  186. # 尝试解析JSON结果
  187. verification_result: Dict[str, Any] = {}
  188. try:
  189. # 查找JSON部分
  190. json_start = generated_text.find('{')
  191. json_end = generated_text.rfind('}') + 1
  192. if json_start != -1 and json_end > json_start:
  193. json_content = generated_text[json_start:json_end]
  194. verification_result = json.loads(json_content)
  195. else:
  196. # 如果没有找到JSON,创建一个包含原始文本的结果
  197. verification_result = {
  198. "raw_analysis": generated_text,
  199. "parsing_note": "无法解析为标准JSON格式,返回原始分析文本"
  200. }
  201. except json.JSONDecodeError:
  202. verification_result = {
  203. "raw_analysis": generated_text,
  204. "parsing_note": "JSON解析失败,返回原始分析文本"
  205. }
  206. # 添加元数据
  207. verification_result["metadata"] = {
  208. "model_used": model_name,
  209. "model_id": model_id,
  210. "api_base": api_base,
  211. "temperature": temperature,
  212. "max_tokens": max_tokens,
  213. "timeout": timeout,
  214. "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
  215. "original_image": image_path,
  216. "ocr_json": ocr_json_path
  217. }
  218. # 保存结果
  219. with open(output_path, 'w', encoding='utf-8') as f:
  220. json.dump(verification_result, f, ensure_ascii=False, indent=2)
  221. print(f"OCR验证结果已保存到: {output_path}")
  222. # 打印详细统计
  223. print("\n📊 验证结果统计:")
  224. if "table_verification" in verification_result:
  225. tv = verification_result["table_verification"]
  226. print(f" 📋 表格检查项目: {tv.get('total_items_checked', 'N/A')}")
  227. print(f" 📈 准确率: {tv.get('accuracy_rate', 'N/A')}")
  228. if "accurate_items" in verification_result:
  229. print(f" ✅ 准确项目数量: {len(verification_result['accurate_items'])}")
  230. if "errors" in verification_result:
  231. error_count = len(verification_result['errors'])
  232. print(f" ❌ 发现错误数量: {error_count}")
  233. if error_count > 0 and verification_result['errors']:
  234. high_errors = len([e for e in verification_result['errors'] if e.get('severity') == '高'])
  235. if high_errors > 0:
  236. print(f" 🔴 高严重程度错误: {high_errors}")
  237. if "missing_items" in verification_result:
  238. missing_count = len(verification_result['missing_items'])
  239. print(f" 📝 遗漏项目数量: {missing_count}")
  240. if missing_count > 0 and verification_result['missing_items']:
  241. high_missing = len([m for m in verification_result['missing_items'] if m.get('importance') == '高'])
  242. if high_missing > 0:
  243. print(f" 🔴 重要遗漏项目: {high_missing}")
  244. if "position_errors" in verification_result:
  245. print(f" 📍 位置错误数量: {len(verification_result['position_errors'])}")
  246. if "format_issues" in verification_result:
  247. print(f" 🎨 格式问题数量: {len(verification_result['format_issues'])}")
  248. return verification_result
  249. except Exception as e:
  250. print(f"OCR验证失败: {e}")
  251. raise Exception(f"OCR验证任务失败: {e}")
  252. def analyze_differences(verification_result_path):
  253. """
  254. 分析OCR验证结果,生成详细的表格差异报告
  255. Args:
  256. verification_result_path: 验证结果JSON文件路径
  257. """
  258. try:
  259. with open(verification_result_path, 'r', encoding='utf-8') as f:
  260. result = json.load(f)
  261. except FileNotFoundError:
  262. raise FileNotFoundError(f"找不到验证结果文件: {verification_result_path}")
  263. print("\n" + "="*60)
  264. print("OCR验证详细差异分析报告")
  265. print("="*60)
  266. if "metadata" in result:
  267. print(f"分析时间: {result['metadata']['timestamp']}")
  268. print(f"使用模型: {result['metadata']['model_used']}")
  269. print(f"原始图片: {result['metadata']['original_image']}")
  270. print(f"OCR结果: {result['metadata']['ocr_json']}")
  271. print(f"分析配置: 温度={result['metadata'].get('temperature', 'N/A')}, "
  272. f"最大tokens={result['metadata'].get('max_tokens', 'N/A')}")
  273. # 表格验证总览
  274. if "table_verification" in result:
  275. tv = result["table_verification"]
  276. print(f"\n📊 表格验证总览:")
  277. print(f" 检查项目总数: {tv.get('total_items_checked', 'N/A')}")
  278. print(f" 准确率: {tv.get('accuracy_rate', 'N/A')}")
  279. print(f" 表格结构正确: {'✅' if tv.get('table_structure_correct') else '❌'}")
  280. print(f"\n1. ✅ 准确识别项目:")
  281. if "accurate_items" in result and result["accurate_items"]:
  282. for i, item in enumerate(result["accurate_items"], 1):
  283. pos = item.get('table_position', '未知位置')
  284. bbox = item.get('bbox', 'N/A')
  285. text = item.get('original_text', 'N/A')
  286. print(f" {i}. {pos} - 位置{bbox}")
  287. print(f" 内容: {text}")
  288. if item.get('notes'):
  289. print(f" 说明: {item['notes']}")
  290. print()
  291. else:
  292. print(" 无准确识别项目或未能解析")
  293. print(f"\n2. ❌ 识别错误 (需要重点关注):")
  294. if "errors" in result and result["errors"]:
  295. for i, error in enumerate(result["errors"], 1):
  296. pos = error.get('table_position', '未知位置')
  297. severity = error.get('severity', '未知')
  298. severity_icon = {"高": "🔴", "中": "🟡", "低": "🟢"}.get(severity, "⚪")
  299. print(f" {severity_icon} 错误 {i} - {pos}")
  300. print(f" 位置坐标: {error.get('bbox', 'N/A')}")
  301. print(f" OCR识别: 「{error.get('original_text', 'N/A')}」")
  302. print(f" 正确内容: 「{error.get('correct_text', 'N/A')}」")
  303. print(f" 错误类型: {error.get('error_type', 'N/A')}")
  304. print(f" 严重程度: {severity}")
  305. print(f" 置信度: {error.get('confidence', 'N/A')}")
  306. if error.get('detailed_description'):
  307. print(f" 详细说明: {error['detailed_description']}")
  308. print()
  309. else:
  310. print(" ✅ 未发现识别错误")
  311. print(f"\n3. 📝 遗漏项目:")
  312. if "missing_items" in result and result["missing_items"]:
  313. for i, missing in enumerate(result["missing_items"], 1):
  314. pos = missing.get('table_position', '未知位置')
  315. importance = missing.get('importance', '未知')
  316. importance_icon = {"高": "🔴", "中": "🟡", "低": "🟢"}.get(importance, "⚪")
  317. print(f" {importance_icon} 遗漏 {i} - {pos}")
  318. print(f" 预估位置: {missing.get('estimated_bbox', 'N/A')}")
  319. print(f" 遗漏内容: 「{missing.get('missing_text', 'N/A')}」")
  320. print(f" 内容类别: {missing.get('category', 'N/A')}")
  321. print(f" 重要程度: {importance}")
  322. if missing.get('impact'):
  323. print(f" 影响说明: {missing['impact']}")
  324. print()
  325. else:
  326. print(" ✅ 未发现遗漏项目")
  327. print(f"\n4. 📍 位置错误:")
  328. if "position_errors" in result and result["position_errors"]:
  329. for i, pos_error in enumerate(result["position_errors"], 1):
  330. table_pos = pos_error.get('table_position', '未知位置')
  331. print(f" 📍 位置错误 {i} - {table_pos}")
  332. print(f" 文本内容: 「{pos_error.get('text', 'N/A')}」")
  333. print(f" OCR位置: {pos_error.get('ocr_bbox', 'N/A')}")
  334. print(f" 实际位置: {pos_error.get('actual_bbox', 'N/A')}")
  335. print(f" 偏差描述: {pos_error.get('deviation', 'N/A')}")
  336. print()
  337. else:
  338. print(" ✅ 未发现位置错误")
  339. # 新增:格式问题
  340. print(f"\n5. 🎨 格式问题:")
  341. if "format_issues" in result and result["format_issues"]:
  342. for i, format_issue in enumerate(result["format_issues"], 1):
  343. pos = format_issue.get('table_position', '未知位置')
  344. print(f" 🎨 格式问题 {i} - {pos}")
  345. print(f" 受影响项目: 「{format_issue.get('item', 'N/A')}」")
  346. print(f" 问题类型: {format_issue.get('issue_type', 'N/A')}")
  347. print(f" OCR格式: 「{format_issue.get('original_format', 'N/A')}」")
  348. print(f" 正确格式: 「{format_issue.get('correct_format', 'N/A')}」")
  349. print()
  350. else:
  351. print(" ✅ 未发现格式问题")
  352. # 生成错误摘要
  353. total_errors = 0
  354. if "errors" in result:
  355. total_errors += len(result["errors"])
  356. if "missing_items" in result:
  357. total_errors += len(result["missing_items"])
  358. if "position_errors" in result:
  359. total_errors += len(result["position_errors"])
  360. if "format_issues" in result:
  361. total_errors += len(result["format_issues"])
  362. print(f"\n" + "="*60)
  363. print(f"📈 验证摘要:")
  364. print(f" 总错误数量: {total_errors}")
  365. if "errors" in result:
  366. high_severity = len([e for e in result["errors"] if e.get('severity') == '高'])
  367. if high_severity > 0:
  368. print(f" 🔴 高严重程度错误: {high_severity} 个")
  369. if total_errors == 0:
  370. print(" 🎉 恭喜!未发现任何OCR错误")
  371. else:
  372. print(f" ⚠️ 建议:仔细检查所有标记为'高'严重程度的错误")
  373. # 如果有原始分析文本,也显示出来
  374. if "raw_analysis" in result:
  375. print(f"\n6. 📄 VLM原始分析内容:")
  376. print("-" * 50)
  377. print(result["raw_analysis"])
  378. print("-" * 50)
  379. if __name__ == "__main__":
  380. # 示例用法
  381. image_path = "至远彩色印刷工业有限公司-2022年母公司_2.png" # 假设这是利润表图片
  382. ocr_json_path = "demo_54fa7ad0_page_1.json" # OCR结果文件
  383. try:
  384. # 进行OCR验证
  385. result = verify_ocr_with_vlm(image_path, ocr_json_path)
  386. # 分析差异
  387. analyze_differences("ocr_differences.json")
  388. except Exception as e:
  389. print(f"OCR验证失败: {e}")