data_classify_agent.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from llmops.config import DATA_CLASSIFY_ENGINE_PARAM_MAPPER, RULES_ENGINE_BASE_URL
  2. import requests
  3. from llmops.agents.data_manager import DataManager
  4. import os
  5. from typing import Dict, List, Any
  6. class DataClassifyAgent:
  7. """
  8. 数据分类打标Agent
  9. """
  10. def __init__(self):
  11. # 数据打标引擎入参映射
  12. self.DATA_CLASSIFY_ENGINE_PARAM_MAPPER = DATA_CLASSIFY_ENGINE_PARAM_MAPPER
  13. # 保存分类文件的表头顺序
  14. self.fields_order = ["txId","txDate","txTime","txAmount","txBalance","txDirection","txSummary","txCounterparty","createdAt", "businessType"]
  15. def invoke_data_classify(self, industry:str, data_set: list[dict], file_name: str) -> list[dict]:
  16. """
  17. 调用分类打标接口
  18. :param industry: 行业
  19. :param data_set: 数据集
  20. :param file_name: 文件名称
  21. :return:
  22. """
  23. try:
  24. url = f"{RULES_ENGINE_BASE_URL}/api/rules/executeKnowledge"
  25. headers = {
  26. "Accept": "*/*",
  27. "Accept-Encoding": "gzip, deflate, br",
  28. "Connection": "keep-alive",
  29. "Content-Type": "application/json",
  30. "User-Agent": "PostmanRuntime-ApipostRuntime/1.1.0"
  31. }
  32. json_data = {
  33. "id": self.DATA_CLASSIFY_ENGINE_PARAM_MAPPER.get(industry),
  34. "input": {
  35. "transactions": data_set
  36. }
  37. }
  38. # 调用分类打标服务
  39. response = requests.post(url, headers=headers, json=json_data, timeout=30)
  40. if response.status_code == 200:
  41. # 取结果
  42. data_set_classified = response.json()
  43. if isinstance(data_set_classified, dict):
  44. # 取出打标数据集
  45. ds = data_set_classified.get('transactions', [])
  46. print(f"✅ 成功分类打标数量: {len(ds)}")
  47. # 将分类好的数据写入数据目录中
  48. self.save_classified_data(ds, file_name)
  49. return ds
  50. else:
  51. print(f"⚠️ 分类打标格式异常: {data_set_classified}")
  52. return []
  53. else:
  54. print(f"❌ 数据分类打标失败,状态码: {response.status_code}")
  55. print(f"响应内容: {response.text}")
  56. return []
  57. except Exception as e:
  58. print(f"❌ 调用数据分类打标时发生错误: {str(e)}")
  59. import traceback
  60. traceback.print_exc()
  61. return []
  62. def save_classified_data(self, json_data, file_name: str):
  63. """
  64. 将生成的分类数据写入CSV文件,将文件保存到 data_files目录下
  65. :param json_data: 分类后的数据,结构 [{}]
  66. :param file_name: 文件名,规则是 原名_label.csv
  67. :return:
  68. """
  69. # 新文件名
  70. new_file_name = file_name.split(".")[0] + "_label.csv"
  71. full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "data_files", new_file_name)
  72. print(f"分类文件:{full_path}, 是否存在:{os.path.exists(full_path)}")
  73. succ = DataManager.write_json_to_csv(json_data=json_data, csv_file_path=full_path, field_order=self.fields_order)
  74. print(f"将分类数据写入文件:{full_path} {'成功' if succ else '失败'}")
  75. async def data_classify(industry: str, data_set: List[Dict[str, Any]], file_name: str) -> List[Dict]:
  76. """
  77. 对数据进行分类
  78. Args:
  79. industry: 行业
  80. data_set: 待处理数据
  81. file_name: 数据文件名称
  82. Returns:
  83. 分类好的数据
  84. """
  85. import asyncio
  86. import time
  87. agent = DataClassifyAgent()
  88. print(f"📝 开始对文件:{file_name} 数据进行分类打标...")
  89. data_set_classified = []
  90. try:
  91. start_time = time.time()
  92. # 进行分类打标
  93. data_set_classified = agent.invoke_data_classify(industry=industry, data_set=data_set, file_name=file_name)
  94. elapsed_time = time.time() - start_time
  95. print(f"分类打标用时:{elapsed_time:.2f} 秒")
  96. process_size = len(data_set_classified)
  97. if process_size == len(data_set):
  98. print(f"\n📝 分类打标处理成功, 处理记录条数:{process_size}")
  99. elif process_size == 0:
  100. print(f"\n📝 分类打标处理失败, 处理记录条数:{process_size}")
  101. else:
  102. print(f"\n📝 分类打标处理部分数据, 处理记录条数:{process_size}")
  103. return data_set_classified
  104. except Exception as e:
  105. elapsed_time = time.time() - start_time if 'start_time' in locals() else 0
  106. import traceback
  107. traceback.print_exc()
  108. print(f" 错误详情: {str(e)}")
  109. return data_set_classified