main.py

from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse
import os
import io
import csv
import traceback
import urllib.parse
from llmops.agents.data_classify_agent import DataClassifyAgent
from llmops.agents.data_manager import DataManager
from llmops.agents.data_stardard import TransactionParserAgent
from llmops.complete_agent_flow_rule import run_complete_agent_flow
from llmops.config import DEEPSEEK_API_KEY, multimodal_api_url

# Disable LangChain/LangSmith tracing for LangGraph
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_API_KEY"] = ""
os.environ["LANGSMITH_TRACING"] = "false"

# Initialize the FastAPI application
app = FastAPI(
    title="Agent Service (Agents API)",
    description="Data classification and labeling service",
    version="1.0"
)


@app.get("/")
def root():
    return {"message": "Hello, Agents for you!"}


# Directory where uploaded files are stored
UPLOAD_FOLDER = "uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

data_classify_agent = DataClassifyAgent()


@app.post("/api/dataset/classify")
async def dataset_classify(file: UploadFile = File(...), industry: str = Form(...)):
    """
    Upload a raw data file (CSV) for classification and labeling.

    Args:
        file: CSV data file uploaded by the user
        industry: industry the data belongs to
    """
    # Build the path where the uploaded file will be stored
    file_path = os.path.join(UPLOAD_FOLDER, file.filename)
    full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_path)
    print(f"Absolute path of uploaded file: {full_path}")
    try:
        # Save the uploaded file to the upload directory
        with open(full_path, "wb") as f:
            f.write(file.file.read())
        # Load the file contents
        data_set = DataManager.load_data_from_csv_file(full_path)
        # Classify and label the data
        data_set_classified = data_classify_agent.invoke_data_classify(
            industry=industry, data_set=data_set, file_name=file.filename
        )

        # Convert the list of dicts into CSV rows.
        # Generator that yields the CSV content row by row.
        def generate_csv_rows():
            # In-memory string buffer
            output = io.StringIO()
            writer = csv.writer(output)
            output.write('\ufeff')  # BOM so Excel displays Chinese text correctly
            # Write the header row
            writer.writerow(data_classify_agent.fields_order)
            yield output.getvalue()
            output.seek(0)
            output.truncate(0)
            # Write the data rows one by one
            for row in data_set_classified:
                writer.writerow([row["txId"], row["txDate"], row["txTime"], row["txAmount"],
                                 row["txBalance"], row["txDirection"], row["txSummary"],
                                 row["txCounterparty"], row["createdAt"], row["businessType"]])
                yield output.getvalue()
                output.seek(0)
                output.truncate(0)

        # Name of the output (labeled) file
        output_file = os.path.splitext(str(file.filename))[0] + "_label.csv"
        # URL-encode the file name to handle Chinese and other non-ASCII characters
        encoded_filename = urllib.parse.quote(output_file)
        # Stream the CSV file back to the client
        return StreamingResponse(
            generate_csv_rows(),
            media_type="text/csv; charset=utf-8",
            headers={
                "Content-Disposition": f"attachment; filename=\"{encoded_filename}\"; filename*=UTF-8''{encoded_filename}"
            }
        )
    except Exception as e:
        print(f"Classification of uploaded file failed: {e}")
        traceback.print_exc()
        # Return an explicit error payload instead of an implicit None
        return {"status": 1, "message": "error"}
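

# A minimal client-side sketch (not part of the service) showing one way to call the
# /api/dataset/classify endpoint above. Assumptions: the `requests` package is
# installed, the server is running locally on port 3699 (see the __main__ block at the
# bottom of this file), and "transactions.csv" / "finance" are hypothetical example values.
def _example_classify_request():
    # Illustrative only; this function is never called by the service itself.
    import requests
    with open("transactions.csv", "rb") as fp:
        resp = requests.post(
            "http://localhost:3699/api/dataset/classify",
            files={"file": ("transactions.csv", fp, "text/csv")},
            data={"industry": "finance"},
        )
    # The endpoint streams the labeled CSV back as an attachment
    with open("transactions_label.csv", "wb") as out:
        out.write(resp.content)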


# Data standardization agent
standard_agent = TransactionParserAgent(
    api_key=DEEPSEEK_API_KEY,
    multimodal_api_url=multimodal_api_url
)


@app.post("/api/report/gen")
async def gen_report(file: UploadFile = File(...), question: str = Form(...), industry: str = Form(...)):
    """
    Upload a raw data file (PDF/image/CSV) together with a question and an industry,
    and generate the corresponding analysis report.

    Args:
        file: data file uploaded by the user
        question: the user's question
        industry: industry the data belongs to

    Returns:
        The report as a JSON structure.
    """
    # Build the path where the uploaded file will be stored
    file_path = os.path.join(UPLOAD_FOLDER, file.filename)
    full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_path)
    print(f"Absolute path of uploaded file: {full_path}")
    try:
        # Save the uploaded file to the upload directory
        with open(full_path, "wb") as f:
            f.write(file.file.read())
        # Standardize the data
        result = await standard_agent.run_workflow_task(full_path)
        if result["status"] == "success":
            print("🎯 Workflow task completed!")
            # Path of the standardized file
            standard_file_path = result["file_path"]
            standard_file_name = os.path.basename(standard_file_path)
            print(f"📂 Full path: {standard_file_path}, file name: {standard_file_name}")
            # Load the standardized file
            data_set = DataManager.load_data_from_csv_file(standard_file_path)
            # Run the complete agent flow
            result = await run_complete_agent_flow(
                question=question,
                industry=industry,
                data=data_set,
                file_name=standard_file_name,
                api_key=DEEPSEEK_API_KEY,
                session_id="direct-test"
            )
            print(result)
            return {
                "status": 0,
                "message": "success",
                "outline_draft": result["result"]["outline_draft"],
                "computed_metrics": result["result"]["computed_metrics"]
            }
        else:
            print(f"❌ Task failed: {result['message']}")
            return {
                "status": 1,
                "message": result["message"],
                "report": {}
            }
    except Exception as e:
        print(f"Failed to generate the transaction analysis report: {e}")
        traceback.print_exc()
        return {
            "status": 1,
            "message": "error",
            "report": {}
        }
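

# A minimal client-side sketch (not part of the service) for the /api/report/gen
# endpoint above. Assumptions: the `requests` package is installed, the server runs
# locally on port 3699, and "statement.pdf", the question, and the industry are
# hypothetical example values.
def _example_report_request():
    # Illustrative only; this function is never called by the service itself.
    import requests
    with open("statement.pdf", "rb") as fp:
        resp = requests.post(
            "http://localhost:3699/api/report/gen",
            files={"file": ("statement.pdf", fp, "application/pdf")},
            data={
                "question": "Summarize the cash flow over the last three months",
                "industry": "finance",
            },
        )
    report = resp.json()
    # On success the payload contains "outline_draft" and "computed_metrics"
    print(report["status"], report["message"])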


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=3699, reload=True)