planning_agent.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. from typing import List, Dict, Optional, Any, Union
  2. from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  3. from pydantic import BaseModel, Field
  4. from langchain_openai import ChatOpenAI
  5. import json
  6. import os
  7. from datetime import datetime
  8. # 数据模型定义
  9. class ActionItem(BaseModel):
  10. """动作项定义"""
  11. action: str = Field(description="动作名称")
  12. parameters: Optional[Dict[str, Any]] = Field(default_factory=dict, description="动作参数")
  13. class ClarificationRequest(BaseModel):
  14. """澄清请求结构化格式"""
  15. questions: List[str] = Field(description="需要澄清的问题列表")
  16. missing_fields: List[str] = Field(default_factory=list, description="缺少的字段或信息")
  17. class PlanningDecision(BaseModel):
  18. """规划决策输出"""
  19. decision: str = Field(
  20. description="决策类型: generate_outline, compute_metrics, finalize_report, clarify_requirements"
  21. )
  22. reasoning: str = Field(description="详细推理过程")
  23. next_actions: List[Union[str, ActionItem]] = Field(
  24. default_factory=list,
  25. description="下一步动作列表"
  26. )
  27. metrics_to_compute: List[str] = Field(
  28. default_factory=list,
  29. description="待计算指标ID列表(如 ['total_income', 'avg_balance'])"
  30. )
  31. priority_metrics: List[str] = Field(
  32. default_factory=list,
  33. description="优先级高的指标ID"
  34. )
  35. additional_requirements: Optional[
  36. Union[Dict[str, Any], List[Any], ClarificationRequest]
  37. ] = Field(default=None, description="额外需求或澄清信息")
  38. def normalize_requirements(req: Any) -> Optional[Dict[str, Any]]:
  39. """
  40. 规范化 additional_requirements
  41. 将列表转换为字典格式
  42. """
  43. if req is None:
  44. return None
  45. if isinstance(req, dict):
  46. return req
  47. if isinstance(req, list):
  48. # 如果LLM错误地返回了列表,转换为字典格式
  49. return {
  50. "questions": [str(item) for item in req],
  51. "missing_fields": []
  52. }
  53. return {"raw": str(req)}
  54. class PlanningAgent:
  55. """规划智能体:负责状态分析和决策制定"""
  56. def __init__(self, api_key: str, base_url: str = "https://api.deepseek.com"):
  57. """
  58. 初始化规划Agent
  59. Args:
  60. api_key: DeepSeek API密钥
  61. base_url: DeepSeek API基础URL
  62. """
  63. self.llm = ChatOpenAI(
  64. model="deepseek-chat",
  65. api_key=api_key,
  66. base_url=base_url,
  67. temperature=0.1
  68. )
  69. # 初始化API调用跟踪
  70. self.api_calls = []
  71. def create_planning_prompt(self) -> ChatPromptTemplate:
  72. """创建规划提示模板"""
  73. return ChatPromptTemplate.from_messages([
  74. ("system", """你是报告规划总控智能体,核心职责是精准分析当前状态并决定下一步行动。
  75. ### 决策选项(二选一)
  76. 1. generate_outline:大纲未生成或大纲无效
  77. 2. compute_metrics:大纲已生成但指标未完成
  78. ### 决策规则(按顺序检查)
  79. 1. 检查 outline_draft 是否为空 → 空则选择 generate_outline
  80. 2. 检查 metrics_requirements 是否为空 → 空则选择 generate_outline
  81. 3. 检查是否有待计算指标 → 有则选择 compute_metrics
  82. 4. 所有指标都已计算完成 → 选择 finalize_report
  83. 5. 如果无法理解需求 → 选择 clarify_requirements
  84. ### 重要原则
  85. - 大纲草稿已存在时,不要重复生成大纲
  86. - 决策为 compute_metrics 时,必须从状态信息中的"有效待计算指标ID列表"中选择
  87. - 确保 metrics_to_compute 是字符串数组格式
  88. - 确保指标ID与大纲中的global_metrics.metric_id完全一致
  89. - 从状态信息中的"有效待计算指标ID列表"中提取metric_id作为metrics_to_compute的值
  90. - 计算失败的指标可以重试最多3次
  91. - 绝对不要自己生成新的指标ID,必须严格使用状态信息中提供的已有指标ID
  92. - 如果状态信息中没有可用的指标ID,不要生成compute_metrics决策
  93. ### 输出字段说明
  94. - decision: 决策字符串
  95. - reasoning: 决策原因说明
  96. - metrics_to_compute: 待计算指标ID列表,必须从状态信息中的"有效待计算指标ID列表"中选择。选择所有可用指标,除非指标数量过多(>10个)需要分批计算
  97. - priority_metrics: 优先级指标列表(前2-3个最重要的指标),从metrics_to_compute中选择
  98. 必须输出有效的JSON格式!"""),
  99. MessagesPlaceholder("messages"),
  100. ("user", "报告需求:{question}\n\n请输出决策结果。")
  101. ])
  102. async def make_decision(self, question: str, industry: str, current_state: Dict[str, Any]) -> PlanningDecision:
  103. """
  104. 根据当前状态做出规划决策
  105. Args:
  106. question: 用户查询
  107. industry: 行业
  108. current_state: 当前状态信息
  109. Returns:
  110. 规划决策结果
  111. """
  112. planner = self.create_planning_prompt() | self.llm
  113. # 构建状态评估上下文
  114. status_info = self._build_status_context(current_state)
  115. # 记录大模型输入
  116. print("========================================")
  117. print("[AGENT] PlanningAgent (规划Agent)")
  118. print("[MODEL_INPUT] PlanningAgent:")
  119. print(f"[CONTEXT] 基于当前状态做出规划决策")
  120. print(f"Question: {question}")
  121. print(f"Status info: {status_info}")
  122. print("========================================")
  123. # 执行规划
  124. start_time = datetime.now()
  125. response = await planner.ainvoke({
  126. "question": question,
  127. "industry": industry,
  128. "messages": [("system", status_info)]
  129. })
  130. end_time = datetime.now()
  131. # 解析JSON响应
  132. try:
  133. # 从响应中提取JSON内容
  134. content = response.content if hasattr(response, 'content') else str(response)
  135. # 尝试找到JSON部分
  136. json_start = content.find('{')
  137. json_end = content.rfind('}') + 1
  138. if json_start >= 0 and json_end > json_start:
  139. json_str = content[json_start:json_end]
  140. decision_data = json.loads(json_str)
  141. # 预处理 additional_requirements 字段
  142. if "additional_requirements" in decision_data:
  143. req = decision_data["additional_requirements"]
  144. if isinstance(req, str):
  145. # 如果是字符串,尝试将其转换为合适的格式
  146. if req.strip():
  147. # 将字符串包装为字典格式
  148. decision_data["additional_requirements"] = {"raw_content": req}
  149. else:
  150. # 空字符串设为 None
  151. decision_data["additional_requirements"] = None
  152. elif isinstance(req, list):
  153. # 如果是列表,转换为字典格式
  154. decision_data["additional_requirements"] = {
  155. "questions": [str(item) for item in req],
  156. "missing_fields": []
  157. }
  158. # 如果已经是 dict 或其他允许的类型,保持不变
  159. decision = PlanningDecision(**decision_data)
  160. # 验证决策的合理性
  161. if decision.decision == "compute_metrics":
  162. if not decision.metrics_to_compute:
  163. raise ValueError("AI决策缺少具体的指标ID")
  164. # 如果AI生成的指标ID明显是错误的(比如metric_001),使用默认逻辑
  165. if any(mid.startswith("metric_") and mid.replace("metric_", "").isdigit()
  166. for mid in decision.metrics_to_compute):
  167. raise ValueError("AI生成的指标ID格式不正确")
  168. else:
  169. raise ValueError("No JSON found in response")
  170. except Exception as e:
  171. print(f"解析规划决策响应失败: {e},使用默认决策")
  172. # 返回默认决策
  173. decision = self._get_default_decision(current_state)
  174. # 记录API调用结果
  175. content = response.content if hasattr(response, 'content') else str(response)
  176. call_id = f"api_mll_规划决策_{'{:.2f}'.format((end_time - start_time).total_seconds())}"
  177. api_call_info = {
  178. "call_id": call_id,
  179. "timestamp": end_time.isoformat(),
  180. "agent": "PlanningAgent",
  181. "model": "deepseek-chat",
  182. "request": {
  183. "question": question,
  184. "status_info": status_info,
  185. "start_time": start_time.isoformat()
  186. },
  187. "response": {
  188. "content": content,
  189. "decision": decision.dict() if hasattr(decision, 'dict') else decision,
  190. "end_time": end_time.isoformat(),
  191. "duration": (end_time - start_time).total_seconds()
  192. },
  193. "success": True
  194. }
  195. self.api_calls.append(api_call_info)
  196. # 保存API结果到文件
  197. api_results_dir = "api_results"
  198. os.makedirs(api_results_dir, exist_ok=True)
  199. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  200. filename = f"{timestamp}_{call_id}.json"
  201. filepath = os.path.join(api_results_dir, filename)
  202. try:
  203. with open(filepath, 'w', encoding='utf-8') as f:
  204. json.dump(api_call_info, f, ensure_ascii=False, indent=2)
  205. print(f"[API_RESULT] 保存API结果文件: {filepath}")
  206. except Exception as e:
  207. print(f"[ERROR] 保存API结果文件失败: {filepath}, 错误: {str(e)}")
  208. # 记录大模型输出
  209. print(f"[MODEL_OUTPUT] PlanningAgent: {json.dumps(decision.dict() if hasattr(decision, 'dict') else decision, ensure_ascii=False)}")
  210. print("========================================")
  211. return decision
  212. def _build_status_context(self, state: Dict[str, Any]) -> str:
  213. """构建状态评估上下文"""
  214. required_count = len(state.get("metrics_requirements", []))
  215. computed_count = len(state.get("computed_metrics", {}))
  216. coverage = computed_count / required_count if required_count > 0 else 0
  217. # 计算失败统计
  218. failed_attempts = state.get("failed_metric_attempts", {})
  219. pending_ids = state.get("pending_metric_ids", [])
  220. # 过滤掉失败次数过多的指标
  221. max_retry = 3
  222. filtered_pending_ids = [
  223. mid for mid in pending_ids
  224. if failed_attempts.get(mid, 0) < max_retry
  225. ]
  226. # 获取可用的指标ID
  227. available_metric_ids = []
  228. outline_draft = state.get('outline_draft')
  229. if outline_draft and outline_draft.global_metrics:
  230. available_metric_ids = [m.metric_id for m in outline_draft.global_metrics if m.metric_id]
  231. return f"""当前状态评估:
  232. - 规划步骤: {state.get('planning_step', 0)}
  233. - 大纲版本: {state.get('outline_version', 0)}
  234. - 大纲草稿存在: {state.get('outline_draft') is not None}
  235. - 指标需求总数: {required_count}
  236. - 已计算指标数: {computed_count}
  237. - 指标覆盖率: {coverage:.2%}
  238. - 待计算指标数: {len(pending_ids)}
  239. - 有效待计算指标ID列表: {filtered_pending_ids}
  240. - 可用指标ID列表: {available_metric_ids}
  241. - 失败尝试记录: {failed_attempts}
  242. """
  243. def analyze_current_state(state: Dict[str, Any]) -> Dict[str, Any]:
  244. """
  245. 分析当前状态,返回关键信息
  246. Args:
  247. state: 当前状态
  248. Returns:
  249. 状态分析结果
  250. """
  251. required_metrics = state.get("metrics_requirements", [])
  252. computed_metrics = state.get("computed_metrics", {})
  253. # 计算覆盖率
  254. required_count = len(required_metrics)
  255. computed_count = len(computed_metrics)
  256. coverage = computed_count / required_count if required_count > 0 else 0
  257. # 找出未计算的指标
  258. computed_ids = set(computed_metrics.keys())
  259. pending_metrics = [
  260. m for m in required_metrics
  261. if m.metric_id not in computed_ids
  262. ]
  263. # 检查失败次数
  264. failed_attempts = state.get("failed_metric_attempts", {})
  265. max_retry = 3
  266. valid_pending_metrics = [
  267. m for m in pending_metrics
  268. if failed_attempts.get(m.metric_id, 0) < max_retry
  269. ]
  270. return {
  271. "has_outline": state.get("outline_draft") is not None,
  272. "required_count": required_count,
  273. "computed_count": computed_count,
  274. "coverage": coverage,
  275. "pending_metrics": pending_metrics,
  276. "valid_pending_metrics": valid_pending_metrics,
  277. "pending_ids": [m.metric_id for m in pending_metrics],
  278. "valid_pending_ids": [m.metric_id for m in valid_pending_metrics],
  279. "planning_step": state.get("planning_step", 0),
  280. "outline_version": state.get("outline_version", 0)
  281. }
  282. async def plan_next_action(question: str, industry: str, current_state: Dict[str, Any], api_key: str) -> PlanningDecision:
  283. """
  284. 规划下一步行动的主函数
  285. Args:
  286. question: 用户查询
  287. current_state: 当前状态
  288. api_key: API密钥
  289. Returns:
  290. 规划决策结果
  291. """
  292. agent = PlanningAgent(api_key)
  293. try:
  294. decision = await agent.make_decision(question, industry, current_state)
  295. print(f"\n🧠 规划决策:{decision.decision}")
  296. print(f" 推理:{decision.reasoning[:100]}...")
  297. if decision.metrics_to_compute:
  298. print(f" 待计算指标:{decision.metrics_to_compute}")
  299. return decision
  300. except Exception as e:
  301. print(f"⚠️ 规划决策出错: {e},使用默认决策")
  302. # 直接返回最基本的默认决策,避免复杂的默认决策逻辑
  303. return PlanningDecision(
  304. decision="finalize_report",
  305. reasoning="规划决策失败,使用默认的报告生成决策",
  306. next_actions=["生成最终报告"],
  307. metrics_to_compute=[],
  308. priority_metrics=[]
  309. )