|
|
@@ -0,0 +1,596 @@
|
|
|
+
|
|
|
+from typing import TypedDict, Any, List, Dict
|
|
|
+
|
|
|
+from langgraph.graph import StateGraph, START, END
|
|
|
+from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
|
|
|
+import pandas as pd
|
|
|
+import json
|
|
|
+
|
|
|
+from langgraph.prebuilt import create_react_agent
|
|
|
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
|
+from llmops.agents.datadev.memory.memory_saver_with_expiry2 import MemorySaverWithExpiry
|
|
|
+
|
|
|
+from llmops.agents.datadev.llm import get_llm
|
|
|
+from llmops.agents.tools.unerstand_dataset_tool import understand_dataset_structure
|
|
|
+
|
|
|
+import logging
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+logging.basicConfig(level=logging.DEBUG,
|
|
|
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
|
+ datefmt='%Y-%m-%d %H:%M:%S')
|
|
|
+
|
|
|
+class AgentState(TypedDict):
|
|
|
+ """
|
|
|
+ ChatBot智能体状态
|
|
|
+ """
|
|
|
+ question: str # 用户问题
|
|
|
+ data_set: List[Any] # 数据集 [{}]
|
|
|
+ transactions_df: pd.DataFrame # pd 数据集
|
|
|
+ intent_num: str # 意图分类ID(1-8)
|
|
|
+ history_messages: List[BaseMessage] # 历史消息
|
|
|
+ bank_type: str # 银行类型
|
|
|
+ transactions_df: pd.DataFrame # 数据集
|
|
|
+ analysis_results: Dict[str, Any] # 分析结果
|
|
|
+ current_step: str # 当前操作步骤
|
|
|
+ file_path: str # 数据集文件路径
|
|
|
+ file_hash: str # 数据集文件hash值
|
|
|
+ status: str # 状态 success | error
|
|
|
+ answer: Any # AI回答
|
|
|
+
|
|
|
+class TxFlowAnalysisAgent:
|
|
|
+ """交易流水分析智能体"""
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ # self.llm_orchestrator = LLMOrchestrator()
|
|
|
+ self.llm = get_llm()
|
|
|
+
|
|
|
+ # 会话记忆
|
|
|
+ self.memory = MemorySaverWithExpiry(expire_time=600, clean_interval=60)
|
|
|
+
|
|
|
+ # 构建图
|
|
|
+ self.graph = self._build_graph()
|
|
|
+
|
|
|
+
|
|
|
+ def _build_graph(self):
|
|
|
+ # 构造计算图
|
|
|
+ graph_builder = StateGraph(AgentState)
|
|
|
+
|
|
|
+ # 添加节点
|
|
|
+ graph_builder.add_node("_intent_recognition", self._intent_recognition)
|
|
|
+ graph_builder.add_node("_say_hello", self._say_hello)
|
|
|
+ graph_builder.add_node("_intent_clarify", self._intent_clarify)
|
|
|
+ graph_builder.add_node("_beyond_ability", self._beyond_ability)
|
|
|
+ graph_builder.add_node("_data_op", self._data_op)
|
|
|
+ graph_builder.add_node("_gen_report", self._gen_report)
|
|
|
+ graph_builder.add_node("_analysis", self._analysis)
|
|
|
+ graph_builder.add_node("_statistics", self._statistics)
|
|
|
+ graph_builder.add_node("_query", self._query)
|
|
|
+ graph_builder.add_node("_invalid", self._invalid)
|
|
|
+
|
|
|
+ # 添加边
|
|
|
+ graph_builder.add_edge(START, "_intent_recognition")
|
|
|
+ graph_builder.add_edge("_say_hello", END)
|
|
|
+ graph_builder.add_edge("_intent_clarify", END)
|
|
|
+ graph_builder.add_edge("_beyond_ability", END)
|
|
|
+ graph_builder.add_edge("_data_op", END)
|
|
|
+ graph_builder.add_edge("_gen_report", END)
|
|
|
+ graph_builder.add_edge("_analysis", END)
|
|
|
+ graph_builder.add_edge("_statistics", END)
|
|
|
+ graph_builder.add_edge("_query", END)
|
|
|
+ graph_builder.add_edge("_invalid", END)
|
|
|
+
|
|
|
+ # 条件边
|
|
|
+ graph_builder.add_conditional_edges(source="_intent_recognition", path=self._router, path_map={
|
|
|
+ "_say_hello": "_say_hello",
|
|
|
+ "_intent_clarify": "_intent_clarify",
|
|
|
+ "_beyond_ability": "_beyond_ability",
|
|
|
+ "_data_op": "_data_op",
|
|
|
+ "_gen_report": "_gen_report",
|
|
|
+ "_analysis": "_analysis",
|
|
|
+ "_statistics": "_statistics",
|
|
|
+ "_query": "_query",
|
|
|
+ "_invalid": "_invalid"
|
|
|
+ })
|
|
|
+
|
|
|
+ return graph_builder.compile(checkpointer=self.memory)
|
|
|
+
|
|
|
+
|
|
|
+ def _intent_recognition(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 根据用户问题,识别用户意图
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ template = """
|
|
|
+ 你是一位银行流水分析助手,专门识别用户问题的意图类别。
|
|
|
+
|
|
|
+ ### 任务说明
|
|
|
+ 根据用户问题和历史对话,判断用户意图属于以下哪个类别。直接输出对应的数字编号。
|
|
|
+
|
|
|
+ ### 历史对话:
|
|
|
+ {history_messages}
|
|
|
+
|
|
|
+ ### 当前用户问题:
|
|
|
+ {question}
|
|
|
+
|
|
|
+ ### 意图分类体系:
|
|
|
+ 1. **流水查询与检索** - 查找特定交易记录
|
|
|
+ - 示例:"查一下昨天给张三的转账"、"找一笔5000元的交易"
|
|
|
+ 2. **统计与汇总** - 计算数值、统计信息
|
|
|
+ - 示例:"本月总支出多少"、"工资收入总计多少"
|
|
|
+ 3. **洞察与分析** - 分析模式、趋势、异常
|
|
|
+ - 示例:"我的消费趋势如何"、"有没有异常交易"
|
|
|
+ 4. **生成报告** - 生成一份分析报告
|
|
|
+ - 示例:"生成一份收入与支持的分析报告"
|
|
|
+ 5. **数据操作与证明** - 验证交易
|
|
|
+ - 示例:"导出本月流水"、"查这笔交易的流水号"
|
|
|
+ 6. **超出能力边界** - 系统无法处理的问题
|
|
|
+ - 示例:"帮我转账"、"下个月我能存多少钱"
|
|
|
+ 7. **意图澄清** - 问题模糊需要进一步确认
|
|
|
+ - 示例:"花了多少钱"(无时间范围)、"查一下交易"(无具体条件)
|
|
|
+ 8. ** 打招呼 **
|
|
|
+ - 示例: "你好"
|
|
|
+
|
|
|
+ ### 输出规则:
|
|
|
+ - 只输出1个数字(1-8),不要任何解释
|
|
|
+ - 如果问题不明确,优先选7(意图澄清)
|
|
|
+ - 如果问题超出流水分析范围,选6(超出能力边界)
|
|
|
+ - 如果有多个意图,选最主要的一个
|
|
|
+ """
|
|
|
+ pt = ChatPromptTemplate.from_template(template)
|
|
|
+ chain = pt | self.llm
|
|
|
+
|
|
|
+ # 从state获取历史消息
|
|
|
+ history_messages = state.get("history_messages", [])
|
|
|
+ # 转换为字符串格式
|
|
|
+ history_str = "\n".join([f"{msg.type}: {msg.content}" for msg in history_messages])
|
|
|
+
|
|
|
+ try:
|
|
|
+
|
|
|
+ response = chain.invoke({"question": state["question"], "history_messages": history_str})
|
|
|
+ logger.info(f"对用户问题:{state['question']} 进行意图识别:{response.content}")
|
|
|
+
|
|
|
+ # 更新历史
|
|
|
+ new_question = HumanMessage(content=state["question"])
|
|
|
+ ai_message = AIMessage(content=response.content)
|
|
|
+ new_history = history_messages + [new_question] + [ai_message]
|
|
|
+
|
|
|
+ return {"intent_num": response.content, "history_messages": new_history[-100:]}
|
|
|
+ except Exception as e:
|
|
|
+ print(f"用户问题:{state['question']}意图识别异常,{str(e)}")
|
|
|
+ return {"status": "error"}
|
|
|
+
|
|
|
+ def _router(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 根据意图进行进行路由
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ intent_num = state["intent_num"]
|
|
|
+ if intent_num == "8":
|
|
|
+ return "_say_hello"
|
|
|
+ elif intent_num == "7": # 意图澄清
|
|
|
+ return "_intent_clarify"
|
|
|
+ elif intent_num == "6": # 超出能力边界
|
|
|
+ return "_beyond_ability"
|
|
|
+ elif intent_num == "5": # 数据操作与证明
|
|
|
+ return "_data_op"
|
|
|
+ elif intent_num == "4": # 生成报告
|
|
|
+ return "_gen_report"
|
|
|
+ elif intent_num == "3": # 洞察分析
|
|
|
+ return "_analysis"
|
|
|
+ elif intent_num == "2": # 统计汇总
|
|
|
+ return "_statistics"
|
|
|
+ elif intent_num == "1": # 流水明细查询
|
|
|
+ return "_query"
|
|
|
+ else:
|
|
|
+ return "_invalid"
|
|
|
+
|
|
|
+ def _say_hello(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 向客户打招呼
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ template = """
|
|
|
+ 你是交易流水分析助手,用户在跟你打招呼,请合理组织话术并进行回答。
|
|
|
+
|
|
|
+ ### 用户说
|
|
|
+ {question}
|
|
|
+
|
|
|
+ ### 回答要求
|
|
|
+ - 友好
|
|
|
+ - 专业
|
|
|
+ - 快速
|
|
|
+ """
|
|
|
+ pt = ChatPromptTemplate.from_template(template)
|
|
|
+ chain = pt | self.llm
|
|
|
+ try:
|
|
|
+ response = chain.invoke({"question": state["question"]})
|
|
|
+ logger.info(f"用户说:{state['question']}, AI回答: {response.content}")
|
|
|
+ return {"answer": response.content}
|
|
|
+ except Exception as e:
|
|
|
+ print(f"say_hello异常 {str(e)}")
|
|
|
+ return {"status": "error", "message": str(e)}
|
|
|
+
|
|
|
+
|
|
|
+ def _intent_clarify(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 意图澄清节点
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ template = """
|
|
|
+ 你是交易流水分析助手,根据用户问题和历史对话,对用户的问题进行反问,以便充分理解用户意图。
|
|
|
+
|
|
|
+ ### 用户问题
|
|
|
+ {question}
|
|
|
+
|
|
|
+ ### 历史对话
|
|
|
+ {history_messages}
|
|
|
+ """
|
|
|
+ pt = ChatPromptTemplate.from_template(template)
|
|
|
+ # 从state获取历史消息
|
|
|
+ history_messages = state.get("history_messages", [])
|
|
|
+ # 转换为字符串格式
|
|
|
+ history_str = "\n".join([f"{msg.type}: {msg.content}" for msg in history_messages])
|
|
|
+
|
|
|
+ chain = pt | self.llm
|
|
|
+ try:
|
|
|
+ response = chain.invoke({
|
|
|
+ "question": state["question"],
|
|
|
+ "history_messages": history_str
|
|
|
+ })
|
|
|
+
|
|
|
+ # 更新历史对话
|
|
|
+ new_question = HumanMessage(content=state["question"])
|
|
|
+ ai_message = AIMessage(content=response.content)
|
|
|
+ new_history = history_messages + [new_question] + [ai_message]
|
|
|
+
|
|
|
+ return {"answer": response.content, "history_messages": new_history[-100:]}
|
|
|
+ except Exception as e:
|
|
|
+ return {"status": "error", "message": str(e)}
|
|
|
+
|
|
|
+ def _beyond_ability(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 超出能力边界处理节点
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ template = """
|
|
|
+ 你是交易流水分析助手,对于用户的问题,已超出你的能力边界,组合合理的话术回答用户。
|
|
|
+
|
|
|
+ ### 用户问题
|
|
|
+ {question}
|
|
|
+
|
|
|
+ ### 要求
|
|
|
+ - 友好
|
|
|
+ - 专业
|
|
|
+ - 快速
|
|
|
+ """
|
|
|
+ pt = ChatPromptTemplate.from_template(template)
|
|
|
+ chain = pt | self.llm
|
|
|
+ try:
|
|
|
+ response = chain.invoke({"question": state["question"]})
|
|
|
+ return {"answer": response.content}
|
|
|
+ except Exception as e:
|
|
|
+ return {"status": "error", "message": str(e)}
|
|
|
+
|
|
|
+ def _data_op(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 数据验证处理节点
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ return {"answer": "暂不支持"}
|
|
|
+
|
|
|
+ def _gen_report(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 生成报告节点
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ template = """
|
|
|
+
|
|
|
+ """
|
|
|
+ return {"answer": {"report": {"title": "this is title", "content": "this is content"}}}
|
|
|
+
|
|
|
+ def _analysis(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 洞察分析处理节点
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ return {"answer": "这是洞察分析结果"}
|
|
|
+
|
|
|
+ def _statistics(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 统计汇总处理节点
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ return {"answer": "这是统计汇总处理结果"}
|
|
|
+
|
|
|
+ def _query(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 流水检索处理节点
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+
|
|
|
+ # 创建 react agent( reasoning/acting )
|
|
|
+ # 工具集,理解数据集结构
|
|
|
+ tools = [understand_dataset_structure]
|
|
|
+
|
|
|
+ system_prompt = f"""
|
|
|
+ 你是交易流水分析助手,根据用户问题和已有的数据集,检索出对应的数据并返回。
|
|
|
+
|
|
|
+ ### 用户问题
|
|
|
+ {state['question']}
|
|
|
+
|
|
|
+ ### 已有数据集
|
|
|
+ {state['data_set']}
|
|
|
+
|
|
|
+ ### 要求
|
|
|
+ - 首先要理解 数据集结构
|
|
|
+ - 必须在 已有数据集(已提供) 中进行检索
|
|
|
+ - 检索出 全部 符合要求的记录
|
|
|
+ - 如果没有符合要求的记录,则友好提示
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ agent = create_react_agent(model=self.llm, tools=tools, prompt=system_prompt)
|
|
|
+ response = agent.invoke({
|
|
|
+ "messages": [
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": f"""
|
|
|
+ ### 用户问题
|
|
|
+ {state['question']}
|
|
|
+ """
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ })
|
|
|
+ answer = response["messages"][-1].content
|
|
|
+
|
|
|
+ # 解析 LLM 的回答
|
|
|
+ try:
|
|
|
+ result = json.loads(answer)
|
|
|
+ if not result:
|
|
|
+ answer = "没有找到符合要求的记录"
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ answer = "LLM 的回答格式不正确,请重新提问"
|
|
|
+
|
|
|
+ logger.info(f"用户问题:{state['question']}, AI回答: {answer}")
|
|
|
+ return {"answer": answer}
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"_query 异常:{str(e)}")
|
|
|
+ return {"status": "error", "message": str(e)}
|
|
|
+
|
|
|
+
|
|
|
+ def _invalid(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 无效问题
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ return {"answer": "这是一个无效问题"}
|
|
|
+
|
|
|
+ def _knowledge_save(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 从用户问题,AI解答等过程中发现知识、存储知识
|
|
|
+ :param state:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+
|
|
|
+ async def process_upload(self, state: AgentState) -> AgentState:
|
|
|
+ """处理文件上传"""
|
|
|
+ print("🔍 处理文件上传...")
|
|
|
+ import os
|
|
|
+ file_path = state.get("file_path", "")
|
|
|
+ if not file_path or not os.path.exists(file_path):
|
|
|
+ state["messages"].append(AIMessage(content="未找到上传的流水文件,请重新上传。"))
|
|
|
+ return state
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 提取文本
|
|
|
+ pdf_text = self.pdf_tools.extract_text_from_pdf(file_path)
|
|
|
+
|
|
|
+ # 检测银行类型
|
|
|
+ bank_type = self.pdf_tools.detect_bank_type(pdf_text)
|
|
|
+ state["bank_type"] = bank_type
|
|
|
+
|
|
|
+ # 解析交易数据
|
|
|
+ if "招商银行" in bank_type:
|
|
|
+ transactions_json = self.pdf_tools.parse_cmb_statement(pdf_text)
|
|
|
+ transactions = json.loads(transactions_json)
|
|
|
+
|
|
|
+ # 转换为DataFrame
|
|
|
+ df = pd.DataFrame(transactions)
|
|
|
+ df['date'] = pd.to_datetime(df['date'])
|
|
|
+ df = df.sort_values('date')
|
|
|
+
|
|
|
+ # 添加分类
|
|
|
+ df['category'] = df['description'].apply(
|
|
|
+ lambda x: self._categorize_transaction(x)
|
|
|
+ )
|
|
|
+
|
|
|
+ state["transactions_df"] = df
|
|
|
+
|
|
|
+ # 创建分析工具实例
|
|
|
+ self.analysis_tools = AnalysisTools(df)
|
|
|
+
|
|
|
+ summary = self.analysis_tools.get_summary_statistics()
|
|
|
+
|
|
|
+ response = f"""
|
|
|
+ ✅ 文件解析成功!
|
|
|
+
|
|
|
+ 📊 基本信息:
|
|
|
+ - 银行类型:{bank_type}
|
|
|
+ - 交易笔数:{summary['total_transactions']} 笔
|
|
|
+ - 时间范围:{summary['date_range']['start']} 至 {summary['date_range']['end']}
|
|
|
+ - 总收入:¥{summary['total_income']:,.2f}
|
|
|
+ - 总支出:¥{summary['total_expense']:,.2f}
|
|
|
+ - 净现金流:¥{summary['net_cash_flow']:,.2f}
|
|
|
+
|
|
|
+ 您可以问我:
|
|
|
+ 1. "分析我的消费模式"
|
|
|
+ 2. "检测异常交易"
|
|
|
+ 3. "生成财务报告"
|
|
|
+ 4. "查询大额交易"
|
|
|
+ 5. "预测未来余额"
|
|
|
+ """
|
|
|
+
|
|
|
+ state["messages"].append(AIMessage(content=response))
|
|
|
+
|
|
|
+ else:
|
|
|
+ state["messages"].append(AIMessage(
|
|
|
+ content=f"检测到{bank_type},当前主要支持招商银行格式,其他银行功能正在开发中。"
|
|
|
+ ))
|
|
|
+
|
|
|
+ state["current_step"] = "analysis_ready"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ state["messages"].append(AIMessage(
|
|
|
+ content=f"文件解析失败:{str(e)}"
|
|
|
+ ))
|
|
|
+
|
|
|
+ return state
|
|
|
+
|
|
|
+ def ask_question(self, session: str, question: str, data_set_file: str = None) -> dict:
|
|
|
+ """
|
|
|
+ 程序调用入口,用户提问
|
|
|
+ :param session: 会话ID
|
|
|
+ :param question: 用户问题
|
|
|
+ :param data_set_file: 数据集文档(json数据文件)
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+
|
|
|
+ data_set = []
|
|
|
+ # 读取 json 文件数据, 格式为 List[Dict]
|
|
|
+ import json
|
|
|
+ import os
|
|
|
+ try:
|
|
|
+ logger.info(f"传入的数据文件路径: {data_set_file}")
|
|
|
+ if data_set_file:
|
|
|
+ if os.path.exists(data_set_file):
|
|
|
+ with open(data_set_file, 'r', encoding='utf-8') as file:
|
|
|
+ data_set = json.load(file)
|
|
|
+ logger.info(f"加载数据条数:{len(data_set)}")
|
|
|
+ # 加载数据集
|
|
|
+ df = pd.DataFrame(data_set)
|
|
|
+ df['txDate'] = pd.to_datetime(df['txDate']) # 提前转换类型
|
|
|
+ else:
|
|
|
+ logger.error(f"数据文件:{data_set_file} 不存在,请检查路径!")
|
|
|
+ else:
|
|
|
+ logger.info("未传入数据集文件")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"读取数据文件:{data_set_file}异常,{str(e)}")
|
|
|
+ return {
|
|
|
+ "status": "error",
|
|
|
+ "message": f"读取数据文件:{data_set_file}异常,{str(e)}"
|
|
|
+ }
|
|
|
+
|
|
|
+ if not session or not question:
|
|
|
+ return {
|
|
|
+ "status": "error",
|
|
|
+ "message": "缺少参数"
|
|
|
+ }
|
|
|
+ result = {
|
|
|
+ "status": "success"
|
|
|
+ }
|
|
|
+ try:
|
|
|
+ config = {"configurable": {"thread_id": session}}
|
|
|
+ current_state = self.memory.get(config)
|
|
|
+ history_messages = current_state["channel_values"]["history_messages"] if current_state else []
|
|
|
+ if len(data_set) == 0: # 没有指定数据集
|
|
|
+ data_set = current_state["channel_values"]["data_set"] if current_state else []
|
|
|
+
|
|
|
+ # 执行
|
|
|
+ response = self.graph.invoke({
|
|
|
+ "session": session,
|
|
|
+ "question": question,
|
|
|
+ "data_set": data_set,
|
|
|
+ "transactions_df": df,
|
|
|
+ "history_messages": history_messages,
|
|
|
+ }, config)
|
|
|
+ result["answer"] = response.get("answer","")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"用户对话失败,异常:{str(e)}")
|
|
|
+ result["status"] = "error"
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 构建智能体图 ====================
|
|
|
+def create_bank_statement_agent():
|
|
|
+ """创建银行流水分析智能体图"""
|
|
|
+
|
|
|
+ # 创建节点
|
|
|
+ nodes = TxFlowAnalysisAgent()
|
|
|
+
|
|
|
+ # 创建状态图
|
|
|
+ workflow = StateGraph(AgentState)
|
|
|
+
|
|
|
+ # 添加节点
|
|
|
+ workflow.add_node("process_upload", nodes.process_upload)
|
|
|
+ workflow.add_node("analyze_query", nodes.analyze_query)
|
|
|
+ workflow.add_node("waiting_for_input", lambda state: state) # 等待输入
|
|
|
+
|
|
|
+ # 设置入口点
|
|
|
+ workflow.set_entry_point("waiting_for_input")
|
|
|
+
|
|
|
+ # 添加条件边
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "waiting_for_input",
|
|
|
+ nodes.route_query,
|
|
|
+ {
|
|
|
+ "process_upload": "process_upload",
|
|
|
+ "analyze_query": "analyze_query",
|
|
|
+ "waiting_for_input": "waiting_for_input"
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ # 添加普通边
|
|
|
+ workflow.add_edge("process_upload", "waiting_for_input")
|
|
|
+ workflow.add_edge("analyze_query", "waiting_for_input")
|
|
|
+
|
|
|
+ # 编译图
|
|
|
+ graph = workflow.compile()
|
|
|
+
|
|
|
+ return graph
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ agent = TxFlowAnalysisAgent()
|
|
|
+ question = "你好"
|
|
|
+ # 数据集文件
|
|
|
+ data_file = "/Applications/work/宇信科技/知识沉淀平台/原始数据-流水分析-农业原始数据.json"
|
|
|
+ result = agent.ask_question(session="s1", question=question, data_set_file=data_file)
|
|
|
+ print(f"问题:{question}, 响应:{result}")
|
|
|
+
|
|
|
+ question = "查询交易日期是2023-01-05对应的收入记录"
|
|
|
+ result = agent.ask_question(session="s1", question=question, data_set_file=data_file)
|
|
|
+ print(f"问题:{question}, 响应:{result}")
|
|
|
+
|
|
|
+ question = "查询交易对手是绿源农产品公司的记录"
|
|
|
+ result = agent.ask_question(session="s1", question=question, data_set_file=data_file)
|
|
|
+ print(f"问题:{question}, 响应:{result}")
|
|
|
+
|
|
|
+ # question = "查找转给贾强的交易记录"
|
|
|
+ # result = agent.ask_question(session="s1", question=question)
|
|
|
+ # print(f"问题:{question}, 意图:{result}")
|
|
|
+ # question = "花了多少钱"
|
|
|
+ # result = agent.ask_question(session="s1", question=question)
|
|
|
+ # print(f"问题:{question}, 意图:{result}")
|
|
|
+ # question = "统计用户总收入"
|
|
|
+ # result = agent.ask_question(session="s1", question=question)
|
|
|
+ # print(f"问题:{question}, 意图:{result}")
|
|
|
+ # question = "生成分析报告"
|
|
|
+ # result = agent.ask_question(session="s1", question=question)
|
|
|
+ # print(f"问题:{question}, 意图:{result}")
|
|
|
+ # question = "转账给张三"
|
|
|
+ # result = agent.ask_question(session="s1", question=question)
|
|
|
+ # print(f"问题:{question}, 意图:{result}")
|
|
|
+ # question = "生成一份异常分析报告"
|
|
|
+ # result = agent.ask_question(session="s1", question=question)
|
|
|
+ # print(f"问题:{question}, 意图:{result}")
|