|
|
@@ -0,0 +1,124 @@
|
|
|
+
|
|
|
+from llmops.config import DATA_CLASSIFY_ENGINE_PARAM_MAPPER, RULES_ENGINE_BASE_URL
|
|
|
+import requests
|
|
|
+from llmops.agents.data_manager import DataManager
|
|
|
+import os
|
|
|
+from typing import Dict, List, Any
|
|
|
+
|
|
|
+class DataClassifyAgent:
|
|
|
+ """
|
|
|
+ 数据分类打标Agent
|
|
|
+ """
|
|
|
+ def __init__(self):
|
|
|
+ # 数据打标引擎入参映射
|
|
|
+ self.DATA_CLASSIFY_ENGINE_PARAM_MAPPER = DATA_CLASSIFY_ENGINE_PARAM_MAPPER
|
|
|
+ # 保存分类文件的表头顺序
|
|
|
+ self.fields_order = ["txId","txDate","txTime","txAmount","txBalance","txDirection","txSummary","txCounterparty","createdAt", "businessType"]
|
|
|
+
|
|
|
+ def invoke_data_classify(self, industry:str, data_set: list[dict], file_name: str) -> list[dict]:
|
|
|
+ """
|
|
|
+ 调用分类打标接口
|
|
|
+ :param industry: 行业
|
|
|
+ :param data_set: 数据集
|
|
|
+ :param file_name: 文件名称
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ url = f"{RULES_ENGINE_BASE_URL}/api/rules/executeKnowledge"
|
|
|
+ headers = {
|
|
|
+ "Accept": "*/*",
|
|
|
+ "Accept-Encoding": "gzip, deflate, br",
|
|
|
+ "Connection": "keep-alive",
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ "User-Agent": "PostmanRuntime-ApipostRuntime/1.1.0"
|
|
|
+ }
|
|
|
+
|
|
|
+ json_data = {
|
|
|
+ "id": self.DATA_CLASSIFY_ENGINE_PARAM_MAPPER.get(industry),
|
|
|
+ "input": {
|
|
|
+ "data": data_set
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ # 调用分类打标服务
|
|
|
+ response = requests.post(url, headers=headers, json=json_data, timeout=30)
|
|
|
+ if response.status_code == 200:
|
|
|
+ # 取结果
|
|
|
+ data_set_classified = response.json()
|
|
|
+ if isinstance(data_set_classified, dict):
|
|
|
+ # 取出打标数据集
|
|
|
+ ds = data_set_classified.get('resultTag', 0)
|
|
|
+ print(f"✅ 成功分类打标数量: {len(ds)}")
|
|
|
+ # 将分类好的数据写入数据目录中
|
|
|
+ self.save_classified_data(ds, file_name)
|
|
|
+
|
|
|
+ return ds
|
|
|
+ else:
|
|
|
+ print(f"⚠️ 分类打标格式异常: {data_set_classified}")
|
|
|
+ return []
|
|
|
+ else:
|
|
|
+ print(f"❌ 数据分类打标失败,状态码: {response.status_code}")
|
|
|
+ print(f"响应内容: {response.text}")
|
|
|
+ return []
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"❌ 调用数据分类打标时发生错误: {str(e)}")
|
|
|
+ return []
|
|
|
+
|
|
|
+
|
|
|
+ def save_classified_data(self, json_data, file_name: str):
|
|
|
+ """
|
|
|
+ 将生成的分类数据写入CSV文件,将文件保存到 data_files目录下
|
|
|
+ :param json_data: 分类后的数据,结构 [{}]
|
|
|
+ :param file_name: 文件名,规则是 原名_label.csv
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ # 新文件名
|
|
|
+ new_file_name = file_name.split(".")[0] + "_label.csv"
|
|
|
+ full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "data_files", new_file_name)
|
|
|
+ print(f"分类文件:{full_path}, 是否存在:{os.path.exists(full_path)}")
|
|
|
+
|
|
|
+ succ = DataManager.write_json_to_csv(json_data=json_data, csv_file_path=full_path, field_order=self.fields_order)
|
|
|
+ print(f"将分类数据写入文件:{full_path} {'成功' if succ else '失败'}")
|
|
|
+
|
|
|
+async def data_classify(industry: str, data_set: List[Dict[str, Any]], file_name: str) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 对数据进行分类
|
|
|
+
|
|
|
+ Args:
|
|
|
+ industry: 行业
|
|
|
+ data_set: 待处理数据
|
|
|
+ file_name: 数据文件名称
|
|
|
+ Returns:
|
|
|
+ 分类好的数据
|
|
|
+ """
|
|
|
+ import asyncio
|
|
|
+ import time
|
|
|
+
|
|
|
+ agent = DataClassifyAgent()
|
|
|
+
|
|
|
+ print(f"📝 开始对文件:{file_name} 数据进行分类打标...")
|
|
|
+ data_set_classified = []
|
|
|
+ try:
|
|
|
+ start_time = time.time()
|
|
|
+ # 进行分类打标
|
|
|
+ data_set_classified = agent.invoke_data_classify(industry=industry, data_set=data_set, file_name=file_name)
|
|
|
+ elapsed_time = time.time() - start_time
|
|
|
+ print(f"分类打标用时:{elapsed_time:.2f} 秒")
|
|
|
+ process_size = len(data_set_classified)
|
|
|
+ if process_size == len(data_set):
|
|
|
+ print(f"\n📝 分类打标处理成功, 处理记录条数:{process_size}")
|
|
|
+ elif process_size == 0:
|
|
|
+ print(f"\n📝 分类打标处理失败, 处理记录条数:{process_size}")
|
|
|
+ else:
|
|
|
+ print(f"\n📝 分类打标处理部分数据, 处理记录条数:{process_size}")
|
|
|
+
|
|
|
+ return data_set_classified
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ elapsed_time = time.time() - start_time if 'start_time' in locals() else 0
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
+ print(f" 错误详情: {str(e)}")
|
|
|
+
|
|
|
+ return data_set_classified
|