- import os
- import re
- import yaml
- import base64
- import json
- import time
- import argparse
- from pathlib import Path
- from typing import Dict, Any, Optional
- from openai import OpenAI
- from dotenv import load_dotenv
# Load environment variables from a .env file; override any already-set values
load_dotenv(override=True)
class LocalVLMProcessor:
    """Drive a local vision-language model (VLM) over single images.

    Configuration (models, prompt templates, defaults) is read from a YAML
    file. A single image is sent, base64-encoded, to an OpenAI-compatible
    chat-completions endpoint, and the result — plain text or a generated
    image — is written to an output directory along with JSON metadata.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """Initialize the local VLM processor.

        Args:
            config_path: Path to the YAML configuration file.
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the YAML configuration file.

        Returns:
            The parsed configuration dictionary.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"配置文件不存在: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        return config

    def _resolve_env_variable(self, value: str) -> str:
        """Resolve environment variables: replace ``${VAR_NAME}`` placeholders
        with the value of the corresponding environment variable.

        Args:
            value: String that may contain ``${VAR_NAME}`` placeholders.
                Non-string inputs are returned unchanged.

        Returns:
            The string with placeholders substituted. Placeholders whose
            variable is unset are left verbatim (with a console warning).
        """
        if not isinstance(value, str):
            return value

        # Match ${VAR_NAME}-style placeholders
        pattern = r'\$\{([^}]+)\}'

        def replace_env_var(match):
            env_var_name = match.group(1)
            env_value = os.getenv(env_var_name)
            if env_value is None:
                print(f"⚠️ 警告: 环境变量 {env_var_name} 未设置,使用原值")
                # Keep the original ${...} text when the variable is unset
                return match.group(0)
            return env_value

        return re.sub(pattern, replace_env_var, value)

    def _is_image_generation_prompt(self, prompt_name: str) -> bool:
        """Report whether a prompt template is an image-generation task.

        Args:
            prompt_name: Name of the prompt template.

        Returns:
            True if the template is one of the known image-generation prompts.
        """
        # Hard-coded list of template names whose output is an image rather
        # than text; anything else is treated as a text-analysis task.
        image_generation_prompts = [
            'photo_restore_classroom',
            'photo_restore_advanced',
            'photo_colorize_classroom',
            'simple_photo_fix'
        ]
        return prompt_name in image_generation_prompts

    def _extract_base64_image(self, response_text: str) -> Optional[str]:
        """Extract base64-encoded image data from a model response.

        Args:
            response_text: Raw API response text.

        Returns:
            The base64 payload with newlines/spaces stripped, or None when no
            plausible image data is found.
        """
        # Common shapes that base64 image data arrives in
        patterns = [
            r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)',  # data URL
            r'base64:([A-Za-z0-9+/=]{100,})',  # "base64:" prefix
            r'```base64\s*\n([A-Za-z0-9+/=\s]+)\n```',  # markdown code block
            r'<img[^>]*src="data:image/[^;]+;base64,([A-Za-z0-9+/=]+)"[^>]*>',  # HTML <img> tag
        ]

        for pattern in patterns:
            match = re.search(pattern, response_text, re.MULTILINE | re.DOTALL)
            if match:
                base64_data = match.group(1).replace('\n', '').replace(' ', '')
                # Heuristic: anything shorter is unlikely to be a real image
                if len(base64_data) > 1000:
                    return base64_data

        return None

    def list_models(self) -> None:
        """Print every model configured in the YAML file to the console."""
        print("📋 可用模型列表:")
        for model_key, model_config in self.config['models'].items():
            # Resolve ${VAR} placeholders so the status reflects the real key
            resolved_api_key = self._resolve_env_variable(model_config['api_key'])
            api_key_status = "✅ 已配置" if resolved_api_key else "❌ 未配置"

            print(f" 🤖 {model_key}: {model_config['name']}")
            print(f" API地址: {model_config['api_base']}")
            print(f" 模型ID: {model_config['model_id']}")
            print(f" API密钥: {api_key_status}")
            print()

    def list_prompts(self) -> None:
        """Print every prompt template configured in the YAML file."""
        print("📝 可用提示词模板:")
        for prompt_key, prompt_config in self.config['prompts'].items():
            is_image_gen = self._is_image_generation_prompt(prompt_key)
            task_type = "🖼️ 图片生成" if is_image_gen else "📝 文本生成"

            print(f" 💬 {prompt_key}: {prompt_config['name']} ({task_type})")
            # Show the first 100 characters of the template as a preview
            template_preview = prompt_config['template'][:100].replace('\n', ' ')
            print(f" 预览: {template_preview}...")
            print()

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """Return a copy of the named model's config with env vars resolved.

        Args:
            model_name: Key of the model in the ``models`` config section.

        Returns:
            A copy of the model configuration (the stored config is not
            mutated) with ``api_key`` placeholders resolved.

        Raises:
            ValueError: If the model name is not present in the configuration.
        """
        if model_name not in self.config['models']:
            raise ValueError(f"未找到模型配置: {model_name},可用模型: {list(self.config['models'].keys())}")

        model_config = self.config['models'][model_name].copy()

        # Resolve ${VAR} placeholders in the API key
        model_config['api_key'] = self._resolve_env_variable(model_config['api_key'])

        return model_config

    def get_prompt_template(self, prompt_name: str) -> str:
        """Return the template text for the named prompt.

        Raises:
            ValueError: If the prompt name is not present in the configuration.
        """
        if prompt_name not in self.config['prompts']:
            raise ValueError(f"未找到提示词模板: {prompt_name},可用模板: {list(self.config['prompts'].keys())}")

        return self.config['prompts'][prompt_name]['template']

    def normalize_financial_numbers(self, text: str) -> str:
        """Normalize financial figures: convert full-width (CJK) characters to
        their half-width ASCII equivalents.

        Args:
            text: Text to normalize; falsy input is returned unchanged.

        Returns:
            The text with full-width digits and punctuation replaced one-for-one.
        """
        if not text:
            return text

        # Full-width → half-width character map (digits and punctuation)
        fullwidth_to_halfwidth = {
            '0': '0', '1': '1', '2': '2', '3': '3', '4': '4',
            '5': '5', '6': '6', '7': '7', '8': '8', '9': '9',
            ',': ',', '。': '.', '.': '.', ':': ':',
            ';': ';', '(': '(', ')': ')', '-': '-',
            '+': '+', '%': '%',
        }

        # Apply each single-character replacement; mappings are 1:1 so the
        # string length is preserved.
        normalized_text = text
        for fullwidth, halfwidth in fullwidth_to_halfwidth.items():
            normalized_text = normalized_text.replace(fullwidth, halfwidth)

        return normalized_text

    def process_image(self,
                      image_path: str,
                      model_name: Optional[str] = None,
                      prompt_name: Optional[str] = None,
                      output_dir: str = "./output",
                      temperature: Optional[float] = None,
                      max_tokens: Optional[int] = None,
                      timeout: Optional[int] = None,
                      normalize_numbers: Optional[bool] = None,
                      custom_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Process a single image through the configured VLM.

        Args:
            image_path: Path to the image file.
            model_name: Model key; defaults to ``default.model`` from config.
            prompt_name: Prompt template key; defaults to ``default.prompt``.
            output_dir: Directory where results are written.
            temperature: Sampling temperature; defaults to the model's config.
            max_tokens: Maximum tokens to generate; defaults to model config.
            timeout: Request timeout in seconds; defaults to model config.
            normalize_numbers: Whether to normalize full-width digits. When
                None, falls back to ``default.normalize_numbers`` from config;
                forced to False for image-generation tasks.
            custom_prompt: Ad-hoc prompt text (takes precedence over
                prompt_name).

        Returns:
            The metadata dictionary written alongside the saved result.

        Raises:
            FileNotFoundError: If the image file does not exist.
            Exception: If the model returns no content, plus any error raised
                by the underlying API client (re-raised after logging).
        """
        # Fall back to configured defaults
        model_name = model_name or self.config['default']['model']
        prompt_name = prompt_name or self.config['default']['prompt']

        # Custom prompts are always treated as text tasks; only named
        # templates can be classified as image generation.
        is_image_generation = custom_prompt is None and self._is_image_generation_prompt(prompt_name)

        # Image-generation tasks never normalize numbers (the output is
        # binary image data, not financial text).
        if is_image_generation:
            normalize_numbers = False
            print(f"🖼️ 检测到图片生成任务,自动禁用数字标准化")
        else:
            normalize_numbers = normalize_numbers if normalize_numbers is not None else self.config['default']['normalize_numbers']

        # Resolve model configuration (including API key env vars)
        model_config = self.get_model_config(model_name)

        # Explicit arguments win over the model's default parameters
        temperature = temperature if temperature is not None else model_config['default_params']['temperature']
        max_tokens = max_tokens if max_tokens is not None else model_config['default_params']['max_tokens']
        timeout = timeout if timeout is not None else model_config['default_params']['timeout']

        # Pick the prompt text
        if custom_prompt:
            prompt = custom_prompt
            print(f"🎯 使用自定义提示词")
        else:
            prompt = self.get_prompt_template(prompt_name)
            task_type = "图片生成" if is_image_generation else "文本分析"
            print(f"🎯 使用提示词模板: {prompt_name} ({task_type})")

        # Read the image and base64-encode it for the data URL
        if not Path(image_path).exists():
            raise FileNotFoundError(f"找不到图片文件: {image_path}")

        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')

        # Derive the MIME type from the file extension (default: JPEG)
        file_extension = Path(image_path).suffix.lower()
        mime_type_map = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }
        mime_type = mime_type_map.get(file_extension, 'image/jpeg')

        # Create the OpenAI-compatible client; some local servers require a
        # non-empty key, hence the "dummy-key" fallback.
        client = OpenAI(
            api_key=model_config['api_key'] or "dummy-key",
            base_url=model_config['api_base']
        )

        # Build a single user message: prompt text + inline image data URL
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_data}"
                        }
                    }
                ]
            }
        ]

        # Echo the effective processing configuration
        print(f"\n🚀 开始处理图片: {Path(image_path).name}")
        print(f"🤖 使用模型: {model_config['name']} ({model_name})")
        print(f"🌐 API地址: {model_config['api_base']}")
        print(f"🔧 参数配置:")
        print(f" - 温度: {temperature}")
        print(f" - 最大Token: {max_tokens}")
        print(f" - 超时时间: {timeout}秒")
        print(f" - 数字标准化: {'启用' if normalize_numbers else '禁用'}")
        print(f" - 任务类型: {'图片生成' if is_image_generation else '文本分析'}")

        try:
            # Call the chat-completions API
            response = client.chat.completions.create(
                model=model_config['model_id'],
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=timeout
            )

            # Extract the generated text
            generated_text = response.choices[0].message.content

            if not generated_text:
                raise Exception("模型没有生成内容")

            # Image-generation path: look for embedded base64 image data
            if is_image_generation:
                base64_image = self._extract_base64_image(generated_text)
                if base64_image:
                    print("🖼️ 检测到生成的图片数据")
                    return self._save_image_results(
                        image_path=image_path,
                        output_dir=output_dir,
                        generated_text=generated_text,
                        base64_image=base64_image,
                        model_name=model_name,
                        prompt_name=prompt_name,
                        model_config=model_config,
                        processing_params={
                            'temperature': temperature,
                            'max_tokens': max_tokens,
                            'timeout': timeout,
                            'normalize_numbers': normalize_numbers,
                            'custom_prompt_used': custom_prompt is not None,
                            'is_image_generation': True
                        }
                    )
                else:
                    # No image payload found — fall through to text handling
                    print("⚠️ 未检测到图片数据,保存为文本结果")

            # Optionally normalize full-width digits/punctuation
            original_text = generated_text
            if normalize_numbers:
                print("🔧 正在标准化数字格式...")
                generated_text = self.normalize_financial_numbers(generated_text)

                # Count changed characters (replacements are 1:1 so a
                # positional zip comparison is valid)
                changes_count = len([1 for o, n in zip(original_text, generated_text) if o != n])
                if changes_count > 0:
                    print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
                else:
                    print("ℹ️ 无需标准化(已是标准格式)")

            print(f"✅ 成功完成处理!")

            # Persist the text result + metadata
            return self._save_text_results(
                image_path=image_path,
                output_dir=output_dir,
                generated_text=generated_text,
                original_text=original_text,
                model_name=model_name,
                prompt_name=prompt_name,
                model_config=model_config,
                processing_params={
                    'temperature': temperature,
                    'max_tokens': max_tokens,
                    'timeout': timeout,
                    'normalize_numbers': normalize_numbers,
                    'custom_prompt_used': custom_prompt is not None,
                    'is_image_generation': is_image_generation
                }
            )

        except Exception as e:
            # Log and re-raise so callers can decide how to handle failures
            print(f"❌ 处理失败: {e}")
            raise

    def _save_image_results(self,
                            image_path: str,
                            output_dir: str,
                            generated_text: str,
                            base64_image: str,
                            model_name: str,
                            prompt_name: str,
                            model_config: Dict[str, Any],
                            processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Save an image-generation result: decoded PNG, optional description
        text, and a JSON metadata file.

        Returns:
            The metadata dictionary that was written to disk.
        """
        # Ensure the output directory exists
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Build timestamped output file names from the source image name
        base_name = Path(image_path).stem
        timestamp = time.strftime("%Y%m%d_%H%M%S")

        # Decode and save the generated image
        try:
            image_bytes = base64.b64decode(base64_image)
            image_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.png"

            with open(image_file, 'wb') as f:
                f.write(image_bytes)
            print(f"🖼️ 生成的图片已保存到: {image_file}")

        except Exception as e:
            print(f"❌ 图片保存失败: {e}")
            # Decoding/writing failed — fall back to saving the raw response
            # as text so nothing is lost.
            text_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.txt"
            with open(text_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 响应内容已保存为文本: {text_file}")
            image_file = text_file

        # If the response carries substantial text beyond the base64 payload,
        # save it separately as a description.
        if len(generated_text.strip()) > len(base64_image) + 100:
            description_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_description.txt"
            with open(description_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📝 响应说明已保存到: {description_file}")

        # Assemble and persist metadata describing this run
        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": image_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "image",
                "text_stats": {
                    "response_length": len(generated_text),
                    "has_image_data": True,
                    "base64_length": len(base64_image)
                }
            }
        }

        metadata_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 元数据已保存到: {metadata_file}")

        return metadata

    def _save_text_results(self,
                           image_path: str,
                           output_dir: str,
                           generated_text: str,
                           original_text: str,
                           model_name: str,
                           prompt_name: str,
                           model_config: Dict[str, Any],
                           processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Save a text result (Markdown for OCR/table tasks, .txt otherwise),
        the pre-normalization original when it differs, and JSON metadata.

        NOTE(review): unlike the image path, these file names carry no
        timestamp, so re-running the same image/model overwrites prior
        results — presumably intentional; confirm before changing.

        Returns:
            The metadata dictionary that was written to disk.
        """
        # Ensure the output directory exists
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Output file names are derived from the source image name
        base_name = Path(image_path).stem

        # Save the main result file
        if prompt_name in ['ocr_standard', 'table_extract']:
            # OCR-style tasks are saved as Markdown
            result_file = output_path / f"{base_name}_{model_name}.md"
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 结果已保存到: {result_file}")
        else:
            # All other tasks are saved as plain text
            result_file = output_path / f"{base_name}_{model_name}_{prompt_name}.txt"
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 结果已保存到: {result_file}")

        # When normalization changed anything, also keep the original text
        if processing_params['normalize_numbers'] and original_text != generated_text:
            original_file = output_path / f"{base_name}_{model_name}_original.txt"
            with open(original_file, 'w', encoding='utf-8') as f:
                f.write(original_text)
            print(f"📄 原始结果已保存到: {original_file}")

        # Assemble and persist metadata describing this run
        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": result_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "text",
                "text_stats": {
                    "original_length": len(original_text),
                    "final_length": len(generated_text),
                    "character_changes": len([1 for o, n in zip(original_text, generated_text) if o != n]) if processing_params['normalize_numbers'] else 0
                }
            }
        }

        metadata_file = output_path / f"{base_name}_{model_name}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 元数据已保存到: {metadata_file}")

        return metadata
def main():
    """CLI entry point: parse arguments and dispatch the requested action.

    Returns:
        Process exit code: 0 on success, 1 on failure or missing arguments.
    """
    parser = argparse.ArgumentParser(description='本地VLM图片处理工具')

    # Basic arguments
    parser.add_argument('image_path', nargs='?', help='图片文件路径')
    parser.add_argument('-c', '--config', default='config.yaml', help='配置文件路径')
    parser.add_argument('-o', '--output', default='./output', help='输出目录')

    # Model and prompt selection
    parser.add_argument('-m', '--model', help='模型名称')
    parser.add_argument('-p', '--prompt', help='提示词模板名称')
    parser.add_argument('--custom-prompt', help='自定义提示词(优先级高于-p参数)')

    # Processing parameters
    parser.add_argument('-t', '--temperature', type=float, help='生成温度')
    parser.add_argument('--max-tokens', type=int, help='最大token数')
    parser.add_argument('--timeout', type=int, help='超时时间(秒)')
    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化, 只有提取表格或ocr相关任务才启用')

    # Informational queries
    parser.add_argument('--list-models', action='store_true', help='列出所有可用模型')
    parser.add_argument('--list-prompts', action='store_true', help='列出所有提示词模板')

    args = parser.parse_args()

    try:
        # Initialize the processor from the given config file
        processor = LocalVLMProcessor(args.config)

        # Handle informational queries first — they need no image
        if args.list_models:
            processor.list_models()
            return 0

        if args.list_prompts:
            processor.list_prompts()
            return 0

        # An image path is required for any processing action
        if not args.image_path:
            print("❌ 错误: 请提供图片文件路径")
            print("\n使用示例:")
            print(" python local_vlm_processor.py image.jpg")
            print(" python local_vlm_processor.py image.jpg -m qwen2_vl -p photo_analysis")
            print(" python local_vlm_processor.py image.jpg -p simple_photo_fix # 图片修复")
            print(" python local_vlm_processor.py --list-models")
            print(" python local_vlm_processor.py --list-prompts")
            return 1

        # BUG FIX: previously `normalize_numbers=not args.no_normalize` always
        # passed a concrete bool, so process_image's fallback to the config's
        # default.normalize_numbers (taken only when the argument is None)
        # could never trigger. Pass None when --no-normalize is absent so the
        # configured default applies; pass False only when the flag was given.
        result = processor.process_image(
            image_path=args.image_path,
            model_name=args.model,
            prompt_name=args.prompt,
            output_dir=args.output,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
            normalize_numbers=False if args.no_normalize else None,
            custom_prompt=args.custom_prompt
        )

        print(f"\n🎉 处理完成!")
        print(f"📊 处理统计:")

        # Report stats appropriate to the result type
        if result['processing_info']['result_type'] == 'image':
            stats = result['processing_info']['text_stats']
            print(f" 响应长度: {stats['response_length']} 字符")
            print(f" 图片数据: {'包含' if stats['has_image_data'] else '不包含'}")
            if stats['has_image_data']:
                print(f" Base64长度: {stats['base64_length']} 字符")
        else:
            stats = result['processing_info']['text_stats']
            print(f" 原始长度: {stats['original_length']} 字符")
            print(f" 最终长度: {stats['final_length']} 字符")
            if stats['character_changes'] > 0:
                print(f" 标准化变更: {stats['character_changes']} 字符")

        return 0

    except Exception as e:
        # Top-level boundary: report the failure and signal a non-zero exit
        print(f"❌ 程序执行失败: {e}")
        return 1
if __name__ == "__main__":
    # When run with no CLI arguments, inject defaults for quick local testing.
    import sys
    if len(sys.argv) == 1:
        sys.argv.extend([
            '../sample_data/工大照片-1.jpg',
            '-p', 'simple_photo_fix',
            '-o', './output',
            '--no-normalize'])
    # Use sys.exit rather than the builtin exit(): exit() is injected by the
    # site module and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(main())