# state.py
from typing import Any, Dict, List, Optional, TypedDict

import numpy as np
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field
# ============= Data models =============
  6. class MetricRequirement(BaseModel):
  7. """指标需求定义"""
  8. metric_id: str = Field(description="指标唯一标识,如 'total_income_jan'")
  9. metric_name: str = Field(description="指标中文名称")
  10. calculation_logic: str = Field(description="计算逻辑描述")
  11. required_fields: List[str] = Field(description="所需字段")
  12. dependencies: List[str] = Field(default_factory=list)
  13. class ReportSection(BaseModel):
  14. """报告大纲章节"""
  15. section_id: str = Field(description="章节ID")
  16. title: str = Field(description="章节标题")
  17. description: str = Field(description="章节内容要求")
  18. metrics_needed: List[str] = Field(description="所需指标ID列表")
  19. class ReportOutline(BaseModel):
  20. """完整报告大纲"""
  21. report_title: str
  22. sections: List[ReportSection]
  23. global_metrics: List[MetricRequirement]
# ============= Serialization helpers =============
  25. def convert_numpy_types(obj: Any) -> Any:
  26. """
  27. 递归转换所有numpy类型为Python原生类型
  28. 关键修复:确保所有数据可序列化
  29. """
  30. if isinstance(obj, dict):
  31. return {str(k): convert_numpy_types(v) for k, v in obj.items()}
  32. elif isinstance(obj, list):
  33. return [convert_numpy_types(item) for item in obj]
  34. elif isinstance(obj, tuple):
  35. return tuple(convert_numpy_types(item) for item in obj)
  36. elif isinstance(obj, set):
  37. return {convert_numpy_types(item) for item in obj}
  38. elif isinstance(obj, np.integer):
  39. return int(obj)
  40. elif isinstance(obj, np.floating):
  41. return float(obj)
  42. elif isinstance(obj, np.bool_):
  43. return bool(obj)
  44. elif isinstance(obj, np.ndarray):
  45. return convert_numpy_types(obj.tolist())
  46. elif hasattr(obj, 'item') and hasattr(obj, 'dtype'): # numpy scalar
  47. return convert_numpy_types(obj.item())
  48. else:
  49. return obj
  50. def create_initial_state(question: str, data: List[Dict[str, Any]], session_id: str = None) -> Dict[str, Any]:
  51. """创建初始状态,确保所有数据已清理"""
  52. cleaned_data = convert_numpy_types(data)
  53. return {
  54. "question": str(question),
  55. "data_set": cleaned_data,
  56. "transactions_df": None,
  57. "planning_step": 0,
  58. "plan_history": [],
  59. "outline_draft": None,
  60. "outline_version": 0,
  61. "metrics_requirements": [],
  62. "computed_metrics": {},
  63. "metrics_cache": {},
  64. "pending_metric_ids": [],
  65. "failed_metric_attempts": {},
  66. "report_draft": {},
  67. "is_complete": False,
  68. "completeness_score": 0.0,
  69. "messages": [],
  70. "current_node": "start",
  71. "session_id": str(session_id) if session_id else "default_session",
  72. "next_route": "planning_node",
  73. "outline_ready": False,
  74. "metrics_ready": False,
  75. "last_decision": "init"
  76. }
# ============= State definition =============
  78. class AgentState(TypedDict):
  79. # === 输入层 ===
  80. question: str
  81. data_set: List[Dict[str, Any]]
  82. transactions_df: Optional[Any]
  83. # === 规划层 ===
  84. planning_step: int
  85. plan_history: List[str]
  86. # === 大纲层 ===
  87. outline_draft: Optional[ReportOutline]
  88. outline_version: int
  89. # === 指标层 ===
  90. metrics_requirements: List[MetricRequirement]
  91. computed_metrics: Dict[str, Any]
  92. metrics_cache: Dict[str, Any]
  93. pending_metric_ids: List[str]
  94. failed_metric_attempts: Dict[str, int]
  95. # === 结果层 ===
  96. report_draft: Dict[str, Any]
  97. is_complete: bool
  98. completeness_score: float
  99. # === 对话历史 ===
  100. messages: List[BaseMessage]
  101. current_node: str
  102. session_id: str
  103. next_route: str
  104. outline_ready: bool
  105. metrics_ready: bool
  106. last_decision: str