|
|
@@ -26,7 +26,7 @@ CONCURRENCE = int(config['app']['concurrence'])
|
|
|
# background_semaphore = threading.BoundedSemaphore(CONCURRENCE)
|
|
|
executor = ThreadPoolExecutor(max_workers=CONCURRENCE)
|
|
|
ESB_CALLBACK = config['app']['esb_callback']
|
|
|
-
|
|
|
+TOP_N = int(config['app'].get("top_n",5))
|
|
|
router = APIRouter(prefix="/v1", tags=["AI Tagging"])
|
|
|
|
|
|
class TaggingRequest(BaseModel):
|
|
|
@@ -61,20 +61,32 @@ class TaggingRequest(BaseModel):
|
|
|
return None
|
|
|
return str(v)
|
|
|
|
|
|
def reg_match(query: str, reglist: list[tuple[str, str]]) -> list:
    """Return the ids whose regular expression matches ``query``.

    Args:
        query: Text to test each pattern against.
        reglist: Iterable of ``(id, pattern)`` pairs. Patterns are interpreted
            with Python ``re`` syntax and applied with ``re.search``, so a
            match anywhere in ``query`` counts.

    Returns:
        The ids of all matching pairs, in the order they appear in
        ``reglist``. Pairs with an empty/``None`` pattern or a pattern that
        does not compile are skipped.
    """
    matched_ids = []
    for tag_id, pattern in reglist:
        # An empty pattern would match every query and a None pattern would
        # raise TypeError; neither is a usable rule, so skip both.
        if not pattern:
            continue
        try:
            if re.search(pattern, query):
                matched_ids.append(tag_id)
        except re.error:
            # Deliberate best-effort: one malformed stored pattern must not
            # prevent the remaining patterns from being evaluated.
            continue
    return matched_ids
|
|
|
+
|
|
|
def execute_reg(log_id:str,tag_category_id:str,phrase: str)-> list:
|
|
|
+ logger.info(f"[{log_id}] Regex filtering start!")
|
|
|
sql = f"""select
|
|
|
tti.id,
|
|
|
tti.reg
|
|
|
from aitag_tag_info tti left join aitag_tag_category ttc
|
|
|
on tti.category_id = ttc.id
|
|
|
where ttc.is_delete=0 and tti.is_delete=0 and ttc.state = 0 and tti.state = 0 and tti.tag_level = ttc.visibility_level
|
|
|
- and '{phrase}' ~ tti.reg and tti.reg is not null and length(tti.reg) > 0
|
|
|
+ and tti.reg is not null and length(tti.reg) > 0
|
|
|
"""
|
|
|
if tag_category_id:
|
|
|
sql += f""" and ttc.id = '{tag_category_id}'"""
|
|
|
labels = dao.query(sql)
|
|
|
+ result = reg_match(phrase, labels)
|
|
|
# 循环调用reg匹配phrase,匹配成功则返回标签id
|
|
|
- result = [label[0] for label in labels]
|
|
|
+ # result = [label[0] for label in labels]
|
|
|
if result and len(result) > 0:
|
|
|
dao.execute(
|
|
|
"""UPDATE aitag_tag_log SET reg_result = %s,tagging_channel = %s WHERE id = %s""",
|
|
|
@@ -83,11 +95,18 @@ def execute_reg(log_id:str,tag_category_id:str,phrase: str)-> list:
|
|
|
logger.info(f"[{log_id}] Regex filtering result: {result}")
|
|
|
return result
|
|
|
|
|
|
-def vector_similarity_search(log_id:str,phrase: str)-> list:
|
|
|
+def vector_similarity_search(log_id:str,phrase: str,tag_ids:list[str]=None)-> list:
|
|
|
logger.info("Starting vector similarity search...")
|
|
|
# 这里应该调用向量数据库进行相似度检索,返回相关标签id列表
|
|
|
+ l1 = time.time()
|
|
|
query = get_embeddings([phrase])[0]
|
|
|
- results = bm25_vector_search(phrase,query)
|
|
|
+ l2 = time.time()
|
|
|
+ logger.info(f"[{log_id}] Vector embedding time: {l2-l1}")
|
|
|
+ l3 = time.time()
|
|
|
+ rrf_score_threshold = float(config['es'].get('rrf_score_threshold',0.016))
|
|
|
+ results = bm25_vector_search(phrase,query,tag_ids=tag_ids,rrf_score_threshold=rrf_score_threshold)
|
|
|
+ l4 = time.time()
|
|
|
+ logger.info(f"[{log_id}] Vector search time: {l4-l3}")
|
|
|
dao.execute(
|
|
|
"""UPDATE aitag_tag_log SET tagging_channel = %s WHERE id = %s""",
|
|
|
(TAGGING_CHANNEL.VECTOR.value, log_id)
|
|
|
@@ -104,15 +123,15 @@ def init_tag_log(request: TaggingRequest):
|
|
|
# 业务编号如果以test开头,则tag_scope = 1,否则都是0
|
|
|
tag_scope = 1 if request.business_attr.startswith("test") else 0
|
|
|
dao.execute(
|
|
|
- """INSERT INTO aitag_tag_log (id,app_id, insert_time, business_attr, phrase, state, tag_scope,esb_seq_no,instucde,instucde_nm,company_nm,company_code,start_user_id,start_user_nm,start_user_org,start_user_endpoint) VALUES (%s, %s, %s, %s, %s, %s, %s,%s,%s,%s,%s,%s,%s,%s,%s,%s)""",
|
|
|
- (id,request.app_id, datetime.now(), request.business_attr, request.phrase, TAGGING_STATE.REQUEST.value, tag_scope, request.esb_seq_no,request.instucde,request.instucde_nm,request.company_nm,request.company_code,request.user_id,request.user_nm,request.user_org,request.user_endpoint)
|
|
|
+ """INSERT INTO aitag_tag_log (id,app_id, insert_time, business_attr, phrase, state, tag_scope,esb_seq_no,instucde,instucde_nm,company_nm,company_code,start_user_id,start_user_nm,start_user_org,start_user_endpoint,contract_no) VALUES (%s, %s, %s, %s, %s, %s, %s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)""",
|
|
|
+ (id,request.app_id, datetime.now(), request.business_attr, request.phrase, TAGGING_STATE.REQUEST.value, tag_scope, request.esb_seq_no,request.instucde,request.instucde_nm,request.company_nm,request.company_code,request.user_id,request.user_nm,request.user_org,request.user_endpoint,request.contract_no)
|
|
|
)
|
|
|
return id
|
|
|
|
|
|
def end_tagging(id: str, result: Optional[str], x_input: Optional[str] = None):
    """Mark a tagging log row as finished and store its outcome.

    Args:
        id: Key of the ``aitag_tag_log`` row to update.
        result: Final tagging result to store; ``None`` is written when
            there is no result.
        x_input: Value written to the ``x_input`` column. Defaults to
            ``None`` so callers that still pass only ``(id, result)`` keep
            working.
    """
    dao.execute(
        """UPDATE aitag_tag_log SET state = %s, result = %s, ai_result_endtime = %s, x_input = %s WHERE id = %s""",
        (TAGGING_STATE.END.value, result, datetime.now(), x_input, id)
    )
|
|
|
|
|
|
def fail_tagging(id:str):
|
|
|
@@ -134,44 +153,73 @@ def start_tagging(id:str, instucde: Optional[str] = None):
|
|
|
)
|
|
|
return is_marine
|
|
|
|
|
|
+
|
|
|
def highlight_long_common_substrings(str_a, str_b):
    """Wrap the parts of ``str_a`` that also occur in ``str_b`` in ``<strong>``.

    ``str_b`` (typically a rule/regex string) is split into runs of two or
    more word characters; regex metacharacters act as separators and are
    dropped. Every occurrence of such a run in ``str_a`` is wrapped in
    ``<strong>...</strong>``.

    The replacement is done in a single pass over ``str_a`` so that markup
    inserted for one fragment is never re-scanned by another. This avoids
    nested tags (a short fragment inside an already highlighted longer one)
    and corrupted markup (a fragment such as ``strong`` matching the tag
    name itself).

    Args:
        str_a: Text to highlight.
        str_b: Text from which the fragments to highlight are extracted.

    Returns:
        ``str_a`` with the common fragments wrapped in ``<strong>`` tags, or
        ``str_a`` unchanged when ``str_b`` yields no fragment.

    NOTE(review): ``str_a`` is not HTML-escaped here; confirm whether the
    caller or the renderer is responsible for escaping.
    """
    # Runs of 2+ word characters in str_b (single characters are ignored).
    fragments = set(re.findall(r'[\w\u4e00-\u9fff]{2,}', str_b))
    if not fragments:
        # An empty alternation would match at every position; bail out early.
        return str_a

    # Longest first so the alternation prefers the longer fragment when one
    # is a prefix of another; ties are ordered by text to stay deterministic.
    ordered = sorted(fragments, key=lambda frag: (-len(frag), frag))
    pattern = re.compile('|'.join(re.escape(frag) for frag in ordered))

    return pattern.sub(lambda m: f'<strong>{m.group(0)}</strong>', str_a)
|
|
|
+
|
|
|
def generate_html(phrase, matched_rule, tag_name):
    """Build the HTML sentence explaining why ``phrase`` was mapped to a tag.

    Args:
        phrase: The text that was matched.
        matched_rule: The rule string that matched ``phrase``.
        tag_name: Name of the tag the rule maps to.

    Returns:
        A string in which the parts of ``phrase`` shared with
        ``matched_rule`` are highlighted, followed by the rule and tag name.
    """
    marked = highlight_long_common_substrings(phrase, matched_rule)
    return f"映射文本【{marked}】与映射规则【{matched_rule}】匹配,映射为标签【{tag_name}】"
|
|
|
|
|
|
def defined_rule_match(phrase: str):
    """Match ``phrase`` against the predefined rules and return the mapped tags.

    Rules are loaded from ``aitag_predefined_rules`` for categories whose
    row has ``is_delete=0`` and ``state=0``. Each rule's ``defined_rule`` is
    applied to ``phrase`` with ``re.search``; on a hit the tag is looked up
    by tag name and category code. A tag id is reported at most once.

    Args:
        phrase: Text to test against every rule.

    Returns:
        A list of dicts with keys ``id``, ``desc``, ``passr``, ``tag_code``,
        ``tag_name``, ``tag_path`` and ``category_id``. The list is empty
        when nothing matches or when loading the rules fails. A failure
        while processing a single rule is logged and that rule is skipped.
    """
    rule_sql = """select r.tag_type,r.tag_nm, r.defined_rule from aitag_predefined_rules r
left join aitag_tag_category c on r.tag_type=c.category_code
where c.is_delete=0 and c.state = 0 and r.defined_rule is not null and r.tag_nm is not null
order by r.tag_type,r.tag_nm,r.defined_rule desc"""
    tag_sql = """select ati.id,ati.category_id, ati.tag_nm, ati.tag_path,ati.tag_code from aitag_tag_info ati left join aitag_tag_category atc on ati.category_id = atc.id where ati.tag_nm = %s and ati.is_delete = 0 and atc.category_code = %s"""

    hits = []
    reported = set()
    try:
        for rule in dao.query(rule_sql) or ():
            try:
                # rule is (tag_type, tag_nm, defined_rule)
                if not re.search(rule[2], phrase):
                    continue
                rows = dao.query(tag_sql, (rule[1], rule[0]))
                # Only report a tag when a record exists and it has not
                # already been added by an earlier rule.
                if not rows or rows[0][0] in reported:
                    continue
                tag = rows[0]
                reported.add(tag[0])
                hits.append({
                    "id": tag[0],
                    "desc": generate_html(phrase, rule[2], tag[2]),
                    "passr": True,
                    "tag_code": tag[4],
                    "tag_name": tag[2],
                    "tag_path": tag[3],
                    "category_id": tag[1]
                })
            except Exception as e:
                logger.error(f"Defined rule match failed 1: {e}")
    except Exception as e:
        logger.error(f"Defined rule match failed 2: {e}")
        hits = []
    return hits
|
|
|
|
|
|
def end_tagging_predefined_rule(id:str, result:str, business_attr: Optional[str] = None):
|
|
|
dao.execute(
|
|
|
- """UPDATE aitag_tag_log SET state = %s, result = %s,ai_result_endtime = %s,tagging_channel = %s WHERE id = %s""",
|
|
|
- (TAGGING_STATE.PREDEFINED_RULE_MATCH.value, result,datetime.now(), TAGGING_CHANNEL.RULES.value, id)
|
|
|
+ """UPDATE aitag_tag_log SET state = %s, result = %s,ai_result_endtime = %s,tagging_channel = %s,feedback=%s WHERE id = %s""",
|
|
|
+ (TAGGING_STATE.PREDEFINED_RULE_MATCH.value, result,datetime.now(), TAGGING_CHANNEL.RULES.value,"agree", id)
|
|
|
)
|
|
|
if business_attr is not None and ESB_CALLBACK is not None:
|
|
|
try:
|
|
|
@@ -190,21 +238,37 @@ def run_ai_pipeline(log_id: str, tag_category_id: str, phrase: str, instucde: Op
|
|
|
defined_rule_result = defined_rule_match(phrase)
|
|
|
if defined_rule_result:
|
|
|
logger.info(f"预设规则匹配成功,直接返回结果: {defined_rule_result}")
|
|
|
- end_tagging_predefined_rule(log_id, json.dumps(defined_rule_result, business_attr))
|
|
|
+ end_tagging_predefined_rule(log_id, json.dumps(defined_rule_result),business_attr)
|
|
|
return
|
|
|
|
|
|
# step1: 正则过滤
|
|
|
result = execute_reg(log_id,tag_category_id,phrase)
|
|
|
+ logger.info(f"正则过滤结果: {result}")
|
|
|
# step2: 向量检索
|
|
|
- if not result or len(result) == 0:
|
|
|
- result = vector_similarity_search(log_id,phrase)
|
|
|
+ # if not result or len(result) == 0 or len(result) >TOP_N: # 正则过滤结果过多或没有结果都进行向量检索,避免正则规则不完善导致的漏匹配问题,同时也避免正则规则过于宽泛导致的过多匹配问题
|
|
|
+ v_result = vector_similarity_search(log_id,phrase)
|
|
|
+ logger.info(f"向量检索结果: {v_result}")
|
|
|
+ # step2.5: 合并结果,取交集优先,交集为空则取并集
|
|
|
+ if result and len(result) > 0:
|
|
|
+ v_result1 = list(set(result) & set(v_result)) # 取交集,既满足正则规则又满足向量相似度的标签,优先级更高
|
|
|
+ if v_result1 and len(v_result1) > 0:
|
|
|
+ result = v_result1
|
|
|
+ logger.info(f"交集结果: {v_result1}")
|
|
|
+ else:
|
|
|
+ result = list(set(result) | set(v_result)) # 取并集,满足正则规则或者满足向量相似度的标签
|
|
|
+ if result and len(result) > TOP_N:
|
|
|
+ result = vector_similarity_search(log_id,phrase,tag_ids=result) # 如果合并后结果过多,则再次进行向量检索过滤一次
|
|
|
+ logger.info(f"并集后再次向量检索结果: {result}")
|
|
|
+ else:
|
|
|
+ result = v_result
|
|
|
+ logger.info(f"最终候选结果: {result}")
|
|
|
# step3: LLM 打标
|
|
|
if result and len(result) > 0:
|
|
|
try:
|
|
|
tags = dao.query_dict(""" select id,tag_nm as tag_name,tag_code, tag_path,category_id,tag_prompt from aitag_tag_info where id in %s """, (tuple(result),))
|
|
|
logger.info(f"筛选结果: {tags}")
|
|
|
from agent.agent import reflect_check_sync
|
|
|
- result = reflect_check_sync(phrase,is_marine, tags)
|
|
|
+ result, x_input = reflect_check_sync(phrase,is_marine, tags)
|
|
|
except Exception as e:
|
|
|
logger.error(f"LLM reflection check failed: {e}")
|
|
|
result = None
|
|
|
@@ -212,7 +276,7 @@ def run_ai_pipeline(log_id: str, tag_category_id: str, phrase: str, instucde: Op
|
|
|
return
|
|
|
# step4: 更新数据库
|
|
|
# 如果result是个空集合,插入None
|
|
|
- end_tagging(log_id, result if result else None)
|
|
|
+ end_tagging(log_id, result if result else None,x_input)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"[{log_id}] Pipeline failed: {e}")
|