浏览代码

增加分类打标状态

jiaqiang 2 天之前
父节点
当前提交
953905f476
共有 1 个文件被更改,包括 35 次插入1 次删除
  1. 35 1
      llmops/workflow_state.py

+ 35 - 1
llmops/workflow_state.py

@@ -83,11 +83,14 @@ class IntegratedWorkflowState(TypedDict):
     # === 基础输入层 (兼容Big Agent) ===
     user_input: str
     question: str  # 别名,兼容报告生成Agent
+
     industry: str  # 行业
 
     # === 数据层 ===
     data_set: List[Dict[str, Any]]  # 报告生成Agent的数据格式
     transactions_df: Optional[Any]  # 可选的数据框格式
+    file_name: str                  # 数据文件名称
+    data_set_classified: List[Dict[str, Any]] # 分类打标后的数据集
 
     # === 意图识别层 (Big Agent原有) ===
     intent_result: Optional[Dict[str, Any]]
@@ -132,13 +135,15 @@ class IntegratedWorkflowState(TypedDict):
 
 # ============= 状态创建和初始化函数 =============
 
-def create_initial_integrated_state(question: str, industry: str, data: List[Dict[str, Any]], session_id: str = None) -> IntegratedWorkflowState:
+def create_initial_integrated_state(question: str, industry: str, data: List[Dict[str, Any]], file_name: str, session_id: str = None) -> IntegratedWorkflowState:
     """
     创建初始的整合状态
 
     Args:
         question: 用户查询
+        industry: 行业
         data: 数据集
+        file_name: 数据文件名称
         session_id: 会话ID
 
     Returns:
@@ -155,6 +160,12 @@ def create_initial_integrated_state(question: str, industry: str, data: List[Dic
 
         # 数据层
         "data_set": convert_numpy_types(data),
+        "data_set_classified": [],    # 分类打标后的数据集
+        "transactions_df": None,
+        "file_name": file_name,       # 文件名称
+
+        # 意图识别层
+        "intent_result": None,
 
         # 规划和大纲层
         "planning_step": 0,
@@ -320,4 +331,27 @@ def finalize_state_with_report(state: IntegratedWorkflowState, final_report: Dic
     progress = get_calculation_progress(new_state)
     new_state["completeness_score"] = progress["coverage_rate"]
 
+    return new_state
+
+def update_state_with_data_classified(state: IntegratedWorkflowState, data_set_classified: List[Dict]) -> IntegratedWorkflowState:
+    """
+    使用分类打标结果更新状态
+
+    Args:
+        state: 当前状态
+        data_set_classified: 分类打标的数据
+
+    Returns:
+        更新后的状态
+    """
+    new_state = state.copy()
+    new_state["data_set_classified"] = data_set_classified
+
+    # 添加消息
+    new_state["messages"].append({
+        "role": "assistant",
+        "content": f"✅ 数据分类打标已完成",
+        "timestamp": datetime.now().isoformat()
+    })
+
     return new_state