jiayongqiang пре 2 дана
родитељ
комит
125f001de7

BIN
agent/agent-0.1.0-py3-none-any.whl


+ 16 - 5
agent/src/agent/api_outter.py

@@ -179,14 +179,25 @@ class FeedbackRequest(BaseModel):
     user_org: Optional[str] = Field(None, description="用户所属机构")
     user_endpoint: Optional[str] = Field(None, description="用户所属网点")
     business_attr: str = Field(..., description="业务属性")
+    phrase: Optional[str] = Field(None, description="打标文本")
 
 @router.post("/feedback")
 def ai_feedback(feedback_request: FeedbackRequest):
     logger.info(f"Received feedback request: {feedback_request}")
-    # 这里将用户的反馈信息保存到数据库中aitag_tag_log,供后续分析和模型优化使用
-    dao.execute(
-        """update aitag_tag_log set feedback = %s, feedback_result = %s, feedback_time = %s, feedback_user_id = %s, feedback_user_nm = %s, contract_no = %s, feedback_user_org = %s, feedback_user_endpoint = %s, state = %s where business_attr = %s""",   
-        (feedback_request.feedback, feedback_request.feedback_result, datetime.now(), feedback_request.user_id, feedback_request.user_nm, feedback_request.contract_no, feedback_request.user_org, feedback_request.user_endpoint, TAGGING_STATE.FEEDBACK.value, feedback_request.business_attr)
-    )
+    # Upsert the feedback: update the live row(s) for this business_attr, or insert a new row when none exist.
+    r = dao.query("""SELECT count(*) FROM aitag_tag_log WHERE business_attr = %s and is_delete = 0""", (feedback_request.business_attr,))
+    if r and len(r) > 0 and r[0][0] >= 1:
+        # At least one live row exists. Use >= 1 (not == 1) so duplicates are updated instead of triggering yet another insert.
+        # Save the user's feedback in aitag_tag_log for later analysis and model optimization; only rows with is_delete = 0 are touched, matching the check above.
+        dao.execute(
+            """update aitag_tag_log set feedback = %s, feedback_result = %s, feedback_time = %s, feedback_user_id = %s, feedback_user_nm = %s, contract_no = %s, feedback_user_org = %s, feedback_user_endpoint = %s, state = %s where business_attr = %s and is_delete = 0""",
+            (feedback_request.feedback, feedback_request.feedback_result, datetime.now(), feedback_request.user_id, feedback_request.user_nm, feedback_request.contract_no, feedback_request.user_org, feedback_request.user_endpoint, TAGGING_STATE.FEEDBACK.value, feedback_request.business_attr)
+        )
+    else:
+        dao.execute(
+            """INSERT INTO aitag_tag_log (id, insert_time, business_attr, phrase, state, feedback, feedback_result, feedback_time, feedback_user_id, feedback_user_nm, contract_no, feedback_user_org, feedback_user_endpoint) 
+            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
+            (uuid.uuid4().hex, datetime.now(), feedback_request.business_attr, feedback_request.phrase, TAGGING_STATE.FEEDBACK.value, feedback_request.feedback, feedback_request.feedback_result, datetime.now(), feedback_request.user_id, feedback_request.user_nm, feedback_request.contract_no, feedback_request.user_org, feedback_request.user_endpoint)
+        )
     return {"code": 200, "message": "Feedback received successfully"}
 

+ 3 - 12
agent/src/agent/core/dao.py

@@ -60,15 +60,6 @@ def execute(sql, params=None):
 
 if __name__ == "__main__":
     print(host, port, user, password, database)
-    with get_db_connection() as conn:
-        with conn.cursor() as cursor:
-            cursor.execute("SELECT VERSION()")
-            version = cursor.fetchone()
-            print("Database version:", version)
-
-    # execute("insert into tab_tag_category (id,category_code,category_nm,state) values (%s,%s,%s,%s)", ("3", "Test Category", "Test Category Name",0))    
-    # result = query("select * from tab_tag_category")
-    # print(result)
-    execute("update tab_tag_category set category_nm=%s where id=%s", ("Updated Category Name", 3))
-    result = query("select * from tab_tag_category")
-    print(result)
+    r = query("""SELECT count(*) FROM aitag_tag_log WHERE business_attr = %s and is_delete = 0""", ('ca6967d61b1d4454a7879085259fe219',))  # ad-hoc check: count aitag_tag_log rows with is_delete = 0 for one hard-coded business_attr
+    if r and len(r) > 0:  # only index into the result when query() returned at least one row
+        print(r[0][0])  # first column of the first row, i.e. the count(*) value

+ 1 - 57
agent/src/agent/core/es.py

@@ -1,12 +1,12 @@
 from elasticsearch import Elasticsearch,helpers
 from agent.logger import logger
-
 from agent.core.config import get_config_path
 config = get_config_path()
 TOP_K = int(config['app']['top_k'])
 
 url = config['es']['url']
 DIMS = int(config['embedding']['default_dims'])
+DIMS = 512  # NOTE(review): unconditionally overrides the value just read from config['embedding']['default_dims'] — confirm the hard-coded 512 is intended
 
 RRF_CONST:int = 60
 
@@ -124,8 +124,6 @@ def bulk_upsert(documents):
         logger.error(f"Bulk upsert errors: {errors}")
     except Exception as e:
         logger.error(f"Bulk upsert failed: {e}")
-        for error in e.errors:
-            logger.error(f"Error: {error}")
         raise
 
 def hybrid_search( target_doc_ids, query_vector, top_k=2):
@@ -155,60 +153,6 @@ def hybrid_search( target_doc_ids, query_vector, top_k=2):
     )
     return response["hits"]["hits"]
 
-def bm25_vector_search(query:str,query_vector,vector_threshold =0.76,rrf_threshold=0.031):
-    resp_bm25 = es.search(
-        index=INDEX_NAME,
-        size=max(TOP_K, 10),
-        query={"match": {"tag_path": query}}
-    )
-    resp_vector = es.search(
-        index=INDEX_NAME,
-        size=max(TOP_K,10),
-        knn={
-            "field": "tag_vector",
-            "query_vector": query_vector,
-            "k": max(TOP_K, 10),
-            "num_candidates": 100
-        }
-    )
-    rrf_scores = {}
-    rrf_scores_data = {}
-    for rank, hit in enumerate(resp_bm25['hits']['hits'], start=1):
-        hit["_source"]["tag_vector"] = None
-        print("a:"+hit["_source"]["tag_path"])
-        doc_id = hit['_id']
-        score = rrf_scores.get(doc_id, 0.0)
-        rrf_scores[doc_id] = score + (1.0 / (RRF_CONST + rank))
-        if doc_id not in rrf_scores_data:
-             rrf_scores_data[doc_id] = hit
-
-    for rank, hit in enumerate(resp_vector['hits']['hits'], start=1):
-        hit["_source"]["tag_vector"] = None
-        print("b:"+hit["_source"]["tag_path"]+";"+str(hit["_score"]))
-        if hit["_score"]>vector_threshold:
-            doc_id = hit['_id']
-            score = rrf_scores.get(doc_id, 0.0)
-            rrf_scores[doc_id] = score + (1.0 / (RRF_CONST + rank))
-            if doc_id not in rrf_scores_data:
-                rrf_scores_data[doc_id] = hit
-            rrf_scores_data[doc_id]['vector_score'] = hit["_score"]
-
-    # 4. 排序并截取 Top K
-    # 按 RRF 分数降序排序
-    sorted_ids = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:TOP_K]
-    final_results = []
-    for doc_id, score in sorted_ids:
-        hit = rrf_scores_data[doc_id]
-        hit['_score'] = score # 覆盖为计算出的 RRF 分数
-        if 'vector_score' in hit and score>rrf_threshold:
-            final_results.append({
-                'rrf_score':score,
-                'vector_score':hit['vector_score'],
-                'tag_path':hit['_source']['tag_path']
-            })
-        
-    return final_results
-
 if __name__ == "__main__":
     results = search_all()
     for r in results: