# ocr_by_vlm.py
  1. import os
  2. import base64
  3. import json
  4. import time
  5. import re
  6. from pathlib import Path
  7. from openai import OpenAI
  8. from dotenv import load_dotenv
  9. from typing import Any, Dict, List
  10. # 加载环境变量
  11. load_dotenv()
  12. def normalize_financial_numbers(text: str) -> str:
  13. """
  14. 标准化财务数字:将全角字符转换为半角字符
  15. Args:
  16. text: 原始文本
  17. Returns:
  18. 标准化后的文本
  19. """
  20. if not text:
  21. return text
  22. # 定义全角到半角的映射
  23. fullwidth_to_halfwidth = {
  24. '0': '0', '1': '1', '2': '2', '3': '3', '4': '4',
  25. '5': '5', '6': '6', '7': '7', '8': '8', '9': '9',
  26. ',': ',', # 全角逗号转半角逗号
  27. '。': '.', # 全角句号转半角句号
  28. '.': '.', # 全角句点转半角句点
  29. ':': ':', # 全角冒号转半角冒号
  30. ';': ';', # 全角分号转半角分号
  31. '(': '(', # 全角左括号转半角左括号
  32. ')': ')', # 全角右括号转半角右括号
  33. '-': '-', # 全角减号转半角减号
  34. '+': '+', # 全角加号转半角加号
  35. '%': '%', # 全角百分号转半角百分号
  36. }
  37. # 执行字符替换
  38. normalized_text = text
  39. for fullwidth, halfwidth in fullwidth_to_halfwidth.items():
  40. normalized_text = normalized_text.replace(fullwidth, halfwidth)
  41. # 特别处理金额格式:识别数字模式并标准化
  42. # 匹配金额模式:数字+全角逗号+数字+小数点+数字
  43. amount_pattern = r'(\d+(?:[,,]\d{3})*(?:[。..]\d{2})?)'
  44. def normalize_amount(match):
  45. amount = match.group(1)
  46. # 将全角逗号替换为半角逗号
  47. amount = amount.replace(',', ',')
  48. # 将全角句号、句点替换为半角小数点
  49. amount = re.sub(r'[。.]', '.', amount)
  50. return amount
  51. normalized_text = re.sub(amount_pattern, normalize_amount, normalized_text)
  52. return normalized_text
  53. def normalize_markdown_table(markdown_content: str) -> str:
  54. """
  55. 专门处理Markdown表格中的数字标准化
  56. Args:
  57. markdown_content: Markdown内容
  58. Returns:
  59. 标准化后的Markdown内容
  60. """
  61. # 使用BeautifulSoup处理HTML表格
  62. from bs4 import BeautifulSoup
  63. soup = BeautifulSoup(markdown_content, 'html.parser')
  64. tables = soup.find_all('table')
  65. for table in tables:
  66. cells = table.find_all(['td', 'th'])
  67. for cell in cells:
  68. original_text = cell.get_text()
  69. normalized_text = normalize_financial_numbers(original_text)
  70. # 如果内容发生了变化,更新单元格内容
  71. if original_text != normalized_text:
  72. cell.string = normalized_text
  73. # 返回更新后的HTML
  74. return str(soup)
  75. def ocr_with_vlm(image_path, output_dir="./output",
  76. api_key=None, api_base=None, model_id=None,
  77. temperature=0.1, max_tokens=4096, timeout=180,
  78. normalize_numbers=True):
  79. """
  80. 使用VLM识别图片中的文本
  81. Args:
  82. image_path: 原图路径
  83. output_dir: 结果输出文件路径
  84. api_key: API密钥,如果为None则从环境变量获取
  85. api_base: API基础URL,如果为None则从环境变量获取
  86. model_id: 模型ID,如果为None则从环境变量获取
  87. temperature: 生成温度,默认0.1
  88. max_tokens: 最大输出token数,默认4096
  89. timeout: 请求超时时间,默认180秒
  90. normalize_numbers: 是否标准化数字格式,默认True
  91. """
  92. # 从参数或环境变量获取API配置
  93. api_key = api_key or os.getenv("YUSYS_MULTIMODAL_API_KEY")
  94. api_base = api_base or os.getenv("YUSYS_MULTIMODAL_API_BASE")
  95. model_id = model_id or os.getenv("YUSYS_MULTIMODAL_ID")
  96. if not api_key:
  97. raise ValueError("未找到API密钥,请通过参数传入或设置YUSYS_MULTIMODAL_API_KEY环境变量")
  98. if not api_base:
  99. raise ValueError("未找到API基础URL,请通过参数传入或设置YUSYS_MULTIMODAL_API_BASE环境变量")
  100. if not model_id:
  101. raise ValueError("未找到模型ID,请通过参数传入或设置YUSYS_MULTIMODAL_ID环境变量")
  102. # 去掉openai/前缀
  103. model_name = model_id.replace("openai/", "")
  104. # 读取图片文件并转换为base64
  105. try:
  106. with open(image_path, "rb") as image_file:
  107. image_data = base64.b64encode(image_file.read()).decode('utf-8')
  108. except FileNotFoundError:
  109. raise FileNotFoundError(f"找不到图片文件: {image_path}")
  110. # 获取图片的MIME类型
  111. file_extension = Path(image_path).suffix.lower()
  112. mime_type_map = {
  113. '.jpg': 'image/jpeg',
  114. '.jpeg': 'image/jpeg',
  115. '.png': 'image/png',
  116. '.gif': 'image/gif',
  117. '.webp': 'image/webp'
  118. }
  119. mime_type = mime_type_map.get(file_extension, 'image/jpeg')
  120. # 构建分析提示词
  121. prompt = r'''You are an AI assistant specialized in converting PDF images to Markdown format. Please follow these instructions for the conversion:
  122. 1. Text Processing:
  123. - Accurately recognize all text content in the PDF image without guessing or inferring.
  124. - Convert the recognized text into Markdown format.
  125. - Maintain the original document structure, including headings, paragraphs, lists, etc.
  126. - For financial amounts, use standard half-width characters (e.g., use "," for thousands separator and "." for decimal point)
  127. 2. Mathematical Formula Processing:
  128. - Convert all mathematical formulas to LaTeX format.
  129. - Enclose inline formulas with \( \). For example: This is an inline formula \( E = mc^2 \)
  130. - Enclose block formulas with \\[ \\]. For example: \[ \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \]
  131. 3. Table Processing:
  132. - Convert tables to HTML format.
  133. - Wrap the entire table with <table> and </table>.
  134. - For financial data in tables, ensure numbers use standard format with half-width commas and periods
  135. 4. Figure Handling:
  136. - Ignore figures content in the PDF image. Do not attempt to describe or convert images.
  137. 5. Output Format:
  138. - Ensure the output Markdown document has a clear structure with appropriate line breaks between elements.
  139. - For complex layouts, try to maintain the original document's structure and format as closely as possible.
  140. - Use standard ASCII characters for punctuation and numbers
  141. Please strictly follow these guidelines to ensure accuracy and consistency in the conversion. Your task is to accurately convert the content of the PDF image into Markdown format without adding any extra explanations or comments.
  142. '''
  143. # 创建OpenAI客户端
  144. client = OpenAI(
  145. api_key=api_key,
  146. base_url=api_base
  147. )
  148. # 构建消息内容
  149. messages: List[Dict[str, Any]] = [
  150. {
  151. "role": "user",
  152. "content": [
  153. {
  154. "type": "text",
  155. "text": prompt
  156. },
  157. {
  158. "type": "image_url",
  159. "image_url": {
  160. "url": f"data:{mime_type};base64,{image_data}"
  161. }
  162. }
  163. ]
  164. }
  165. ]
  166. try:
  167. print(f"正在通过模型 {model_name} 进行OCR...")
  168. print(f"API地址: {api_base}")
  169. print(f"数字标准化: {'启用' if normalize_numbers else '禁用'}")
  170. # 调用API
  171. response = client.chat.completions.create(
  172. model=model_name,
  173. messages=messages, # type: ignore
  174. temperature=temperature,
  175. max_tokens=max_tokens,
  176. timeout=timeout
  177. )
  178. # 提取响应内容
  179. generated_text = response.choices[0].message.content
  180. if not generated_text:
  181. raise Exception("模型没有生成文本内容")
  182. # 标准化数字格式(如果启用)
  183. original_text = generated_text
  184. if normalize_numbers:
  185. print("🔧 正在标准化数字格式...")
  186. # generated_text = normalize_financial_numbers(generated_text)
  187. # 只对Markdown表格进行数字标准化
  188. generated_text = normalize_markdown_table(generated_text)
  189. # 统计标准化的变化
  190. changes_count = len([1 for o, n in zip(original_text, generated_text) if o != n])
  191. if changes_count > 0:
  192. print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
  193. else:
  194. print("ℹ️ 无需标准化(已是标准格式)")
  195. print(f"✅ 成功使用模型 {model_id} 完成OCR!")
  196. # 保存结果文件
  197. Path(output_dir).mkdir(parents=True, exist_ok=True)
  198. # 保存标准化后的Markdown文件
  199. markdown_path = Path(image_path).with_suffix('.md')
  200. markdown_path = Path(output_dir) / markdown_path.name
  201. markdown_path = markdown_path.resolve()
  202. with open(markdown_path, 'w', encoding='utf-8') as f:
  203. f.write(generated_text)
  204. # 如果启用了标准化,也保存原始版本用于对比
  205. if normalize_numbers and original_text != generated_text:
  206. original_markdown_path = Path(output_dir) / f"{Path(image_path).stem}_original.md"
  207. with open(original_markdown_path, 'w', encoding='utf-8') as f:
  208. f.write(original_text)
  209. print(f"📄 原始OCR结果已保存到: {original_markdown_path}")
  210. # 准备元数据
  211. ocr_result: Dict[str, Any] = {
  212. "processing_info": {
  213. "normalize_numbers": normalize_numbers,
  214. "changes_applied": original_text != generated_text if normalize_numbers else False,
  215. "character_changes_count": len([1 for o, n in zip(original_text, generated_text) if o != n]) if normalize_numbers else 0
  216. }
  217. }
  218. result_path = Path(image_path).with_suffix('.json')
  219. result_path = Path(output_dir) / result_path.name
  220. result_path = result_path.resolve()
  221. # 添加元数据
  222. ocr_result["metadata"] = {
  223. "model_used": model_id,
  224. "api_base": api_base,
  225. "temperature": temperature,
  226. "max_tokens": max_tokens,
  227. "timeout": timeout,
  228. "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
  229. "original_image": Path(image_path).resolve().as_posix(),
  230. "output_path": Path(markdown_path).resolve().as_posix(),
  231. "normalize_numbers": normalize_numbers
  232. }
  233. # 保存结果
  234. with open(result_path, 'w', encoding='utf-8') as f:
  235. json.dump(ocr_result, f, ensure_ascii=False, indent=2)
  236. print(f"📄 OCR结果已保存到: {markdown_path}")
  237. print(f"📊 元数据已保存到: {result_path}")
  238. # 打印详细统计
  239. print("\n📊 OCR处理统计")
  240. print(f" 原始图片: {ocr_result['metadata']['original_image']}")
  241. print(f" 输出路径: {ocr_result['metadata']['output_path']}")
  242. print(f" 使用模型: {ocr_result['metadata']['model_used']}")
  243. print(f" 数字标准化: {ocr_result['metadata']['normalize_numbers']}")
  244. if normalize_numbers:
  245. print(f" 字符变化数: {ocr_result['processing_info']['character_changes_count']}")
  246. print(f" 应用了标准化: {ocr_result['processing_info']['changes_applied']}")
  247. print(f" 处理时间: {ocr_result['metadata']['timestamp']}")
  248. return ocr_result
  249. except Exception as e:
  250. import traceback
  251. traceback.print_exc()
  252. raise Exception(f"OCR任务失败: {e}")
  253. def batch_normalize_existing_files(input_dir: str, output_dir: str = None):
  254. """
  255. 批量标准化已有的Markdown文件中的数字格式
  256. Args:
  257. input_dir: 输入目录
  258. output_dir: 输出目录,如果为None则覆盖原文件
  259. """
  260. input_path = Path(input_dir)
  261. output_path = Path(output_dir) if output_dir else input_path
  262. if not input_path.exists():
  263. raise ValueError(f"输入目录不存在: {input_dir}")
  264. output_path.mkdir(parents=True, exist_ok=True)
  265. md_files = list(input_path.glob("*.md"))
  266. if not md_files:
  267. print("⚠️ 未找到Markdown文件")
  268. return
  269. print(f"🔧 开始批量标准化 {len(md_files)} 个Markdown文件...")
  270. for md_file in md_files:
  271. print(f" 处理: {md_file.name}")
  272. # 读取原文件
  273. with open(md_file, 'r', encoding='utf-8') as f:
  274. original_content = f.read()
  275. # 标准化内容
  276. normalized_content = normalize_financial_numbers(original_content)
  277. normalized_content = normalize_markdown_table(normalized_content)
  278. # 保存标准化后的文件
  279. output_file = output_path / md_file.name
  280. with open(output_file, 'w', encoding='utf-8') as f:
  281. f.write(normalized_content)
  282. # 统计变化
  283. changes = len([1 for o, n in zip(original_content, normalized_content) if o != n])
  284. if changes > 0:
  285. print(f" ✅ 标准化了 {changes} 个字符")
  286. else:
  287. print(f" ℹ️ 无需更改")
  288. print(f"✅ 批量标准化完成!结果保存到: {output_path}")
  289. if __name__ == "__main__":
  290. import argparse
  291. parser = argparse.ArgumentParser(description='VLM OCR识别工具')
  292. parser.add_argument('image_path', nargs='?', help='图片文件路径')
  293. parser.add_argument('-o', '--output', default='./output', help='输出目录')
  294. parser.add_argument('-t', '--temperature', type=float, default=0.1, help='生成温度')
  295. parser.add_argument('-m', '--max-tokens', type=int, default=4096, help='最大token数')
  296. parser.add_argument('--timeout', type=int, default=180, help='超时时间(秒)')
  297. parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
  298. parser.add_argument('--batch-normalize', help='批量标准化指定目录中的Markdown文件')
  299. args = parser.parse_args()
  300. if args.batch_normalize:
  301. # 批量标准化模式
  302. batch_normalize_existing_files(args.batch_normalize, args.output)
  303. elif args.image_path:
  304. # 单文件OCR模式
  305. try:
  306. result = ocr_with_vlm(
  307. image_path=args.image_path,
  308. output_dir=args.output,
  309. temperature=args.temperature,
  310. max_tokens=args.max_tokens,
  311. timeout=args.timeout,
  312. normalize_numbers=not args.no_normalize
  313. )
  314. print("\n🎉 OCR识别完成!")
  315. except Exception as e:
  316. print(f"❌ OCR识别失败: {e}")
  317. else:
  318. # 默认示例
  319. image_path = "sample_data/至远彩色印刷工业有限公司-2022年母公司_2.png"
  320. try:
  321. result = ocr_with_vlm(image_path)
  322. print("\n🎉 OCR识别完成!")
  323. except Exception as e:
  324. print(f"❌ OCR识别失败: {e}")