# local_vlm_processor.py

import os
import re
import yaml
import base64
import json
import time
import argparse
from pathlib import Path
from typing import Dict, Any, Optional
from openai import OpenAI
from dotenv import load_dotenv

# Load environment variables
load_dotenv(override=True)
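# For reference: a minimal .env sketch for this setup. _resolve_env_variable below
# substitutes ${VAR_NAME} placeholders from the environment, so any key referenced
# in config.yaml must be defined in the shell or in a .env file like this
# (the variable name is a hypothetical example, not one required by the code):
#
#   MY_VLM_API_KEY=sk-xxxxxxxxxxxxxxxx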

class LocalVLMProcessor:
    def __init__(self, config_path: str = "config.yaml"):
        """
        Initialize the local VLM processor.

        Args:
            config_path: Path to the configuration file
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration file."""
        if not self.config_path.exists():
            raise FileNotFoundError(f"Config file not found: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        return config
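
    # A sketch of the config.yaml layout this class expects, inferred from how the
    # config is accessed in the methods below. Key names come from the code; the
    # concrete values are illustrative assumptions only:
    #
    #   models:
    #     qwen2_vl:
    #       name: "Qwen2-VL (local)"
    #       api_base: "http://localhost:8000/v1"
    #       api_key: "${MY_VLM_API_KEY}"
    #       model_id: "qwen2-vl-7b-instruct"
    #       default_params:
    #         temperature: 0.1
    #         max_tokens: 4096
    #         timeout: 120
    #   prompts:
    #     ocr_standard:
    #       name: "Standard OCR"
    #       template: "Extract all text from the image ..."
    #   default:
    #     model: qwen2_vl
    #     prompt: ocr_standard
    #     normalize_numbers: true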
    def _resolve_env_variable(self, value: str) -> str:
        """
        Resolve environment variables: replace ${VAR_NAME} placeholders with the
        actual environment variable values.

        Args:
            value: A string that may contain environment variable placeholders

        Returns:
            The resolved string
        """
        if not isinstance(value, str):
            return value
        # Match environment variables written as ${VAR_NAME}
        pattern = r'\$\{([^}]+)\}'

        def replace_env_var(match):
            env_var_name = match.group(1)
            env_value = os.getenv(env_var_name)
            if env_value is None:
                print(f"⚠️ Warning: environment variable {env_var_name} is not set; keeping the original value")
                return match.group(0)
            return env_value

        return re.sub(pattern, replace_env_var, value)
    def _is_image_generation_prompt(self, prompt_name: str) -> bool:
        """
        Check whether a prompt name refers to an image-generation task.

        Args:
            prompt_name: Prompt template name

        Returns:
            True if it is an image-generation task
        """
        image_generation_prompts = [
            'photo_restore_classroom',
            'photo_restore_advanced',
            'photo_colorize_classroom',
            'simple_photo_fix'
        ]
        return prompt_name in image_generation_prompts
    def _extract_base64_image(self, response_text: str) -> Optional[str]:
        """
        Extract base64-encoded image data from the response text.

        Args:
            response_text: API response text

        Returns:
            The base64-encoded image data, or None if none is found
        """
        # Common formats for base64 image data
        patterns = [
            r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)',                      # data URL
            r'base64:([A-Za-z0-9+/=]{100,})',                                   # base64: prefix
            r'```base64\s*\n([A-Za-z0-9+/=\s]+)\n```',                          # Markdown code block
            r'<img[^>]*src="data:image/[^;]+;base64,([A-Za-z0-9+/=]+)"[^>]*>',  # HTML img tag
        ]
        for pattern in patterns:
            match = re.search(pattern, response_text, re.MULTILINE | re.DOTALL)
            if match:
                base64_data = match.group(1).replace('\n', '').replace(' ', '')
                if len(base64_data) > 1000:  # a plausible image size
                    return base64_data
        return None
    def list_models(self) -> None:
        """List all available models."""
        print("📋 Available models:")
        for model_key, model_config in self.config['models'].items():
            resolved_api_key = self._resolve_env_variable(model_config['api_key'])
            api_key_status = "✅ configured" if resolved_api_key else "❌ not configured"
            print(f"  🤖 {model_key}: {model_config['name']}")
            print(f"     API base: {model_config['api_base']}")
            print(f"     Model ID: {model_config['model_id']}")
            print(f"     API key: {api_key_status}")
            print()

    def list_prompts(self) -> None:
        """List all available prompt templates."""
        print("📝 Available prompt templates:")
        for prompt_key, prompt_config in self.config['prompts'].items():
            is_image_gen = self._is_image_generation_prompt(prompt_key)
            task_type = "🖼️ image generation" if is_image_gen else "📝 text generation"
            print(f"  💬 {prompt_key}: {prompt_config['name']} ({task_type})")
            # Show the first 100 characters of the template
            template_preview = prompt_config['template'][:100].replace('\n', ' ')
            print(f"     Preview: {template_preview}...")
            print()
    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """Get a model's configuration."""
        if model_name not in self.config['models']:
            raise ValueError(f"Model config not found: {model_name}; available models: {list(self.config['models'].keys())}")
        model_config = self.config['models'][model_name].copy()
        # Resolve environment variables
        model_config['api_key'] = self._resolve_env_variable(model_config['api_key'])
        return model_config

    def get_prompt_template(self, prompt_name: str) -> str:
        """Get a prompt template."""
        if prompt_name not in self.config['prompts']:
            raise ValueError(f"Prompt template not found: {prompt_name}; available templates: {list(self.config['prompts'].keys())}")
        return self.config['prompts'][prompt_name]['template']
    def normalize_financial_numbers(self, text: str) -> str:
        """
        Normalize financial numbers: convert full-width characters to half-width.
        """
        if not text:
            return text
        # Map full-width characters (keys) to their half-width equivalents (values)
        fullwidth_to_halfwidth = {
            '0': '0', '1': '1', '2': '2', '3': '3', '4': '4',
            '5': '5', '6': '6', '7': '7', '8': '8', '9': '9',
            ',': ',', '。': '.', '.': '.', ':': ':',
            ';': ';', '(': '(', ')': ')', '-': '-',
            '+': '+', '%': '%',
        }
        # Apply the character replacements
        normalized_text = text
        for fullwidth, halfwidth in fullwidth_to_halfwidth.items():
            normalized_text = normalized_text.replace(fullwidth, halfwidth)
        return normalized_text
    def process_image(self,
                      image_path: str,
                      model_name: Optional[str] = None,
                      prompt_name: Optional[str] = None,
                      output_dir: str = "./output",
                      temperature: Optional[float] = None,
                      max_tokens: Optional[int] = None,
                      timeout: Optional[int] = None,
                      normalize_numbers: Optional[bool] = None,
                      custom_prompt: Optional[str] = None) -> Dict[str, Any]:
        """
        Process a single image.

        Args:
            image_path: Path to the image
            model_name: Model name
            prompt_name: Prompt template name
            output_dir: Output directory
            temperature: Sampling temperature
            max_tokens: Maximum number of tokens
            timeout: Timeout in seconds
            normalize_numbers: Whether to normalize numbers
            custom_prompt: Custom prompt (takes precedence over prompt_name)

        Returns:
            A dictionary with the processing results
        """
        # Fall back to the defaults from the config
        model_name = model_name or self.config['default']['model']
        prompt_name = prompt_name or self.config['default']['prompt']
        # Determine whether this is an image-generation task
        is_image_generation = custom_prompt is None and self._is_image_generation_prompt(prompt_name)
        # Image-generation tasks skip number normalization by default
        if is_image_generation:
            normalize_numbers = False
            print("🖼️ Image-generation task detected; number normalization disabled automatically")
        else:
            normalize_numbers = normalize_numbers if normalize_numbers is not None else self.config['default']['normalize_numbers']
        # Get the model configuration
        model_config = self.get_model_config(model_name)
        # Set parameters, preferring explicitly passed values
        temperature = temperature if temperature is not None else model_config['default_params']['temperature']
        max_tokens = max_tokens if max_tokens is not None else model_config['default_params']['max_tokens']
        timeout = timeout if timeout is not None else model_config['default_params']['timeout']
        # Choose the prompt
        if custom_prompt:
            prompt = custom_prompt
            print("🎯 Using custom prompt")
        else:
            prompt = self.get_prompt_template(prompt_name)
            task_type = "image generation" if is_image_generation else "text analysis"
            print(f"🎯 Using prompt template: {prompt_name} ({task_type})")
        # Read the image file and encode it as base64
        if not Path(image_path).exists():
            raise FileNotFoundError(f"Image file not found: {image_path}")
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')
        # Determine the image's MIME type
        file_extension = Path(image_path).suffix.lower()
        mime_type_map = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }
        mime_type = mime_type_map.get(file_extension, 'image/jpeg')
        # Create the OpenAI-compatible client
        client = OpenAI(
            api_key=model_config['api_key'] or "dummy-key",
            base_url=model_config['api_base']
        )
        # Build the chat messages (text prompt plus the image as a data URL)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_data}"
                        }
                    }
                ]
            }
        ]
        # Show processing info
        print(f"\n🚀 Processing image: {Path(image_path).name}")
        print(f"🤖 Model: {model_config['name']} ({model_name})")
        print(f"🌐 API base: {model_config['api_base']}")
        print("🔧 Parameters:")
        print(f"  - temperature: {temperature}")
        print(f"  - max tokens: {max_tokens}")
        print(f"  - timeout: {timeout}s")
        print(f"  - number normalization: {'enabled' if normalize_numbers else 'disabled'}")
        print(f"  - task type: {'image generation' if is_image_generation else 'text analysis'}")
        try:
            # Call the API
            response = client.chat.completions.create(
                model=model_config['model_id'],
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=timeout
            )
            # Extract the response content
            generated_text = response.choices[0].message.content
            if not generated_text:
                raise Exception("The model returned no content")
            # Handle image-generation results
            if is_image_generation:
                # Try to extract base64 image data
                base64_image = self._extract_base64_image(generated_text)
                if base64_image:
                    print("🖼️ Generated image data detected")
                    return self._save_image_results(
                        image_path=image_path,
                        output_dir=output_dir,
                        generated_text=generated_text,
                        base64_image=base64_image,
                        model_name=model_name,
                        prompt_name=prompt_name,
                        model_config=model_config,
                        processing_params={
                            'temperature': temperature,
                            'max_tokens': max_tokens,
                            'timeout': timeout,
                            'normalize_numbers': normalize_numbers,
                            'custom_prompt_used': custom_prompt is not None,
                            'is_image_generation': True
                        }
                    )
                else:
                    print("⚠️ No image data detected; saving as a text result")
            # Normalize number formatting (if enabled)
            original_text = generated_text
            if normalize_numbers:
                print("🔧 Normalizing number formatting...")
                generated_text = self.normalize_financial_numbers(generated_text)
                # Count the characters changed by normalization
                changes_count = len([1 for o, n in zip(original_text, generated_text) if o != n])
                if changes_count > 0:
                    print(f"✅ Normalized {changes_count} characters (full-width → half-width)")
                else:
                    print("ℹ️ Nothing to normalize (already in standard format)")
            print("✅ Processing completed successfully!")
            # Save the text results
            return self._save_text_results(
                image_path=image_path,
                output_dir=output_dir,
                generated_text=generated_text,
                original_text=original_text,
                model_name=model_name,
                prompt_name=prompt_name,
                model_config=model_config,
                processing_params={
                    'temperature': temperature,
                    'max_tokens': max_tokens,
                    'timeout': timeout,
                    'normalize_numbers': normalize_numbers,
                    'custom_prompt_used': custom_prompt is not None,
                    'is_image_generation': is_image_generation
                }
            )
        except Exception as e:
            print(f"❌ Processing failed: {e}")
            raise
    def _save_image_results(self,
                            image_path: str,
                            output_dir: str,
                            generated_text: str,
                            base64_image: str,
                            model_name: str,
                            prompt_name: str,
                            model_config: Dict[str, Any],
                            processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Save image-generation results."""
        # Create the output directory
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Build the output file names
        base_name = Path(image_path).stem
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        # Save the generated image
        try:
            image_bytes = base64.b64decode(base64_image)
            image_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.png"
            with open(image_file, 'wb') as f:
                f.write(image_bytes)
            print(f"🖼️ Generated image saved to: {image_file}")
        except Exception as e:
            print(f"❌ Failed to save the image: {e}")
            # If saving the image fails, save the response as text instead
            text_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.txt"
            with open(text_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 Response content saved as text: {text_file}")
            image_file = text_file
        # Save the raw response text (which may contain explanatory text)
        if len(generated_text.strip()) > len(base64_image) + 100:  # extra explanatory text present
            description_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_description.txt"
            with open(description_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📝 Response description saved to: {description_file}")
        # Save metadata
        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": image_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "image",
                "text_stats": {
                    "response_length": len(generated_text),
                    "has_image_data": True,
                    "base64_length": len(base64_image)
                }
            }
        }
        metadata_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 Metadata saved to: {metadata_file}")
        return metadata
    def _save_text_results(self,
                           image_path: str,
                           output_dir: str,
                           generated_text: str,
                           original_text: str,
                           model_name: str,
                           prompt_name: str,
                           model_config: Dict[str, Any],
                           processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Save text results."""
        # Create the output directory
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        # Build the output file name
        base_name = Path(image_path).stem
        # Save the main result file
        if prompt_name in ['ocr_standard', 'table_extract']:
            # OCR-related tasks are saved as Markdown
            result_file = output_path / f"{base_name}_{model_name}.md"
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 Result saved to: {result_file}")
        else:
            # Other tasks are saved as plain text
            result_file = output_path / f"{base_name}_{model_name}_{prompt_name}.txt"
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 Result saved to: {result_file}")
        # If number normalization was applied, also keep the original version
        if processing_params['normalize_numbers'] and original_text != generated_text:
            original_file = output_path / f"{base_name}_{model_name}_original.txt"
            with open(original_file, 'w', encoding='utf-8') as f:
                f.write(original_text)
            print(f"📄 Original result saved to: {original_file}")
        # Save metadata
        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": result_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "text",
                "text_stats": {
                    "original_length": len(original_text),
                    "final_length": len(generated_text),
                    "character_changes": len([1 for o, n in zip(original_text, generated_text) if o != n]) if processing_params['normalize_numbers'] else 0
                }
            }
        }
        metadata_file = output_path / f"{base_name}_{model_name}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 Metadata saved to: {metadata_file}")
        return metadata

def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='Local VLM image processing tool')
    # Basic arguments
    parser.add_argument('image_path', nargs='?', help='Path to the image file')
    parser.add_argument('-c', '--config', default='config.yaml', help='Path to the configuration file')
    parser.add_argument('-o', '--output', default='./output', help='Output directory')
    # Model and prompt selection
    parser.add_argument('-m', '--model', help='Model name')
    parser.add_argument('-p', '--prompt', help='Prompt template name')
    parser.add_argument('--custom-prompt', help='Custom prompt (takes precedence over -p)')
    # Processing parameters
    parser.add_argument('-t', '--temperature', type=float, help='Sampling temperature')
    parser.add_argument('--max-tokens', type=int, help='Maximum number of tokens')
    parser.add_argument('--timeout', type=int, help='Timeout in seconds')
    parser.add_argument('--no-normalize', action='store_true', help='Disable number normalization (only table-extraction or OCR-related tasks need it)')
    # Informational queries
    parser.add_argument('--list-models', action='store_true', help='List all available models')
    parser.add_argument('--list-prompts', action='store_true', help='List all prompt templates')
    args = parser.parse_args()
    try:
        # Initialize the processor
        processor = LocalVLMProcessor(args.config)
        # Handle informational queries
        if args.list_models:
            processor.list_models()
            return 0
        if args.list_prompts:
            processor.list_prompts()
            return 0
        # Make sure an image path was provided
        if not args.image_path:
            print("❌ Error: please provide an image file path")
            print("\nUsage examples:")
            print("  python local_vlm_processor.py image.jpg")
            print("  python local_vlm_processor.py image.jpg -m qwen2_vl -p photo_analysis")
            print("  python local_vlm_processor.py image.jpg -p simple_photo_fix  # photo restoration")
            print("  python local_vlm_processor.py --list-models")
            print("  python local_vlm_processor.py --list-prompts")
            return 1
        # Process the image
        result = processor.process_image(
            image_path=args.image_path,
            model_name=args.model,
            prompt_name=args.prompt,
            output_dir=args.output,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
            normalize_numbers=not args.no_normalize,
            custom_prompt=args.custom_prompt
        )
        print("\n🎉 Processing complete!")
        print("📊 Processing stats:")
        if result['processing_info']['result_type'] == 'image':
            stats = result['processing_info']['text_stats']
            print(f"  Response length: {stats['response_length']} characters")
            print(f"  Image data: {'present' if stats['has_image_data'] else 'absent'}")
            if stats['has_image_data']:
                print(f"  Base64 length: {stats['base64_length']} characters")
        else:
            stats = result['processing_info']['text_stats']
            print(f"  Original length: {stats['original_length']} characters")
            print(f"  Final length: {stats['final_length']} characters")
            if stats['character_changes'] > 0:
                print(f"  Normalization changes: {stats['character_changes']} characters")
        return 0
    except Exception as e:
        print(f"❌ Program execution failed: {e}")
        return 1

if __name__ == "__main__":
    # If no command-line arguments were given, supply defaults for testing
    import sys
    if len(sys.argv) == 1:
        sys.argv.extend([
            '../sample_data/工大照片-1.jpg',
            '-p', 'simple_photo_fix',
            '-o', './output',
            '--no-normalize'])
    sys.exit(main())