workflow_state.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
"""
整合的工作流状态定义
===================
此文件定义了整合多个Agent的工作流状态,兼容现有的Big Agent状态管理和新增的报告生成Agent状态。

状态层次:
1. 输入层:用户查询和数据
2. 意图层:意图识别结果
3. 规划层:规划决策和大纲生成
4. 计算层:指标计算结果
5. 结果层:最终报告生成
6. 对话层:消息历史和错误处理

兼容性:
- 兼容现有的Big Agent WorkflowState
- 整合来自other_agents的AgentState
- 支持扩展新的Agent状态需求

作者: Big Agent Team
版本: 1.0.0
创建时间: 2024-12-20
"""
import uuid
from datetime import datetime
from typing import TypedDict, List, Dict, Any, Optional

from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field
# ============= Data models (数据模型) =============
# NOTE(review): pydantic includes class docstrings and Field descriptions in
# model_json_schema(), so they may reach an LLM prompt / structured-output
# schema. They are deliberately left untranslated; English notes are comments.
class MetricRequirement(BaseModel):
    """指标需求定义"""
    # Definition of one metric the report needs.
    # Unique metric identifier, e.g. 'total_income_jan'.
    metric_id: str = Field(description="指标唯一标识,如 'total_income_jan'")
    # Human-readable (Chinese) metric name.
    metric_name: str = Field(description="指标中文名称")
    # Free-text description of how the metric is calculated.
    calculation_logic: str = Field(description="计算逻辑描述")
    # Names of the data fields the calculation needs (required, no default).
    required_fields: List[str] = Field(description="所需字段")
    # IDs of other metrics this one depends on; each instance gets its own
    # empty list by default.
    dependencies: List[str] = Field(default_factory=list, description="依赖的其他指标ID")
class ReportSection(BaseModel):
    """报告大纲章节"""
    # One section of the report outline. The docstring and Field descriptions
    # are kept verbatim because pydantic may export them into the JSON schema.
    # Section identifier.
    section_id: str = Field(description="章节ID")
    # Section heading.
    title: str = Field(description="章节标题")
    # What the section's content is required to cover.
    description: str = Field(description="章节内容要求")
    # IDs of the metrics this section needs (presumably values of
    # MetricRequirement.metric_id — confirm against the outline generator).
    metrics_needed: List[str] = Field(description="所需指标ID列表")
class ReportOutline(BaseModel):
    """完整报告大纲"""
    # The complete report outline: a title, its sections and the metrics to
    # compute. Docstring / descriptions kept verbatim (possible schema text).
    # Report title.
    report_title: str = Field(description="报告标题")
    # List of report sections.
    sections: List[ReportSection] = Field(description="章节列表")
    # Global metric list; update_state_with_outline_generation uses it to fill
    # the workflow state's metrics_requirements and pending_metric_ids.
    global_metrics: List[MetricRequirement] = Field(description="全局指标列表")
  43. # ============= 序列化工具函数 =============
  44. def convert_numpy_types(obj: Any) -> Any:
  45. """
  46. 递归转换所有numpy类型为Python原生类型
  47. 确保所有数据可序列化
  48. """
  49. if isinstance(obj, dict):
  50. return {str(k): convert_numpy_types(v) for k, v in obj.items()}
  51. elif isinstance(obj, list):
  52. return [convert_numpy_types(item) for item in obj]
  53. elif isinstance(obj, tuple):
  54. return tuple(convert_numpy_types(item) for item in obj)
  55. elif isinstance(obj, set):
  56. return {convert_numpy_types(item) for item in obj}
  57. elif hasattr(obj, 'item') and hasattr(obj, 'dtype'): # numpy scalar
  58. return convert_numpy_types(obj.item())
  59. else:
  60. return obj
  61. # ============= 整合的工作流状态定义 =============
  62. class IntegratedWorkflowState(TypedDict):
  63. """整合的工作流状态定义 - 兼容多个Agent系统"""
  64. # === 基础输入层 (兼容Big Agent) ===
  65. user_input: str
  66. question: str # 别名,兼容报告生成Agent
  67. industry: str # 行业
  68. # === 数据层 ===
  69. data_set: List[Dict[str, Any]] # 报告生成Agent的数据格式
  70. transactions_df: Optional[Any] # 可选的数据框格式
  71. file_name: str # 数据文件名称
  72. data_set_classified: List[Dict[str, Any]] # 分类打标后的数据集
  73. # === 意图识别层 (Big Agent原有) ===
  74. intent_result: Optional[Dict[str, Any]]
  75. # === 规划和大纲层 (新增) ===
  76. planning_step: int
  77. plan_history: List[str]
  78. outline_draft: Optional[ReportOutline]
  79. outline_version: int
  80. outline_ready: bool
  81. # === 指标计算层 ===
  82. metrics_requirements: List[MetricRequirement] # 报告生成Agent格式
  83. computed_metrics: Dict[str, Any] # 计算结果
  84. metrics_cache: Dict[str, Any] # 缓存
  85. pending_metric_ids: List[str] # 待计算指标ID
  86. failed_metric_attempts: Dict[str, int] # 失败统计
  87. calculation_results: Optional[Dict[str, Any]] # Big Agent格式的计算结果
  88. # === 结果层 ===
  89. report_draft: Dict[str, Any] # 报告草稿
  90. knowledge_result: Optional[Dict[str, Any]] # Big Agent知识沉淀结果
  91. is_complete: bool
  92. completeness_score: float
  93. answer: Optional[str] # 最终答案
  94. # === 对话和消息层 ===
  95. messages: List[Dict[str, Any]] # Big Agent消息格式
  96. current_node: str
  97. session_id: str
  98. next_route: str
  99. # === 错误处理层 ===
  100. errors: List[str]
  101. last_decision: str
  102. # === 时间跟踪层 ===
  103. start_time: str
  104. end_time: Optional[str]
  105. api_result: Dict[str, Any] # 存储所有API调用结果
  106. # ============= 状态创建和初始化函数 =============
  107. def create_initial_integrated_state(question: str, industry: str, data: List[Dict[str, Any]], file_name: str, session_id: str = None) -> IntegratedWorkflowState:
  108. """
  109. 创建初始的整合状态
  110. Args:
  111. question: 用户查询
  112. industry: 行业
  113. data: 数据集
  114. file_name: 数据文件名称
  115. session_id: 会话ID
  116. Returns:
  117. 初始化后的状态
  118. """
  119. current_time = datetime.now().isoformat()
  120. session = session_id or f"session_{int(datetime.now().timestamp())}"
  121. return {
  122. # 基础输入
  123. "user_input": question,
  124. "question": question,
  125. "industry": industry,
  126. # 数据层
  127. "data_set": convert_numpy_types(data),
  128. "data_set_classified": [], # 分类打标后的数据集
  129. "transactions_df": None,
  130. "file_name": file_name, # 文件名称
  131. # 意图识别层
  132. "intent_result": None,
  133. # 规划和大纲层
  134. "planning_step": 0,
  135. "plan_history": [],
  136. "outline_draft": None,
  137. "outline_version": 0,
  138. "outline_ready": False,
  139. # 指标计算层
  140. "metrics_requirements": [],
  141. "computed_metrics": {},
  142. "metrics_cache": {},
  143. "pending_metric_ids": [],
  144. "failed_metric_attempts": {},
  145. "calculation_results": None,
  146. # 结果层
  147. "report_draft": {},
  148. "knowledge_result": None,
  149. "is_complete": False,
  150. "completeness_score": 0.0,
  151. "answer": None,
  152. # 对话和消息层
  153. "messages": [{
  154. "role": "user",
  155. "content": question,
  156. "timestamp": current_time
  157. }],
  158. "current_node": "start",
  159. "session_id": session,
  160. "next_route": "planning_node",
  161. # 错误处理层
  162. "errors": [],
  163. "last_decision": "init",
  164. # 时间跟踪层
  165. "start_time": current_time,
  166. "end_time": None,
  167. "api_result": {}, # 存储所有API调用结果
  168. # 计算模式配置层
  169. "use_rules_engine_only": False,
  170. "use_traditional_engine_only": False
  171. }
  172. def is_state_ready_for_calculation(state: IntegratedWorkflowState) -> bool:
  173. """
  174. 检查状态是否准备好进行指标计算
  175. Args:
  176. state: 当前状态
  177. Returns:
  178. 是否准备好
  179. """
  180. return (
  181. state.get("outline_draft") is not None and
  182. len(state.get("metrics_requirements", [])) > 0 and
  183. len(state.get("pending_metric_ids", [])) > 0
  184. )
  185. def get_calculation_progress(state: IntegratedWorkflowState) -> Dict[str, Any]:
  186. """
  187. 获取指标计算进度信息
  188. Args:
  189. state: 当前状态
  190. Returns:
  191. 进度信息
  192. """
  193. required = len(state.get("metrics_requirements", []))
  194. computed = len(state.get("computed_metrics", {}))
  195. pending = len(state.get("pending_metric_ids", []))
  196. return {
  197. "required_count": required,
  198. "computed_count": computed,
  199. "pending_count": pending,
  200. "coverage_rate": computed / required if required > 0 else 0,
  201. "is_complete": computed >= required * 0.8 # 80%覆盖率视为完成
  202. }
  203. def update_state_with_outline_generation(state: IntegratedWorkflowState, outline: ReportOutline) -> IntegratedWorkflowState:
  204. """
  205. 使用大纲生成结果更新状态
  206. Args:
  207. state: 当前状态
  208. outline: 生成的大纲
  209. Returns:
  210. 更新后的状态
  211. """
  212. new_state = state.copy()
  213. new_state["outline_draft"] = outline
  214. new_state["outline_version"] += 1
  215. new_state["outline_ready"] = True
  216. new_state["metrics_requirements"] = outline.global_metrics
  217. new_state["pending_metric_ids"] = [m.metric_id for m in outline.global_metrics]
  218. # 添加消息
  219. new_state["messages"].append({
  220. "role": "assistant",
  221. "content": f"✅ 大纲生成完成 v{new_state['outline_version']}:{outline.report_title}",
  222. "timestamp": datetime.now().isoformat()
  223. })
  224. return new_state
  225. def update_state_with_planning_decision(state: IntegratedWorkflowState, decision: Dict[str, Any]) -> IntegratedWorkflowState:
  226. """
  227. 使用规划决策结果更新状态
  228. Args:
  229. state: 当前状态
  230. decision: 规划决策
  231. Returns:
  232. 更新后的状态
  233. """
  234. new_state = state.copy()
  235. new_state["planning_step"] += 1
  236. new_state["last_decision"] = decision.get("decision", "unknown")
  237. new_state["next_route"] = decision.get("next_route", "planning_node")
  238. # 如果有待计算指标,更新待计算列表
  239. if decision.get("metrics_to_compute"):
  240. new_state["pending_metric_ids"] = decision["metrics_to_compute"]
  241. # 添加规划历史
  242. new_state["plan_history"].append(
  243. f"Step {new_state['planning_step']}: {decision.get('decision', 'unknown')}"
  244. )
  245. return new_state
  246. def finalize_state_with_report(state: IntegratedWorkflowState, final_report: Dict[str, Any]) -> IntegratedWorkflowState:
  247. """
  248. 使用最终报告完成状态
  249. Args:
  250. state: 当前状态
  251. final_report: 最终报告
  252. Returns:
  253. 完成的状态
  254. """
  255. new_state = state.copy()
  256. new_state["report_draft"] = final_report
  257. new_state["is_complete"] = True
  258. new_state["answer"] = final_report
  259. new_state["end_time"] = datetime.now().isoformat()
  260. # 计算完整性分数
  261. progress = get_calculation_progress(new_state)
  262. new_state["completeness_score"] = progress["coverage_rate"]
  263. return new_state
  264. def update_state_with_data_classified(state: IntegratedWorkflowState, data_set_classified: List[Dict]) -> IntegratedWorkflowState:
  265. """
  266. 使用分类打标结果更新状态
  267. Args:
  268. state: 当前状态
  269. data_set_classified: 分类打标的数据
  270. Returns:
  271. 更新后的状态
  272. """
  273. new_state = state.copy()
  274. new_state["data_set_classified"] = data_set_classified
  275. # 添加消息
  276. new_state["messages"].append({
  277. "role": "assistant",
  278. "content": f"✅ 数据分类打标已完成",
  279. "timestamp": datetime.now().isoformat()
  280. })
  281. return new_state