|
|
@@ -0,0 +1,152 @@
|
|
|
+from fastapi import FastAPI, File, UploadFile, Form
|
|
|
+import os
|
|
|
+from llmops.agents.data_classify_agent import DataClassifyAgent
|
|
|
+from llmops.agents.data_manager import DataManager
|
|
|
+import csv
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
+import io
|
|
|
+import urllib.parse
|
|
|
+from llmops.complete_agent_flow_rule import run_complete_agent_flow
|
|
|
+from llmops.config import DEEPSEEK_API_KEY
|
|
|
+
|
|
|
+# 初始化FastAPI应用
|
|
|
+app = FastAPI(
|
|
|
+ title="智能体服务<Agents API>",
|
|
|
+ description="提供数据分类打标服务",
|
|
|
+ version="1.0"
|
|
|
+)
|
|
|
+
|
|
|
+@app.get("/")
|
|
|
+def root():
|
|
|
+ return {"message": "Hello, Agents for you!"}
|
|
|
+
|
|
|
+# 设置文件保存路径
|
|
|
+UPLOAD_FOLDER = "uploads"
|
|
|
+os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+data_classify_agent = DataClassifyAgent()
|
|
|
+
|
|
|
+@app.post("/api/dataset/classify")
|
|
|
+async def dataset_classify(file: UploadFile = File(...), industry: str = Form(...)):
|
|
|
+ """
|
|
|
+ 上传原始数据文件(格式CSV)进行分类打标
|
|
|
+ Args:
|
|
|
+ file: 用户上传的CSV数据文件
|
|
|
+ industry: 行业
|
|
|
+ """
|
|
|
+ # 获取文件的存储路径
|
|
|
+ file_path = os.path.join(UPLOAD_FOLDER, file.filename)
|
|
|
+ full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_path)
|
|
|
+
|
|
|
+ print(f"上传文件绝对路径:{full_path}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 将文件内容写入到上传目录中
|
|
|
+ with open(full_path, "wb") as f:
|
|
|
+ f.write(file.file.read())
|
|
|
+
|
|
|
+ # 读取文件内容
|
|
|
+ data_set = DataManager.load_data_from_csv_file(full_path)
|
|
|
+ # 对数据进行分类打标
|
|
|
+ data_set_classified = data_classify_agent.invoke_data_classify(industry=industry, data_set=data_set, file_name=file.filename)
|
|
|
+ # 格式转换 [{}] -> csv 列表形式
|
|
|
+ # 生成器函数,用于逐行生成 CSV 内容
|
|
|
+ def generate_csv_rows():
|
|
|
+ # 创建内存中的字符串缓冲区
|
|
|
+ output = io.StringIO()
|
|
|
+ writer = csv.writer(output)
|
|
|
+ output.write('\ufeff') # BOM头,确保Excel正确显示中文
|
|
|
+
|
|
|
+ # 写入表头
|
|
|
+ writer.writerow(data_classify_agent.fields_order)
|
|
|
+ yield output.getvalue()
|
|
|
+ output.seek(0)
|
|
|
+ output.truncate(0)
|
|
|
+ # 逐行写入数据
|
|
|
+ for row in data_set_classified:
|
|
|
+ writer.writerow([row["txId"], row["txDate"], row["txTime"], row["txAmount"], row["txBalance"],
|
|
|
+ row["txDirection"], row["txSummary"], row["txCounterparty"], row["createdAt"],
|
|
|
+ row["businessType"]])
|
|
|
+ yield output.getvalue()
|
|
|
+ output.seek(0)
|
|
|
+ output.truncate(0)
|
|
|
+
|
|
|
+ # 输出文件名(打好标文件)
|
|
|
+ output_file = str(file.filename).split(".")[0] + "_label.csv"
|
|
|
+
|
|
|
+ # 对文件名进行URL编码,处理中文等非ASCII字符
|
|
|
+ encoded_filename = urllib.parse.quote(output_file)
|
|
|
+
|
|
|
+ # 使用 StreamingResponse 返回 CSV 文件
|
|
|
+ return StreamingResponse(
|
|
|
+ generate_csv_rows(),
|
|
|
+ media_type="text/csv;charset=utf-8-sig",
|
|
|
+ headers={
|
|
|
+ "Content-Disposition": f"attachment; filename=\"{encoded_filename}\"; filename*=UTF-8''{encoded_filename}"
|
|
|
+ }
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ print(f"上传文件分类打标异常 {e}")
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/report/gen")
|
|
|
+async def gen_report(file: UploadFile = File(...), question: str = Form(...), industry: str = Form(...)):
|
|
|
+ """
|
|
|
+ 上传原始数据文件(格式CSV),输入问题和行业,生成对应的分析报告
|
|
|
+ Args:
|
|
|
+ file: 用户上传的CSV数据文件
|
|
|
+ question: 用户问题
|
|
|
+ industry: 行业
|
|
|
+ Returns:
|
|
|
+ 报告的JSON结构
|
|
|
+ """
|
|
|
+ # 获取文件的存储路径
|
|
|
+ file_path = os.path.join(UPLOAD_FOLDER, file.filename)
|
|
|
+ full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_path)
|
|
|
+
|
|
|
+ print(f"上传文件绝对路径:{full_path}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 将文件内容写入到上传目录中
|
|
|
+ with open(full_path, "wb") as f:
|
|
|
+ f.write(file.file.read())
|
|
|
+
|
|
|
+ # 读取文件内容,标准化文件
|
|
|
+ data_set = DataManager.load_data_from_csv_file(full_path)
|
|
|
+ # 执行测试
|
|
|
+ result = await run_complete_agent_flow(
|
|
|
+ question=question,
|
|
|
+ industry=industry,
|
|
|
+ data=data_set,
|
|
|
+ file_name=file.filename,
|
|
|
+ api_key=DEEPSEEK_API_KEY,
|
|
|
+ session_id="direct-test"
|
|
|
+ )
|
|
|
+ print(result)
|
|
|
+ return {
|
|
|
+ "status": 0,
|
|
|
+ "message": "success",
|
|
|
+ "outline_draft": result["result"]["outline_draft"],
|
|
|
+ "computed_metrics": result["result"]["computed_metrics"]
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ print(f"生成流水分析报告异常: {e}")
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
+ return {
|
|
|
+ "status": 1,
|
|
|
+ "message": "error",
|
|
|
+ "report": {}
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import uvicorn
|
|
|
+ uvicorn.run("main:app", host="0.0.0.0", port=3699, reload=True)
|