|
|
@@ -32,19 +32,21 @@ from typing import Dict, Any, List
|
|
|
from datetime import datetime
|
|
|
from langgraph.graph import StateGraph, START, END
|
|
|
|
|
|
-from workflow_state import (
|
|
|
+from llmops.workflow_state import (
|
|
|
IntegratedWorkflowState,
|
|
|
create_initial_integrated_state,
|
|
|
get_calculation_progress,
|
|
|
update_state_with_outline_generation,
|
|
|
update_state_with_planning_decision,
|
|
|
+ update_state_with_data_classified,
|
|
|
convert_numpy_types,
|
|
|
-
|
|
|
)
|
|
|
-from agents.outline_agent import OutlineGeneratorAgent, generate_report_outline
|
|
|
-from agents.planning_agent import PlanningAgent, plan_next_action, analyze_current_state
|
|
|
-from agents.rules_engine_metric_calculation_agent import RulesEngineMetricCalculationAgent
|
|
|
-
|
|
|
+from llmops.agents.outline_agent import generate_report_outline
|
|
|
+from llmops.agents.planning_agent import plan_next_action
|
|
|
+from llmops.agents.rules_engine_metric_calculation_agent import RulesEngineMetricCalculationAgent
|
|
|
+from llmops.agents.data_manager import DataManager
|
|
|
+import os
|
|
|
+from llmops.agents.data_classify_agent import data_classify
|
|
|
|
|
|
class CompleteAgentFlow:
|
|
|
"""完整的智能体工作流"""
|
|
|
@@ -74,6 +76,7 @@ class CompleteAgentFlow:
|
|
|
workflow.add_node("planning_node", self._planning_node)
|
|
|
workflow.add_node("outline_generator", self._outline_generator_node)
|
|
|
workflow.add_node("metric_calculator", self._metric_calculator_node)
|
|
|
+ workflow.add_node("data_classify", self._data_classify_node)
|
|
|
|
|
|
# 设置入口点
|
|
|
workflow.set_entry_point("planning_node")
|
|
|
@@ -85,13 +88,15 @@ class CompleteAgentFlow:
|
|
|
{
|
|
|
"outline_generator": "outline_generator",
|
|
|
"metric_calculator": "metric_calculator",
|
|
|
+ "data_classify": "data_classify",
|
|
|
END: END
|
|
|
}
|
|
|
)
|
|
|
|
|
|
# 从各个节点返回规划节点重新决策
|
|
|
+ workflow.add_edge("data_classify", "planning_node")
|
|
|
workflow.add_edge("outline_generator", "planning_node")
|
|
|
- workflow.add_edge("metric_calculator", "planning_node")
|
|
|
+ workflow.add_edge("metric_calculator", END)
|
|
|
|
|
|
return workflow
|
|
|
|
|
|
@@ -106,6 +111,7 @@ class CompleteAgentFlow:
|
|
|
目标节点名称
|
|
|
"""
|
|
|
print(f"\n🔍 [路由决策] 步骤={state['planning_step']}, "
|
|
|
+ f"数据集分类打标数量={len(state.get('data_set_classified', []))}",
|
|
|
f"大纲={state.get('outline_draft') is not None}, "
|
|
|
f"指标需求={len(state.get('metrics_requirements', []))}")
|
|
|
|
|
|
@@ -114,11 +120,21 @@ class CompleteAgentFlow:
|
|
|
print("⚠️ 规划步骤超过30次,强制结束流程")
|
|
|
return END
|
|
|
|
|
|
+ # 数据分类打标数量为0 → 分类打标
|
|
|
+ if len(state.get("data_set_classified", [])) == 0:
|
|
|
+ print("→ 路由到 data_classify(分类打标)")
|
|
|
+ return "data_classify"
|
|
|
+
|
|
|
# 如果大纲为空 → 生成大纲
|
|
|
if not state.get("outline_draft"):
|
|
|
print("→ 路由到 outline_generator(生成大纲)")
|
|
|
return "outline_generator"
|
|
|
|
|
|
+ # 如果指标需求为空但大纲已生成 → 评估指标需求
|
|
|
+ if not state.get("metrics_requirements") and state.get("outline_draft"):
|
|
|
+ print("→ 路由到 metric_evaluator(评估指标需求)")
|
|
|
+ return "metric_evaluator"
|
|
|
+
|
|
|
# 计算覆盖率
|
|
|
progress = get_calculation_progress(state)
|
|
|
coverage = progress["coverage_rate"]
|
|
|
@@ -130,22 +146,10 @@ class CompleteAgentFlow:
|
|
|
print(f"→ 路由到 metric_calculator(计算指标,覆盖率={coverage:.2%})")
|
|
|
return "metric_calculator"
|
|
|
|
|
|
- # 检查是否应该结束流程
|
|
|
- pending_ids = state.get("pending_metric_ids", [])
|
|
|
- failed_attempts = state.get("failed_metric_attempts", {})
|
|
|
- max_retries = 3
|
|
|
-
|
|
|
- # 计算还有哪些指标可以重试(未达到最大重试次数)
|
|
|
- retryable_metrics = [
|
|
|
- mid for mid in pending_ids
|
|
|
- if failed_attempts.get(mid, 0) < max_retries
|
|
|
- ]
|
|
|
-
|
|
|
- # 如果覆盖率 >= 80%,或者没有可重试的指标 → 结束流程
|
|
|
- if coverage >= 0.8 or not retryable_metrics:
|
|
|
- reason = "覆盖率达到80%" if coverage >= 0.8 else "没有可重试指标"
|
|
|
- print(f"→ 结束流程(覆盖率={coverage:.2%},原因:{reason})")
|
|
|
- return END
|
|
|
+ # 如果没有待计算指标或覆盖率 >= 80% → 生成最终报告
|
|
|
+ if not state.get("pending_metric_ids") or coverage >= 0.8:
|
|
|
+ print(f"→ 路由到 report_finalizer(生成最终报告,覆盖率={coverage:.2%})")
|
|
|
+ return "report_finalizer"
|
|
|
|
|
|
# 默认返回规划节点
|
|
|
return "planning_node"
|
|
|
@@ -219,6 +223,32 @@ class CompleteAgentFlow:
|
|
|
new_state["errors"].append(f"大纲生成错误: {str(e)}")
|
|
|
return convert_numpy_types(new_state)
|
|
|
|
|
|
+ async def _data_classify_node(self, state: IntegratedWorkflowState) -> IntegratedWorkflowState:
|
|
|
+ """数据分类打标节点"""
|
|
|
+ try:
|
|
|
+ print("📝 正在对数据进行分类打标...")
|
|
|
+
|
|
|
+ # 对数据进行分类打标
|
|
|
+ data_set_classified = await data_classify(
|
|
|
+ industry=state["industry"],
|
|
|
+ data_set=state["data_set"],
|
|
|
+ file_name=state["file_name"]
|
|
|
+ )
|
|
|
+
|
|
|
+ # 更新状态
|
|
|
+ new_state = update_state_with_data_classified(state, data_set_classified)
|
|
|
+
|
|
|
+ print(f"✅ 数据分类打标完成,打标记录数: {len(data_set_classified)}")
|
|
|
+
|
|
|
+ return convert_numpy_types(new_state)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ 数据分类打标失败: {e}")
|
|
|
+ new_state = state.copy()
|
|
|
+ new_state["errors"].append(f"数据分类打标错误: {str(e)}")
|
|
|
+ return convert_numpy_types(new_state)
|
|
|
+
|
|
|
+
|
|
|
def _print_ai_selection_analysis(self, outline):
|
|
|
"""打印AI指标选择的推理过程分析 - 完全通用版本"""
|
|
|
print()
|
|
|
@@ -270,6 +300,50 @@ class CompleteAgentFlow:
|
|
|
print(' 能够根据具体业务场景动态调整分析框架,确保分析的针对性和有效性。')
|
|
|
print()
|
|
|
|
|
|
+ async def _metric_evaluator_node(self, state: IntegratedWorkflowState) -> IntegratedWorkflowState:
|
|
|
+ """指标评估节点:根据大纲确定需要计算的指标"""
|
|
|
+ try:
|
|
|
+ print("🔍 正在评估指标需求...")
|
|
|
+
|
|
|
+ new_state = state.copy()
|
|
|
+ outline = state.get("outline_draft")
|
|
|
+
|
|
|
+ if not outline:
|
|
|
+ print("⚠️ 没有大纲信息,跳过指标评估")
|
|
|
+ return convert_numpy_types(new_state)
|
|
|
+
|
|
|
+ # 从大纲中提取指标需求
|
|
|
+ metrics_requirements = outline.global_metrics
|
|
|
+ metric_ids = [m.metric_id for m in metrics_requirements]
|
|
|
+
|
|
|
+ # 设置待计算指标
|
|
|
+ new_state["metrics_requirements"] = metrics_requirements
|
|
|
+ new_state["pending_metric_ids"] = metric_ids.copy()
|
|
|
+ new_state["computed_metrics"] = {}
|
|
|
+ new_state["metrics_cache"] = {}
|
|
|
+
|
|
|
+ print(f"✅ 指标评估完成,发现 {len(metric_ids)} 个待计算指标")
|
|
|
+ for i, metric_id in enumerate(metric_ids[:5], 1): # 只显示前5个
|
|
|
+ print(f" {i}. {metric_id}")
|
|
|
+
|
|
|
+ if len(metric_ids) > 5:
|
|
|
+ print(f" ... 还有 {len(metric_ids) - 5} 个指标")
|
|
|
+
|
|
|
+ # 添加消息
|
|
|
+ new_state["messages"].append({
|
|
|
+ "role": "assistant",
|
|
|
+ "content": f"🔍 指标评估完成:发现 {len(metric_ids)} 个待计算指标",
|
|
|
+ "timestamp": datetime.now().isoformat()
|
|
|
+ })
|
|
|
+
|
|
|
+ return convert_numpy_types(new_state)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ 指标评估失败: {e}")
|
|
|
+ new_state = state.copy()
|
|
|
+ new_state["errors"].append(f"指标评估错误: {str(e)}")
|
|
|
+ return convert_numpy_types(new_state)
|
|
|
+
|
|
|
async def _metric_calculator_node(self, state: IntegratedWorkflowState) -> IntegratedWorkflowState:
|
|
|
"""指标计算节点"""
|
|
|
try:
|
|
|
@@ -315,15 +389,11 @@ class CompleteAgentFlow:
|
|
|
# 找到对应的指标需求
|
|
|
metric_req = next((m for m in metrics_requirements if m.metric_id == metric_id), None)
|
|
|
if not metric_req:
|
|
|
- # 修复:找不到指标需求时,创建临时的指标需求结构,避免跳过指标
|
|
|
- print(f"⚠️ 指标 {metric_id} 找不到需求信息,创建临时配置继续计算")
|
|
|
- metric_req = type('MetricRequirement', (), {
|
|
|
- 'metric_id': metric_id,
|
|
|
- 'metric_name': metric_id.replace('metric-', '') if metric_id.startswith('metric-') else metric_id,
|
|
|
- 'calculation_logic': f'计算 {metric_id}',
|
|
|
- 'required_fields': ['transactions'],
|
|
|
- 'dependencies': []
|
|
|
- })()
|
|
|
+ print(f"⚠️ 找不到指标 {metric_id} 的需求信息,跳过")
|
|
|
+ # 仍然从待计算列表中移除,避免无限循环
|
|
|
+ if metric_id in new_state["pending_metric_ids"]:
|
|
|
+ new_state["pending_metric_ids"].remove(metric_id)
|
|
|
+ continue
|
|
|
|
|
|
print(f"🧮 计算指标: {metric_id} - {metric_req.metric_name}")
|
|
|
|
|
|
@@ -366,55 +436,27 @@ class CompleteAgentFlow:
|
|
|
}
|
|
|
|
|
|
# 处理计算结果
|
|
|
- calculation_success = False
|
|
|
for result in results.get("results", []):
|
|
|
if result.get("result", {}).get("success"):
|
|
|
# 计算成功
|
|
|
new_state["computed_metrics"][metric_id] = result["result"]
|
|
|
successful_calculations += 1
|
|
|
- calculation_success = True
|
|
|
print(f"✅ 指标 {metric_id} 计算成功")
|
|
|
- break # 找到一个成功的就算成功
|
|
|
else:
|
|
|
# 计算失败
|
|
|
failed_calculations += 1
|
|
|
print(f"❌ 指标 {metric_id} 计算失败")
|
|
|
|
|
|
- # 初始化失败尝试记录
|
|
|
- if "failed_metric_attempts" not in new_state:
|
|
|
- new_state["failed_metric_attempts"] = {}
|
|
|
-
|
|
|
- # 根据计算结果处理指标
|
|
|
- if calculation_success:
|
|
|
- # 计算成功:从待计算列表中移除
|
|
|
- if metric_id in new_state["pending_metric_ids"]:
|
|
|
- new_state["pending_metric_ids"].remove(metric_id)
|
|
|
- # 重置失败计数
|
|
|
- new_state["failed_metric_attempts"].pop(metric_id, None)
|
|
|
- else:
|
|
|
- # 计算失败:记录失败次数,不从待计算列表移除
|
|
|
- new_state["failed_metric_attempts"][metric_id] = new_state["failed_metric_attempts"].get(metric_id, 0) + 1
|
|
|
- max_retries = 3
|
|
|
- if new_state["failed_metric_attempts"][metric_id] >= max_retries:
|
|
|
- print(f"⚠️ 指标 {metric_id} 已达到最大重试次数 ({max_retries}),从待计算列表中移除")
|
|
|
- if metric_id in new_state["pending_metric_ids"]:
|
|
|
- new_state["pending_metric_ids"].remove(metric_id)
|
|
|
+ # 从待计算列表中移除(无论成功还是失败)
|
|
|
+ if metric_id in new_state["pending_metric_ids"]:
|
|
|
+ new_state["pending_metric_ids"].remove(metric_id)
|
|
|
|
|
|
except Exception as e:
|
|
|
print(f"❌ 计算指标 {metric_id} 时发生异常: {e}")
|
|
|
failed_calculations += 1
|
|
|
-
|
|
|
- # 初始化失败尝试记录
|
|
|
- if "failed_metric_attempts" not in new_state:
|
|
|
- new_state["failed_metric_attempts"] = {}
|
|
|
-
|
|
|
- # 记录失败次数
|
|
|
- new_state["failed_metric_attempts"][metric_id] = new_state["failed_metric_attempts"].get(metric_id, 0) + 1
|
|
|
- max_retries = 3
|
|
|
- if new_state["failed_metric_attempts"][metric_id] >= max_retries:
|
|
|
- print(f"⚠️ 指标 {metric_id} 异常已达到最大重试次数 ({max_retries}),从待计算列表中移除")
|
|
|
- if metric_id in new_state["pending_metric_ids"]:
|
|
|
- new_state["pending_metric_ids"].remove(metric_id)
|
|
|
+ # 即使异常,也要从待计算列表中移除,避免无限循环
|
|
|
+ if metric_id in new_state["pending_metric_ids"]:
|
|
|
+ new_state["pending_metric_ids"].remove(metric_id)
|
|
|
|
|
|
# 更新计算结果统计
|
|
|
new_state["calculation_results"] = {
|
|
|
@@ -456,9 +498,10 @@ class CompleteAgentFlow:
|
|
|
def _decision_to_route(self, decision: str) -> str:
|
|
|
"""将规划决策转换为路由"""
|
|
|
decision_routes = {
|
|
|
+ "data_classify": "data_classify",
|
|
|
"generate_outline": "outline_generator",
|
|
|
"compute_metrics": "metric_calculator",
|
|
|
- "finalize_report": END # 直接结束流程
|
|
|
+ "finalize_report": "report_finalizer"
|
|
|
}
|
|
|
return decision_routes.get(decision, "planning_node")
|
|
|
|
|
|
@@ -480,7 +523,7 @@ class CompleteAgentFlow:
|
|
|
except:
|
|
|
return "🤔 规划决策已完成"
|
|
|
|
|
|
- async def run_workflow(self, question: str, industry: str, data: List[Dict[str, Any]], session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
|
|
|
+ async def run_workflow(self, question: str, industry: str, data: List[Dict[str, Any]], file_name: str, session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
|
|
|
"""
|
|
|
运行完整的工作流
|
|
|
|
|
|
@@ -488,6 +531,7 @@ class CompleteAgentFlow:
|
|
|
question: 用户查询
|
|
|
industry: 行业
|
|
|
data: 数据集
|
|
|
+ file_name: 数据文件名称
|
|
|
session_id: 会话ID
|
|
|
use_rules_engine_only: 是否只使用规则引擎指标计算
|
|
|
use_traditional_engine_only: 是否只使用传统引擎指标计算
|
|
|
@@ -499,6 +543,7 @@ class CompleteAgentFlow:
|
|
|
print("🚀 启动完整智能体工作流...")
|
|
|
print(f"问题:{question}")
|
|
|
print(f"行业:{industry}")
|
|
|
+ print(f"数据文件:{file_name}")
|
|
|
print(f"数据条数:{len(data)}")
|
|
|
|
|
|
if use_rules_engine_only:
|
|
|
@@ -509,7 +554,7 @@ class CompleteAgentFlow:
|
|
|
print("计算模式:标准模式")
|
|
|
|
|
|
# 创建初始状态
|
|
|
- initial_state = create_initial_integrated_state(question, industry, data, session_id)
|
|
|
+ initial_state = create_initial_integrated_state(question, industry, data, file_name, session_id)
|
|
|
|
|
|
# 设置计算模式标记
|
|
|
if use_rules_engine_only:
|
|
|
@@ -553,13 +598,14 @@ class CompleteAgentFlow:
|
|
|
|
|
|
|
|
|
# 便捷函数
|
|
|
-async def run_complete_agent_flow(question: str, industry: str, data: List[Dict[str, Any]], api_key: str, session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
|
|
|
+async def run_complete_agent_flow(question: str, industry: str, data: List[Dict[str, Any]], file_name: str, api_key: str, session_id: str = None, use_rules_engine_only: bool = False, use_traditional_engine_only: bool = False) -> Dict[str, Any]:
|
|
|
"""
|
|
|
运行完整智能体工作流的便捷函数
|
|
|
|
|
|
Args:
|
|
|
question: 用户查询
|
|
|
data: 数据集
|
|
|
+ file_name: 数据文件名称
|
|
|
api_key: API密钥
|
|
|
session_id: 会话ID
|
|
|
use_rules_engine_only: 是否只使用规则引擎指标计算
|
|
|
@@ -569,7 +615,7 @@ async def run_complete_agent_flow(question: str, industry: str, data: List[Dict[
|
|
|
工作流结果
|
|
|
"""
|
|
|
workflow = CompleteAgentFlow(api_key)
|
|
|
- return await workflow.run_workflow(question, industry, data, session_id, use_rules_engine_only, use_traditional_engine_only)
|
|
|
+ return await workflow.run_workflow(question, industry, data, file_name, session_id, use_rules_engine_only, use_traditional_engine_only)
|
|
|
|
|
|
|
|
|
# 主函数用于测试
|
|
|
@@ -585,23 +631,26 @@ async def main():
|
|
|
print("❌ 未找到API密钥")
|
|
|
return
|
|
|
|
|
|
- # 测试数据
|
|
|
- test_data = [
|
|
|
- {
|
|
|
+ # 行业
|
|
|
+ industry = "农业"
|
|
|
|
|
|
- }
|
|
|
- ]
|
|
|
+ # 测试文件
|
|
|
+ file_name = "test_temp_agriculture_transaction_flow.csv"
|
|
|
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
+ file_path = os.path.join(curr_dir, "..", "data_files", file_name)
|
|
|
|
|
|
- print(f"📊 测试数据: {len(test_data)} 条记录")
|
|
|
+ # 加载测试数据集并展示两条样例
|
|
|
+ test_data = DataManager.load_data_from_csv_file(file_path)
|
|
|
|
|
|
+ print(f"📊 读取测试数据文件: {file_name} 数据, 加载 {len(test_data)} 条记录")
|
|
|
+ print(f"测试数据样例: {test_data[0:1]}")
|
|
|
|
|
|
# 执行测试
|
|
|
result = await run_complete_agent_flow(
|
|
|
question="请生成一份详细的农业经营贷流水分析报告,需要包含:1.总收入和总支出统计 2.收入笔数和支出笔数 3.各类型收入支出占比分析 4.交易对手收入支出TOP3排名 5.按月份的收入支出趋势分析 6.账户数量和交易时间范围统计 7.资金流入流出月度统计等全面指标",
|
|
|
- industry = "农业",
|
|
|
- # question="请生成一份详细的黑色金属相关经营贷流水分析报告,需要包含:1.总收入统计 2.收入笔数 3.各类型收入占比分析 4.交易对手收入排名 5.按月份的收入趋势分析 6.账户数量和交易时间范围统计 7.资金流入流出月度统计等全面指标",
|
|
|
- # industry = "黑色金属",
|
|
|
+ industry = industry,
|
|
|
data=test_data,
|
|
|
+ file_name=file_name,
|
|
|
api_key=config.DEEPSEEK_API_KEY,
|
|
|
session_id="direct-test"
|
|
|
)
|