Browse Source

数据分类打标智能体

jiaqiang 2 days ago
parent
commit
53b69a74ce
1 changed files with 124 additions and 0 deletions
  1. 124 0
      llmops/agents/data_classify_agent.py

+ 124 - 0
llmops/agents/data_classify_agent.py

@@ -0,0 +1,124 @@
+
+from llmops.config import DATA_CLASSIFY_ENGINE_PARAM_MAPPER, RULES_ENGINE_BASE_URL
+import requests
+from llmops.agents.data_manager import DataManager
+import os
+from typing import Dict, List, Any
+
+class DataClassifyAgent:
+    """
+    数据分类打标Agent
+    """
+    def __init__(self):
+        # 数据打标引擎入参映射
+        self.DATA_CLASSIFY_ENGINE_PARAM_MAPPER = DATA_CLASSIFY_ENGINE_PARAM_MAPPER
+        # 保存分类文件的表头顺序
+        self.fields_order = ["txId","txDate","txTime","txAmount","txBalance","txDirection","txSummary","txCounterparty","createdAt", "businessType"]
+
+    def invoke_data_classify(self, industry:str, data_set: list[dict], file_name: str) -> list[dict]:
+        """
+        调用分类打标接口
+        :param industry:  行业
+        :param data_set:  数据集
+        :param file_name: 文件名称
+        :return:
+        """
+        try:
+            url = f"{RULES_ENGINE_BASE_URL}/api/rules/executeKnowledge"
+            headers = {
+                "Accept": "*/*",
+                "Accept-Encoding": "gzip, deflate, br",
+                "Connection": "keep-alive",
+                "Content-Type": "application/json",
+                "User-Agent": "PostmanRuntime-ApipostRuntime/1.1.0"
+            }
+
+            json_data = {
+                "id": self.DATA_CLASSIFY_ENGINE_PARAM_MAPPER.get(industry),
+                "input": {
+                    "data": data_set
+                }
+            }
+
+            # 调用分类打标服务
+            response = requests.post(url, headers=headers, json=json_data, timeout=30)
+            if response.status_code == 200:
+                # 取结果
+                data_set_classified = response.json()
+                if isinstance(data_set_classified, dict):
+                    # 取出打标数据集
+                    ds = data_set_classified.get('resultTag', 0)
+                    print(f"✅ 成功分类打标数量: {len(ds)}")
+                    # 将分类好的数据写入数据目录中
+                    self.save_classified_data(ds, file_name)
+
+                    return ds
+                else:
+                    print(f"⚠️ 分类打标格式异常: {data_set_classified}")
+                    return []
+            else:
+                print(f"❌ 数据分类打标失败,状态码: {response.status_code}")
+                print(f"响应内容: {response.text}")
+                return []
+
+        except Exception as e:
+            print(f"❌ 调用数据分类打标时发生错误: {str(e)}")
+            return []
+
+
+    def save_classified_data(self, json_data, file_name: str):
+        """
+        将生成的分类数据写入CSV文件,将文件保存到 data_files目录下
+        :param json_data: 分类后的数据,结构 [{}]
+        :param file_name: 文件名,规则是 原名_label.csv
+        :return:
+        """
+        # 新文件名
+        new_file_name = file_name.split(".")[0] + "_label.csv"
+        full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "data_files", new_file_name)
+        print(f"分类文件:{full_path}, 是否存在:{os.path.exists(full_path)}")
+
+        succ = DataManager.write_json_to_csv(json_data=json_data, csv_file_path=full_path, field_order=self.fields_order)
+        print(f"将分类数据写入文件:{full_path} {'成功' if succ else '失败'}")
+
+async def data_classify(industry: str, data_set: List[Dict[str, Any]], file_name: str) -> List[Dict]:
+    """
+    对数据进行分类
+
+    Args:
+        industry: 行业
+        data_set: 待处理数据
+        file_name: 数据文件名称
+    Returns:
+        分类好的数据
+    """
+    import asyncio
+    import time
+
+    agent = DataClassifyAgent()
+
+    print(f"📝 开始对文件:{file_name} 数据进行分类打标...")
+    data_set_classified = []
+    try:
+        start_time = time.time()
+        # 进行分类打标
+        data_set_classified = agent.invoke_data_classify(industry=industry, data_set=data_set, file_name=file_name)
+        elapsed_time = time.time() - start_time
+        print(f"分类打标用时:{elapsed_time:.2f} 秒")
+        process_size = len(data_set_classified)
+        if process_size == len(data_set):
+            print(f"\n📝 分类打标处理成功, 处理记录条数:{process_size}")
+        elif process_size == 0:
+            print(f"\n📝 分类打标处理失败, 处理记录条数:{process_size}")
+        else:
+            print(f"\n📝 分类打标处理部分数据, 处理记录条数:{process_size}")
+
+        return data_set_classified
+
+    except Exception as e:
+        elapsed_time = time.time() - start_time if 'start_time' in locals() else 0
+        import traceback
+        traceback.print_exc()
+        print(f"   错误详情: {str(e)}")
+
+    return data_set_classified