import argparse
import base64
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import yaml
from dotenv import load_dotenv
from openai import OpenAI

# Load environment variables from a .env file, overriding pre-existing values.
load_dotenv(override=True)


class LocalVLMProcessor:
    """Drive a local, OpenAI-compatible vision-language model over single images.

    Reads model endpoints and prompt templates from a YAML config file, sends an
    image plus prompt to the configured chat-completions endpoint, and saves
    either text results or generated-image results along with JSON metadata.
    """

    # Prompt names that denote image-generation (restoration/colorization) tasks.
    _IMAGE_GENERATION_PROMPTS = frozenset({
        'photo_restore_classroom',
        'photo_restore_advanced',
        'photo_colorize_classroom',
        'simple_photo_fix',
    })

    # Fullwidth -> halfwidth single-character map used by
    # normalize_financial_numbers(); str.translate does the whole
    # substitution in one C-level pass instead of chained .replace() calls.
    _FULLWIDTH_TO_HALFWIDTH = str.maketrans({
        '0': '0', '1': '1', '2': '2', '3': '3', '4': '4',
        '5': '5', '6': '6', '7': '7', '8': '8', '9': '9',
        ',': ',', '。': '.', '.': '.', ':': ':', ';': ';',
        '(': '(', ')': ')', '-': '-', '+': '+', '%': '%',
    })

    def __init__(self, config_path: str = "config.yaml"):
        """Initialize the processor.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the YAML configuration file.

        Returns:
            The parsed configuration dictionary.

        Raises:
            FileNotFoundError: If the config file does not exist.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    def _resolve_env_variable(self, value: str) -> str:
        """Expand ``${VAR_NAME}`` placeholders with environment-variable values.

        Args:
            value: A string that may contain ``${VAR}`` placeholders.

        Returns:
            The string with each resolvable placeholder substituted; an
            unresolved placeholder is kept verbatim (a warning is printed).
        """
        if not isinstance(value, str):
            return value

        def _substitute(match: "re.Match[str]") -> str:
            env_var_name = match.group(1)
            env_value = os.getenv(env_var_name)
            if env_value is None:
                print(f"⚠️ 警告: 环境变量 {env_var_name} 未设置,使用原值")
                return match.group(0)  # keep the ${...} literal untouched
            return env_value

        return re.sub(r'\$\{([^}]+)\}', _substitute, value)

    def _is_image_generation_prompt(self, prompt_name: str) -> bool:
        """Return True if *prompt_name* refers to an image-generation task."""
        return prompt_name in self._IMAGE_GENERATION_PROMPTS

    def _extract_base64_image(self, response_text: str) -> Optional[str]:
        """Extract base64-encoded image data embedded in a model response.

        Args:
            response_text: Raw text returned by the API.

        Returns:
            The base64 payload with whitespace stripped, if a plausibly sized
            image is found; otherwise None.
        """
        # Common encodings of inline image data, tried in order.
        patterns = [
            r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)',        # data URL
            r'base64:([A-Za-z0-9+/=]{100,})',                    # "base64:" prefix
            r'```base64\s*\n([A-Za-z0-9+/=\s]+)\n```',           # markdown code fence
            # HTML <img> tag. NOTE(review): the original pattern began with
            # a bare ']*' — presumably '<img[^>]*' lost to markup escaping;
            # restored here.
            r'<img[^>]*src="data:image/[^;]+;base64,([A-Za-z0-9+/=]+)"[^>]*>',
        ]
        for pattern in patterns:
            match = re.search(pattern, response_text, re.MULTILINE | re.DOTALL)
            if match:
                base64_data = match.group(1).replace('\n', '').replace(' ', '')
                # Heuristic: a real image yields well over 1000 base64 chars.
                if len(base64_data) > 1000:
                    return base64_data
        return None

    def list_models(self) -> None:
        """Print every model configured in the YAML file."""
        print("📋 可用模型列表:")
        for model_key, model_config in self.config['models'].items():
            resolved_api_key = self._resolve_env_variable(model_config['api_key'])
            api_key_status = "✅ 已配置" if resolved_api_key else "❌ 未配置"
            print(f" 🤖 {model_key}: {model_config['name']}")
            print(f" API地址: {model_config['api_base']}")
            print(f" 模型ID: {model_config['model_id']}")
            print(f" API密钥: {api_key_status}")
            print()

    def list_prompts(self) -> None:
        """Print every prompt template configured in the YAML file."""
        print("📝 可用提示词模板:")
        for prompt_key, prompt_config in self.config['prompts'].items():
            is_image_gen = self._is_image_generation_prompt(prompt_key)
            task_type = "🖼️ 图片生成" if is_image_gen else "📝 文本生成"
            print(f" 💬 {prompt_key}: {prompt_config['name']} ({task_type})")
            # Show only the first 100 characters of the template.
            template_preview = prompt_config['template'][:100].replace('\n', ' ')
            print(f" 预览: {template_preview}...")
            print()

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """Return a copy of the named model's config with env vars resolved.

        Raises:
            ValueError: If *model_name* is not in the configuration.
        """
        if model_name not in self.config['models']:
            raise ValueError(f"未找到模型配置: {model_name},可用模型: {list(self.config['models'].keys())}")
        model_config = self.config['models'][model_name].copy()
        # Resolve ${VAR} placeholders in the API key.
        model_config['api_key'] = self._resolve_env_variable(model_config['api_key'])
        return model_config

    def get_prompt_template(self, prompt_name: str) -> str:
        """Return the template text of the named prompt.

        Raises:
            ValueError: If *prompt_name* is not in the configuration.
        """
        if prompt_name not in self.config['prompts']:
            raise ValueError(f"未找到提示词模板: {prompt_name},可用模板: {list(self.config['prompts'].keys())}")
        return self.config['prompts'][prompt_name]['template']

    def normalize_financial_numbers(self, text: str) -> str:
        """Normalize financial digits: convert fullwidth characters to halfwidth.

        Every mapping is 1 char -> 1 char, so the output length equals the
        input length.
        """
        if not text:
            return text
        return text.translate(self._FULLWIDTH_TO_HALFWIDTH)

    @staticmethod
    def _encode_image(image_path: str) -> Tuple[str, str]:
        """Read an image file and return its (base64 string, MIME type)."""
        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')
        mime_type_map = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
        }
        # Unknown extensions fall back to JPEG, matching the original behavior.
        mime_type = mime_type_map.get(Path(image_path).suffix.lower(), 'image/jpeg')
        return image_data, mime_type

    def process_image(self, image_path: str,
                      model_name: Optional[str] = None,
                      prompt_name: Optional[str] = None,
                      output_dir: str = "./output",
                      temperature: Optional[float] = None,
                      max_tokens: Optional[int] = None,
                      timeout: Optional[int] = None,
                      normalize_numbers: Optional[bool] = None,
                      custom_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Process a single image through the configured VLM.

        Args:
            image_path: Path to the input image.
            model_name: Model key from the config; None uses the config default.
            prompt_name: Prompt template key; None uses the config default.
            output_dir: Directory where results are written.
            temperature: Sampling temperature; None uses the model default.
            max_tokens: Maximum tokens to generate; None uses the model default.
            timeout: Request timeout in seconds; None uses the model default.
            normalize_numbers: Whether to normalize fullwidth digits; None
                means "use the config default" (forced off for image tasks).
            custom_prompt: Free-form prompt that overrides *prompt_name*.

        Returns:
            The metadata dict describing the saved results.

        Raises:
            FileNotFoundError: If the image file does not exist.
            RuntimeError: If the model returns no content.
        """
        model_name = model_name or self.config['default']['model']
        prompt_name = prompt_name or self.config['default']['prompt']

        # A custom prompt is always treated as a text task.
        is_image_generation = custom_prompt is None and self._is_image_generation_prompt(prompt_name)

        # Image-generation tasks never normalize numbers.
        if is_image_generation:
            normalize_numbers = False
            print(f"🖼️ 检测到图片生成任务,自动禁用数字标准化")
        else:
            normalize_numbers = normalize_numbers if normalize_numbers is not None else self.config['default']['normalize_numbers']

        model_config = self.get_model_config(model_name)

        # Explicit arguments win over the model's configured defaults.
        temperature = temperature if temperature is not None else model_config['default_params']['temperature']
        max_tokens = max_tokens if max_tokens is not None else model_config['default_params']['max_tokens']
        timeout = timeout if timeout is not None else model_config['default_params']['timeout']

        if custom_prompt:
            prompt = custom_prompt
            print(f"🎯 使用自定义提示词")
        else:
            prompt = self.get_prompt_template(prompt_name)
            task_type = "图片生成" if is_image_generation else "文本分析"
            print(f"🎯 使用提示词模板: {prompt_name} ({task_type})")

        if not Path(image_path).exists():
            raise FileNotFoundError(f"找不到图片文件: {image_path}")
        image_data, mime_type = self._encode_image(image_path)

        # "dummy-key" keeps the OpenAI client usable against keyless local servers.
        client = OpenAI(
            api_key=model_config['api_key'] or "dummy-key",
            base_url=model_config['api_base']
        )

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{image_data}"}
                    }
                ]
            }
        ]

        print(f"\n🚀 开始处理图片: {Path(image_path).name}")
        print(f"🤖 使用模型: {model_config['name']} ({model_name})")
        print(f"🌐 API地址: {model_config['api_base']}")
        print(f"🔧 参数配置:")
        print(f" - 温度: {temperature}")
        print(f" - 最大Token: {max_tokens}")
        print(f" - 超时时间: {timeout}秒")
        print(f" - 数字标准化: {'启用' if normalize_numbers else '禁用'}")
        print(f" - 任务类型: {'图片生成' if is_image_generation else '文本分析'}")

        try:
            response = client.chat.completions.create(
                model=model_config['model_id'],
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=timeout
            )
            generated_text = response.choices[0].message.content
            if not generated_text:
                # RuntimeError (was a bare Exception) — still caught by the
                # surrounding `except Exception` handler below.
                raise RuntimeError("模型没有生成内容")

            processing_params = {
                'temperature': temperature,
                'max_tokens': max_tokens,
                'timeout': timeout,
                'normalize_numbers': normalize_numbers,
                'custom_prompt_used': custom_prompt is not None,
                'is_image_generation': is_image_generation,
            }

            if is_image_generation:
                # Try to pull base64 image data out of the response.
                base64_image = self._extract_base64_image(generated_text)
                if base64_image:
                    print("🖼️ 检测到生成的图片数据")
                    return self._save_image_results(
                        image_path=image_path,
                        output_dir=output_dir,
                        generated_text=generated_text,
                        base64_image=base64_image,
                        model_name=model_name,
                        prompt_name=prompt_name,
                        model_config=model_config,
                        processing_params=processing_params
                    )
                print("⚠️ 未检测到图片数据,保存为文本结果")

            # Optional fullwidth -> halfwidth normalization of the text result.
            original_text = generated_text
            if normalize_numbers:
                print("🔧 正在标准化数字格式...")
                generated_text = self.normalize_financial_numbers(generated_text)
                # The mapping is 1:1 in length, so zip() never truncates here.
                changes_count = sum(1 for o, n in zip(original_text, generated_text) if o != n)
                if changes_count > 0:
                    print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
                else:
                    print("ℹ️ 无需标准化(已是标准格式)")

            print(f"✅ 成功完成处理!")

            return self._save_text_results(
                image_path=image_path,
                output_dir=output_dir,
                generated_text=generated_text,
                original_text=original_text,
                model_name=model_name,
                prompt_name=prompt_name,
                model_config=model_config,
                processing_params=processing_params
            )
        except Exception as e:
            print(f"❌ 处理失败: {e}")
            raise

    def _save_image_results(self, image_path: str, output_dir: str,
                            generated_text: str, base64_image: str,
                            model_name: str, prompt_name: str,
                            model_config: Dict[str, Any],
                            processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Persist an image-generation result: PNG, optional description, metadata.

        Returns:
            The metadata dict written alongside the image.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        base_name = Path(image_path).stem
        timestamp = time.strftime("%Y%m%d_%H%M%S")

        try:
            image_bytes = base64.b64decode(base64_image)
            image_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.png"
            with open(image_file, 'wb') as f:
                f.write(image_bytes)
            print(f"🖼️ 生成的图片已保存到: {image_file}")
        except Exception as e:
            # If decoding/writing the image fails, fall back to a text dump.
            print(f"❌ 图片保存失败: {e}")
            text_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.txt"
            with open(text_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 响应内容已保存为文本: {text_file}")
            image_file = text_file

        # If the response carries explanatory text beyond the image payload,
        # keep the full response in a separate description file.
        if len(generated_text.strip()) > len(base64_image) + 100:
            description_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_description.txt"
            with open(description_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📝 响应说明已保存到: {description_file}")

        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": image_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "image",
                "text_stats": {
                    "response_length": len(generated_text),
                    "has_image_data": True,
                    "base64_length": len(base64_image)
                }
            }
        }
        metadata_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 元数据已保存到: {metadata_file}")
        return metadata

    def _save_text_results(self, image_path: str, output_dir: str,
                           generated_text: str, original_text: str,
                           model_name: str, prompt_name: str,
                           model_config: Dict[str, Any],
                           processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Persist a text result: main file, optional pre-normalization copy, metadata.

        Returns:
            The metadata dict written alongside the text.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        base_name = Path(image_path).stem

        if prompt_name in ['ocr_standard', 'table_extract']:
            # OCR-related tasks are saved as Markdown.
            result_file = output_path / f"{base_name}_{model_name}.md"
        else:
            # Everything else is saved as plain text.
            result_file = output_path / f"{base_name}_{model_name}_{prompt_name}.txt"
        with open(result_file, 'w', encoding='utf-8') as f:
            f.write(generated_text)
        print(f"📄 结果已保存到: {result_file}")

        # If normalization changed anything, also keep the original response.
        if processing_params['normalize_numbers'] and original_text != generated_text:
            original_file = output_path / f"{base_name}_{model_name}_original.txt"
            with open(original_file, 'w', encoding='utf-8') as f:
                f.write(original_text)
            print(f"📄 原始结果已保存到: {original_file}")

        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": result_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "text",
                "text_stats": {
                    "original_length": len(original_text),
                    "final_length": len(generated_text),
                    "character_changes": sum(1 for o, n in zip(original_text, generated_text) if o != n) if processing_params['normalize_numbers'] else 0
                }
            }
        }
        metadata_file = output_path / f"{base_name}_{model_name}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 元数据已保存到: {metadata_file}")
        return metadata


def main():
    """CLI entry point.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(description='本地VLM图片处理工具')

    # Positional / basic options
    parser.add_argument('image_path', nargs='?', help='图片文件路径')
    parser.add_argument('-c', '--config', default='config.yaml', help='配置文件路径')
    parser.add_argument('-o', '--output', default='./output', help='输出目录')

    # Model / prompt selection
    parser.add_argument('-m', '--model', help='模型名称')
    parser.add_argument('-p', '--prompt', help='提示词模板名称')
    parser.add_argument('--custom-prompt', help='自定义提示词(优先级高于-p参数)')

    # Generation parameters
    parser.add_argument('-t', '--temperature', type=float, help='生成温度')
    parser.add_argument('--max-tokens', type=int, help='最大token数')
    parser.add_argument('--timeout', type=int, help='超时时间(秒)')
    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化, 只有提取表格或ocr相关任务才启用')

    # Informational queries
    parser.add_argument('--list-models', action='store_true', help='列出所有可用模型')
    parser.add_argument('--list-prompts', action='store_true', help='列出所有提示词模板')

    args = parser.parse_args()

    try:
        processor = LocalVLMProcessor(args.config)

        if args.list_models:
            processor.list_models()
            return 0
        if args.list_prompts:
            processor.list_prompts()
            return 0

        if not args.image_path:
            print("❌ 错误: 请提供图片文件路径")
            print("\n使用示例:")
            print(" python local_vlm_processor.py image.jpg")
            print(" python local_vlm_processor.py image.jpg -m qwen2_vl -p photo_analysis")
            print(" python local_vlm_processor.py image.jpg -p simple_photo_fix # 图片修复")
            print(" python local_vlm_processor.py --list-models")
            print(" python local_vlm_processor.py --list-prompts")
            return 1

        # BUG FIX: previously this passed `not args.no_normalize`, which always
        # supplied an explicit True/False and therefore bypassed the config's
        # `default.normalize_numbers` setting. Pass None unless the user
        # explicitly disabled normalization, so the config default applies.
        result = processor.process_image(
            image_path=args.image_path,
            model_name=args.model,
            prompt_name=args.prompt,
            output_dir=args.output,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
            normalize_numbers=False if args.no_normalize else None,
            custom_prompt=args.custom_prompt
        )

        print(f"\n🎉 处理完成!")
        print(f"📊 处理统计:")
        if result['processing_info']['result_type'] == 'image':
            stats = result['processing_info']['text_stats']
            print(f" 响应长度: {stats['response_length']} 字符")
            print(f" 图片数据: {'包含' if stats['has_image_data'] else '不包含'}")
            if stats['has_image_data']:
                print(f" Base64长度: {stats['base64_length']} 字符")
        else:
            stats = result['processing_info']['text_stats']
            print(f" 原始长度: {stats['original_length']} 字符")
            print(f" 最终长度: {stats['final_length']} 字符")
            if stats['character_changes'] > 0:
                print(f" 标准化变更: {stats['character_changes']} 字符")
        return 0

    except Exception as e:
        print(f"❌ 程序执行失败: {e}")
        return 1


if __name__ == "__main__":
    # If no CLI arguments were given, inject defaults for quick local testing.
    if len(sys.argv) == 1:
        sys.argv.extend([
            '../sample_data/工大照片-1.jpg',
            '-p', 'simple_photo_fix',
            '-o', './output',
            '--no-normalize'])
    sys.exit(main())