image_generator.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. import os
  2. import requests
  3. import time
  4. import json
  5. import yaml
  6. import base64
  7. import argparse
  8. from pathlib import Path
  9. from typing import Dict, Any, Optional, List
  10. from PIL import Image
  11. from io import BytesIO
  12. from dotenv import load_dotenv
# Load environment variables via python-dotenv at import time, so that the
# `${VAR}` placeholders in the config can later be resolved with os.getenv
# (see ImageGenerator._resolve_env_variable).
load_dotenv()
  15. class ImageGenerator:
  16. def __init__(self, config_path: str = "config.yaml"):
  17. """
  18. 初始化图片生成器
  19. Args:
  20. config_path: 配置文件路径
  21. """
  22. self.config_path = Path(config_path)
  23. self.config = self._load_config()
  24. def _load_config(self) -> Dict[str, Any]:
  25. """加载配置文件"""
  26. if not self.config_path.exists():
  27. raise FileNotFoundError(f"配置文件不存在: {self.config_path}")
  28. with open(self.config_path, 'r', encoding='utf-8') as f:
  29. config = yaml.safe_load(f)
  30. return config
  31. def _resolve_env_variable(self, value: str) -> str:
  32. """解析环境变量"""
  33. if not isinstance(value, str):
  34. return value
  35. import re
  36. pattern = r'\$\{([^}]+)\}'
  37. def replace_env_var(match):
  38. env_var_name = match.group(1)
  39. env_value = os.getenv(env_var_name)
  40. if env_value is None:
  41. print(f"⚠️ 警告: 环境变量 {env_var_name} 未设置")
  42. return ""
  43. return env_value
  44. return re.sub(pattern, replace_env_var, value)
  45. def list_models(self, model_type: str = "image_generation") -> None:
  46. """列出指定类型的模型"""
  47. print(f"📋 可用的{model_type}模型列表:")
  48. for model_key, model_config in self.config['models'].items():
  49. if model_config.get('type') == model_type:
  50. api_key = self._resolve_env_variable(model_config['api_key'])
  51. api_key_status = "✅ 已配置" if api_key else "❌ 未配置"
  52. print(f" 🎨 {model_key}: {model_config['name']}")
  53. print(f" 生成类型: {model_config.get('generation_type', 'N/A')}")
  54. print(f" API地址: {model_config['api_base']}")
  55. print(f" API密钥: {api_key_status}")
  56. print()
  57. def list_styles(self) -> None:
  58. """列出可用的风格预设"""
  59. print("🎨 可用风格预设:")
  60. for style_key, styles in self.config.get('style_presets', {}).items():
  61. print(f"\n 📝 {style_key}:")
  62. for style in styles:
  63. print(f" {style['index']}: {style['name']} - {style['description']}")
  64. def list_prompts(self, prompt_type: str = "image_generation") -> None:
  65. """列出指定类型的提示词模板"""
  66. print(f"📝 可用的{prompt_type}提示词模板:")
  67. for prompt_key, prompt_config in self.config.get('prompts', {}).items():
  68. if prompt_config.get('type') == prompt_type:
  69. print(f" 💬 {prompt_key}: {prompt_config['name']}")
  70. # 显示兼容的模型
  71. compatible_models = prompt_config.get('compatible_models', [])
  72. if compatible_models:
  73. print(f" 兼容模型: {', '.join(compatible_models)}")
  74. # 显示模板预览(前100个字符)
  75. template_preview = prompt_config.get('template', '')[:100].replace('\n', ' ')
  76. print(f" 模板预览: {template_preview}...")
  77. print()
  78. def get_model_config(self, model_name: str) -> Dict[str, Any]:
  79. """获取模型配置并解析环境变量"""
  80. if model_name not in self.config['models']:
  81. raise ValueError(f"未找到模型配置: {model_name}")
  82. model_config = self.config['models'][model_name].copy()
  83. model_config['api_key'] = self._resolve_env_variable(model_config['api_key'])
  84. return model_config
  85. def get_prompt_template(self, prompt_name: str) -> str:
  86. """获取提示词模板"""
  87. if prompt_name not in self.config.get('prompts', {}):
  88. raise ValueError(f"未找到提示词模板: {prompt_name},可用模板: {list(self.config.get('prompts', {}).keys())}")
  89. return self.config['prompts'][prompt_name]['template']
  90. def check_prompt_model_compatibility(self, prompt_name: str, model_name: str) -> bool:
  91. """检查提示词模板与模型的兼容性"""
  92. prompt_config = self.config.get('prompts', {}).get(prompt_name, {})
  93. compatible_models = prompt_config.get('compatible_models', [])
  94. # 如果没有指定兼容模型,则认为所有模型都兼容
  95. if not compatible_models:
  96. return True
  97. return model_name in compatible_models
  98. def upload_image_to_temp(self, image_path: str, convert_to_rgba: bool = False) -> str:
  99. """
  100. 上传图片到临时存储并转换为指定格式
  101. Args:
  102. image_path: 图片路径
  103. convert_to_rgba: 是否转换为RGBA格式
  104. Returns:
  105. base64编码的图片数据URL
  106. """
  107. if not Path(image_path).exists():
  108. raise FileNotFoundError(f"找不到图片文件: {image_path}")
  109. # 使用PIL打开图片
  110. with Image.open(image_path) as img:
  111. # 如果需要转换为RGBA格式
  112. if convert_to_rgba:
  113. if img.mode != 'RGBA':
  114. print(f"🔄 将图片从 {img.mode} 模式转换为 RGBA 模式")
  115. # 转换为RGBA模式
  116. if img.mode == 'RGB':
  117. # RGB转RGBA,添加不透明度通道
  118. img = img.convert('RGBA')
  119. elif img.mode == 'L':
  120. # 灰度转RGBA
  121. img = img.convert('RGBA')
  122. elif img.mode == 'P':
  123. # 调色板模式转RGBA
  124. img = img.convert('RGBA')
  125. else:
  126. # 其他模式先转RGB再转RGBA
  127. img = img.convert('RGB').convert('RGBA')
  128. print(f"✅ 图片已转换为 RGBA 模式")
  129. # 保存为PNG格式的字节流(PNG支持RGBA)
  130. from io import BytesIO
  131. img_buffer = BytesIO()
  132. # 如果是RGBA模式,保存为PNG;否则保存为JPEG
  133. if img.mode == 'RGBA':
  134. img.save(img_buffer, format='PNG')
  135. mime_type = 'image/png'
  136. else:
  137. # 对于RGB等模式,转换为RGB再保存为JPEG
  138. if img.mode != 'RGB':
  139. img = img.convert('RGB')
  140. img.save(img_buffer, format='JPEG', quality=95)
  141. mime_type = 'image/jpeg'
  142. img_buffer.seek(0)
  143. image_data = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
  144. return f"data:{mime_type};base64,{image_data}"
  145. def generate_image_dashscope_style_repaint(self,
  146. model_config: Dict[str, Any],
  147. image_path: str,
  148. style_index: int = None,
  149. custom_style_url: str = None,
  150. prompt_template: str = None) -> Dict[str, Any]:
  151. """
  152. 使用通义万相进行风格重绘
  153. 注意:通义万相的风格重绘API不支持文本提示词,只能通过style_index或style_ref_url控制风格
  154. """
  155. headers = {
  156. "Authorization": f"Bearer {model_config['api_key']}",
  157. "Content-Type": "application/json",
  158. "X-DashScope-Async": "enable"
  159. }
  160. # 上传图片(风格重绘不需要RGBA格式)
  161. print(f"📤 读取图片: {Path(image_path).name}")
  162. image_url = self.upload_image_to_temp(image_path, convert_to_rgba=False)
  163. # 构建请求体
  164. if custom_style_url:
  165. # 使用自定义风格
  166. body = {
  167. "model": model_config['model_id'],
  168. "input": {
  169. "image_url": image_url,
  170. "style_ref_url": custom_style_url,
  171. "style_index": -1
  172. }
  173. }
  174. print(f"🎨 使用自定义风格参考: {custom_style_url}")
  175. else:
  176. # 使用预置风格
  177. style_idx = style_index if style_index is not None else model_config['default_params']['style_index']
  178. # 如果有提示词模板但没有指定风格索引,尝试根据模板内容智能选择风格
  179. if prompt_template and style_index is None:
  180. style_idx = self._select_style_from_template(prompt_template)
  181. print(f"🤖 根据提示词模板智能选择风格索引: {style_idx}")
  182. body = {
  183. "model": model_config['model_id'],
  184. "input": {
  185. "image_url": image_url,
  186. "style_index": style_idx
  187. }
  188. }
  189. # 显示提示词模板信息(仅用于记录,不影响API调用)
  190. if prompt_template:
  191. print(f"📝 提示词模板内容(仅作为风格选择参考):")
  192. print(f" {prompt_template[:200]}...")
  193. print(f"⚠️ 注意: 通义万相风格重绘API不支持文本提示,仅通过风格索引控制效果")
  194. # 提交任务
  195. print(f"🚀 提交风格重绘任务...")
  196. print(f" 风格索引: {body['input'].get('style_index', '自定义')}")
  197. response = requests.post(model_config['api_base'], headers=headers, json=body)
  198. if response.status_code != 200:
  199. raise Exception(f"任务提交失败: {response.status_code}, {response.text}")
  200. task_id = response.json().get('output', {}).get('task_id')
  201. if not task_id:
  202. raise Exception("未获取到任务ID")
  203. print(f"✅ 任务提交成功,任务ID: {task_id}")
  204. # 轮询查询结果
  205. return self._poll_dashscope_task(model_config, task_id)
  206. def generate_image_modelscope(self,
  207. model_config: Dict[str, Any],
  208. prompt: str,
  209. prompt_template: str = None) -> Dict[str, Any]:
  210. """
  211. 使用ModelScope进行文生图
  212. """
  213. headers = {
  214. "Authorization": f"Bearer {model_config['api_key']}",
  215. "Content-Type": "application/json",
  216. "X-ModelScope-Async-Mode": "true"
  217. }
  218. # 如果提供了提示词模板,将其与用户提示词结合
  219. final_prompt = prompt
  220. if prompt_template and prompt_template.strip():
  221. print(f"🎯 使用提示词模板优化提示词")
  222. # 简单的模板应用:将用户提示词插入到模板中
  223. if "{prompt}" in prompt_template:
  224. final_prompt = prompt_template.replace("{prompt}", prompt)
  225. else:
  226. # 如果模板中没有占位符,则将用户提示词追加到模板后
  227. final_prompt = f"{prompt_template}\n\n具体要求:{prompt}"
  228. body = {
  229. "model": model_config['model_id'],
  230. "prompt": final_prompt
  231. }
  232. print(f"🚀 提交文生图任务...")
  233. print(f" 最终提示词: {final_prompt[:100]}...")
  234. response = requests.post(model_config['api_base'], headers=headers, json=body)
  235. if response.status_code != 200:
  236. raise Exception(f"任务提交失败: {response.status_code}, {response.text}")
  237. task_id = response.json().get("task_id")
  238. if not task_id:
  239. raise Exception("未获取到任务ID")
  240. print(f"✅ 任务提交成功,任务ID: {task_id}")
  241. # 轮询查询结果
  242. return self._poll_modelscope_task(model_config, task_id)
  243. def generate_image_dashscope_flux(self,
  244. model_config: Dict[str, Any],
  245. prompt: str,
  246. size: str = None,
  247. prompt_template: str = None) -> Dict[str, Any]:
  248. """
  249. 使用通义万相FLUX进行文生图
  250. """
  251. headers = {
  252. "Authorization": f"Bearer {model_config['api_key']}",
  253. "Content-Type": "application/json",
  254. "X-DashScope-Async": "enable"
  255. }
  256. # 如果提供了提示词模板,将其与用户提示词结合
  257. final_prompt = prompt
  258. if prompt_template and prompt_template.strip():
  259. print(f"🎯 使用提示词模板优化提示词")
  260. if "{prompt}" in prompt_template:
  261. final_prompt = prompt_template.replace("{prompt}", prompt)
  262. else:
  263. final_prompt = f"{prompt_template}\n\n具体要求:{prompt}"
  264. body = {
  265. "model": model_config['model_id'],
  266. "input": {
  267. "prompt": final_prompt,
  268. "size": size or model_config['default_params']['size']
  269. }
  270. }
  271. print(f"🚀 提交FLUX文生图任务...")
  272. print(f" 图片尺寸: {body['input']['size']}")
  273. print(f" 最终提示词: {final_prompt[:100]}...")
  274. response = requests.post(model_config['api_base'], headers=headers, json=body)
  275. if response.status_code != 200:
  276. raise Exception(f"任务提交失败: {response.status_code}, {response.text}")
  277. task_id = response.json().get('output', {}).get('task_id')
  278. if not task_id:
  279. raise Exception("未获取到任务ID")
  280. print(f"✅ 任务提交成功,任务ID: {task_id}")
  281. return self._poll_dashscope_task(model_config, task_id)
  282. def generate_image_dashscope_background(self,
  283. model_config: Dict[str, Any],
  284. image_path: str,
  285. ref_prompt: str,
  286. prompt_template: str = None) -> Dict[str, Any]:
  287. """
  288. 使用通义万相进行背景生成
  289. """
  290. headers = {
  291. "Authorization": f"Bearer {model_config['api_key']}",
  292. "Content-Type": "application/json",
  293. "X-DashScope-Async": "enable"
  294. }
  295. # 上传图片并转换为RGBA格式(背景生成API要求RGBA格式)
  296. print(f"📤 读取并处理图片: {Path(image_path).name}")
  297. image_url = self.upload_image_to_temp(image_path, convert_to_rgba=True)
  298. # 如果提供了提示词模板,将其与用户提示词结合
  299. final_prompt = ref_prompt
  300. if prompt_template and prompt_template.strip():
  301. print(f"🎯 使用提示词模板优化背景描述")
  302. if "{prompt}" in prompt_template:
  303. final_prompt = prompt_template.replace("{prompt}", ref_prompt)
  304. else:
  305. final_prompt = f"{prompt_template}\n\n具体要求:{ref_prompt}"
  306. # 构建请求体
  307. body = {
  308. "model": model_config['model_id'],
  309. "input": {
  310. "base_image_url": image_url,
  311. "ref_prompt": final_prompt
  312. },
  313. "parameters": {
  314. "model_version": model_config['default_params'].get('model_version', 'v3'),
  315. "n": model_config['default_params'].get('n', 1)
  316. }
  317. }
  318. # 提交任务
  319. print(f"🚀 提交背景生成任务...")
  320. print(f" 背景描述: {final_prompt}")
  321. print(f" 模型版本: {body['parameters']['model_version']}")
  322. print(f" 生成数量: {body['parameters']['n']}")
  323. response = requests.post(model_config['api_base'], headers=headers, json=body)
  324. if response.status_code != 200:
  325. raise Exception(f"任务提交失败: {response.status_code}, {response.text}")
  326. task_id = response.json().get('output', {}).get('task_id')
  327. if not task_id:
  328. raise Exception("未获取到任务ID")
  329. print(f"✅ 任务提交成功,任务ID: {task_id}")
  330. # 轮询查询结果
  331. return self._poll_dashscope_task(model_config, task_id)
  332. def _poll_dashscope_task(self, model_config: Dict[str, Any], task_id: str) -> Dict[str, Any]:
  333. """轮询通义万相任务结果"""
  334. query_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
  335. headers = {"Authorization": f"Bearer {model_config['api_key']}"}
  336. poll_interval = model_config['default_params'].get('poll_interval', 5)
  337. timeout = model_config['default_params'].get('timeout', 300)
  338. start_time = time.time()
  339. print("🔍 开始查询任务状态...")
  340. while True:
  341. if time.time() - start_time > timeout:
  342. raise Exception(f"任务超时({timeout}秒)")
  343. response = requests.get(query_url, headers=headers)
  344. if response.status_code != 200:
  345. raise Exception(f"查询失败: {response.status_code}, {response.text}")
  346. response_data = response.json()
  347. task_status = response_data.get('output', {}).get('task_status')
  348. if task_status == 'SUCCEEDED':
  349. print("✅ 任务成功完成!")
  350. return response_data
  351. elif task_status == 'FAILED':
  352. error_msg = response_data.get('output', {}).get('message', '未知错误')
  353. raise Exception(f"任务失败: {error_msg}")
  354. else:
  355. print(f"⏳ 任务处理中,当前状态: {task_status}...")
  356. time.sleep(poll_interval)
  357. def _poll_modelscope_task(self, model_config: Dict[str, Any], task_id: str) -> Dict[str, Any]:
  358. """轮询ModelScope任务结果"""
  359. query_url = f"https://api-inference.modelscope.cn/v1/tasks/{task_id}"
  360. headers = {
  361. "Authorization": f"Bearer {model_config['api_key']}",
  362. "X-ModelScope-Task-Type": "image_generation"
  363. }
  364. poll_interval = model_config['default_params'].get('poll_interval', 5)
  365. timeout = model_config['default_params'].get('timeout', 300)
  366. start_time = time.time()
  367. print("🔍 开始查询任务状态...")
  368. while True:
  369. if time.time() - start_time > timeout:
  370. raise Exception(f"任务超时({timeout}秒)")
  371. response = requests.get(query_url, headers=headers)
  372. if response.status_code != 200:
  373. raise Exception(f"查询失败: {response.status_code}, {response.text}")
  374. response_data = response.json()
  375. task_status = response_data.get('task_status')
  376. if task_status == 'SUCCEED':
  377. print("✅ 任务成功完成!")
  378. return response_data
  379. elif task_status == 'FAILED':
  380. raise Exception(f"任务失败: {response_data}")
  381. else:
  382. print(f"⏳ 任务处理中,当前状态: {task_status}...")
  383. time.sleep(poll_interval)
  384. def generate_image(self,
  385. model_name: str,
  386. prompt: str = None,
  387. image_path: str = None,
  388. style_index: int = None,
  389. custom_style_url: str = None,
  390. prompt_template_name: str = None,
  391. output_dir: str = "./output") -> Dict[str, Any]:
  392. """
  393. 统一的图片生成接口
  394. """
  395. model_config = self.get_model_config(model_name)
  396. if model_config.get('type') != 'image_generation':
  397. raise ValueError(f"模型 {model_name} 不是图片生成模型")
  398. # 获取提示词模板
  399. prompt_template = None
  400. if prompt_template_name:
  401. # 检查兼容性
  402. if not self.check_prompt_model_compatibility(prompt_template_name, model_name):
  403. print(f"⚠️ 警告: 提示词模板 {prompt_template_name} 可能与模型 {model_name} 不兼容")
  404. prompt_template = self.get_prompt_template(prompt_template_name)
  405. print(f"🎯 使用提示词模板: {prompt_template_name}")
  406. print(f"🎨 使用模型: {model_config['name']}")
  407. print(f"🔧 生成类型: {model_config.get('generation_type')}")
  408. # 根据不同的模型调用对应的生成方法
  409. if model_name == "dashscope_wanx":
  410. if not image_path:
  411. raise ValueError("风格重绘需要提供输入图片")
  412. result = self.generate_image_dashscope_style_repaint(
  413. model_config, image_path, style_index, custom_style_url, prompt_template
  414. )
  415. elif model_name == "dashscope_background":
  416. if not image_path:
  417. raise ValueError("背景生成需要提供输入图片")
  418. if not prompt:
  419. raise ValueError("背景生成需要提供背景描述")
  420. result = self.generate_image_dashscope_background(
  421. model_config, image_path, prompt, prompt_template
  422. )
  423. elif model_name == "modelscope_qwen":
  424. if not prompt:
  425. raise ValueError("文生图需要提供文本提示")
  426. result = self.generate_image_modelscope(model_config, prompt, prompt_template)
  427. elif model_name == "dashscope_flux":
  428. if not prompt:
  429. raise ValueError("FLUX文生图需要提供文本提示")
  430. result = self.generate_image_dashscope_flux(model_config, prompt, None, prompt_template)
  431. else:
  432. raise ValueError(f"不支持的模型: {model_name}")
  433. # 保存结果
  434. return self._save_generated_images(result, model_name, output_dir, prompt_template_name)
  435. def _save_generated_images(self,
  436. result: Dict[str, Any],
  437. model_name: str,
  438. output_dir: str,
  439. prompt_template_name: str = None) -> Dict[str, Any]:
  440. """保存生成的图片"""
  441. output_path = Path(output_dir)
  442. output_path.mkdir(parents=True, exist_ok=True)
  443. timestamp = time.strftime("%Y%m%d_%H%M%S")
  444. saved_files = []
  445. # 根据不同API的响应格式提取图片URL
  446. if model_name == "dashscope_wanx" or model_name == "dashscope_flux":
  447. # 通义万相格式
  448. results = result.get('output', {}).get('results', [])
  449. for i, img_result in enumerate(results):
  450. img_url = img_result.get('url')
  451. if img_url:
  452. # 如果使用了提示词模板,在文件名中体现
  453. template_suffix = f"_{prompt_template_name}" if prompt_template_name else ""
  454. filename = f"{model_name}_{timestamp}{template_suffix}_{i+1}.png"
  455. filepath = output_path / filename
  456. # 下载并保存图片
  457. img_response = requests.get(img_url)
  458. if img_response.status_code == 200:
  459. image = Image.open(BytesIO(img_response.content))
  460. image.save(filepath)
  461. saved_files.append(filepath)
  462. print(f"🖼️ 图片已保存: {filepath}")
  463. else:
  464. print(f"❌ 下载图片失败: {img_url}")
  465. elif model_name == "modelscope_qwen":
  466. # ModelScope格式
  467. output_images = result.get('output_images', [])
  468. for i, img_url in enumerate(output_images):
  469. template_suffix = f"_{prompt_template_name}" if prompt_template_name else ""
  470. filename = f"{model_name}_{timestamp}{template_suffix}_{i+1}.png"
  471. filepath = output_path / filename
  472. img_response = requests.get(img_url)
  473. if img_response.status_code == 200:
  474. image = Image.open(BytesIO(img_response.content))
  475. image.save(filepath)
  476. saved_files.append(filepath)
  477. print(f"🖼️ 图片已保存: {filepath}")
  478. else:
  479. print(f"❌ 下载图片失败: {img_url}")
  480. # 保存元数据
  481. metadata = {
  482. "generation_info": {
  483. "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
  484. "model_used": model_name,
  485. "prompt_template_used": prompt_template_name,
  486. "saved_files": [str(f) for f in saved_files],
  487. "api_response": result
  488. }
  489. }
  490. template_suffix = f"_{prompt_template_name}" if prompt_template_name else ""
  491. metadata_file = output_path / f"{model_name}_{timestamp}{template_suffix}_metadata.json"
  492. with open(metadata_file, 'w', encoding='utf-8') as f:
  493. json.dump(metadata, f, ensure_ascii=False, indent=2)
  494. print(f"📊 元数据已保存: {metadata_file}")
  495. return metadata
  496. def main():
  497. """主函数"""
  498. parser = argparse.ArgumentParser(description='AI图片生成工具')
  499. # 基本参数
  500. parser.add_argument('-c', '--config', default='config.yaml', help='配置文件路径')
  501. parser.add_argument('-o', '--output', default='./output', help='输出目录')
  502. # 模型选择
  503. parser.add_argument('-m', '--model', help='模型名称')
  504. # 生成参数
  505. parser.add_argument('-p', '--prompt', help='文本提示(用于文生图)')
  506. parser.add_argument('-i', '--image', help='输入图片路径(用于风格重绘)')
  507. parser.add_argument('-s', '--style', type=int, help='风格索引(0-6)')
  508. parser.add_argument('--style-ref', help='自定义风格参考图片URL')
  509. # 提示词模板
  510. parser.add_argument('-t', '--template', help='提示词模板名称')
  511. # 信息查询
  512. parser.add_argument('--list-models', action='store_true', help='列出所有可用的图片生成模型')
  513. parser.add_argument('--list-styles', action='store_true', help='列出所有可用风格')
  514. parser.add_argument('--list-prompts', action='store_true', help='列出所有可用的提示词模板')
  515. args = parser.parse_args()
  516. try:
  517. generator = ImageGenerator(args.config)
  518. # 处理信息查询
  519. if args.list_models:
  520. generator.list_models("image_generation")
  521. return 0
  522. if args.list_styles:
  523. generator.list_styles()
  524. return 0
  525. if args.list_prompts:
  526. generator.list_prompts("image_generation")
  527. return 0
  528. # 检查必要参数
  529. if not args.model:
  530. print("❌ 错误: 请指定模型名称")
  531. print("\n使用示例:")
  532. print(" # 风格重绘")
  533. print(" python image_generator.py -m dashscope_wanx -i photo.jpg -s 3")
  534. print(" # 使用提示词模板进行风格重绘")
  535. print(" python image_generator.py -m dashscope_wanx -i photo.jpg -t photo_restoration")
  536. print(" # 文生图")
  537. print(" python image_generator.py -m modelscope_qwen -p '一只可爱的金色小猫'")
  538. print(" # 使用提示词模板进行文生图")
  539. print(" python image_generator.py -m modelscope_qwen -p '金色小猫' -t text_to_image_simple")
  540. print(" # 查看信息")
  541. print(" python image_generator.py --list-models")
  542. print(" python image_generator.py --list-prompts")
  543. return 1
  544. # 生成图片
  545. result = generator.generate_image(
  546. model_name=args.model,
  547. prompt=args.prompt,
  548. image_path=args.image,
  549. style_index=args.style,
  550. custom_style_url=args.style_ref,
  551. prompt_template_name=args.template,
  552. output_dir=args.output
  553. )
  554. print(f"\n🎉 图片生成完成!")
  555. saved_files = result.get('generation_info', {}).get('saved_files', [])
  556. print(f"📊 生成统计: 共保存 {len(saved_files)} 张图片")
  557. return 0
  558. except Exception as e:
  559. print(f"❌ 程序执行失败: {e}")
  560. return 1
  561. if __name__ == "__main__":
  562. # 调试用的默认参数
  563. import sys
  564. if len(sys.argv) == 1:
  565. sys.argv.extend([
  566. '-m', 'dashscope_background',
  567. '-i', '../sample_data/工大照片-1.jpg',
  568. '-t', 'background_studio', # 使用提示词模板
  569. '-p', '温馨的书房环境', # 文生图提示词
  570. '-o', './output'
  571. ])
  572. exit(main())