# outline_agent.py
  1. from typing import List, Dict, Any
  2. from langchain_openai import ChatOpenAI
  3. from langchain_core.prompts import ChatPromptTemplate
  4. import json # 确保导入json
  5. import uuid
  6. from llmops.agents.state import AgentState, ReportOutline, ReportSection, MetricRequirement, convert_numpy_types
  7. from llmops.agents.datadev.llm import get_llm
  8. class OutlineGenerator:
  9. """大纲生成智能体:将报告需求转化为结构化大纲"""
  10. def __init__(self, llm):
  11. self.llm = llm.with_structured_output(ReportOutline)
  12. def create_prompt(self, question: str, sample_data: List[Dict]) -> str:
  13. """创建大纲生成提示"""
  14. available_fields = list(sample_data[0].keys()) if sample_data else []
  15. sample_str = json.dumps(sample_data[:2], ensure_ascii=False, indent=2)
  16. # 关键修复:提供详细的字段说明和示例
  17. return f"""你是银行流水报告大纲专家。根据用户需求和样本数据,生成专业、可执行的报告大纲。
  18. 需求分析:
  19. {question}
  20. 可用字段:
  21. {', '.join(available_fields)}
  22. 样本数据:
  23. {sample_str}
  24. 输出要求(必须生成有效的JSON):
  25. 1. report_title: 报告标题(字符串)
  26. 2. sections: 章节列表,每个章节必须包含:
  27. - section_id: 章节唯一ID(如"sec_1", "sec_2")
  28. - title: 章节标题
  29. - description: 章节描述
  30. - metrics_needed: 所需指标ID列表(字符串数组,可为空)
  31. 3. global_metrics: 全局指标列表,每个指标必须包含:
  32. - metric_id: 指标唯一ID(如"total_income", "avg_balance")
  33. - metric_name: 指标名称
  34. - calculation_logic: 计算逻辑描述
  35. - required_fields: 所需字段列表
  36. - dependencies: 依赖的其他指标ID(可为空)
  37. 重要提示:
  38. - 必须生成section_id,格式为"sec_1", "sec_2"等
  39. - 必须生成metric_id,格式为字母+下划线+描述
  40. - metrics_needed必须是字符串数组
  41. - 确保所有字段都存在,不能缺失
  42. 输出示例:
  43. {{
  44. "report_title": "2024年第三季度分析报告",
  45. "sections": [
  46. {{
  47. "section_id": "sec_1",
  48. "title": "收入概览",
  49. "description": "分析收入总额",
  50. "metrics_needed": ["total_income", "avg_income"]
  51. }}
  52. ],
  53. "global_metrics": [
  54. {{
  55. "metric_id": "total_income",
  56. "metric_name": "总收入",
  57. "calculation_logic": "sum of all income transactions",
  58. "required_fields": ["txAmount", "txDirection"],
  59. "dependencies": []
  60. }}
  61. ]
  62. }}"""
  63. async def generate(self, state: AgentState) -> ReportOutline:
  64. """异步生成大纲(修复版:自动补全缺失字段)"""
  65. prompt = self.create_prompt(
  66. question=state["question"],
  67. sample_data=state["data_set"][:2]
  68. )
  69. messages = [
  70. ("system", "你是一名专业的报告大纲生成专家,必须输出完整、有效的JSON格式,包含所有必需字段。"),
  71. ("user", prompt)
  72. ]
  73. outline = await self.llm.ainvoke(messages)
  74. # 关键修复:后处理,补全缺失的section_id和metric_id
  75. outline = self._post_process_outline(outline)
  76. return outline
  77. def _post_process_outline(self, outline: ReportOutline) -> ReportOutline:
  78. """
  79. 后处理大纲,自动补全缺失的必需字段
  80. """
  81. # 为章节补全section_id
  82. for idx, section in enumerate(outline.sections):
  83. if not section.section_id:
  84. section.section_id = f"sec_{idx + 1}"
  85. # 确保metrics_needed是列表
  86. if not isinstance(section.metrics_needed, list):
  87. section.metrics_needed = []
  88. # 为指标补全metric_id和dependencies
  89. for idx, metric in enumerate(outline.global_metrics):
  90. if not metric.metric_id:
  91. metric.metric_id = f"metric_{idx + 1}"
  92. # 确保dependencies是列表
  93. if not isinstance(metric.dependencies, list):
  94. metric.dependencies = []
  95. # 推断required_fields(如果为空)
  96. if not metric.required_fields:
  97. metric.required_fields = self._infer_required_fields(
  98. metric.calculation_logic
  99. )
  100. return outline
  101. def _infer_required_fields(self, logic: str) -> List[str]:
  102. """从计算逻辑推断所需字段"""
  103. field_mapping = {
  104. "收入": ["txAmount", "txDirection"],
  105. "支出": ["txAmount", "txDirection"],
  106. "余额": ["txBalance"],
  107. "对手方": ["txCounterparty"],
  108. "日期": ["txDate"],
  109. "时间": ["txTime", "txDate"],
  110. "摘要": ["txSummary"],
  111. "创建时间": ["createdAt"]
  112. }
  113. fields = []
  114. for keyword, field_list in field_mapping.items():
  115. if keyword in logic:
  116. fields.extend(field_list)
  117. return list(set(fields))
  118. async def outline_node(state: AgentState) -> AgentState:
  119. """大纲生成节点:设置成功标志,防止重复生成"""
  120. llm = get_llm()
  121. generator = OutlineGenerator(llm)
  122. try:
  123. # 异步生成大纲
  124. outline = await generator.generate(state)
  125. # 更新状态
  126. new_state = state.copy()
  127. new_state["outline_draft"] = outline
  128. new_state["outline_version"] += 1
  129. # 防护:设置成功标志
  130. new_state["outline_ready"] = True # 明确标志:大纲已就绪
  131. new_state["metrics_requirements"] = outline.global_metrics
  132. new_state["metrics_pending"] = outline.global_metrics.copy() # 待计算指标
  133. new_state["messages"].append(
  134. ("ai", f"✅ 大纲生成完成 v{new_state['outline_version']}:{outline.report_title}")
  135. )
  136. print(f"\n📝 大纲已生成:{outline.report_title}")
  137. print(f" 章节数:{len(outline.sections)}")
  138. print(f" 指标数:{len(outline.global_metrics)}")
  139. # 新增:详细打印大纲内容
  140. print("\n" + "=" * 70)
  141. print("📋 详细大纲内容")
  142. print("=" * 70)
  143. print(json.dumps(outline.dict(), ensure_ascii=False, indent=2))
  144. print("=" * 70)
  145. # 关键修复:返回前清理状态
  146. return convert_numpy_types(new_state)
  147. except Exception as e:
  148. print(f"⚠️ 大纲生成出错: {e},使用默认结构")
  149. # 创建默认大纲
  150. default_outline = ReportOutline(
  151. report_title="默认交易分析报告",
  152. sections=[
  153. ReportSection(
  154. section_id="sec_1",
  155. title="交易概览",
  156. description="基础交易情况分析",
  157. metrics_needed=["total_transactions", "total_income", "total_expense"]
  158. )
  159. ],
  160. global_metrics=[
  161. MetricRequirement(
  162. metric_id="total_transactions",
  163. metric_name="总交易笔数",
  164. calculation_logic="count all transactions",
  165. required_fields=["txId"],
  166. dependencies=[]
  167. ),
  168. MetricRequirement(
  169. metric_id="total_income",
  170. metric_name="总收入",
  171. calculation_logic="sum of income transactions",
  172. required_fields=["txAmount", "txDirection"],
  173. dependencies=[]
  174. )
  175. ]
  176. )
  177. new_state = state.copy()
  178. new_state["outline_draft"] = default_outline
  179. new_state["outline_version"] += 1
  180. new_state["outline_ready"] = True # 即使默认也标记为就绪
  181. new_state["metrics_requirements"] = default_outline.global_metrics
  182. new_state["messages"].append(
  183. ("ai", f"⚠️ 使用默认大纲 v{new_state['outline_version']}")
  184. )
  185. # 新增:详细打印默认大纲内容
  186. print("\n" + "=" * 70)
  187. print("📋 默认大纲内容")
  188. print("=" * 70)
  189. print(json.dumps(default_outline.dict(), ensure_ascii=False, indent=2))
  190. print("=" * 70)
  191. # 关键修复:返回前清理状态
  192. return convert_numpy_types(new_state)