浏览代码

测试BM-25

jiayongqiang 6 天之前
父节点
当前提交
00b35fd169

+ 2 - 2
agent/config.ini

@@ -1,5 +1,5 @@
 [database]
-host = 10.192.72.11  ;数据库地址
+host = 10.192.72.11  
 port = 4321
 user = root
 password = admin
@@ -25,7 +25,7 @@ url = http://10.192.72.13:9200
 [app]
 top_k = 2
 port=9876
-concurrence=3
+concurrence=1
 
 [logging]
 log_path= logs/aitagging-app.log

二进制
agent/data/样本数据标注后-完整版.xlsx


二进制
agent/dist/agent-0.1.5-py3-none-any.whl


二进制
agent/dist/agent-0.1.5.tar.gz


二进制
agent/logs/aitagging-app.2026-03-17_10-36-37_002744.log.zip


二进制
agent/logs/aitagging-app.2026-03-23_14-44-11_403520.log.zip


二进制
agent/logs/aitagging-app.2026-03-26_18-19-50_023840.log.zip


+ 1 - 1
agent/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "agent"
-version = "0.1.4"
+version = "0.1.5"
 description = "Default template for PDM package"
 authors = [
     {name = "jiayongqiang", email = "15936285643@163.com"},

+ 33 - 9
agent/src/agent/api_outter.py

@@ -16,6 +16,8 @@ import json
 from agent.logger import logger
 from agent.core.config import get_config_path
 import asyncio
+from agent.core.tagging_state import TAGGING_STATE
+import time
 
 config = get_config_path()
 TOP_K = config['app']['top_k']
@@ -31,6 +33,7 @@ class TaggingRequest(BaseModel):
     business_attr: str = Field(..., description="业务属性")
     phrase: str = Field(..., description="需要打标签的文本")
     tag_category_id: Optional[str] = Field(None, description="指定标签类别ID,默认为空表示不指定")
+    esb_seq_no: Optional[str] = Field(None,description="ESB流水号")
 
 async def execute_reg(log_id:str,tag_category_id:str,phrase: str)-> list:
     sql = f"""select 
@@ -73,7 +76,7 @@ def vector_similarity_search(phrase: str, ids:list)-> list:
     results = hybrid_search(ids, query, top_k=TOP_K)
     # return [{"id": r["_id"], "score": r["_score"], "tag_prompt": r["_source"]["tag_prompt"],"tag_name": r["_source"]["tag_name"],"tag_code": r["_source"]["tag_code"]} for r in results]
     r = [{"id": r["_id"], "tag_remark":r["_source"]["tag_remark"], "tag_prompt": r["_source"]["tag_prompt"],"tag_name": r["_source"]["tag_name"],"tag_code": r["_source"]["tag_code"],"tag_path": r["_source"]["tag_path"],"category_id": r["_source"]["category_id"]   } for r in results]
-    logger.info(f"{phrase} Vector search result: {r}")
+    # logger.info(f"{phrase} Vector search result: {r}")
     return r
 
 def init_tag_log(request: TaggingRequest):
@@ -86,20 +89,36 @@ def init_tag_log(request: TaggingRequest):
     # 业务编号如果以test开头,则tag_scope = 1,否则都是0
     tag_scope = 1 if request.business_attr.startswith("test") else 0
     dao.execute(
-        """INSERT INTO aitag_tag_log (id,app_id, insert_time, business_attr, phrase, state, tag_scope) VALUES (%s, %s, %s, %s, %s, %s, %s)""",
-        (id,request.app_id, datetime.now(), request.business_attr, request.phrase, 0, tag_scope)
+        """INSERT INTO aitag_tag_log (id,app_id, insert_time, business_attr, phrase, state, tag_scope,esb_seq_no) VALUES (%s, %s, %s, %s, %s, %s, %s,%s)""",
+        (id,request.app_id, datetime.now(), request.business_attr, request.phrase, TAGGING_STATE.REQUEST.value, tag_scope, request.esb_seq_no)
     )
     return id
 
-def update_tag_log(id:str, result:str):
+def end_tagging(id:str, result:str):
     dao.execute(
             """UPDATE aitag_tag_log SET state = %s, result = %s, ai_result_endtime = %s WHERE id = %s""",
-            (1, result, datetime.now(), id)
+            (TAGGING_STATE.END.value, result, datetime.now(), id)
         )
 
+def fail_tagging(id:str):
+    dao.execute(
+            """UPDATE aitag_tag_log SET state = %s,  ai_result_endtime = %s WHERE id = %s""",
+            (TAGGING_STATE.FAIL.value,  datetime.now(), id)
+        )
+
+def start_tagging(id:str):
+    dao.execute(
+            """UPDATE aitag_tag_log SET state = %s,  ai_result_starttime = %s WHERE id = %s""",
+            (TAGGING_STATE.BEGIN.value, datetime.now(),  id)
+        )
+
+
 async def run_ai_pipeline(log_id: str, tag_category_id: str, phrase: str):
     try:
         async with background_semaphore:
+            logger.info(f"开始打标:{log_id}, {phrase}")
+            # step0: 开始打标
+            start_tagging(log_id)
             # step1: 正则过滤
             result = await execute_reg(log_id,tag_category_id,phrase)
             # step2: 向量检索
@@ -111,17 +130,22 @@ async def run_ai_pipeline(log_id: str, tag_category_id: str, phrase: str):
                 except Exception as e:
                     logger.error(f"LLM reflection check failed: {e}")
                     result = None
+                    fail_tagging(log_id)
+                    return
             # step4: 更新数据库
             # 如果result是个空集合,插入None
-            update_tag_log(log_id, result if result else None)
+            end_tagging(log_id, result if result else None)
             
     except Exception as e:
         logger.error(f"[{log_id}] Pipeline failed: {e}")
-        update_tag_log(log_id, None)
+        fail_tagging(log_id)
+
 
+# 0:请求已接收;1:打标完成;2:客户经理已经确认;3:结果已推送;
+# 4:开始打标;5:打标失败
 @router.post("/tagging")
 async def ai_tagging(request: TaggingRequest,background_tasks: BackgroundTasks):
-    logger.info(f"app_id: {request.app_id}, timestamp: {request.timestamp}, sign: {request.sign}, business_attr: {request.business_attr}, phrase: {request.phrase}")
+    logger.info(f"esb_seq_no: {request.esb_seq_no}, business_attr: {request.business_attr}, phrase: {request.phrase}")
     # 数据库中插入一条记录,记录请求的app_id、timestamp、business_attr、phrase等信息,状态设为“处理中”,后续步骤完成后更新状态和结果
     id = init_tag_log(request)
     # 执行异步任务
@@ -162,7 +186,7 @@ def ai_feedback(feedback_request: FeedbackRequest):
     # 这里将用户的反馈信息保存到数据库中aitag_tag_log,供后续分析和模型优化使用
     dao.execute(
         """update aitag_tag_log set feedback = %s, feedback_result = %s, feedback_time = %s, feedback_user_id = %s, feedback_user_nm = %s, contract_no = %s, feedback_user_org = %s, feedback_user_endpoint = %s, state = %s where business_attr = %s""",   
-        (feedback_request.feedback, feedback_request.feedback_result, datetime.now(), feedback_request.user_id, feedback_request.user_nm, feedback_request.contract_no, feedback_request.user_org, feedback_request.user_endpoint, 2, feedback_request.business_attr)
+        (feedback_request.feedback, feedback_request.feedback_result, datetime.now(), feedback_request.user_id, feedback_request.user_nm, feedback_request.contract_no, feedback_request.user_org, feedback_request.user_endpoint, TAGGING_STATE.FEEDBACK.value, feedback_request.business_attr)
     )
     return {"code": 200, "message": "Feedback received successfully"}
 

+ 94 - 4
agent/src/agent/core/es.py

@@ -3,11 +3,13 @@ from agent.logger import logger
 
 from agent.core.config import get_config_path
 config = get_config_path()
-TOP_K = config['app']['top_k']
+TOP_K = int(config['app']['top_k'])
 
 url = config['es']['url']
 DIMS = int(config['embedding']['default_dims'])
 
+RRF_CONST:int = 60
+
 es = Elasticsearch(
     hosts=[url],
     retry_on_timeout=True,
@@ -19,11 +21,40 @@ INDEX_NAME = "ai-tagging"
 if not es.indices.exists(index=INDEX_NAME):
     es.indices.create(
         index=INDEX_NAME,
+        settings={
+            "index": {
+                "similarity": {
+                    "my_bm25": { 
+                    "type": "BM25",
+                    "b": "0.9",   
+                    "k1": "0.6"   
+                    }
+                },
+                "analysis": {
+                    "analyzer": {
+                        "my_custom_analyzer": {
+                            "type": "custom",
+                            "tokenizer": "standard", 
+                            "filter": [
+                                "lowercase", 
+                                "my_stop_filter" 
+                            ]
+                        }
+                    },
+                    "filter": {
+                        "my_stop_filter": {
+                            "type": "stop",
+                            "stopwords": ["的", "了", "是", "在", "职","业","投","向","海","洋","水"] 
+                        }
+                    }
+                }
+            },
+        },
         mappings={
             "properties": {
                 "tag_code": {"type": "text"},
                 "tag_name": {"type": "text"},
-                "tag_path": {"type": "text"},
+                "tag_path": {"type": "text","similarity": "my_bm25"},
                 "tag_level": {"type": "integer"},
                 "tag_remark": {"type": "text"},
                 "tag_reg": {"type": "text"},
@@ -58,7 +89,8 @@ def delete_category_documents(category_id):
             }
         }
     }
-    es.options(request_timeout=60).delete_by_query(index=INDEX_NAME, body=query)
+    if es.indices.exists(index=INDEX_NAME):
+        es.options(request_timeout=60).delete_by_query(index=INDEX_NAME, body=query)
     
 def search_documents(query):
     response = es.search(
@@ -116,17 +148,75 @@ def hybrid_search( target_doc_ids, query_vector, top_k=2):
             "k": top_k,           # 不超过候选集大小
             "num_candidates": 100
         }
-    print(f"Constructed knn query: {knn}")
+  
     response = es.search(
         index=INDEX_NAME,
         knn=knn
     )
     return response["hits"]["hits"]
 
+def bm25_vector_search(query:str,query_vector,vector_threshold =0.76,rrf_threshold=0.031):
+    resp_bm25 = es.search(
+        index=INDEX_NAME,
+        size=max(TOP_K, 10),
+        query={"match": {"tag_path": query}}
+    )
+    resp_vector = es.search(
+        index=INDEX_NAME,
+        size=max(TOP_K,10),
+        knn={
+            "field": "tag_vector",
+            "query_vector": query_vector,
+            "k": max(TOP_K, 10),
+            "num_candidates": 100
+        }
+    )
+    rrf_scores = {}
+    rrf_scores_data = {}
+    for rank, hit in enumerate(resp_bm25['hits']['hits'], start=1):
+        hit["_source"]["tag_vector"] = None
+        print("a:"+hit["_source"]["tag_path"])
+        doc_id = hit['_id']
+        score = rrf_scores.get(doc_id, 0.0)
+        rrf_scores[doc_id] = score + (1.0 / (RRF_CONST + rank))
+        if doc_id not in rrf_scores_data:
+             rrf_scores_data[doc_id] = hit
+
+    for rank, hit in enumerate(resp_vector['hits']['hits'], start=1):
+        hit["_source"]["tag_vector"] = None
+        print("b:"+hit["_source"]["tag_path"]+";"+str(hit["_score"]))
+        if hit["_score"]>vector_threshold:
+            doc_id = hit['_id']
+            score = rrf_scores.get(doc_id, 0.0)
+            rrf_scores[doc_id] = score + (1.0 / (RRF_CONST + rank))
+            if doc_id not in rrf_scores_data:
+                rrf_scores_data[doc_id] = hit
+            rrf_scores_data[doc_id]['vector_score'] = hit["_score"]
+
+    # 4. 排序并截取 Top K
+    # 按 RRF 分数降序排序
+    sorted_ids = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:TOP_K]
+    final_results = []
+    for doc_id, score in sorted_ids:
+        hit = rrf_scores_data[doc_id]
+        hit['_score'] = score # 覆盖为计算出的 RRF 分数
+        if 'vector_score' in hit and score>rrf_threshold:
+            final_results.append({
+                'rrf_score':score,
+                'vector_score':hit['vector_score'],
+                'tag_path':hit['_source']['tag_path']
+            })
+        
+    return final_results
+
 if __name__ == "__main__":
     results = search_all()
     for r in results:
         # 排除 tag_vector 字段的输出
         print({k: v for k, v in r["_source"].items() if k != "tag_vector"})
     print(f"Search results: {len(results)}")
+
+
+
+
     

+ 11 - 0
agent/src/agent/core/tagging_state.py

@@ -0,0 +1,11 @@
+# 0:请求已接收;1:打标完成;2:客户经理已经确认;3:结果已推送;
+# 4:开始打标;5:打标失败
+from enum import Enum
+
+# Values stored in the `state` column of aitag_tag_log. The numeric order is
+# not the lifecycle order: BEGIN (4) and FAIL (5) come after PUSHED (3).
+class TAGGING_STATE(Enum):
+    # Request received.
+    REQUEST = 0
+    # Tagging finished.
+    END = 1
+    # Result confirmed by the account manager.
+    FEEDBACK = 2
+    # Result has been pushed.
+    PUSHED = 3
+    # Tagging has started.
+    BEGIN = 4
+    # Tagging failed.
+    FAIL = 5

+ 0 - 48
agent/tests/load_tag_2_es.py

@@ -1,48 +0,0 @@
-import agent.core.es as es
-import agent.core.dao as dao
-import numpy as np
-
-if __name__ == "__main__":
-    tags = dao.query("""select 
-                     tti.id,
-                    tti.tag_nm,
-                    tti.tag_code,
-                    tti.tag_remark,
-                    tti.reg,
-                    tti.level,
-                    tti.tag_path,
-                    ttc.category_nm,
-                    ttc.category_code
-                    from fjnx.tab_tag_info tti left join fjnx.tab_tag_category  ttc 
-                    on tti.category_id = ttc.id 
-                    where ttc.is_delete=0 and tti.is_delete=0 and ttc.state =  0 and tti.state = 0 and tti.`level` = ttc.visibility_level""")
-    for tag in tags:
-        tag_id = tag[0]         
-        tag_name = tag[1]
-        tag_code = tag[2]
-        tag_remark = tag[3]
-        tag_reg = tag[4]
-        tag_level = tag[5]
-        tag_path = tag[6]
-        tag_category = tag[7]
-        tag_category_code = tag[8]
-        print(tag_name, tag_code, tag_remark, tag_reg, tag_level, tag_path, tag_category, tag_category_code)
-        # 生成随机向量,实际应用中应使用真实的向量
-        tag_vector = np.random.rand(es.DIMS).tolist()
-        document = {
-            "tag_code": tag_code,
-            "tag_name": tag_name,   
-            "tag_remark": tag_remark,
-            "tag_reg": tag_reg,
-            "tag_level": tag_level,
-            "tag_path": tag_path,   
-            "tag_category": tag_category,
-            "tag_category_code": tag_category_code, 
-            "tag_vector": tag_vector    
-        }
-        es.upsert_document(tag_id, document)
-        print(f"Upserted document with ID {tag_id}")
-        res = es.search_documents({
-            "tag_name": "海水养殖"
-        })
-        print(res)

+ 19 - 0
agent/tests/test_bm25.py

@@ -0,0 +1,19 @@
+from agent.core.es import bm25_vector_search
+from agent.core.vector import get_embeddings
+
+# Manual smoke check for the BM25 + kNN fusion search. It runs at import time;
+# bm25_vector_search queries Elasticsearch, so a reachable cluster is required.
+# NOTE(review): get_embeddings presumably calls the embedding service and
+# returns one vector per input phrase - confirm.
+phrase = "渔业产品批发,海带边周转"
+phrase_vector = get_embeddings([phrase])[0]
+r = bm25_vector_search(phrase,phrase_vector)
+print(r)
+# from openpyxl import load_workbook
+# workbook = load_workbook('data/样本数据标注后-完整版.xlsx', data_only=True)
+# sheet = workbook['核对结果'] 
+# for row in sheet.iter_rows(min_row=2,min_col=2, max_col=3, values_only=True):
+#     touxiang = row[0] # B列的值
+#     yongtu = row[1] # C列的值
+#     phrase = f"投向:{touxiang},用途:{yongtu}"
+#     phrase_vector = get_embeddings([phrase])[0]
+#     r = bm25_vector_search(phrase,phrase_vector)
+#     print(f"输入:{phrase}")
+#     print(f"输出:{r}")
+# workbook.close()

+ 10 - 10
agent/tests/test_inner_api.py

@@ -15,7 +15,7 @@ ids = [
 "20260200013",
 "20260200014",
 "20260200015",
-"202602000016",
+"20260200016",
 "20260200017",  
 "20260200018",
 "20260200019",
@@ -85,17 +85,17 @@ ids = [
 "20260200083",
 ]
 
-res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_tag", json={
-    "tag_ids": ids
-})
-print(res.text)
+# res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_tag", json={
+#     "tag_ids": ids
+# })
+# print(res.text)
 
-res = requests.post("http://localhost:9876/api/aitag/admin/v1/delete_tag", json={
-    "tag_ids": ["20260200083"]
-})
-print(res.text)
+# res = requests.post("http://localhost:9876/api/aitag/admin/v1/delete_tag", json={
+#     "tag_ids": ["20260200083"]
+# })
+# print(res.text)
 
-res = requests.post("http://10.192.72.13:9876/api/aitag/admin/v1/synchronize_category", json={
+res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_category", json={
     "category_id": "f47ac10b-58cc-4372-a567-0e02b2c3d479"
 })
 print(res.text)

+ 3 - 0
agent/tests/test_state.py

@@ -0,0 +1,3 @@
+from agent.core.tagging_state import TAGGING_STATE
+
+print(TAGGING_STATE.REQUEST.value)

+ 1 - 1
agent/tests/test_sync_category.py

@@ -1,6 +1,6 @@
 import requests
 
-res = requests.post("http://10.192.72.13:9876/api/aitag/admin/v1/synchronize_category", json={
+res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_category", json={
     "category_id": "f47ac10b-58cc-4372-a567-0e02b2c3d479"
 })
 print(res.text)

+ 2 - 1
agent/tests/test_tagging.py

@@ -3,10 +3,11 @@ import logging
 logging.basicConfig(level=logging.INFO, force=True,format='%(asctime)s - %(levelname)s - %(message)s')
 logging.info("app starting!")
 
-res = requests.post("http://10.192.72.13:9876/api/aitag/v1/tagging", json={
+res = requests.post("http://localhost:9876/api/aitag/v1/tagging", json={
     # "app_id": "test_app",
     # "timestamp": 1234567890,
     # "sign": "test_sign",
+    "esb_seq_no":"abc",
     "business_attr": "test_attr",
     "phrase": "职业:水产养殖人员 投向:内陆养殖 用途:养殖鲍鱼"
 })

+ 1 - 0
agent/tests/test_tagging_batch.py

@@ -13,6 +13,7 @@ with open("tests/test_data.txt", "r", encoding="utf-8") as f:
             phrase = "职业:"+phrase[0]+" "+"投向:"+phrase[1]+" "+"用途:"+phrase[2]
 
         res = requests.post("http://10.192.72.13:9876/api/aitag/v1/tagging", json={
+            "esb_seq_no": uuid.uuid4().hex,
             "business_attr": uuid.uuid4().hex,
             "phrase": phrase
         })