import base64
import json
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from openai import OpenAI

from normalize_financial_numbers import normalize_financial_numbers, normalize_markdown_table
# Load environment variables from a local .env file (API key, base URL, model id).
load_dotenv()
def ocr_with_vlm(image_path, output_dir="./output",
                 api_key=None, api_base=None, model_id=None,
                 temperature=0.1, max_tokens=4096, timeout=180,
                 normalize_numbers=True):
    """Recognize text in an image with a vision-language model (VLM).

    The image is base64-encoded and sent to an OpenAI-compatible chat API
    together with a PDF-to-Markdown conversion prompt. The recognized
    Markdown is written to ``<output_dir>/<image stem>.md`` and run metadata
    to ``<output_dir>/<image stem>.json``.

    Args:
        image_path: Path to the source image file.
        output_dir: Directory for result files (created if missing).
        api_key: API key; falls back to env var YUSYS_MULTIMODAL_API_KEY.
        api_base: API base URL; falls back to env var YUSYS_MULTIMODAL_API_BASE.
        model_id: Model id; falls back to env var YUSYS_MULTIMODAL_ID.
        temperature: Sampling temperature, default 0.1.
        max_tokens: Maximum output tokens, default 4096.
        timeout: Request timeout in seconds, default 180.
        normalize_numbers: If True, normalize full-width digits/punctuation
            inside Markdown tables to half-width characters.

    Returns:
        dict with ``processing_info`` and ``metadata`` sections (the same
        content that is saved to the JSON result file).

    Raises:
        ValueError: If an API configuration value is missing.
        FileNotFoundError: If the image file does not exist.
        RuntimeError: If the API call or result post-processing fails.
    """
    # Resolve API configuration from arguments first, then environment variables.
    api_key = api_key or os.getenv("YUSYS_MULTIMODAL_API_KEY")
    api_base = api_base or os.getenv("YUSYS_MULTIMODAL_API_BASE")
    model_id = model_id or os.getenv("YUSYS_MULTIMODAL_ID")
    if not api_key:
        raise ValueError("未找到API密钥,请通过参数传入或设置YUSYS_MULTIMODAL_API_KEY环境变量")
    if not api_base:
        raise ValueError("未找到API基础URL,请通过参数传入或设置YUSYS_MULTIMODAL_API_BASE环境变量")
    if not model_id:
        raise ValueError("未找到模型ID,请通过参数传入或设置YUSYS_MULTIMODAL_ID环境变量")

    # Strip an "openai/" *prefix* only. The previous str.replace() would also
    # remove the substring from the middle of the id.
    if model_id.startswith("openai/"):
        model_name = model_id[len("openai/"):]
    else:
        model_name = model_id

    # Read the image and encode it as base64 for the data URL.
    try:
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        raise FileNotFoundError(f"找不到图片文件: {image_path}")

    # Map the file extension to a MIME type; default to JPEG for unknown types.
    file_extension = Path(image_path).suffix.lower()
    mime_type_map = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp'
    }
    mime_type = mime_type_map.get(file_extension, 'image/jpeg')

    # Conversion instructions sent to the model.
    # NOTE(review): the "Wrap the entire table with" instruction below appears
    # truncated (the tag name is missing, likely lost in an earlier edit).
    # Kept verbatim to preserve behavior — confirm the intended wrapper tag.
    prompt = r'''You are an AI assistant specialized in converting PDF images to Markdown format. Please follow these instructions for the conversion:
1. Text Processing:
- Accurately recognize all text content in the PDF image without guessing or inferring.
- Convert the recognized text into Markdown format.
- Maintain the original document structure, including headings, paragraphs, lists, etc.
- For financial amounts, use standard half-width characters (e.g., use "," for thousands separator and "." for decimal point)
2. Mathematical Formula Processing:
- Convert all mathematical formulas to LaTeX format.
- Enclose inline formulas with \( \). For example: This is an inline formula \( E = mc^2 \)
- Enclose block formulas with \\[ \\]. For example: \[ \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \]
3. Table Processing:
- Convert tables to HTML format.
- Wrap the entire table with
.
- For financial data in tables, ensure numbers use standard format with half-width commas and periods
4. Figure Handling:
- Ignore figures content in the PDF image. Do not attempt to describe or convert images.
5. Output Format:
- Ensure the output Markdown document has a clear structure with appropriate line breaks between elements.
- For complex layouts, try to maintain the original document's structure and format as closely as possible.
- Use standard ASCII characters for punctuation and numbers
Please strictly follow these guidelines to ensure accuracy and consistency in the conversion. Your task is to accurately convert the content of the PDF image into Markdown format without adding any extra explanations or comments.
'''

    # OpenAI-compatible client pointed at the configured API base.
    client = OpenAI(
        api_key=api_key,
        base_url=api_base
    )

    # One user message carrying the prompt text plus the image as a data URL.
    messages: List[Dict[str, Any]] = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{image_data}"
                    }
                }
            ]
        }
    ]

    try:
        print(f"正在通过模型 {model_name} 进行OCR...")
        print(f"API地址: {api_base}")
        print(f"数字标准化: {'启用' if normalize_numbers else '禁用'}")

        response = client.chat.completions.create(
            model=model_name,
            messages=messages,  # type: ignore
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout
        )

        generated_text = response.choices[0].message.content
        if not generated_text:
            raise RuntimeError("模型没有生成文本内容")

        # Optionally normalize full-width characters inside Markdown tables.
        original_text = generated_text
        changes_count = 0
        if normalize_numbers:
            print("🔧 正在标准化数字格式...")
            # Full-text normalization (normalize_financial_numbers) was
            # deliberately disabled; only Markdown tables are normalized.
            generated_text = normalize_markdown_table(generated_text)
            # NOTE(review): zip() truncates at the shorter text, so this count
            # is only exact when normalization preserves the text length.
            changes_count = sum(1 for o, n in zip(original_text, generated_text) if o != n)
            if changes_count > 0:
                print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
            else:
                print("ℹ️ 无需标准化(已是标准格式)")

        print(f"✅ 成功使用模型 {model_id} 完成OCR!")

        # Write the (normalized) Markdown result next to a metadata JSON file.
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        markdown_path = (Path(output_dir) / Path(image_path).with_suffix('.md').name).resolve()
        with open(markdown_path, 'w', encoding='utf-8') as f:
            f.write(generated_text)

        # Keep the raw model output for comparison when normalization changed it.
        if normalize_numbers and original_text != generated_text:
            original_markdown_path = Path(output_dir) / f"{Path(image_path).stem}_original.md"
            with open(original_markdown_path, 'w', encoding='utf-8') as f:
                f.write(original_text)
            print(f"📄 原始OCR结果已保存到: {original_markdown_path}")

        # Assemble the result metadata (also returned to the caller).
        ocr_result: Dict[str, Any] = {
            "processing_info": {
                "normalize_numbers": normalize_numbers,
                "changes_applied": original_text != generated_text if normalize_numbers else False,
                "character_changes_count": changes_count if normalize_numbers else 0
            }
        }
        result_path = (Path(output_dir) / Path(image_path).with_suffix('.json').name).resolve()
        ocr_result["metadata"] = {
            "model_used": model_id,
            "api_base": api_base,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "timeout": timeout,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "original_image": Path(image_path).resolve().as_posix(),
            "output_path": Path(markdown_path).resolve().as_posix(),
            "normalize_numbers": normalize_numbers
        }
        with open(result_path, 'w', encoding='utf-8') as f:
            json.dump(ocr_result, f, ensure_ascii=False, indent=2)

        print(f"📄 OCR结果已保存到: {markdown_path}")
        print(f"📊 元数据已保存到: {result_path}")

        # Detailed processing summary.
        print("\n📊 OCR处理统计")
        print(f" 原始图片: {ocr_result['metadata']['original_image']}")
        print(f" 输出路径: {ocr_result['metadata']['output_path']}")
        print(f" 使用模型: {ocr_result['metadata']['model_used']}")
        print(f" 数字标准化: {ocr_result['metadata']['normalize_numbers']}")
        if normalize_numbers:
            print(f" 字符变化数: {ocr_result['processing_info']['character_changes_count']}")
            print(f" 应用了标准化: {ocr_result['processing_info']['changes_applied']}")
        print(f" 处理时间: {ocr_result['metadata']['timestamp']}")
        return ocr_result
    except Exception as e:
        # Chain the cause ("from e") so the original traceback is preserved
        # instead of printing it manually and raising a bare Exception.
        raise RuntimeError(f"OCR任务失败: {e}") from e
def batch_normalize_existing_files(input_dir: str, output_dir: Optional[str] = None) -> None:
    """Normalize number formatting in all Markdown files of a directory.

    Each ``*.md`` file in ``input_dir`` is run through
    ``normalize_financial_numbers`` and ``normalize_markdown_table`` and
    written to ``output_dir`` (or back over the original when ``output_dir``
    is None/empty).

    Args:
        input_dir: Directory containing Markdown files.
        output_dir: Destination directory; if None, files are overwritten
            in place.

    Raises:
        ValueError: If ``input_dir`` does not exist.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir) if output_dir else input_path
    if not input_path.exists():
        raise ValueError(f"输入目录不存在: {input_dir}")
    output_path.mkdir(parents=True, exist_ok=True)

    # Sorted for deterministic processing order across runs/platforms.
    md_files = sorted(input_path.glob("*.md"))
    if not md_files:
        print("⚠️ 未找到Markdown文件")
        return

    print(f"🔧 开始批量标准化 {len(md_files)} 个Markdown文件...")
    for md_file in md_files:
        print(f" 处理: {md_file.name}")
        original_content = md_file.read_text(encoding='utf-8')

        # Full-text pass first, then the table-specific pass.
        normalized_content = normalize_financial_numbers(original_content)
        normalized_content = normalize_markdown_table(normalized_content)

        (output_path / md_file.name).write_text(normalized_content, encoding='utf-8')

        # NOTE(review): zip() truncates at the shorter text, so this count is
        # only exact when normalization preserves the text length.
        changes = sum(1 for o, n in zip(original_content, normalized_content) if o != n)
        if changes > 0:
            print(f" ✅ 标准化了 {changes} 个字符")
        else:
            print(f" ℹ️ 无需更改")
    print(f"✅ 批量标准化完成!结果保存到: {output_path}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='VLM OCR识别工具')
    parser.add_argument('image_path', nargs='?', help='图片文件路径')
    parser.add_argument('-o', '--output', default='./output', help='输出目录')
    parser.add_argument('-t', '--temperature', type=float, default=0.1, help='生成温度')
    parser.add_argument('-m', '--max-tokens', type=int, default=4096, help='最大token数')
    parser.add_argument('--timeout', type=int, default=180, help='超时时间(秒)')
    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
    parser.add_argument('--batch-normalize', help='批量标准化指定目录中的Markdown文件')
    args = parser.parse_args()

    if args.batch_normalize:
        # Batch mode: normalize existing Markdown files in a directory.
        batch_normalize_existing_files(args.batch_normalize, args.output)
    elif args.image_path:
        # Single-image OCR mode. The return value is unused here; results are
        # written to disk by ocr_with_vlm itself.
        try:
            ocr_with_vlm(
                image_path=args.image_path,
                output_dir=args.output,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
                timeout=args.timeout,
                normalize_numbers=not args.no_normalize
            )
            print("\n🎉 OCR识别完成!")
        except Exception as e:
            print(f"❌ OCR识别失败: {e}")
    else:
        # Demo fallback: run OCR on the bundled sample image.
        image_path = "sample_data/至远彩色印刷工业有限公司-2022年母公司_2.png"
        try:
            ocr_with_vlm(image_path)
            print("\n🎉 OCR识别完成!")
        except Exception as e:
            print(f"❌ OCR识别失败: {e}")