|
|
@@ -0,0 +1,692 @@
|
|
|
+import os
|
|
|
+import requests
|
|
|
+import time
|
|
|
+import json
|
|
|
+import yaml
|
|
|
+import base64
|
|
|
+import argparse
|
|
|
+from pathlib import Path
|
|
|
+from typing import Dict, Any, Optional, List
|
|
|
+from PIL import Image
|
|
|
+from io import BytesIO
|
|
|
+from dotenv import load_dotenv
|
|
|
+
|
|
|
+# Load environment variables via python-dotenv before any configuration is read
|
|
|
+load_dotenv()
|
|
|
+
|
|
|
class ImageGenerator:
    """Config-driven front end for several hosted image-generation APIs.

    The YAML configuration supplies three top-level sections that this class
    reads: ``models`` (endpoint, model id, API key, default parameters),
    ``prompts`` (named prompt templates) and ``style_presets`` (named style
    indexes). Each ``generate_image_*`` method submits an asynchronous task,
    polls it until it finishes and returns the raw API response;
    ``generate_image`` is the single entry point that also downloads the
    resulting images and writes a metadata file.
    """

    # Seconds allowed for any single HTTP request. Without it ``requests``
    # waits forever on a stalled connection, bypassing the polling timeout.
    HTTP_TIMEOUT = 60

    # Source modes Pillow converts straight to RGBA; going through RGB first
    # would throw away the alpha channel of LA / PA / RGBa images.
    _DIRECT_RGBA_MODES = ('RGB', 'L', 'P', 'LA', 'PA', 'RGBa')

    # Models whose final response lists images under output.results[].url.
    _DASHSCOPE_MODELS = ('dashscope_wanx', 'dashscope_flux', 'dashscope_background')

    def __init__(self, config_path: str = "config.yaml"):
        """
        Initialize the image generator.

        Args:
            config_path: Path of the YAML configuration file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
        """
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Read and parse the YAML configuration file.

        Returns:
            The parsed configuration; an empty dict when the file is empty.

        Raises:
            FileNotFoundError: If ``self.config_path`` does not exist.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"配置文件不存在: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # An empty YAML document parses to None; normalize it so later
        # lookups fail with a clear "not found" instead of a TypeError.
        return config or {}

    def _resolve_env_variable(self, value: str) -> str:
        """Expand every ``${NAME}`` placeholder in *value* from the environment.

        Non-string values are returned untouched. An unset variable expands
        to an empty string and a warning is printed.
        """
        if not isinstance(value, str):
            return value

        import re

        def replace_env_var(match):
            env_var_name = match.group(1)
            env_value = os.getenv(env_var_name)
            if env_value is None:
                print(f"⚠️ 警告: 环境变量 {env_var_name} 未设置")
                return ""
            return env_value

        return re.sub(r'\$\{([^}]+)\}', replace_env_var, value)

    def list_models(self, model_type: str = "image_generation") -> None:
        """Print every configured model whose ``type`` equals *model_type*."""
        print(f"📋 可用的{model_type}模型列表:")
        for model_key, model_config in (self.config.get('models') or {}).items():
            if model_config.get('type') != model_type:
                continue

            # A listing must not crash on an entry that has no api_key at all;
            # such an entry is simply reported as "not configured".
            api_key = self._resolve_env_variable(model_config.get('api_key', ''))
            api_key_status = "✅ 已配置" if api_key else "❌ 未配置"

            print(f" 🎨 {model_key}: {model_config['name']}")
            print(f" 生成类型: {model_config.get('generation_type', 'N/A')}")
            print(f" API地址: {model_config['api_base']}")
            print(f" API密钥: {api_key_status}")
            print()

    def list_styles(self) -> None:
        """Print the style presets defined in the configuration."""
        print("🎨 可用风格预设:")
        for style_key, styles in self.config.get('style_presets', {}).items():
            print(f"\n 📝 {style_key}:")
            for style in styles:
                print(f" {style['index']}: {style['name']} - {style['description']}")

    def list_prompts(self, prompt_type: str = "image_generation") -> None:
        """Print every prompt template whose ``type`` equals *prompt_type*."""
        print(f"📝 可用的{prompt_type}提示词模板:")
        for prompt_key, prompt_config in self.config.get('prompts', {}).items():
            if prompt_config.get('type') != prompt_type:
                continue

            print(f" 💬 {prompt_key}: {prompt_config['name']}")

            # Show which models the template declares itself compatible with.
            compatible_models = prompt_config.get('compatible_models', [])
            if compatible_models:
                print(f" 兼容模型: {', '.join(compatible_models)}")

            # Preview only the first 100 characters, flattened to one line.
            template_preview = prompt_config.get('template', '')[:100].replace('\n', ' ')
            print(f" 模板预览: {template_preview}...")
            print()

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """Return a copy of a model's configuration with its API key resolved.

        Args:
            model_name: Key of the model under the ``models`` section.

        Returns:
            A shallow copy of the model entry whose ``api_key`` has had its
            ``${NAME}`` placeholders expanded. The stored configuration is
            left unmodified.

        Raises:
            ValueError: If *model_name* is not configured.
        """
        models = self.config.get('models') or {}
        if model_name not in models:
            raise ValueError(f"未找到模型配置: {model_name}")

        model_config = models[model_name].copy()
        model_config['api_key'] = self._resolve_env_variable(model_config['api_key'])
        return model_config

    def get_prompt_template(self, prompt_name: str) -> str:
        """Return the template text of the named prompt.

        Raises:
            ValueError: If *prompt_name* is not configured; the message lists
                the available template names.
        """
        prompts = self.config.get('prompts', {})
        if prompt_name not in prompts:
            raise ValueError(f"未找到提示词模板: {prompt_name},可用模板: {list(prompts.keys())}")

        return prompts[prompt_name]['template']

    def check_prompt_model_compatibility(self, prompt_name: str, model_name: str) -> bool:
        """Tell whether a prompt template declares support for a model.

        A template without a ``compatible_models`` list (or an unknown
        template) is considered compatible with every model.
        """
        prompt_config = self.config.get('prompts', {}).get(prompt_name, {})
        compatible_models = prompt_config.get('compatible_models', [])

        if not compatible_models:
            return True
        return model_name in compatible_models

    def _select_style_from_template(self, prompt_template: str, default: int = 0) -> int:
        """Pick a style index by looking for preset names inside a template.

        This is a plain substring heuristic: the first configured style
        preset whose ``name`` occurs (case-insensitively) in
        *prompt_template* wins.

        Args:
            prompt_template: Template text to search.
            default: Index returned when no preset name is found.

        Returns:
            The ``index`` of the matching preset, or *default*.
        """
        if not prompt_template:
            return default

        haystack = prompt_template.casefold()
        for styles in (self.config.get('style_presets') or {}).values():
            for style in styles or []:
                name = str(style.get('name', '')).strip().casefold()
                index = style.get('index')
                if name and index is not None and name in haystack:
                    return index
        return default

    @staticmethod
    def _apply_prompt_template(prompt: str, prompt_template: Optional[str]) -> str:
        """Merge the user's prompt into a template.

        A blank or missing template leaves *prompt* unchanged. A template
        containing ``{prompt}`` has the placeholder substituted; any other
        template is used as a prefix with the prompt appended.
        """
        if not (prompt_template and prompt_template.strip()):
            return prompt
        if "{prompt}" in prompt_template:
            return prompt_template.replace("{prompt}", prompt)
        return f"{prompt_template}\n\n具体要求:{prompt}"

    @staticmethod
    def _dashscope_headers(model_config: Dict[str, Any]) -> Dict[str, str]:
        """Build the headers for submitting an asynchronous DashScope task."""
        return {
            "Authorization": f"Bearer {model_config['api_key']}",
            "Content-Type": "application/json",
            "X-DashScope-Async": "enable"
        }

    def _submit_task(self,
                     model_config: Dict[str, Any],
                     headers: Dict[str, str],
                     body: Dict[str, Any],
                     nested_task_id: bool = True) -> str:
        """POST a task to the model endpoint and return its task id.

        Args:
            model_config: Resolved model configuration (provides ``api_base``).
            headers: HTTP headers for the request.
            body: JSON request body.
            nested_task_id: True when the id is at ``output.task_id``
                (DashScope), False when it is at top-level ``task_id``
                (ModelScope).

        Raises:
            Exception: On a non-200 response or a response without a task id.
        """
        response = requests.post(model_config['api_base'], headers=headers, json=body,
                                 timeout=self.HTTP_TIMEOUT)
        if response.status_code != 200:
            raise Exception(f"任务提交失败: {response.status_code}, {response.text}")

        payload = response.json()
        if nested_task_id:
            task_id = payload.get('output', {}).get('task_id')
        else:
            task_id = payload.get("task_id")
        if not task_id:
            raise Exception("未获取到任务ID")

        print(f"✅ 任务提交成功,任务ID: {task_id}")
        return task_id

    def upload_image_to_temp(self, image_path: str, convert_to_rgba: bool = False) -> str:
        """
        Read an image file and encode it as a base64 data URL.

        Nothing is uploaded: the image is re-encoded in memory (PNG when its
        mode is RGBA, JPEG at quality 95 otherwise) and embedded in the URL.

        Args:
            image_path: Path of the image file.
            convert_to_rgba: Convert the image to RGBA mode before encoding.

        Returns:
            A ``data:<mime>;base64,<payload>`` URL.

        Raises:
            FileNotFoundError: If *image_path* does not exist.
        """
        if not Path(image_path).exists():
            raise FileNotFoundError(f"找不到图片文件: {image_path}")

        with Image.open(image_path) as img:
            if convert_to_rgba and img.mode != 'RGBA':
                print(f"🔄 将图片从 {img.mode} 模式转换为 RGBA 模式")
                if img.mode in self._DIRECT_RGBA_MODES:
                    # Direct conversion keeps any existing transparency.
                    img = img.convert('RGBA')
                else:
                    # Remaining modes (e.g. CMYK) go through RGB first.
                    img = img.convert('RGB').convert('RGBA')
                print(f"✅ 图片已转换为 RGBA 模式")

            img_buffer = BytesIO()
            if img.mode == 'RGBA':
                # PNG is used because JPEG cannot store an alpha channel.
                img.save(img_buffer, format='PNG')
                mime_type = 'image/png'
            else:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img.save(img_buffer, format='JPEG', quality=95)
                mime_type = 'image/jpeg'

        image_data = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
        return f"data:{mime_type};base64,{image_data}"

    def generate_image_dashscope_style_repaint(self,
                                               model_config: Dict[str, Any],
                                               image_path: str,
                                               style_index: Optional[int] = None,
                                               custom_style_url: Optional[str] = None,
                                               prompt_template: Optional[str] = None) -> Dict[str, Any]:
        """
        Repaint an image in another style with Tongyi Wanxiang (DashScope).

        The request carries no text prompt; the style comes only from
        ``style_index`` or ``style_ref_url``. A prompt template is therefore
        used solely to pick a style index when none was given.

        Args:
            model_config: Resolved model configuration.
            image_path: Path of the input image.
            style_index: Preset style index; when None the configured default
                (or a template-derived index) is used.
            custom_style_url: URL of a custom style reference image; takes
                precedence over *style_index*.
            prompt_template: Optional template text used to choose a style.

        Returns:
            The final task response returned by the polling step.
        """
        print(f"📤 读取图片: {Path(image_path).name}")
        # Style repaint does not need an RGBA input.
        image_url = self.upload_image_to_temp(image_path, convert_to_rgba=False)

        if custom_style_url:
            # Custom style: style_index -1 is sent alongside the reference URL.
            body = {
                "model": model_config['model_id'],
                "input": {
                    "image_url": image_url,
                    "style_ref_url": custom_style_url,
                    "style_index": -1
                }
            }
            print(f"🎨 使用自定义风格参考: {custom_style_url}")
        else:
            if style_index is not None:
                style_idx = style_index
            else:
                style_idx = model_config['default_params']['style_index']
                if prompt_template:
                    # No explicit index: try to derive one from the template,
                    # falling back to the configured default.
                    style_idx = self._select_style_from_template(prompt_template, style_idx)
                    print(f"🤖 根据提示词模板智能选择风格索引: {style_idx}")

            body = {
                "model": model_config['model_id'],
                "input": {
                    "image_url": image_url,
                    "style_index": style_idx
                }
            }

        # The template is echoed for the record only; it is not sent.
        if prompt_template:
            print(f"📝 提示词模板内容(仅作为风格选择参考):")
            print(f" {prompt_template[:200]}...")
            print(f"⚠️ 注意: 通义万相风格重绘API不支持文本提示,仅通过风格索引控制效果")

        print(f"🚀 提交风格重绘任务...")
        print(f" 风格索引: {body['input'].get('style_index', '自定义')}")

        task_id = self._submit_task(model_config, self._dashscope_headers(model_config), body)
        return self._poll_dashscope_task(model_config, task_id)

    def generate_image_modelscope(self,
                                  model_config: Dict[str, Any],
                                  prompt: str,
                                  prompt_template: Optional[str] = None) -> Dict[str, Any]:
        """
        Generate an image from text with ModelScope.

        Args:
            model_config: Resolved model configuration.
            prompt: User prompt.
            prompt_template: Optional template merged with *prompt*.

        Returns:
            The final task response returned by the polling step.
        """
        headers = {
            "Authorization": f"Bearer {model_config['api_key']}",
            "Content-Type": "application/json",
            "X-ModelScope-Async-Mode": "true"
        }

        if prompt_template and prompt_template.strip():
            print(f"🎯 使用提示词模板优化提示词")
        final_prompt = self._apply_prompt_template(prompt, prompt_template)

        body = {
            "model": model_config['model_id'],
            "prompt": final_prompt
        }

        print(f"🚀 提交文生图任务...")
        print(f" 最终提示词: {final_prompt[:100]}...")

        task_id = self._submit_task(model_config, headers, body, nested_task_id=False)
        return self._poll_modelscope_task(model_config, task_id)

    def generate_image_dashscope_flux(self,
                                      model_config: Dict[str, Any],
                                      prompt: str,
                                      size: Optional[str] = None,
                                      prompt_template: Optional[str] = None) -> Dict[str, Any]:
        """
        Generate an image from text with the FLUX model on DashScope.

        Args:
            model_config: Resolved model configuration.
            prompt: User prompt.
            size: Image size; the configured default is used when omitted.
            prompt_template: Optional template merged with *prompt*.

        Returns:
            The final task response returned by the polling step.
        """
        if prompt_template and prompt_template.strip():
            print(f"🎯 使用提示词模板优化提示词")
        final_prompt = self._apply_prompt_template(prompt, prompt_template)

        body = {
            "model": model_config['model_id'],
            "input": {
                "prompt": final_prompt,
                "size": size or model_config['default_params']['size']
            }
        }

        print(f"🚀 提交FLUX文生图任务...")
        print(f" 图片尺寸: {body['input']['size']}")
        print(f" 最终提示词: {final_prompt[:100]}...")

        task_id = self._submit_task(model_config, self._dashscope_headers(model_config), body)
        return self._poll_dashscope_task(model_config, task_id)

    def generate_image_dashscope_background(self,
                                            model_config: Dict[str, Any],
                                            image_path: str,
                                            ref_prompt: str,
                                            prompt_template: Optional[str] = None) -> Dict[str, Any]:
        """
        Generate a new background for an image with Tongyi Wanxiang (DashScope).

        Args:
            model_config: Resolved model configuration.
            image_path: Path of the foreground image.
            ref_prompt: Text description of the desired background.
            prompt_template: Optional template merged with *ref_prompt*.

        Returns:
            The final task response returned by the polling step.
        """
        print(f"📤 读取并处理图片: {Path(image_path).name}")
        # The input is sent as RGBA for this API.
        image_url = self.upload_image_to_temp(image_path, convert_to_rgba=True)

        if prompt_template and prompt_template.strip():
            print(f"🎯 使用提示词模板优化背景描述")
        final_prompt = self._apply_prompt_template(ref_prompt, prompt_template)

        default_params = model_config.get('default_params') or {}
        body = {
            "model": model_config['model_id'],
            "input": {
                "base_image_url": image_url,
                "ref_prompt": final_prompt
            },
            "parameters": {
                "model_version": default_params.get('model_version', 'v3'),
                "n": default_params.get('n', 1)
            }
        }

        print(f"🚀 提交背景生成任务...")
        print(f" 背景描述: {final_prompt}")
        print(f" 模型版本: {body['parameters']['model_version']}")
        print(f" 生成数量: {body['parameters']['n']}")

        task_id = self._submit_task(model_config, self._dashscope_headers(model_config), body)
        return self._poll_dashscope_task(model_config, task_id)

    def _poll_until_done(self,
                         query_url: str,
                         headers: Dict[str, str],
                         model_config: Dict[str, Any],
                         read_status,
                         success_state: str,
                         failed_states,
                         describe_failure) -> Dict[str, Any]:
        """Poll a task URL until the task succeeds, fails or times out.

        Args:
            query_url: URL returning the task state as JSON.
            headers: HTTP headers for each query.
            model_config: Supplies ``default_params.poll_interval`` (default 5
                seconds) and ``default_params.timeout`` (default 300 seconds).
            read_status: Callable extracting the status string from a response.
            success_state: Status value meaning the task finished successfully.
            failed_states: Status values after which the task cannot succeed.
            describe_failure: Callable turning a failed response into the
                detail shown in the raised error.

        Returns:
            The parsed JSON response of the successful query.

        Raises:
            Exception: On timeout, a non-200 query response or task failure.
        """
        default_params = model_config.get('default_params') or {}
        poll_interval = default_params.get('poll_interval', 5)
        timeout = default_params.get('timeout', 300)
        # A monotonic clock keeps the deadline immune to wall-clock changes.
        started = time.monotonic()

        print("🔍 开始查询任务状态...")
        while True:
            if time.monotonic() - started > timeout:
                raise Exception(f"任务超时({timeout}秒)")

            response = requests.get(query_url, headers=headers, timeout=self.HTTP_TIMEOUT)
            if response.status_code != 200:
                raise Exception(f"查询失败: {response.status_code}, {response.text}")

            response_data = response.json()
            task_status = read_status(response_data)

            if task_status == success_state:
                print("✅ 任务成功完成!")
                return response_data
            if task_status in failed_states:
                raise Exception(f"任务失败: {describe_failure(response_data)}")

            print(f"⏳ 任务处理中,当前状态: {task_status}...")
            time.sleep(poll_interval)

    def _poll_dashscope_task(self, model_config: Dict[str, Any], task_id: str) -> Dict[str, Any]:
        """Poll a DashScope task until it reaches a terminal state.

        Both ``FAILED`` and ``CANCELED`` end the wait with an error, so a
        cancelled task is not polled until the timeout expires.
        """
        return self._poll_until_done(
            f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}",
            {"Authorization": f"Bearer {model_config['api_key']}"},
            model_config,
            read_status=lambda data: data.get('output', {}).get('task_status'),
            success_state='SUCCEEDED',
            failed_states=('FAILED', 'CANCELED'),
            describe_failure=lambda data: data.get('output', {}).get('message', '未知错误'),
        )

    def _poll_modelscope_task(self, model_config: Dict[str, Any], task_id: str) -> Dict[str, Any]:
        """Poll a ModelScope task until it reaches a terminal state."""
        return self._poll_until_done(
            f"https://api-inference.modelscope.cn/v1/tasks/{task_id}",
            {
                "Authorization": f"Bearer {model_config['api_key']}",
                "X-ModelScope-Task-Type": "image_generation"
            },
            model_config,
            read_status=lambda data: data.get('task_status'),
            success_state='SUCCEED',
            failed_states=('FAILED',),
            describe_failure=lambda data: data,
        )

    def generate_image(self,
                       model_name: str,
                       prompt: Optional[str] = None,
                       image_path: Optional[str] = None,
                       style_index: Optional[int] = None,
                       custom_style_url: Optional[str] = None,
                       prompt_template_name: Optional[str] = None,
                       output_dir: str = "./output") -> Dict[str, Any]:
        """
        Generate images with the named model and save them to disk.

        Args:
            model_name: Key of the model under the ``models`` section.
            prompt: Text prompt (text-to-image) or background description.
            image_path: Input image (style repaint / background generation).
            style_index: Preset style index for style repaint.
            custom_style_url: Custom style reference URL for style repaint.
            prompt_template_name: Name of a configured prompt template.
            output_dir: Directory receiving the images and the metadata file.

        Returns:
            The metadata dict produced when saving the results.

        Raises:
            ValueError: If the model or template is unknown, the model is not
                an image-generation model, or a required input is missing.
        """
        model_config = self.get_model_config(model_name)

        if model_config.get('type') != 'image_generation':
            raise ValueError(f"模型 {model_name} 不是图片生成模型")

        prompt_template = None
        if prompt_template_name:
            # Incompatibility is only a warning; generation still proceeds.
            if not self.check_prompt_model_compatibility(prompt_template_name, model_name):
                print(f"⚠️ 警告: 提示词模板 {prompt_template_name} 可能与模型 {model_name} 不兼容")

            prompt_template = self.get_prompt_template(prompt_template_name)
            print(f"🎯 使用提示词模板: {prompt_template_name}")

        print(f"🎨 使用模型: {model_config['name']}")
        print(f"🔧 生成类型: {model_config.get('generation_type')}")

        # Dispatch to the generation method matching the model.
        if model_name == "dashscope_wanx":
            if not image_path:
                raise ValueError("风格重绘需要提供输入图片")
            result = self.generate_image_dashscope_style_repaint(
                model_config, image_path, style_index, custom_style_url, prompt_template
            )
        elif model_name == "dashscope_background":
            if not image_path:
                raise ValueError("背景生成需要提供输入图片")
            if not prompt:
                raise ValueError("背景生成需要提供背景描述")
            result = self.generate_image_dashscope_background(
                model_config, image_path, prompt, prompt_template
            )
        elif model_name == "modelscope_qwen":
            if not prompt:
                raise ValueError("文生图需要提供文本提示")
            result = self.generate_image_modelscope(model_config, prompt, prompt_template)
        elif model_name == "dashscope_flux":
            if not prompt:
                raise ValueError("FLUX文生图需要提供文本提示")
            result = self.generate_image_dashscope_flux(model_config, prompt, None, prompt_template)
        else:
            raise ValueError(f"不支持的模型: {model_name}")

        return self._save_generated_images(result, model_name, output_dir, prompt_template_name)

    @classmethod
    def _extract_image_urls(cls, result: Dict[str, Any], model_name: str) -> List[Any]:
        """List the image URLs of an API response, in response order.

        DashScope models report ``output.results[].url`` (an entry without a
        URL yields None so result positions are preserved); ModelScope
        reports ``output_images``. Any other model yields an empty list.
        """
        if model_name in cls._DASHSCOPE_MODELS:
            return [item.get('url') for item in result.get('output', {}).get('results', [])]
        if model_name == "modelscope_qwen":
            return list(result.get('output_images', []))
        return []

    def _save_generated_images(self,
                               result: Dict[str, Any],
                               model_name: str,
                               output_dir: str,
                               prompt_template_name: Optional[str] = None) -> Dict[str, Any]:
        """Download the generated images and write a JSON metadata file.

        Args:
            result: Final task response from the API.
            model_name: Model key; selects the response layout and prefixes
                the file names.
            output_dir: Target directory, created when missing.
            prompt_template_name: When given, appended to the file names.

        Returns:
            The metadata dict that was also written to disk. It records the
            timestamp, model, template, saved file paths and raw response.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        template_suffix = f"_{prompt_template_name}" if prompt_template_name else ""
        saved_files = []

        for position, img_url in enumerate(self._extract_image_urls(result, model_name), start=1):
            if not img_url:
                continue

            filepath = output_path / f"{model_name}_{timestamp}{template_suffix}_{position}.png"
            img_response = requests.get(img_url, timeout=self.HTTP_TIMEOUT)
            if img_response.status_code == 200:
                # Re-encode through Pillow so the file really is a PNG.
                with Image.open(BytesIO(img_response.content)) as image:
                    image.save(filepath)
                saved_files.append(filepath)
                print(f"🖼️ 图片已保存: {filepath}")
            else:
                print(f"❌ 下载图片失败: {img_url}")

        metadata = {
            "generation_info": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "model_used": model_name,
                "prompt_template_used": prompt_template_name,
                "saved_files": [str(f) for f in saved_files],
                "api_response": result
            }
        }

        metadata_file = output_path / f"{model_name}_{timestamp}{template_suffix}_metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

        print(f"📊 元数据已保存: {metadata_file}")

        return metadata
|
|
|
+
|
|
|
+
|
|
|
def main():
    """Command-line entry point.

    Parses the arguments, serves the ``--list-*`` queries, and otherwise runs
    one image generation.

    Returns:
        0 on success; 1 when no model was given or any exception occurred.
    """
    cli = argparse.ArgumentParser(description='AI图片生成工具')

    # General options
    cli.add_argument('-c', '--config', default='config.yaml', help='配置文件路径')
    cli.add_argument('-o', '--output', default='./output', help='输出目录')

    # Model selection
    cli.add_argument('-m', '--model', help='模型名称')

    # Generation inputs
    cli.add_argument('-p', '--prompt', help='文本提示(用于文生图)')
    cli.add_argument('-i', '--image', help='输入图片路径(用于风格重绘)')
    cli.add_argument('-s', '--style', type=int, help='风格索引(0-6)')
    cli.add_argument('--style-ref', help='自定义风格参考图片URL')

    # Prompt template
    cli.add_argument('-t', '--template', help='提示词模板名称')

    # Informational queries
    cli.add_argument('--list-models', action='store_true', help='列出所有可用的图片生成模型')
    cli.add_argument('--list-styles', action='store_true', help='列出所有可用风格')
    cli.add_argument('--list-prompts', action='store_true', help='列出所有可用的提示词模板')

    options = cli.parse_args()

    try:
        generator = ImageGenerator(options.config)

        # Only the first requested listing is shown, in this fixed order.
        listings = (
            (options.list_models, lambda: generator.list_models("image_generation")),
            (options.list_styles, lambda: generator.list_styles()),
            (options.list_prompts, lambda: generator.list_prompts("image_generation")),
        )
        for requested, show in listings:
            if requested:
                show()
                return 0

        # A model is mandatory for generation; print usage examples otherwise.
        if not options.model:
            usage_lines = (
                "❌ 错误: 请指定模型名称",
                "\n使用示例:",
                " # 风格重绘",
                " python image_generator.py -m dashscope_wanx -i photo.jpg -s 3",
                " # 使用提示词模板进行风格重绘",
                " python image_generator.py -m dashscope_wanx -i photo.jpg -t photo_restoration",
                " # 文生图",
                " python image_generator.py -m modelscope_qwen -p '一只可爱的金色小猫'",
                " # 使用提示词模板进行文生图",
                " python image_generator.py -m modelscope_qwen -p '金色小猫' -t text_to_image_simple",
                " # 查看信息",
                " python image_generator.py --list-models",
                " python image_generator.py --list-prompts",
            )
            for usage_line in usage_lines:
                print(usage_line)
            return 1

        outcome = generator.generate_image(
            model_name=options.model,
            prompt=options.prompt,
            image_path=options.image,
            style_index=options.style,
            custom_style_url=options.style_ref,
            prompt_template_name=options.template,
            output_dir=options.output
        )

        print("\n🎉 图片生成完成!")
        generation_info = outcome.get('generation_info', {})
        image_count = len(generation_info.get('saved_files', []))
        print(f"📊 生成统计: 共保存 {image_count} 张图片")

        return 0

    except Exception as e:
        print(f"❌ 程序执行失败: {e}")
        return 1
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    import sys

    # Debugging convenience: when the script is started without any argument,
    # inject a default background-generation run.
    if len(sys.argv) == 1:
        sys.argv.extend([
            '-m', 'dashscope_background',
            '-i', '../sample_data/工大照片-1.jpg',
            '-t', 'background_studio',  # use a prompt template
            '-p', '温馨的书房环境',  # prompt text (the background description for this model)
            '-o', './output'
        ])

    # sys.exit is always available; the bare exit() helper is injected by the
    # site module and is missing under `python -S` or in frozen builds.
    sys.exit(main())
|