| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- from fastapi import FastAPI, File, UploadFile, Form
- import os
- from llmops.agents.data_classify_agent import DataClassifyAgent
- from llmops.agents.data_manager import DataManager
- import csv
- from fastapi.responses import StreamingResponse
- import io
- import urllib.parse
- from llmops.complete_agent_flow_rule import run_complete_agent_flow
- from llmops.config import DEEPSEEK_API_KEY
# Initialise the FastAPI application; the metadata below is shown in the
# auto-generated API documentation.
app = FastAPI(
    version="1.0",
    description="提供数据分类打标服务",
    title="智能体服务<Agents API>",
)
@app.get("/")
def root():
    """Return the static greeting payload served at the service root."""
    greeting = "Hello, Agents for you!"
    return dict(message=greeting)
# Name of the directory that receives uploaded files. The request handlers
# resolve it relative to this module's directory, not the working directory.
UPLOAD_FOLDER = "uploads"
# Original behaviour: make sure the folder exists relative to the current
# working directory.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Also create the folder next to this module, which is where the handlers
# actually write; otherwise uploads fail whenever the process is started from
# a different working directory.
os.makedirs(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER),
    exist_ok=True,
)
# Single shared agent instance used by the classification endpoint.
data_classify_agent = DataClassifyAgent()
@app.post("/api/dataset/classify")
async def dataset_classify(file: UploadFile = File(...), industry: str = Form(...)):
    """Classify and label an uploaded CSV data file and return the labelled CSV.

    The upload is saved under ``UPLOAD_FOLDER`` next to this module, loaded with
    ``DataManager.load_data_from_csv_file``, labelled by
    ``data_classify_agent.invoke_data_classify`` and streamed back one row at a
    time as a CSV attachment named ``<original stem>_label.csv``.

    Args:
        file: CSV data file uploaded by the user.
        industry: Industry passed through to the classification agent.

    Returns:
        A ``StreamingResponse`` with the labelled CSV on success, or the dict
        ``{"status": 1, "message": "error"}`` when processing fails.
    """
    # Keep only the last path component of the client-supplied name so values
    # such as "../../x.csv" or "C:\\x.csv" cannot escape the upload directory.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        safe_name = "upload.csv"
    upload_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
    full_path = os.path.join(upload_dir, safe_name)
    print(f"上传文件绝对路径:{full_path}")
    try:
        # The module-level makedirs is relative to the working directory, so
        # make sure the directory we actually write to exists.
        os.makedirs(upload_dir, exist_ok=True)
        # Persist the upload; the awaited read keeps the event loop free.
        payload = await file.read()
        with open(full_path, "wb") as f:
            f.write(payload)
        # Load the saved file into records.
        data_set = DataManager.load_data_from_csv_file(full_path)
        # Classify and label the records.
        data_set_classified = data_classify_agent.invoke_data_classify(
            industry=industry, data_set=data_set, file_name=safe_name
        )

        # Columns written for every data row, in output order.
        row_keys = (
            "txId", "txDate", "txTime", "txAmount", "txBalance",
            "txDirection", "txSummary", "txCounterparty", "createdAt",
            "businessType",
        )

        def generate_csv_rows():
            """Yield the CSV content chunk by chunk (header first, then one row each)."""
            # A single in-memory buffer is reused for every chunk.
            output = io.StringIO()
            writer = csv.writer(output)

            def drain():
                # Return what has been buffered so far and reset the buffer.
                chunk = output.getvalue()
                output.seek(0)
                output.truncate(0)
                return chunk

            output.write('\ufeff')  # BOM so that Excel renders Chinese text correctly
            # Header row.
            writer.writerow(data_classify_agent.fields_order)
            yield drain()
            # Data rows.
            for row in data_set_classified:
                writer.writerow([row[key] for key in row_keys])
                yield drain()

        # Name of the labelled output file; splitext keeps dots inside the stem
        # (e.g. "a.b.csv" -> "a.b_label.csv").
        output_file = os.path.splitext(safe_name)[0] + "_label.csv"
        # URL-encode the name so non-ASCII characters survive in the header.
        encoded_filename = urllib.parse.quote(output_file)
        # Stream the CSV back as a downloadable attachment.
        return StreamingResponse(
            generate_csv_rows(),
            media_type="text/csv;charset=utf-8-sig",
            headers={
                "Content-Disposition": f"attachment; filename=\"{encoded_filename}\"; filename*=UTF-8''{encoded_filename}"
            }
        )
    except Exception as e:
        print(f"上传文件分类打标异常 {e}")
        import traceback
        traceback.print_exc()
        # Report the failure explicitly instead of falling through and
        # returning None (which would be sent as an HTTP 200 "null" body).
        return {
            "status": 1,
            "message": "error"
        }
@app.post("/api/report/gen")
async def gen_report(file: UploadFile = File(...), question: str = Form(...), industry: str = Form(...)):
    """Generate an analysis report from an uploaded CSV file, a question and an industry.

    The upload is saved under ``UPLOAD_FOLDER`` next to this module, loaded with
    ``DataManager.load_data_from_csv_file`` and passed to
    ``run_complete_agent_flow``.

    Args:
        file: CSV data file uploaded by the user.
        question: The user's question.
        industry: Industry passed through to the agent flow.

    Returns:
        On success, a dict with ``status`` 0, ``message`` "success" and the
        ``outline_draft`` and ``computed_metrics`` taken from the flow result.
        On failure, a dict with ``status`` 1, ``message`` "error" and an empty
        ``report``.
    """
    # Keep only the last path component of the client-supplied name so values
    # such as "../../x.csv" or "C:\\x.csv" cannot escape the upload directory.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        safe_name = "upload.csv"
    upload_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), UPLOAD_FOLDER)
    full_path = os.path.join(upload_dir, safe_name)
    print(f"上传文件绝对路径:{full_path}")
    try:
        # The module-level makedirs is relative to the working directory, so
        # make sure the directory we actually write to exists.
        os.makedirs(upload_dir, exist_ok=True)
        # Persist the upload; the awaited read keeps the event loop free.
        payload = await file.read()
        with open(full_path, "wb") as f:
            f.write(payload)
        # Load the saved file into records.
        data_set = DataManager.load_data_from_csv_file(full_path)
        # Run the agent flow on the loaded data.
        result = await run_complete_agent_flow(
            question=question,
            industry=industry,
            data=data_set,
            file_name=safe_name,
            api_key=DEEPSEEK_API_KEY,
            session_id="direct-test"
        )
        print(result)
        flow_result = result["result"]
        return {
            "status": 0,
            "message": "success",
            "outline_draft": flow_result["outline_draft"],
            "computed_metrics": flow_result["computed_metrics"]
        }
    except Exception as e:
        print(f"生成流水分析报告异常: {e}")
        import traceback
        traceback.print_exc()
        return {
            "status": 1,
            "message": "error",
            "report": {}
        }
if __name__ == "__main__":
    # Start the development server (all interfaces, auto-reload) when this
    # file is executed directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 3699, "reload": True}
    uvicorn.run("main:app", **server_options)
|