jiaqiang 3 дней назад
Родитель
Сommit
e40017caa1
1 измененных файлов с 152 добавлено и 0 удалено
  1. 152 0
      llmops/main.py

+ 152 - 0
llmops/main.py

@@ -0,0 +1,152 @@
+from fastapi import FastAPI, File, UploadFile, Form
+import os
+from llmops.agents.data_classify_agent import DataClassifyAgent
+from llmops.agents.data_manager import DataManager
+import csv
+from fastapi.responses import StreamingResponse
+import io
+import urllib.parse
+from llmops.complete_agent_flow_rule import run_complete_agent_flow
+from llmops.config import DEEPSEEK_API_KEY
+
+# 初始化FastAPI应用
+app = FastAPI(
+    title="智能体服务<Agents API>",
+    description="提供数据分类打标服务",
+    version="1.0"
+)
+
+@app.get("/")
+def root():
+    return {"message": "Hello, Agents for you!"}
+
+# 设置文件保存路径
+UPLOAD_FOLDER = "uploads"
+os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+
+
+data_classify_agent = DataClassifyAgent()
+
+@app.post("/api/dataset/classify")
+async def dataset_classify(file: UploadFile = File(...), industry: str = Form(...)):
+    """
+    上传原始数据文件(格式CSV)进行分类打标
+    Args:
+        file: 用户上传的CSV数据文件
+        industry: 行业
+    """
+    # 获取文件的存储路径
+    file_path = os.path.join(UPLOAD_FOLDER, file.filename)
+    full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_path)
+
+    print(f"上传文件绝对路径:{full_path}")
+
+    try:
+        # 将文件内容写入到上传目录中
+        with open(full_path, "wb") as f:
+            f.write(file.file.read())
+
+        # 读取文件内容
+        data_set = DataManager.load_data_from_csv_file(full_path)
+        # 对数据进行分类打标
+        data_set_classified = data_classify_agent.invoke_data_classify(industry=industry, data_set=data_set, file_name=file.filename)
+        # 格式转换 [{}] -> csv 列表形式
+        # 生成器函数,用于逐行生成 CSV 内容
+        def generate_csv_rows():
+            # 创建内存中的字符串缓冲区
+            output = io.StringIO()
+            writer = csv.writer(output)
+            output.write('\ufeff')  # BOM头,确保Excel正确显示中文
+
+            # 写入表头
+            writer.writerow(data_classify_agent.fields_order)
+            yield output.getvalue()
+            output.seek(0)
+            output.truncate(0)
+            # 逐行写入数据
+            for row in data_set_classified:
+                writer.writerow([row["txId"], row["txDate"], row["txTime"], row["txAmount"], row["txBalance"],
+                                 row["txDirection"], row["txSummary"], row["txCounterparty"], row["createdAt"],
+                                 row["businessType"]])
+                yield output.getvalue()
+                output.seek(0)
+                output.truncate(0)
+
+        # 输出文件名(打好标文件)
+        output_file = str(file.filename).split(".")[0] + "_label.csv"
+
+        # 对文件名进行URL编码,处理中文等非ASCII字符
+        encoded_filename = urllib.parse.quote(output_file)
+
+        # 使用 StreamingResponse 返回 CSV 文件
+        return StreamingResponse(
+            generate_csv_rows(),
+            media_type="text/csv;charset=utf-8-sig",
+            headers={
+                "Content-Disposition": f"attachment; filename=\"{encoded_filename}\"; filename*=UTF-8''{encoded_filename}"
+            }
+        )
+    except Exception as e:
+        print(f"上传文件分类打标异常 {e}")
+        import traceback
+        traceback.print_exc()
+
+
+@app.post("/api/report/gen")
+async def gen_report(file: UploadFile = File(...), question: str = Form(...), industry: str = Form(...)):
+    """
+    上传原始数据文件(格式CSV),输入问题和行业,生成对应的分析报告
+    Args:
+        file: 用户上传的CSV数据文件
+        question: 用户问题
+        industry: 行业
+    Returns:
+        报告的JSON结构
+    """
+    # 获取文件的存储路径
+    file_path = os.path.join(UPLOAD_FOLDER, file.filename)
+    full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), file_path)
+
+    print(f"上传文件绝对路径:{full_path}")
+
+    try:
+        # 将文件内容写入到上传目录中
+        with open(full_path, "wb") as f:
+            f.write(file.file.read())
+
+        # 读取文件内容,标准化文件
+        data_set = DataManager.load_data_from_csv_file(full_path)
+        # 执行测试
+        result = await run_complete_agent_flow(
+            question=question,
+            industry=industry,
+            data=data_set,
+            file_name=file.filename,
+            api_key=DEEPSEEK_API_KEY,
+            session_id="direct-test"
+        )
+        print(result)
+        return {
+            "status": 0,
+            "message": "success",
+            "outline_draft": result["result"]["outline_draft"],
+            "computed_metrics": result["result"]["computed_metrics"]
+        }
+    except Exception as e:
+        print(f"生成流水分析报告异常: {e}")
+        import traceback
+        traceback.print_exc()
+        return {
+            "status": 1,
+            "message": "error",
+            "report": {}
+        }
+
+
+
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("main:app", host="0.0.0.0", port=3699, reload=True)