Browse Source

调整流水分析接口,加入数据标准化处理

jiaqiang 3 days ago
parent
commit
05c1b137ad
3 changed files with 70 additions and 28 deletions
  1. 9 1
      llmops/batch_runner.py
  2. 7 4
      llmops/complete_agent_flow_rule.py
  3. 54 23
      llmops/main.py

+ 9 - 1
llmops/batch_runner.py

@@ -51,7 +51,7 @@ else:
 import config
 
 # ========== 配置参数 ==========
-RUNS = 2  # 运行次数
+RUNS = 10  # 运行次数
 INDUSTRY = "农业"  # 行业
 DATA_FILE = "data_files/交易流水样例数据.csv"  # 数据文件路径
 QUESTION = "请生成一份详细的农业经营贷流水分析报告,需要包含:1.总收入和总支出统计 2.收入笔数和支出笔数 3.各类型收入支出占比分析 4.交易对手收入支出TOP3排名 5.按月份的收入支出趋势分析 6.账户数量和交易时间范围统计 7.资金流入流出月度统计等全面指标"  # 分析查询
@@ -147,10 +147,17 @@ async def run_batch(runs: int, question: str, industry: str, data_file: str):
     failed_runs = 0
     results = []
 
+    # 运行总时长,秒
+    total_time = 0
+    import time
     # 逐个运行
     for i in range(1, runs + 1):
         run_id = str(i)
+        start_time = time.perf_counter()
+        # 单次执行
         result = await run_single_flow(run_id, question, industry, data, os.path.basename(data_file))
+        end_time = time.perf_counter()
+        total_time += (end_time - start_time)
 
         results.append(result)
 
@@ -168,6 +175,7 @@ async def run_batch(runs: int, question: str, industry: str, data_file: str):
     print("📊 批量运行完成统计")
     print(f"{'='*80}")
     print(f"总运行次数: {runs}")
+    print(f"总运行总用时: {total_time:.2f}秒,单次用时:{total_time/runs:.2f}秒")
     print(f"成功次数: {successful_runs}")
     print(f"失败次数: {failed_runs}")
     print(f"成功率: {successful_runs/runs*100:.1f}%")

+ 7 - 4
llmops/complete_agent_flow_rule.py

@@ -625,12 +625,15 @@ async def run_complete_agent_flow(question: str, industry: str, data: List[Dict[
 # 主函数用于测试
 async def main():
     """主函数:执行系统测试"""
+    import os
+    os.environ["LANGCHAIN_TRACING_V2"] = "false"
+    os.environ["LANGCHAIN_API_KEY"] = ""
+    # 禁用 LangSmith 的追踪
+    os.environ["LANGSMITH_TRACING"] = "false"
+
     print("🚀 执行CompleteAgentFlow系统测试")
     print("=" * 50)
 
-    # 导入配置
-    import config
-
     if not DEEPSEEK_API_KEY:
         print("❌ 未找到API密钥")
         return
@@ -655,7 +658,7 @@ async def main():
         industry = industry,
         data=test_data,
         file_name=file_name,
-        api_key=config.DEEPSEEK_API_KEY,
+        api_key=DEEPSEEK_API_KEY,
         session_id="direct-test"
     )
 

+ 54 - 23
llmops/main.py

@@ -8,6 +8,14 @@ import io
 import urllib.parse
 from llmops.complete_agent_flow_rule import run_complete_agent_flow
 from llmops.config import DEEPSEEK_API_KEY
+from llmops.agents.data_stardard import TransactionParserAgent
+from llmops.config import multimodal_api_url
+
+
+os.environ["LANGCHAIN_TRACING_V2"] = "false"
+os.environ["LANGCHAIN_API_KEY"] = ""
+# 禁用 LangSmith 的追踪
+os.environ["LANGSMITH_TRACING"] = "false"
 
 # 初始化FastAPI应用
 app = FastAPI(
@@ -93,12 +101,22 @@ async def dataset_classify(file: UploadFile = File(...), industry: str = Form(..
         traceback.print_exc()
 
 
+
+
+
+
+# 数据标准化agent
+standard_agent = TransactionParserAgent(
+            api_key=DEEPSEEK_API_KEY,
+            multimodal_api_url=multimodal_api_url
+)
+
 @app.post("/api/report/gen")
 async def gen_report(file: UploadFile = File(...), question: str = Form(...), industry: str = Form(...)):
     """
-    上传原始数据文件(格式CSV),输入问题和行业,生成对应的分析报告
+    上传原始数据文件(格式支持pdf/img/csv),输入问题和行业,生成对应的分析报告
     Args:
-        file: 用户上传的CSV数据文件
+        file: 用户上传的数据文件
         question: 用户问题
         industry: 行业
     Returns:
@@ -115,24 +133,40 @@ async def gen_report(file: UploadFile = File(...), question: str = Form(...), in
         with open(full_path, "wb") as f:
             f.write(file.file.read())
 
-        # 读取文件内容,标准化文件
-        data_set = DataManager.load_data_from_csv_file(full_path)
-        # 执行测试
-        result = await run_complete_agent_flow(
-            question=question,
-            industry=industry,
-            data=data_set,
-            file_name=file.filename,
-            api_key=DEEPSEEK_API_KEY,
-            session_id="direct-test"
-        )
-        print(result)
-        return {
-            "status": 0,
-            "message": "success",
-            "outline_draft": result["result"]["outline_draft"],
-            "computed_metrics": result["result"]["computed_metrics"]
-        }
+        # 数据标准化
+        result = await standard_agent.run_workflow_task(full_path)
+        if result["status"] == "success":
+            print(f"🎯 Workflow 任务完成!")
+            # 标准化后的文件
+            standard_file_path = result['file_path']
+            standard_file_name = os.path.basename(standard_file_path)
+            print(f"📂 文件全路径: {standard_file_path}, 文件名:{standard_file_name}")
+
+            # 读取标准化后的文件内容
+            data_set = DataManager.load_data_from_csv_file(standard_file_path)
+            # 执行完整的流水分析流程
+            result = await run_complete_agent_flow(
+                question=question,
+                industry=industry,
+                data=data_set,
+                file_name=standard_file_name,
+                api_key=DEEPSEEK_API_KEY,
+                session_id="direct-test"
+            )
+            print(result)
+            return {
+                "status": 0,
+                "message": "success",
+                "outline_draft": result["result"]["outline_draft"],
+                "computed_metrics": result["result"]["computed_metrics"]
+            }
+        else:
+            print(f"❌ 任务失败: {result['message']}")
+            return {
+                "status": 1,
+                "message": result['message'],
+                "report": {}
+            }
     except Exception as e:
         print(f"生成流水分析报告异常: {e}")
         import traceback
@@ -144,9 +178,6 @@ async def gen_report(file: UploadFile = File(...), question: str = Form(...), in
         }
 
 
-
-
-
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("main:app", host="0.0.0.0", port=3699, reload=True)