planning_agent.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. from typing import List, Dict, Optional, Any, Union
  2. from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  3. from pydantic import BaseModel, Field
  4. from langchain_openai import ChatOpenAI
  5. import json
  6. from llmops.agents.state import AgentState, MetricRequirement, convert_numpy_types
  7. from llmops.agents.datadev.llm import get_llm
  8. class ActionItem(BaseModel):
  9. """动作项定义"""
  10. action: str = Field(description="动作名称")
  11. parameters: Optional[Dict[str, Any]] = Field(default_factory=dict)
  12. class ClarificationRequest(BaseModel):
  13. """澄清请求结构化格式"""
  14. questions: List[str] = Field(description="需要澄清的问题列表")
  15. missing_fields: List[str] = Field(default_factory=list, description="缺少的字段或信息")
  16. class PlanningOutput(BaseModel):
  17. """规划决策输出 - 支持灵活格式"""
  18. decision: str = Field(
  19. description="决策类型: generate_outline, compute_metrics, finalize, clarify"
  20. )
  21. reasoning: str = Field(description="详细推理过程")
  22. next_actions: List[Union[str, ActionItem]] = Field(
  23. default_factory=list,
  24. description="下一步动作列表"
  25. )
  26. # 关键修复:明确传递待计算指标ID列表
  27. metrics_to_compute: List[str] = Field(
  28. default_factory=list,
  29. description="待计算指标ID列表(如 ['total_income', 'avg_balance'])"
  30. )
  31. additional_requirements: Optional[
  32. Union[Dict[str, Any], List[Any], ClarificationRequest]
  33. ] = Field(default=None, description="额外需求或澄清信息")
  34. def normalize_additional_requirements(req: Any) -> Optional[Dict[str, Any]]:
  35. """
  36. 规范化 additional_requirements
  37. 将列表转换为字典格式
  38. """
  39. if req is None:
  40. return None
  41. if isinstance(req, dict):
  42. return req
  43. if isinstance(req, list):
  44. # 如果LLM错误地返回了列表,转换为字典格式
  45. return {
  46. "questions": [str(item) for item in req],
  47. "missing_fields": []
  48. }
  49. return {"raw": str(req)}
  50. def create_planning_agent(llm, state: AgentState):
  51. """创建规划智能体(修复版:移除JSON示例,避免变量冲突)"""
  52. prompt = ChatPromptTemplate.from_messages([
  53. ("system", """你是报告规划总控智能体,核心职责是精准分析当前状态并决定下一步行动。
  54. ### 决策选项(四选一)
  55. 1. generate_outline:大纲未生成或大纲无效
  56. 2. compute_metrics:大纲已生成但指标未完成(覆盖率<80%)
  57. 3. finalize:指标覆盖率≥80%,信息充足
  58. 4. clarify:用户需求模糊,缺少关键信息
  59. ### 决策规则(按顺序检查)
  60. 1. 检查 outline_draft 是否为空 → 空则选择 generate_outline
  61. 2. 检查 metrics_requirements 是否为空 → 空则选择 generate_outline
  62. 3. 计算指标覆盖率 = 已计算指标 / 总需求指标
  63. - 覆盖率 < 0.8 → 选择 compute_metrics
  64. - 覆盖率 ≥ 0.8 → 选择 finalize
  65. 4. 如果无法理解需求 → 选择 clarify
  66. ### 重要原则
  67. - 大纲草稿已存在时,不要重复生成大纲
  68. - 决策为 compute_metrics 时,必须提供具体的指标ID列表
  69. - 确保 metrics_to_compute 是字符串数组格式
  70. ### 输出字段说明
  71. - decision: 决策字符串
  72. - reasoning: 决策原因说明
  73. - next_actions: 动作列表(可选)
  74. - metrics_to_compute: 待计算指标ID列表(决策为compute_metrics时必须提供)
  75. - additional_requirements: 额外需求(可选)
  76. 必须输出有效的JSON格式!"""),
  77. MessagesPlaceholder("messages"),
  78. ("user", "报告需求:{question}\n\n请输出决策结果。")
  79. ])
  80. return prompt | llm.with_structured_output(PlanningOutput)
  81. async def planning_node(state: AgentState) -> AgentState:
  82. """规划节点:正确识别待计算指标并传递"""
  83. llm = get_llm()
  84. planner = create_planning_agent(llm, state)
  85. # 构建完整的状态评估上下文
  86. required_count = len(state["metrics_requirements"])
  87. computed_count = len(state["computed_metrics"])
  88. coverage = computed_count / required_count if required_count > 0 else 0
  89. # 新增:跟踪失败次数,避免无限循环
  90. failed_attempts = state.get("failed_metric_attempts", {})
  91. pending_ids = state.get("pending_metric_ids", [])
  92. # 过滤掉失败次数过多的指标
  93. max_retry = 3
  94. filtered_pending_ids = [
  95. mid for mid in pending_ids
  96. if failed_attempts.get(mid, 0) < max_retry
  97. ]
  98. status_snapshot = f"""当前状态评估:
  99. - 规划步骤: {state['planning_step']}
  100. - 大纲版本: {state['outline_version']}
  101. - 大纲草稿存在: {state['outline_draft'] is not None}
  102. - 指标需求总数: {required_count}
  103. - 已计算指标数: {computed_count}
  104. - 指标覆盖率: {coverage:.2%}
  105. - 待计算指标数: {len(pending_ids)}
  106. - 有效待计算指标数: {len(filtered_pending_ids)}
  107. - 失败尝试记录: {failed_attempts}
  108. 建议下一步: {"计算指标" if coverage < 0.8 else "生成报告"}"""
  109. # 执行规划
  110. result = await planner.ainvoke({
  111. "question": state["question"],
  112. "messages": [("system", status_snapshot)]
  113. })
  114. # 规范化结果
  115. normalized_req = normalize_additional_requirements(result.additional_requirements)
  116. # 找出所有未计算的指标
  117. computed_ids = set(state["computed_metrics"].keys())
  118. required_metrics = state["metrics_requirements"]
  119. pending_metrics = [
  120. m for m in required_metrics
  121. if m.metric_id not in computed_ids
  122. ]
  123. # 关键:使用 LLM 返回的指标ID,如果没有则使用全部待计算指标
  124. if result.metrics_to_compute:
  125. pending_ids = result.metrics_to_compute
  126. valid_ids = [m.metric_id for m in pending_metrics]
  127. pending_metrics = [m for m in pending_metrics if m.metric_id in pending_ids and m.metric_id in valid_ids]
  128. # 更新状态
  129. new_state = state.copy()
  130. new_state["plan_history"].append(
  131. f"Step {new_state['planning_step']}: {result.decision}"
  132. )
  133. new_state["planning_step"] += 1
  134. new_state["additional_requirements"] = normalized_req
  135. # 关键:保存待计算指标ID列表
  136. if pending_metrics:
  137. pending_ids = [m.metric_id for m in pending_metrics]
  138. new_state["pending_metric_ids"] = pending_ids
  139. new_state["metrics_to_compute"] = pending_metrics # 保存完整对象
  140. # 设置路由标志
  141. if result.decision == "generate_outline":
  142. new_state["messages"].append(
  143. ("ai", f"📋 规划决策:生成大纲 (v{new_state['outline_version'] + 1})")
  144. )
  145. new_state["next_route"] = "outline_generator"
  146. elif result.decision == "compute_metrics":
  147. # 修复:确保显示正确的数量
  148. if not pending_metrics:
  149. # 如果没有待计算指标但有需求,则计算所有未完成的
  150. computed_ids = set(state["computed_metrics"].keys())
  151. pending_metrics = [m for m in required_metrics if m.metric_id not in computed_ids]
  152. # 新增:如果有效待计算指标为空但还有指标未计算,说明都失败了太多次
  153. if not filtered_pending_ids and pending_ids:
  154. new_state["messages"].append(
  155. ("ai", f"⚠️ 剩余 {len(pending_ids)} 个指标已多次计算失败,将跳过这些指标直接生成报告")
  156. )
  157. new_state["next_route"] = "report_compiler"
  158. # 关键修复:返回前清理状态
  159. return convert_numpy_types(new_state)
  160. new_state["messages"].append(
  161. ("ai", f"🧮 规划决策:计算 {len(pending_metrics)} 个指标 ({[m.metric_id for m in pending_metrics]})")
  162. )
  163. new_state["next_route"] = "metrics_calculator"
  164. elif result.decision == "finalize":
  165. new_state["is_complete"] = True
  166. new_state["messages"].append(
  167. ("ai", f"✅ 规划决策:信息充足,生成最终报告(覆盖率 {coverage:.2%})")
  168. )
  169. new_state["next_route"] = "report_compiler"
  170. elif result.decision == "clarify":
  171. questions = []
  172. if normalized_req and "questions" in normalized_req:
  173. questions = normalized_req["questions"]
  174. new_state["messages"].append(
  175. ("ai", f"❓ 需要澄清:{';'.join(questions) if questions else '请提供更详细的需求'}")
  176. )
  177. new_state["next_route"] = "clarify_node"
  178. # 关键修复:返回前清理状态
  179. return convert_numpy_types(new_state)