|
|
@@ -5,7 +5,7 @@ from requests import request
|
|
|
from agent.core.sign_check import check
|
|
|
import agent.core.dao as dao
|
|
|
from agent.core.vector import get_embeddings
|
|
|
-from agent.core.es import hybrid_search
|
|
|
+from agent.core.es import hybrid_search,bm25_vector_search
|
|
|
from agent.agent import reflect_check
|
|
|
from datetime import datetime
|
|
|
import uuid
|
|
|
@@ -19,7 +19,6 @@ from agent.core.tagging_state import TAGGING_STATE
|
|
|
import time
|
|
|
|
|
|
config = get_config_path()
|
|
|
-TOP_K = config['app']['top_k']
|
|
|
CONCURRENCE = int(config['app']['concurrence'])
|
|
|
background_semaphore = asyncio.Semaphore(CONCURRENCE)
|
|
|
|
|
|
@@ -71,19 +70,6 @@ async def execute_reg(log_id:str,tag_category_id:str,phrase: str)-> list:
|
|
|
labels = dao.query(sql)
|
|
|
# 循环调用reg匹配phrase,匹配成功则返回标签id
|
|
|
result = [label[0] for label in labels]
|
|
|
- # try:
|
|
|
- # for label in labels:
|
|
|
- # reg = label[1]
|
|
|
- # if reg is not None:
|
|
|
- # pattern = re.compile(reg, re.VERBOSE)
|
|
|
- # # logger.info(f"Executing regex for label_id {label[0]}: {reg}")
|
|
|
- # if pattern.match(phrase):
|
|
|
- # logger.info(f"Executing regex for label_id {label[0]}: {reg} true")
|
|
|
- # result.append(label[0])
|
|
|
- # else:
|
|
|
- # result.append(label[0])
|
|
|
- # except Exception as e:
|
|
|
- # logger.error(f"Regex execution failed: {e}")
|
|
|
dao.execute(
|
|
|
"""UPDATE aitag_tag_log SET reg_result = %s WHERE id = %s""",
|
|
|
(str(result), log_id)
|
|
|
@@ -91,15 +77,12 @@ async def execute_reg(log_id:str,tag_category_id:str,phrase: str)-> list:
|
|
|
logger.info(f"[{log_id}] Regex filtering result: {result}")
|
|
|
return result
|
|
|
|
|
|
-def vector_similarity_search(phrase: str, ids:list)-> list:
|
|
|
+def vector_similarity_search(phrase: str)-> list:
|
|
|
logger.info("Starting vector similarity search...")
|
|
|
# 这里应该调用向量数据库进行相似度检索,返回相关标签id列表
|
|
|
query = get_embeddings([phrase])[0]
|
|
|
- results = hybrid_search(ids, query, top_k=TOP_K)
|
|
|
- # return [{"id": r["_id"], "score": r["_score"], "tag_prompt": r["_source"]["tag_prompt"],"tag_name": r["_source"]["tag_name"],"tag_code": r["_source"]["tag_code"]} for r in results]
|
|
|
- r = [{"id": r["_id"], "tag_remark":r["_source"]["tag_remark"], "tag_prompt": r["_source"]["tag_prompt"],"tag_name": r["_source"]["tag_name"],"tag_code": r["_source"]["tag_code"],"tag_path": r["_source"]["tag_path"],"category_id": r["_source"]["category_id"] } for r in results]
|
|
|
- # logger.info(f"{phrase} Vector search result: {r}")
|
|
|
- return r
|
|
|
+ results = bm25_vector_search(phrase,query)
|
|
|
+ return results
|
|
|
|
|
|
def init_tag_log(request: TaggingRequest):
|
|
|
id = uuid.uuid4().hex
|
|
|
@@ -192,11 +175,13 @@ async def run_ai_pipeline(log_id: str, tag_category_id: str, phrase: str, instuc
|
|
|
# step1: 正则过滤
|
|
|
result = await execute_reg(log_id,tag_category_id,phrase)
|
|
|
# step2: 向量检索
|
|
|
- result = vector_similarity_search(phrase, result)
|
|
|
+ if not result or len(result) == 0:
|
|
|
+ result = vector_similarity_search(phrase)
|
|
|
# step3: LLM 打标
|
|
|
if result:
|
|
|
try:
|
|
|
- result = await reflect_check(phrase,is_marine, result)
|
|
|
+ tags = dao.query_dict(""" select id,tag_nm as tag_name,tag_code, tag_path,category_id,tag_prompt from aitag_tag_info where id in %s """, (tuple(result),))
|
|
|
+ result = await reflect_check(phrase,is_marine, tags)
|
|
|
except Exception as e:
|
|
|
logger.error(f"LLM reflection check failed: {e}")
|
|
|
result = None
|