Jelajahi Sumber

增加数据分类打标逻辑

jiaqiang 2 hari lalu
induk
melakukan
808065e73b
1 mengubah file dengan 129 tambahan dan 80 penghapusan
  1. 129 80
      llmops/complete_agent_flow_rule.py

+ 129 - 80
llmops/complete_agent_flow_rule.py

@@ -32,19 +32,21 @@ from typing import Dict, Any, List
 from datetime import datetime
 from langgraph.graph import StateGraph, START, END
 
-from workflow_state import (
+from llmops.workflow_state import (
     IntegratedWorkflowState,
     create_initial_integrated_state,
     get_calculation_progress,
     update_state_with_outline_generation,
     update_state_with_planning_decision,
+    update_state_with_data_classified,
     convert_numpy_types,
-
 )
-from agents.outline_agent import OutlineGeneratorAgent, generate_report_outline
-from agents.planning_agent import PlanningAgent, plan_next_action, analyze_current_state
-from agents.rules_engine_metric_calculation_agent import RulesEngineMetricCalculationAgent
-
+from llmops.agents.outline_agent import  generate_report_outline
+from llmops.agents.planning_agent import  plan_next_action
+from llmops.agents.rules_engine_metric_calculation_agent import RulesEngineMetricCalculationAgent
+from llmops.agents.data_manager import DataManager
+import os
+from llmops.agents.data_classify_agent import data_classify
 
 class CompleteAgentFlow:
     """完整的智能体工作流"""
@@ -74,6 +76,7 @@ class CompleteAgentFlow:
         workflow.add_node("planning_node", self._planning_node)
         workflow.add_node("outline_generator", self._outline_generator_node)
         workflow.add_node("metric_calculator", self._metric_calculator_node)
+        workflow.add_node("data_classify", self._data_classify_node)
 
         # 设置入口点
         workflow.set_entry_point("planning_node")
@@ -85,13 +88,15 @@ class CompleteAgentFlow:
             {
                 "outline_generator": "outline_generator",
                 "metric_calculator": "metric_calculator",
+                "data_classify": "data_classify",
                 END: END
             }
         )
 
         # 从各个节点返回规划节点重新决策
+        workflow.add_edge("data_classify", "planning_node")
         workflow.add_edge("outline_generator", "planning_node")
-        workflow.add_edge("metric_calculator", "planning_node")
+        workflow.add_edge("metric_calculator", END)
 
         return workflow
 
@@ -106,6 +111,7 @@ class CompleteAgentFlow:
             目标节点名称
         """
         print(f"\n🔍 [路由决策] 步骤={state['planning_step']}, "
+              f"数据集分类打标数量={len(state.get('data_set_classified', []))}",
               f"大纲={state.get('outline_draft') is not None}, "
               f"指标需求={len(state.get('metrics_requirements', []))}")
 
@@ -114,11 +120,21 @@ class CompleteAgentFlow:
             print("⚠️ 规划步骤超过30次,强制结束流程")
             return END
 
+        # 数据分类打标数量为0 → 分类打标
+        if len(state.get("data_set_classified", [])) == 0:
+            print("→ 路由到 data_classify(分类打标)")
+            return "data_classify"
+
         # 如果大纲为空 → 生成大纲
         if not state.get("outline_draft"):
             print("→ 路由到 outline_generator(生成大纲)")
             return "outline_generator"
 
+        # 如果指标需求为空但大纲已生成 → 评估指标需求
+        if not state.get("metrics_requirements") and state.get("outline_draft"):
+            print("→ 路由到 metric_evaluator(评估指标需求)")
+            return "metric_evaluator"
+
         # 计算覆盖率
         progress = get_calculation_progress(state)
         coverage = progress["coverage_rate"]
@@ -130,22 +146,10 @@ class CompleteAgentFlow:
             print(f"→ 路由到 metric_calculator(计算指标,覆盖率={coverage:.2%})")
             return "metric_calculator"
 
-        # 检查是否应该结束流程
-        pending_ids = state.get("pending_metric_ids", [])
-        failed_attempts = state.get("failed_metric_attempts", {})
-        max_retries = 3
-
-        # 计算还有哪些指标可以重试(未达到最大重试次数)
-        retryable_metrics = [
-            mid for mid in pending_ids
-            if failed_attempts.get(mid, 0) < max_retries
-        ]
-
-        # 如果覆盖率 >= 80%,或者没有可重试的指标 → 结束流程
-        if coverage >= 0.8 or not retryable_metrics:
-            reason = "覆盖率达到80%" if coverage >= 0.8 else "没有可重试指标"
-            print(f"→ 结束流程(覆盖率={coverage:.2%},原因:{reason})")
-            return END
+        # 如果没有待计算指标或覆盖率 >= 80% → 生成最终报告
+        if not state.get("pending_metric_ids") or coverage >= 0.8:
+            print(f"→ 路由到 report_finalizer(生成最终报告,覆盖率={coverage:.2%})")
+            return "report_finalizer"
 
         # 默认返回规划节点
         return "planning_node"
@@ -219,6 +223,32 @@ class CompleteAgentFlow:
             new_state["errors"].append(f"大纲生成错误: {str(e)}")
             return convert_numpy_types(new_state)
 
+    async def _data_classify_node(self, state: IntegratedWorkflowState) -> IntegratedWorkflowState:
+        """数据分类打标节点"""
+        try:
+            print("📝 正在对数据进行分类打标...")
+
+            # 对数据进行分类打标
+            data_set_classified = await data_classify(
+                industry=state["industry"],
+                data_set=state["data_set"],
+                file_name=state["file_name"]
+            )
+
+            # 更新状态
+            new_state = update_state_with_data_classified(state, data_set_classified)
+
+            print(f"✅ 数据分类打标完成,打标记录数: {len(data_set_classified)}")
+
+            return convert_numpy_types(new_state)
+
+        except Exception as e:
+            print(f"❌ 数据分类打标失败: {e}")
+            new_state = state.copy()
+            new_state["errors"].append(f"数据分类打标错误: {str(e)}")
+            return convert_numpy_types(new_state)
+
+
     def _print_ai_selection_analysis(self, outline):
         """打印AI指标选择的推理过程分析 - 完全通用版本"""
         print()
@@ -270,6 +300,50 @@ class CompleteAgentFlow:
         print('   能够根据具体业务场景动态调整分析框架,确保分析的针对性和有效性。')
         print()
 
+    async def _metric_evaluator_node(self, state: IntegratedWorkflowState) -> IntegratedWorkflowState:
+        """指标评估节点:根据大纲确定需要计算的指标"""
+        try:
+            print("🔍 正在评估指标需求...")
+
+            new_state = state.copy()
+            outline = state.get("outline_draft")
+
+            if not outline:
+                print("⚠️ 没有大纲信息,跳过指标评估")
+                return convert_numpy_types(new_state)
+
+            # 从大纲中提取指标需求
+            metrics_requirements = outline.global_metrics
+            metric_ids = [m.metric_id for m in metrics_requirements]
+
+            # 设置待计算指标
+            new_state["metrics_requirements"] = metrics_requirements
+            new_state["pending_metric_ids"] = metric_ids.copy()
+            new_state["computed_metrics"] = {}
+            new_state["metrics_cache"] = {}
+
+            print(f"✅ 指标评估完成,发现 {len(metric_ids)} 个待计算指标")
+            for i, metric_id in enumerate(metric_ids[:5], 1):  # 只显示前5个
+                print(f"   {i}. {metric_id}")
+
+            if len(metric_ids) > 5:
+                print(f"   ... 还有 {len(metric_ids) - 5} 个指标")
+
+            # 添加消息
+            new_state["messages"].append({
+                "role": "assistant",
+                "content": f"🔍 指标评估完成:发现 {len(metric_ids)} 个待计算指标",
+                "timestamp": datetime.now().isoformat()
+            })
+
+            return convert_numpy_types(new_state)
+
+        except Exception as e:
+            print(f"❌ 指标评估失败: {e}")
+            new_state = state.copy()
+            new_state["errors"].append(f"指标评估错误: {str(e)}")
+            return convert_numpy_types(new_state)
+
     async def _metric_calculator_node(self, state: IntegratedWorkflowState) -> IntegratedWorkflowState:
         """指标计算节点"""
         try:
@@ -315,15 +389,11 @@ class CompleteAgentFlow:
                     # 找到对应的指标需求
                     metric_req = next((m for m in metrics_requirements if m.metric_id == metric_id), None)
                     if not metric_req:
-                        # 修复:找不到指标需求时,创建临时的指标需求结构,避免跳过指标
-                        print(f"⚠️ 指标 {metric_id} 找不到需求信息,创建临时配置继续计算")
-                        metric_req = type('MetricRequirement', (), {
-                            'metric_id': metric_id,
-                            'metric_name': metric_id.replace('metric-', '') if metric_id.startswith('metric-') else metric_id,
-                            'calculation_logic': f'计算 {metric_id}',
-                            'required_fields': ['transactions'],
-                            'dependencies': []
-                        })()
+                        print(f"⚠️ 找不到指标 {metric_id} 的需求信息,跳过")
+                        # 仍然从待计算列表中移除,避免无限循环
+                        if metric_id in new_state["pending_metric_ids"]:
+                            new_state["pending_metric_ids"].remove(metric_id)
+                        continue
 
                     print(f"🧮 计算指标: {metric_id} - {metric_req.metric_name}")
 
@@ -366,55 +436,27 @@ class CompleteAgentFlow:
                         }
 
                     # 处理计算结果
-                    calculation_success = False
                     for result in results.get("results", []):
                         if result.get("result", {}).get("success"):
                             # 计算成功
                             new_state["computed_metrics"][metric_id] = result["result"]
                             successful_calculations += 1
-                            calculation_success = True
                             print(f"✅ 指标 {metric_id} 计算成功")
-                            break  # 找到一个成功的就算成功
                         else:
                             # 计算失败
                             failed_calculations += 1
                             print(f"❌ 指标 {metric_id} 计算失败")
 
-                    # 初始化失败尝试记录
-                    if "failed_metric_attempts" not in new_state:
-                        new_state["failed_metric_attempts"] = {}
-
-                    # 根据计算结果处理指标
-                    if calculation_success:
-                        # 计算成功:从待计算列表中移除
-                        if metric_id in new_state["pending_metric_ids"]:
-                            new_state["pending_metric_ids"].remove(metric_id)
-                        # 重置失败计数
-                        new_state["failed_metric_attempts"].pop(metric_id, None)
-                    else:
-                        # 计算失败:记录失败次数,不从待计算列表移除
-                        new_state["failed_metric_attempts"][metric_id] = new_state["failed_metric_attempts"].get(metric_id, 0) + 1
-                        max_retries = 3
-                        if new_state["failed_metric_attempts"][metric_id] >= max_retries:
-                            print(f"⚠️ 指标 {metric_id} 已达到最大重试次数 ({max_retries}),从待计算列表中移除")
-                            if metric_id in new_state["pending_metric_ids"]:
-                                new_state["pending_metric_ids"].remove(metric_id)
+                    # 从待计算列表中移除(无论成功还是失败)
+                    if metric_id in new_state["pending_metric_ids"]:
+                        new_state["pending_metric_ids"].remove(metric_id)
 
                 except Exception as e:
                     print(f"❌ 计算指标 {metric_id} 时发生异常: {e}")
                     failed_calculations += 1
-
-                    # 初始化失败尝试记录
-                    if "failed_metric_attempts" not in new_state:
-                        new_state["failed_metric_attempts"] = {}
-
-                    # 记录失败次数
-                    new_state["failed_metric_attempts"][metric_id] = new_state["failed_metric_attempts"].get(metric_id, 0) + 1
-                    max_retries = 3
-                    if new_state["failed_metric_attempts"][metric_id] >= max_retries:
-                        print(f"⚠️ 指标 {metric_id} 异常已达到最大重试次数 ({max_retries}),从待计算列表中移除")
-                        if metric_id in new_state["pending_metric_ids"]:
-                            new_state["pending_metric_ids"].remove(metric_id)
+                    # 即使异常,也要从待计算列表中移除,避免无限循环
+                    if metric_id in new_state["pending_metric_ids"]:
+                        new_state["pending_metric_ids"].remove(metric_id)
 
             # 更新计算结果统计
             new_state["calculation_results"] = {
@@ -456,9 +498,10 @@ class CompleteAgentFlow:
     def _decision_to_route(self, decision: str) -> str:
         """将规划决策转换为路由"""
         decision_routes = {
+            "data_classify": "data_classify",
             "generate_outline": "outline_generator",
             "compute_metrics": "metric_calculator",
-            "finalize_report": END  # 直接结束流程
+            "finalize_report": "report_finalizer"
         }
         return decision_routes.get(decision, "planning_node")
 
@@ -480,7 +523,7 @@ class CompleteAgentFlow:
         except:
             return "🤔 规划决策已完成"
 
-    async def run_workflow(self, question: str, industry: str, data: List[Dict[str, Any]], session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
+    async def run_workflow(self, question: str, industry: str, data: List[Dict[str, Any]], file_name: str, session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
         """
         运行完整的工作流
 
@@ -488,6 +531,7 @@ class CompleteAgentFlow:
             question: 用户查询
             industry: 行业
             data: 数据集
+            file_name: 数据文件名称
             session_id: 会话ID
             use_rules_engine_only: 是否只使用规则引擎指标计算
             use_traditional_engine_only: 是否只使用传统引擎指标计算
@@ -499,6 +543,7 @@ class CompleteAgentFlow:
             print("🚀 启动完整智能体工作流...")
             print(f"问题:{question}")
             print(f"行业:{industry}")
+            print(f"数据文件:{file_name}")
             print(f"数据条数:{len(data)}")
 
             if use_rules_engine_only:
@@ -509,7 +554,7 @@ class CompleteAgentFlow:
                 print("计算模式:标准模式")
 
             # 创建初始状态
-            initial_state = create_initial_integrated_state(question, industry, data, session_id)
+            initial_state = create_initial_integrated_state(question, industry, data, file_name, session_id)
 
             # 设置计算模式标记
             if use_rules_engine_only:
@@ -553,13 +598,14 @@ class CompleteAgentFlow:
 
 
 # 便捷函数
-async def run_complete_agent_flow(question: str, industry: str, data: List[Dict[str, Any]], api_key: str, session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
+async def run_complete_agent_flow(question: str, industry: str, data: List[Dict[str, Any]], file_name: str, api_key: str, session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
     """
     运行完整智能体工作流的便捷函数
 
     Args:
         question: 用户查询
         data: 数据集
+        file_name: 数据文件名称
         api_key: API密钥
         session_id: 会话ID
         use_rules_engine_only: 是否只使用规则引擎指标计算
@@ -569,7 +615,7 @@ async def run_complete_agent_flow(question: str, industry: str, data: List[Dict[
         工作流结果
     """
     workflow = CompleteAgentFlow(api_key)
-    return await workflow.run_workflow(question, industry, data, session_id, use_rules_engine_only, use_traditional_engine_only)
+    return await workflow.run_workflow(question, industry, data, file_name, session_id, use_rules_engine_only, use_traditional_engine_only)
 
 
 # 主函数用于测试
@@ -585,23 +631,26 @@ async def main():
         print("❌ 未找到API密钥")
         return
 
-    # 测试数据
-    test_data = [
-        {
+    # 行业
+    industry = "农业"
 
-        }
-    ]
+    # 测试文件
+    file_name = "test_temp_agriculture_transaction_flow.csv"
+    curr_dir = os.path.dirname(os.path.abspath(__file__))
+    file_path = os.path.join(curr_dir, "..", "data_files", file_name)
 
-    print(f"📊 测试数据: {len(test_data)} 条记录")
+    # 加载测试数据集并展示两条样例
+    test_data = DataManager.load_data_from_csv_file(file_path)
 
+    print(f"📊 读取测试数据文件: {file_name} 数据, 加载 {len(test_data)} 条记录")
+    print(f"测试数据样例: {test_data[0:1]}")
 
     # 执行测试
     result = await run_complete_agent_flow(
         question="请生成一份详细的农业经营贷流水分析报告,需要包含:1.总收入和总支出统计 2.收入笔数和支出笔数 3.各类型收入支出占比分析 4.交易对手收入支出TOP3排名 5.按月份的收入支出趋势分析 6.账户数量和交易时间范围统计 7.资金流入流出月度统计等全面指标",
-        industry = "农业",
-        # question="请生成一份详细的黑色金属相关经营贷流水分析报告,需要包含:1.总收入统计 2.收入笔数 3.各类型收入占比分析 4.交易对手收入排名 5.按月份的收入趋势分析 6.账户数量和交易时间范围统计 7.资金流入流出月度统计等全面指标",
-        # industry = "黑色金属",
+        industry = industry,
         data=test_data,
+        file_name=file_name,
         api_key=config.DEEPSEEK_API_KEY,
         session_id="direct-test"
     )