Ver código fonte

优化打标逻辑 移除TOPN

jiayongqiang 4 dias atrás
pai
commit
b58e922695

+ 0 - 1
agent/config.ini

@@ -29,7 +29,6 @@ username = elastic
 password = 123456
 
 [app]
-top_k = 2
 port=9876
 concurrence=1
 

BIN
agent/logs/aitagging-app.2026-05-12_10-11-58_363564.log.zip


+ 1 - 1
agent/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "agent"
-version = "0.1.8.1"
+version = "0.1.8.2"
 description = "Default template for PDM package"
 authors = [
     {name = "jiayongqiang", email = "15936285643@163.com"},

+ 1 - 2
agent/src/agent/agent.py

@@ -8,7 +8,6 @@ import uuid
 from datetime import datetime
 from agent.logger import logger
 config = get_config_path()
-TOP_K = config['app']['top_k']
 
 base_url = config['llm']['base_url']
 api_key_env_var = config['llm']['api_key']
@@ -53,7 +52,7 @@ class Lables(BaseModel):
     labels: list[Lable] = Field(description="List of optimized labels after reflection.")
     
 
-async def reflect_check(context: str,is_marine: bool, labels: list):
+async def reflect_check(context: str,is_marine: bool, labels: list[str]):
     agent = create_agent(
         model = llm, 
         response_format=Lables

+ 2 - 3
agent/src/agent/api_inner.py

@@ -10,7 +10,6 @@ from agent.logger import logger
 from agent.core.config import get_config_path
 from typing import Optional
 config = get_config_path()
-TOP_K = config['app']['top_k']
 
 router = APIRouter(prefix="/v1", tags=["平台内部接口"])
 
@@ -65,7 +64,7 @@ def load_tag_2_es(tag_ids: list[str]):
         "tag_reg": label[4],
         "tag_prompt": label[9],
         "category_id":label[10],
-        "tag_vector": get_embeddings([(label[6] or '')+(label[4] or '')])[0] 
+        "tag_vector": get_embeddings([(label[6] or '')+"\n"+(label[3] or '')])[0] 
     } for label in labels])
     return labels
 
@@ -99,7 +98,7 @@ def load_category_2_es(category_id: str):
         "tag_reg": label[4],
         "tag_prompt": label[9],
         "category_id":label[10],
-        "tag_vector": get_embeddings([(label[6] or '')+(label[4] or '')])[0] 
+        "tag_vector": get_embeddings([(label[6] or '')+"\n"+(label[3] or '')])[0] 
     } for label in labels])
     return labels
 

+ 8 - 23
agent/src/agent/api_outter.py

@@ -5,7 +5,7 @@ from requests import request
 from agent.core.sign_check import check
 import agent.core.dao as dao
 from agent.core.vector import get_embeddings
-from agent.core.es import hybrid_search
+from agent.core.es import hybrid_search,bm25_vector_search
 from agent.agent import reflect_check
 from datetime import datetime
 import uuid
@@ -19,7 +19,6 @@ from agent.core.tagging_state import TAGGING_STATE
 import time
 
 config = get_config_path()
-TOP_K = config['app']['top_k']
 CONCURRENCE = int(config['app']['concurrence'])
 background_semaphore = asyncio.Semaphore(CONCURRENCE)
 
@@ -71,19 +70,6 @@ async def execute_reg(log_id:str,tag_category_id:str,phrase: str)-> list:
     labels = dao.query(sql)
     # 循环调用reg匹配phrase,匹配成功则返回标签id
     result = [label[0] for label in labels]
-    # try:
-        # for label in labels:
-        #     reg = label[1]
-        #     if reg is not None:
-        #         pattern = re.compile(reg, re.VERBOSE)
-        #         # logger.info(f"Executing regex for label_id {label[0]}: {reg}")
-        #         if pattern.match(phrase):
-        #             logger.info(f"Executing regex for label_id {label[0]}: {reg} true")
-        #             result.append(label[0])
-        #     else:
-        #         result.append(label[0])
-    # except Exception as e:
-    #     logger.error(f"Regex execution failed: {e}")
     dao.execute(
             """UPDATE aitag_tag_log SET reg_result = %s WHERE id = %s""",
             (str(result), log_id)
@@ -91,15 +77,12 @@ async def execute_reg(log_id:str,tag_category_id:str,phrase: str)-> list:
     logger.info(f"[{log_id}] Regex filtering result: {result}")
     return result
 
-def vector_similarity_search(phrase: str, ids:list)-> list:
+def vector_similarity_search(phrase: str)-> list:
     logger.info("Starting vector similarity search...")
     # 这里应该调用向量数据库进行相似度检索,返回相关标签id列表
     query = get_embeddings([phrase])[0]
-    results = hybrid_search(ids, query, top_k=TOP_K)
-    # return [{"id": r["_id"], "score": r["_score"], "tag_prompt": r["_source"]["tag_prompt"],"tag_name": r["_source"]["tag_name"],"tag_code": r["_source"]["tag_code"]} for r in results]
-    r = [{"id": r["_id"], "tag_remark":r["_source"]["tag_remark"], "tag_prompt": r["_source"]["tag_prompt"],"tag_name": r["_source"]["tag_name"],"tag_code": r["_source"]["tag_code"],"tag_path": r["_source"]["tag_path"],"category_id": r["_source"]["category_id"]   } for r in results]
-    # logger.info(f"{phrase} Vector search result: {r}")
-    return r
+    results = bm25_vector_search(phrase,query)
+    return results
 
 def init_tag_log(request: TaggingRequest):
     id = uuid.uuid4().hex
@@ -192,11 +175,13 @@ async def run_ai_pipeline(log_id: str, tag_category_id: str, phrase: str, instuc
             # step1: 正则过滤
             result = await execute_reg(log_id,tag_category_id,phrase)
             # step2: 向量检索
-            result = vector_similarity_search(phrase, result)
+            if not result or len(result) == 0:
+                result = vector_similarity_search(phrase)
             # step3: LLM 打标
             if result:
                 try:
-                    result = await reflect_check(phrase,is_marine, result)
+                    tags = dao.query_dict(""" select id,tag_nm as tag_name,tag_code, tag_path,category_id,tag_prompt from aitag_tag_info where id in %s """, (tuple(result),))
+                    result = await reflect_check(phrase,is_marine, tags)
                 except Exception as e:
                     logger.error(f"LLM reflection check failed: {e}")
                     result = None

+ 0 - 1
agent/src/agent/core/dao.py

@@ -5,7 +5,6 @@ from contextlib import contextmanager
 
 from agent.core.config import get_config_path
 config = get_config_path()
-TOP_K = config['app']['top_k']
 
 host = config['database']['host']
 port = int(config['database']['port'])

+ 45 - 18
agent/src/agent/core/es.py

@@ -2,7 +2,6 @@ from elasticsearch import Elasticsearch,helpers
 from agent.logger import logger
 from agent.core.config import get_config_path
 config = get_config_path()
-TOP_K = int(config['app']['top_k'])
 
 url = config['es']['url']
 DIMS = int(config['embedding']['default_dims'])
@@ -127,32 +126,60 @@ def bulk_upsert(documents):
         logger.error(f"Bulk upsert failed: {e}")
         raise
 
-def hybrid_search( target_doc_ids, query_vector, top_k=2):
-    logger.info(f"Performing hybrid search with target_doc_ids: {target_doc_ids}, query_vector: {len(query_vector)}, top_k: {top_k}")
+def hybrid_search(query_vector):
+    logger.info(f"Performing hybrid search with query_vector: {len(query_vector)}")
     knn={
-            "field": "tag_vector",
-            "query_vector": query_vector,
-            "k": top_k,           # 不超过候选集大小
-            "num_candidates": 100,
-            "filter": {
-                "terms": {
-                    "_id": target_doc_ids
-                }
+        "field": "tag_vector",
+        "query_vector": query_vector,
+        "k": 50,           # 不超过候选集大小
+        "num_candidates": 100
+    }
+    response = es.search(
+        index=INDEX_NAME,
+        knn=knn,
+        size=50,
+        min_score=0.7
+    )
+    r =  response["hits"]["hits"]
+    return [item["_id"] for item in r]
+
+def bm25_vector_search(query:str,query_vector,rrf_score_threshold=0.016):
+    resp_bm25 = es.search(
+        index=INDEX_NAME,
+        size=10,
+        query={
+            "multi_match": {
+                "query": query,
+                "fields": ["tag_path", "tag_remark", "tag_prompt"]
             }
         }
-    if not target_doc_ids:
+    )
+    resp_vector = es.search(
+        index=INDEX_NAME,
+        size=10,
         knn={
             "field": "tag_vector",
             "query_vector": query_vector,
-            "k": top_k,           # 不超过候选集大小
+            "k": 10,
             "num_candidates": 100
         }
-  
-    response = es.search(
-        index=INDEX_NAME,
-        knn=knn
     )
-    return response["hits"]["hits"]
+    rrf_scores = {}
+    for rank, hit in enumerate(resp_bm25['hits']['hits'], start=1):
+        hit["_source"]["tag_vector"] = None
+        doc_id = hit['_id']
+        score = rrf_scores.get(doc_id, 0.0) 
+        rrf_scores[doc_id] = score + (1.0 / (RRF_CONST + rank))
+    
+    for rank, hit in enumerate(resp_vector['hits']['hits'], start=1):
+        hit["_source"]["tag_vector"] = None
+        doc_id = hit['_id']
+        score = rrf_scores.get(doc_id, 0.0)
+        rrf_scores[doc_id] = score + (1.0 / (RRF_CONST + rank))
+    
+    sorted_ids = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
+    result =  [id for id,score in sorted_ids if score>rrf_score_threshold]
+    return result[:10]
 
 if __name__ == "__main__":
     results = search_all()

+ 0 - 1
agent/src/agent/core/vector.py

@@ -3,7 +3,6 @@ import json
 from agent.logger import logger
 from agent.core.config import get_config_path
 config = get_config_path()
-TOP_K = config['app']['top_k']
 
 model = config['embedding']['model']
 base_url = config['embedding']['base_url']

+ 0 - 0
agent/tests/load_tag_2_es.py


+ 8 - 17
agent/tests/test_bm25.py

@@ -1,19 +1,10 @@
-from agent.core.es import bm25_vector_search
-from agent.core.vector import get_embeddings
+from agent.api_outter import vector_similarity_search
+import agent.core.dao as dao
 
-phrase = "渔业产品批发,海带边周转"
-phrase_vector = get_embeddings([phrase])[0]
-r = bm25_vector_search(phrase,phrase_vector)
+phrase = "职业:民宿服务; 投向:民宿服务; 用途:经营民宿"
+r = vector_similarity_search(phrase)
 print(r)
-# from openpyxl import load_workbook
-# workbook = load_workbook('data/样本数据标注后-完整版.xlsx', data_only=True)
-# sheet = workbook['核对结果'] 
-# for row in sheet.iter_rows(min_row=2,min_col=2, max_col=3, values_only=True):
-#     touxiang = row[0] # B列的值
-#     yongtu = row[1] # C列的值
-#     phrase = f"投向:{touxiang},用途:{yongtu}"
-#     phrase_vector = get_embeddings([phrase])[0]
-#     r = bm25_vector_search(phrase,phrase_vector)
-#     print(f"输入:{phrase}")
-#     print(f"输出:{r}")
-# workbook.close()
+
+tags = dao.query_dict(""" select id,tag_nm as tag_name,tag_code, tag_path,category_id,tag_prompt from aitag_tag_info where id in %s """, (tuple(r),))
+print(tags)
+

+ 10 - 4
agent/tests/test_sync_category.py

@@ -1,10 +1,16 @@
 import requests
 
-res = requests.post("http://10.192.72.13:9876/api/aitag/admin/v1/synchronize_category", json={
-    "category_id": "f47ac10b-58cc-4372-a567-0e02b2c3d479"
-})
-# res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_category", json={
+# res = requests.post("http://10.192.72.13:9876/api/aitag/admin/v1/synchronize_category", json={
 #     "category_id": "f47ac10b-58cc-4372-a567-0e02b2c3d479"
 # })
+res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_category", json={
+    "category_id": "f47ac10b-58cc-4372-a567-0e02b2c3d479"
+})
+res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_category", json={
+    "category_id": "0a2dc889-6205-4cb2-be31-d67c6390a0d6"
+})
+res = requests.post("http://localhost:9876/api/aitag/admin/v1/synchronize_category", json={
+    "category_id": "cd4de5d4-491f-4779-8d96-9246c861e907"
+})
 print(res.text)