|
|
@@ -0,0 +1,598 @@
|
|
|
+import os
|
|
|
+import re
|
|
|
+import yaml
|
|
|
+import base64
|
|
|
+import json
|
|
|
+import time
|
|
|
+import argparse
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, Any, Optional
|
|
|
+from openai import OpenAI
|
|
|
+from dotenv import load_dotenv
|
|
|
+
|
|
|
+# Load environment variables from .env; override=True lets file values replace existing process env vars
|
|
|
+load_dotenv(override=True)
|
|
|
+
|
|
|
+class LocalVLMProcessor:
|
|
|
+ def __init__(self, config_path: str = "config.yaml"):
|
|
|
+ """
|
|
|
+ 初始化本地VLM处理器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config_path: 配置文件路径
|
|
|
+ """
|
|
|
+ self.config_path = Path(config_path)
|
|
|
+ self.config = self._load_config()
|
|
|
+
|
|
|
+ def _load_config(self) -> Dict[str, Any]:
|
|
|
+ """加载配置文件"""
|
|
|
+ if not self.config_path.exists():
|
|
|
+ raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
|
|
|
+
|
|
|
+ with open(self.config_path, 'r', encoding='utf-8') as f:
|
|
|
+ config = yaml.safe_load(f)
|
|
|
+
|
|
|
+ return config
|
|
|
+
|
|
|
+ def _resolve_env_variable(self, value: str) -> str:
|
|
|
+ """
|
|
|
+ 解析环境变量:将 ${VAR_NAME} 格式替换为实际的环境变量值
|
|
|
+
|
|
|
+ Args:
|
|
|
+ value: 可能包含环境变量的字符串
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 解析后的字符串
|
|
|
+ """
|
|
|
+ if not isinstance(value, str):
|
|
|
+ return value
|
|
|
+
|
|
|
+ # 匹配 ${VAR_NAME} 格式的环境变量
|
|
|
+ pattern = r'\$\{([^}]+)\}'
|
|
|
+
|
|
|
+ def replace_env_var(match):
|
|
|
+ env_var_name = match.group(1)
|
|
|
+ env_value = os.getenv(env_var_name)
|
|
|
+ if env_value is None:
|
|
|
+ print(f"⚠️ 警告: 环境变量 {env_var_name} 未设置,使用原值")
|
|
|
+ return match.group(0)
|
|
|
+ return env_value
|
|
|
+
|
|
|
+ return re.sub(pattern, replace_env_var, value)
|
|
|
+
|
|
|
+ def _is_image_generation_prompt(self, prompt_name: str) -> bool:
|
|
|
+ """
|
|
|
+ 判断是否为图片生成相关的提示词
|
|
|
+
|
|
|
+ Args:
|
|
|
+ prompt_name: 提示词名称
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True if 是图片生成任务
|
|
|
+ """
|
|
|
+ image_generation_prompts = [
|
|
|
+ 'photo_restore_classroom',
|
|
|
+ 'photo_restore_advanced',
|
|
|
+ 'photo_colorize_classroom',
|
|
|
+ 'simple_photo_fix'
|
|
|
+ ]
|
|
|
+ return prompt_name in image_generation_prompts
|
|
|
+
|
|
|
+ def _extract_base64_image(self, response_text: str) -> Optional[str]:
|
|
|
+ """
|
|
|
+ 从响应文本中提取base64编码的图片
|
|
|
+
|
|
|
+ Args:
|
|
|
+ response_text: API响应文本
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ base64编码的图片数据,如果没找到返回None
|
|
|
+ """
|
|
|
+ # 常见的base64图片数据格式
|
|
|
+ patterns = [
|
|
|
+ r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)', # data URL格式
|
|
|
+ r'base64:([A-Za-z0-9+/=]{100,})', # base64:前缀
|
|
|
+ r'```base64\s*\n([A-Za-z0-9+/=\s]+)\n```', # markdown代码块
|
|
|
+ r'<img[^>]*src="data:image/[^;]+;base64,([A-Za-z0-9+/=]+)"[^>]*>', # HTML img标签
|
|
|
+ ]
|
|
|
+
|
|
|
+ for pattern in patterns:
|
|
|
+ match = re.search(pattern, response_text, re.MULTILINE | re.DOTALL)
|
|
|
+ if match:
|
|
|
+ base64_data = match.group(1).replace('\n', '').replace(' ', '')
|
|
|
+ if len(base64_data) > 1000: # 合理的图片大小
|
|
|
+ return base64_data
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+ def list_models(self) -> None:
|
|
|
+ """列出所有可用的模型"""
|
|
|
+ print("📋 可用模型列表:")
|
|
|
+ for model_key, model_config in self.config['models'].items():
|
|
|
+ resolved_api_key = self._resolve_env_variable(model_config['api_key'])
|
|
|
+ api_key_status = "✅ 已配置" if resolved_api_key else "❌ 未配置"
|
|
|
+
|
|
|
+ print(f" 🤖 {model_key}: {model_config['name']}")
|
|
|
+ print(f" API地址: {model_config['api_base']}")
|
|
|
+ print(f" 模型ID: {model_config['model_id']}")
|
|
|
+ print(f" API密钥: {api_key_status}")
|
|
|
+ print()
|
|
|
+
|
|
|
+ def list_prompts(self) -> None:
|
|
|
+ """列出所有可用的提示词模板"""
|
|
|
+ print("📝 可用提示词模板:")
|
|
|
+ for prompt_key, prompt_config in self.config['prompts'].items():
|
|
|
+ is_image_gen = self._is_image_generation_prompt(prompt_key)
|
|
|
+ task_type = "🖼️ 图片生成" if is_image_gen else "📝 文本生成"
|
|
|
+
|
|
|
+ print(f" 💬 {prompt_key}: {prompt_config['name']} ({task_type})")
|
|
|
+ # 显示模板的前100个字符
|
|
|
+ template_preview = prompt_config['template'][:100].replace('\n', ' ')
|
|
|
+ print(f" 预览: {template_preview}...")
|
|
|
+ print()
|
|
|
+
|
|
|
+ def get_model_config(self, model_name: str) -> Dict[str, Any]:
|
|
|
+ """获取模型配置"""
|
|
|
+ if model_name not in self.config['models']:
|
|
|
+ raise ValueError(f"未找到模型配置: {model_name},可用模型: {list(self.config['models'].keys())}")
|
|
|
+
|
|
|
+ model_config = self.config['models'][model_name].copy()
|
|
|
+
|
|
|
+ # 解析环境变量
|
|
|
+ model_config['api_key'] = self._resolve_env_variable(model_config['api_key'])
|
|
|
+
|
|
|
+ return model_config
|
|
|
+
|
|
|
+ def get_prompt_template(self, prompt_name: str) -> str:
|
|
|
+ """获取提示词模板"""
|
|
|
+ if prompt_name not in self.config['prompts']:
|
|
|
+ raise ValueError(f"未找到提示词模板: {prompt_name},可用模板: {list(self.config['prompts'].keys())}")
|
|
|
+
|
|
|
+ return self.config['prompts'][prompt_name]['template']
|
|
|
+
|
|
|
+ def normalize_financial_numbers(self, text: str) -> str:
|
|
|
+ """
|
|
|
+ 标准化财务数字:将全角字符转换为半角字符
|
|
|
+ """
|
|
|
+ if not text:
|
|
|
+ return text
|
|
|
+
|
|
|
+ # 定义全角到半角的映射
|
|
|
+ fullwidth_to_halfwidth = {
|
|
|
+ '0': '0', '1': '1', '2': '2', '3': '3', '4': '4',
|
|
|
+ '5': '5', '6': '6', '7': '7', '8': '8', '9': '9',
|
|
|
+ ',': ',', '。': '.', '.': '.', ':': ':',
|
|
|
+ ';': ';', '(': '(', ')': ')', '-': '-',
|
|
|
+ '+': '+', '%': '%',
|
|
|
+ }
|
|
|
+
|
|
|
+ # 执行字符替换
|
|
|
+ normalized_text = text
|
|
|
+ for fullwidth, halfwidth in fullwidth_to_halfwidth.items():
|
|
|
+ normalized_text = normalized_text.replace(fullwidth, halfwidth)
|
|
|
+
|
|
|
+ return normalized_text
|
|
|
+
|
|
|
    def process_image(self,
                      image_path: str,
                      model_name: Optional[str] = None,
                      prompt_name: Optional[str] = None,
                      output_dir: str = "./output",
                      temperature: Optional[float] = None,
                      max_tokens: Optional[int] = None,
                      timeout: Optional[int] = None,
                      normalize_numbers: Optional[bool] = None,
                      custom_prompt: Optional[str] = None) -> Dict[str, Any]:
        """Process a single image through the configured VLM endpoint.

        The image is base64-encoded and sent as one OpenAI-style chat
        completion message containing the prompt text plus an inline
        ``image_url`` part. Depending on the prompt, the response is handled
        either as generated image data (extracted and written to disk) or as
        plain text (optionally number-normalized).

        Args:
            image_path: Path to the input image file.
            model_name: Model key from config; falls back to the config default.
            prompt_name: Prompt-template key; falls back to the config default.
            output_dir: Directory where results are written.
            temperature: Sampling temperature; falls back to the model's default_params.
            max_tokens: Max tokens to generate; falls back to default_params.
            timeout: Request timeout in seconds; falls back to default_params.
            normalize_numbers: Whether to convert full-width digits/punctuation
                to ASCII; forced to False for image-generation prompts.
            custom_prompt: Free-form prompt; overrides prompt_name when given.

        Returns:
            The metadata dict produced by _save_image_results or
            _save_text_results.

        Raises:
            FileNotFoundError: if *image_path* does not exist.
            Exception: if the model returns empty content; any API client
                error is logged and re-raised.
        """
        # Fall back to config-level defaults when not supplied
        model_name = model_name or self.config['default']['model']
        prompt_name = prompt_name or self.config['default']['prompt']

        # A custom prompt is always treated as a text task; only known
        # template names can mark an image-generation task
        is_image_generation = custom_prompt is None and self._is_image_generation_prompt(prompt_name)

        # Image-generation output is binary-ish, so normalization is forced off
        if is_image_generation:
            normalize_numbers = False
            print(f"🖼️ 检测到图片生成任务,自动禁用数字标准化")
        else:
            normalize_numbers = normalize_numbers if normalize_numbers is not None else self.config['default']['normalize_numbers']

        # Model configuration (with env-resolved API key)
        model_config = self.get_model_config(model_name)

        # Explicit arguments win over the model's default_params
        temperature = temperature if temperature is not None else model_config['default_params']['temperature']
        max_tokens = max_tokens if max_tokens is not None else model_config['default_params']['max_tokens']
        timeout = timeout if timeout is not None else model_config['default_params']['timeout']

        # Resolve the prompt text (custom prompt has priority)
        if custom_prompt:
            prompt = custom_prompt
            print(f"🎯 使用自定义提示词")
        else:
            prompt = self.get_prompt_template(prompt_name)
            task_type = "图片生成" if is_image_generation else "文本分析"
            print(f"🎯 使用提示词模板: {prompt_name} ({task_type})")

        # Read the image file and base64-encode it for the data URL
        if not Path(image_path).exists():
            raise FileNotFoundError(f"找不到图片文件: {image_path}")

        with open(image_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode('utf-8')

        # Guess the MIME type from the file extension (JPEG by default)
        file_extension = Path(image_path).suffix.lower()
        mime_type_map = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }
        mime_type = mime_type_map.get(file_extension, 'image/jpeg')

        # OpenAI-compatible client pointed at the configured endpoint;
        # local servers often need no key, hence the dummy fallback
        client = OpenAI(
            api_key=model_config['api_key'] or "dummy-key",
            base_url=model_config['api_base']
        )

        # One user message combining the prompt text and the inline image
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_data}"
                        }
                    }
                ]
            }
        ]

        # Log the effective processing configuration
        print(f"\n🚀 开始处理图片: {Path(image_path).name}")
        print(f"🤖 使用模型: {model_config['name']} ({model_name})")
        print(f"🌐 API地址: {model_config['api_base']}")
        print(f"🔧 参数配置:")
        print(f" - 温度: {temperature}")
        print(f" - 最大Token: {max_tokens}")
        print(f" - 超时时间: {timeout}秒")
        print(f" - 数字标准化: {'启用' if normalize_numbers else '禁用'}")
        print(f" - 任务类型: {'图片生成' if is_image_generation else '文本分析'}")

        try:
            # Call the chat-completions API
            response = client.chat.completions.create(
                model=model_config['model_id'],
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                timeout=timeout
            )

            # The first choice's message content is the whole result
            generated_text = response.choices[0].message.content

            if not generated_text:
                raise Exception("模型没有生成内容")

            # Image-generation: try to pull base64 image data out of the text
            if is_image_generation:
                base64_image = self._extract_base64_image(generated_text)
                if base64_image:
                    print("🖼️ 检测到生成的图片数据")
                    return self._save_image_results(
                        image_path=image_path,
                        output_dir=output_dir,
                        generated_text=generated_text,
                        base64_image=base64_image,
                        model_name=model_name,
                        prompt_name=prompt_name,
                        model_config=model_config,
                        processing_params={
                            'temperature': temperature,
                            'max_tokens': max_tokens,
                            'timeout': timeout,
                            'normalize_numbers': normalize_numbers,
                            'custom_prompt_used': custom_prompt is not None,
                            'is_image_generation': True
                        }
                    )
                else:
                    # No image payload found -> fall through to text handling
                    print("⚠️ 未检测到图片数据,保存为文本结果")

            # Optionally normalize full-width digits/punctuation
            original_text = generated_text
            if normalize_numbers:
                print("🔧 正在标准化数字格式...")
                generated_text = self.normalize_financial_numbers(generated_text)

                # Count changed character positions (zip compares position by
                # position; normalization is 1:1 per char so lengths match)
                changes_count = len([1 for o, n in zip(original_text, generated_text) if o != n])
                if changes_count > 0:
                    print(f"✅ 已标准化 {changes_count} 个字符(全角→半角)")
                else:
                    print("ℹ️ 无需标准化(已是标准格式)")

            print(f"✅ 成功完成处理!")

            # Persist the text result plus metadata sidecar
            return self._save_text_results(
                image_path=image_path,
                output_dir=output_dir,
                generated_text=generated_text,
                original_text=original_text,
                model_name=model_name,
                prompt_name=prompt_name,
                model_config=model_config,
                processing_params={
                    'temperature': temperature,
                    'max_tokens': max_tokens,
                    'timeout': timeout,
                    'normalize_numbers': normalize_numbers,
                    'custom_prompt_used': custom_prompt is not None,
                    'is_image_generation': is_image_generation
                }
            )

        except Exception as e:
            print(f"❌ 处理失败: {e}")
            raise
|
|
|
+
|
|
|
    def _save_image_results(self,
                            image_path: str,
                            output_dir: str,
                            generated_text: str,
                            base64_image: str,
                            model_name: str,
                            prompt_name: str,
                            model_config: Dict[str, Any],
                            processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Persist an image-generation result: decoded PNG + metadata JSON.

        Decodes *base64_image* to a ``.png`` file; if decoding/writing fails,
        the raw response text is saved as ``.txt`` instead. When the response
        carries substantial text beyond the image payload, it is also saved
        as a ``_description.txt`` sidecar.

        Returns:
            The metadata dict (also written to ``*_metadata.json``).
        """
        # Ensure the output directory exists
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Output names derive from the input stem plus a run timestamp
        base_name = Path(image_path).stem
        timestamp = time.strftime("%Y%m%d_%H%M%S")

        # Decode and write the generated image
        try:
            image_bytes = base64.b64decode(base64_image)
            image_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.png"

            with open(image_file, 'wb') as f:
                f.write(image_bytes)
            print(f"🖼️ 生成的图片已保存到: {image_file}")

        except Exception as e:
            print(f"❌ 图片保存失败: {e}")
            # Fallback: keep the raw response as text so nothing is lost
            text_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}.txt"
            with open(text_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 响应内容已保存为文本: {text_file}")
            # The metadata's output_file then points at the text fallback
            image_file = text_file

        # Save the raw response separately if it contains meaningful prose
        # beyond the base64 payload (heuristic: >100 extra characters)
        if len(generated_text.strip()) > len(base64_image) + 100:
            description_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_description.txt"
            with open(description_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📝 响应说明已保存到: {description_file}")

        # Metadata sidecar describing this run
        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": image_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "image",
                "text_stats": {
                    "response_length": len(generated_text),
                    "has_image_data": True,
                    "base64_length": len(base64_image)
                }
            }
        }

        metadata_file = output_path / f"{base_name}_{model_name}_{prompt_name}_{timestamp}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 元数据已保存到: {metadata_file}")

        return metadata
|
|
|
+
|
|
|
    def _save_text_results(self,
                           image_path: str,
                           output_dir: str,
                           generated_text: str,
                           original_text: str,
                           model_name: str,
                           prompt_name: str,
                           model_config: Dict[str, Any],
                           processing_params: Dict[str, Any]) -> Dict[str, Any]:
        """Persist a text result plus a JSON metadata sidecar.

        OCR/table prompts are written as Markdown, everything else as plain
        ``.txt``. When number normalization changed the text, the
        pre-normalization version is saved alongside.

        Returns:
            The metadata dict (also written to ``*_metadata.json``).
        """
        # Ensure the output directory exists
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Output names derive from the input image's stem
        base_name = Path(image_path).stem

        # Main result file: OCR-related tasks get Markdown, others plain text
        if prompt_name in ['ocr_standard', 'table_extract']:
            result_file = output_path / f"{base_name}_{model_name}.md"
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 结果已保存到: {result_file}")
        else:
            result_file = output_path / f"{base_name}_{model_name}_{prompt_name}.txt"
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(generated_text)
            print(f"📄 结果已保存到: {result_file}")

        # Keep the un-normalized original whenever normalization changed it
        if processing_params['normalize_numbers'] and original_text != generated_text:
            original_file = output_path / f"{base_name}_{model_name}_original.txt"
            with open(original_file, 'w', encoding='utf-8') as f:
                f.write(original_text)
            print(f"📄 原始结果已保存到: {original_file}")

        # Metadata sidecar describing this run
        metadata = {
            "processing_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "image_path": Path(image_path).resolve().as_posix(),
                "output_file": result_file.resolve().as_posix(),
                "model_used": model_name,
                "model_config": model_config,
                "prompt_template": prompt_name,
                "processing_params": processing_params,
                "result_type": "text",
                "text_stats": {
                    "original_length": len(original_text),
                    "final_length": len(generated_text),
                    # Positional diff via zip (1:1 normalization keeps lengths equal)
                    "character_changes": len([1 for o, n in zip(original_text, generated_text) if o != n]) if processing_params['normalize_numbers'] else 0
                }
            }
        }

        metadata_file = output_path / f"{base_name}_{model_name}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)
        print(f"📊 元数据已保存到: {metadata_file}")

        return metadata
|
|
|
+
|
|
|
+
|
|
|
def main():
    """CLI entry point: parse arguments and dispatch to LocalVLMProcessor.

    Returns:
        Process exit code: 0 on success, 1 on error or missing image path.
    """
    parser = argparse.ArgumentParser(description='本地VLM图片处理工具')

    # Basic arguments
    parser.add_argument('image_path', nargs='?', help='图片文件路径')
    parser.add_argument('-c', '--config', default='config.yaml', help='配置文件路径')
    parser.add_argument('-o', '--output', default='./output', help='输出目录')

    # Model / prompt selection
    parser.add_argument('-m', '--model', help='模型名称')
    parser.add_argument('-p', '--prompt', help='提示词模板名称')
    parser.add_argument('--custom-prompt', help='自定义提示词(优先级高于-p参数)')

    # Generation parameters
    parser.add_argument('-t', '--temperature', type=float, help='生成温度')
    parser.add_argument('--max-tokens', type=int, help='最大token数')
    parser.add_argument('--timeout', type=int, help='超时时间(秒)')
    parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化, 只有提取表格或ocr相关任务才启用')

    # Introspection commands
    parser.add_argument('--list-models', action='store_true', help='列出所有可用模型')
    parser.add_argument('--list-prompts', action='store_true', help='列出所有提示词模板')

    args = parser.parse_args()

    try:
        # Initialize the processor from the config file
        processor = LocalVLMProcessor(args.config)

        # Introspection requests short-circuit before any processing
        if args.list_models:
            processor.list_models()
            return 0

        if args.list_prompts:
            processor.list_prompts()
            return 0

        # An image path is mandatory for actual processing
        if not args.image_path:
            print("❌ 错误: 请提供图片文件路径")
            print("\n使用示例:")
            print(" python local_vlm_processor.py image.jpg")
            print(" python local_vlm_processor.py image.jpg -m qwen2_vl -p photo_analysis")
            print(" python local_vlm_processor.py image.jpg -p simple_photo_fix # 图片修复")
            print(" python local_vlm_processor.py --list-models")
            print(" python local_vlm_processor.py --list-prompts")
            return 1

        # Process the image.
        # BUG FIX: the previous code passed `not args.no_normalize`, which is
        # always a concrete bool and therefore silently overrode the config
        # default (`default.normalize_numbers`) on every CLI run. Passing
        # None when the flag is absent lets process_image fall back to the
        # configured default; --no-normalize still forces it off.
        result = processor.process_image(
            image_path=args.image_path,
            model_name=args.model,
            prompt_name=args.prompt,
            output_dir=args.output,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
            normalize_numbers=False if args.no_normalize else None,
            custom_prompt=args.custom_prompt
        )

        print(f"\n🎉 处理完成!")
        print(f"📊 处理统计:")

        # Summarize according to the result type produced by process_image
        if result['processing_info']['result_type'] == 'image':
            stats = result['processing_info']['text_stats']
            print(f" 响应长度: {stats['response_length']} 字符")
            print(f" 图片数据: {'包含' if stats['has_image_data'] else '不包含'}")
            if stats['has_image_data']:
                print(f" Base64长度: {stats['base64_length']} 字符")
        else:
            stats = result['processing_info']['text_stats']
            print(f" 原始长度: {stats['original_length']} 字符")
            print(f" 最终长度: {stats['final_length']} 字符")
            if stats['character_changes'] > 0:
                print(f" 标准化变更: {stats['character_changes']} 字符")

        return 0

    except Exception as e:
        print(f"❌ 程序执行失败: {e}")
        return 1
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    # When run with no CLI arguments (e.g. from an IDE), inject a demo
    # invocation so the script can be smoke-tested directly.
    import sys
    if len(sys.argv) == 1:
        sys.argv.extend([
            '../sample_data/工大照片-1.jpg',
            '-p', 'simple_photo_fix',
            '-o', './output',
            '--no-normalize'])

    # sys.exit instead of the bare exit() builtin: exit() is injected by the
    # `site` module for interactive use and is absent under `python -S` or
    # in frozen builds; sys.exit is the portable way to set the exit code.
    sys.exit(main())
|