from langchain_core.prompts import ChatPromptTemplate from langgraph.graph import START, StateGraph, END from llmops.agents.datadev.llm import get_llm, get_llm_coder from typing import List from typing_extensions import TypedDict from pydantic import BaseModel, Field from llmops.agents.datadev.memory.memory_saver_with_expiry2 import MemorySaverWithExpiry from langchain_core.output_parsers import PydanticOutputParser from config import linage_agent_config from llmops.agents.datadev.tools.timeit import timeit import asyncio from datetime import datetime class Column(BaseModel): """ 字段信息 """ col_name: str = Field(description="字段名称") def __eq__(self, other): if isinstance(other, Column): return self.col_name == other.col_name return False class ColumnDep(BaseModel): """ 目标表字段信息,包含依赖关系 """ col_name: str = Field(description="目标表字段名称") col_comment: str = Field(description="目标表字段说明") from_cols: List[str] = Field(description="字段来源信息,格式是 源库名.源表名.源字段名 或 中间表名.中间表字段名") dep_type: List[str] = Field(description="字段获取方式 1:直取 2:函数 3:表达式") desp: List[str] = Field(description="详细解释选取来源字段的原因") class SourceTable(BaseModel): """ 单个来源信息表信息 """ database: str = Field(description="来源数据库名称") table_name: str = Field(description="来源表名称") table_name_alias: str = Field(default="", description="来源表别名") col_list: List[Column] = Field(description="来源表的字段集合") is_temp: bool = Field(description="是否是临时表") def __eq__(self, other): if isinstance(other, SourceTable): return self.database == other.database and self.table_name == other.table_name return False class SourceTableList(BaseModel): """ 所有来源信息表信息 """ source_tables: List[SourceTable] = Field(description="所有来源表信息") class WhereCondItem(BaseModel): """ 单个表过滤条件信息 """ database: str = Field(description="数据库名") table_name: str = Field(description="表原名") col_name: str = Field(description="条件字段名称") operator: str = Field(default="=", description="条件操作符,如=,>,<,in") value: str = Field(description="条件值") class WhereCondList(BaseModel): """ 表过滤条件集合 """ where_list: List[WhereCondItem] = Field(description="所有的where条件") class JoinColumnPair(BaseModel): """关联字段对""" main_table_alias: str = Field(description="主表别名") main_column: str = Field(description="主表关联字段") join_table_alias: str = Field(description="从表别名") join_column: str = Field(description="从表关联字段") operator: str = Field(default="=", description="关联操作符,如=,>,<") class JoinRelation(BaseModel): """完整的JOIN关系""" main_database: str = Field(description="主表数据库名") main_table: str = Field(description="主表物理表名") main_alias: str = Field(description="主表别名") join_type: str = Field(description="JOIN类型:INNER/LEFT/RIGHT/FULL") join_database: str = Field(description="从表数据库名") join_table: str = Field(description="从表物理表名") join_alias: str = Field(description="从表别名") column_pairs: List[JoinColumnPair] = Field(description="关联字段对列表") class JoinRelationList(BaseModel): """ join关系集合 """ join_list: List[JoinRelation] = Field(description="所有的join关系") class TargetTable(BaseModel): """ 目标表信息 """ database: str = Field(default="", description="目标表的数据库名称") table_name: str = Field(description="目标表名称") col_size: int = Field(description="目标表字段数量") col_list: List[str] = Field(description="目标表字段集合") col_dep: List[ColumnDep] = Field(description="目标表字段依赖信息") src_table_size: int = Field(description="依赖来源表数量") class CreateTable(BaseModel): """ 建表信息 """ database: str = Field(description="数据库名") table: str = Field(description="表名") col_list: list[str] = Field(description="字段名集合") source_table_list: list[str] = Field(description="来源表集合") class AgentState(TypedDict): """ Agent状态 """ question: str session: str dialect: str sql: str # 带行号的SQL语句 sql_type: str # sql操作类型 lineage: dict status: str error: str stop: bool file_name: str # 脚本文件 target_table: dict # 目标表信息 source_tables: list[dict] # 来源表信息 where_list: list[dict] # where条件信息 join_list: list[dict] # 表关联信息 class SqlLineageAgent: """ 解析用户提供的SQL语句/SQL脚本中的数据血缘关系 """ def __init__(self): # 读取配置文件 c = linage_agent_config # 数据库方言 self.db_dialect = c["dialect"] # 并发度 self.concurrency = c["concurrency"] # 大模型实例 self.llm = get_llm() self.llm_coder = get_llm_coder() self.memory = MemorySaverWithExpiry(expire_time=600, clean_interval=60) # 构建图 self.graph = self._build_graph() # SQL 语句中设置的变量 self.var_list: list[str] = [] # 最终的来源表集合 self.final_source_table_list: list[dict] = [] # 最终的中间表集合 self.final_mid_table_list: list[dict] = [] # 最终的where集合 self.final_where_list: list[dict] = [] # 最终的join集合 self.final_join_list: list[dict] = [] # 最终的目标表集合 self.final_target_list: list[dict] = [] # 最终目标表,需合并 self.final_target_table = {} # 目标表 格式 库名.表名 self.target_table = "" # 内部建表信息 self.ct_list = [] def _build_graph(self): # 构造计算流程 graph_builder = StateGraph(AgentState) # 添加节点 graph_builder.add_node("_sql_type", self._sql_type) graph_builder.add_node("_extract_set", self._extract_set) graph_builder.add_node("_extract_create_table", self._extract_create_table) graph_builder.add_node("_invalid", self._invalid) graph_builder.add_node("source", self.__extract_source) graph_builder.add_node("target", self.__extract_target) graph_builder.add_node("where", self.__extract_where) graph_builder.add_node("join", self.__extract_join) graph_builder.add_node("merge", self.__merge_result) # 添加边 graph_builder.add_edge(START, "_sql_type") graph_builder.add_conditional_edges("_sql_type", path=self.__parallel_route2, path_map={ "_invalid": "_invalid", "_extract_set": "_extract_set", "_extract_create_table": "_extract_create_table", "where": "where", "join": "join" }) graph_builder.add_edge("_invalid", END) graph_builder.add_edge("_extract_set", END) graph_builder.add_edge("_extract_set", END) graph_builder.add_edge("_extract_create_table", END) graph_builder.add_edge("where", "source") graph_builder.add_edge("join", "source") graph_builder.add_edge("source", "target") graph_builder.add_edge("target", "merge") graph_builder.add_edge("merge", END) return graph_builder.compile(checkpointer=self.memory) def _sql_type(self, state: AgentState): """ 判断SQL的操作类型 """ template = """ 你是数据库 {dialect} 专家,对 SQL语句: {sql} 中的操作进行分类, 直接返回分类ID(数字) ### 操作分类标准如下: 1:create, 如建表、建视图等 2:insert, 向表中插入数据 3:set、设置参数操作、变量赋值操作 4:其它 不要重复用户问题,不要中间分析过程,直接回答。 """ pt = ChatPromptTemplate.from_template(template) chain = pt | self.llm response = chain.invoke({"dialect": state["dialect"], "sql": state["question"]}) return {"sql_type": response.content} def __parallel_route(self, state: AgentState): """并发节点""" status = state["status"] if status == "invalid" or status == "canceled" or state["sql_type"] == "3": # 无效问题 或 用户中止 return END return "continue" def __parallel_route2(self, state: AgentState): """并发节点""" if state["sql_type"] == "4": # 非create和非insert操作 return "_invalid" if state["sql_type"] == "3": # 设置参数 return "_extract_set" if state["sql_type"] == "1": # 建表 return "_extract_create_table" return ["where", "join"] def _invalid(self, state: AgentState): """ sql_type == 4,不关注的语句,直接跳过 """ return { "lineage": { "source_tables": [], "join_list": [], "where_list": [], "target_table": {} } } async def _extract_set(self, state: AgentState): """ 提取 SQL 中设置的参数变量 """ if state["sql_type"] != "3": return state template = """ 你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 参数设置的全部内容,包括变量赋值、注释。 不要重复用户问题,不需要中间思考过程,直接回答。 """ pt = ChatPromptTemplate.from_template(template) chain = pt | self.llm_coder response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]}) self.var_list.append(response.content) return { "lineage": { "source_tables": [], "join_list": [], "where_list": [], "target_table": {} } } async def _extract_create_table(self, state: AgentState): """ 提取建表信息 """ if state["sql_type"] != "1": return state template = """ 你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 建表信息。 ### 提取要求 1、提取建表对应的数据库名(无则留空) 2、提取建表对应的表名 3、提取建表对应的字段名 4、提取建表对应的来源表名,来源表名格式是 来源库库名.来源表名 ### 输出规范 {format_instructions} 不要重复用户问题,不需要中间思考过程,直接回答。 """ parser = PydanticOutputParser(pydantic_object=CreateTable) pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions()) chain = pt | self.llm_coder | parser response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]}) # print(f"extract-create-table: {response}") self.ct_list.append(response.model_dump()) return { "lineage": { "source_tables": [], "join_list": [], "where_list": [], "target_table": {} } } async def _parse(self, semaphore, sql: str, session: str, dialect: str, file_name: str = "SQL_FILE.SQL"): """ 解析SQL(片段)的血缘关系 """ async with semaphore: # 进入信号量区域(限制并发) config = {"configurable": {"thread_id": session}} # 调用任务流 response = await self.graph.ainvoke({ "question": sql, "session": session, "dialect": dialect or self.db_dialect, "file_name": file_name, "stop": False, "error": "", "sql_type": "0", "status": "success"}, config) return response async def find_target_table(self, sql: str, dialect: str): """ 找出目标表 """ template = """ 你是SQL数据血缘分析专家,参照数据库方言{dialect},仔细分析SQL语句,找出目标表并返回。 ### SQL片段如下 {sql} ### 目标表标准 1、目标表一定是出现在 insert into 后的 表 2、最后 insert into 后的表为目标表 3、目标表只有一张 ### 核心要求 1、目标表的返回格式是 数据库名.表名 (如果无数据库名,则留空) 2、SQL语句中可能包括多个insert into语句,但要从全局分析,找出最终插入的目标表 不要重复用户问题、不要中间分析过程,直接回答。 """ pt = ChatPromptTemplate.from_template(template) chain = pt | self.llm_coder response = await chain.ainvoke({"sql": sql, "dialect": dialect}) return response.content.strip().upper() def split_sql_statements(self, sql): """ 将SQL语句按分号分段,同时正确处理字符串和注释中的分号 参数: sql: 包含一个或多个SQL语句的字符串 返回: 分割后的SQL语句列表 """ # 状态变量 in_single_quote = False in_double_quote = False in_line_comment = False in_block_comment = False escape_next = False statements = [] current_start = 0 i = 0 while i < len(sql): char = sql[i] # 处理转义字符 if escape_next: escape_next = False i += 1 continue # 处理字符串和注释中的情况(这些地方的分号不应该作为分隔符) if in_line_comment: if char == '\n': in_line_comment = False elif in_block_comment: if char == '*' and i + 1 < len(sql) and sql[i + 1] == '/': in_block_comment = False i += 1 # 跳过下一个字符 elif in_single_quote: if char == "'": in_single_quote = False elif char == '\\': escape_next = True elif in_double_quote: if char == '"': in_double_quote = False elif char == '\\': escape_next = True else: # 不在字符串或注释中,检查是否进入这些状态 if char == "'": in_single_quote = True elif char == '"': in_double_quote = True elif char == '-' and i + 1 < len(sql) and sql[i + 1] == '-': in_line_comment = True i += 1 # 跳过下一个字符 elif char == '/' and i + 1 < len(sql) and sql[i + 1] == '*': in_block_comment = True i += 1 # 跳过下一个字符 elif char == ';': # 找到真正的分号分隔符 statement = sql[current_start:i].strip() if statement: statements.append(statement) current_start = i + 1 i += 1 # 添加最后一个语句(如果没有以分号结尾) if current_start < len(sql): statement = sql[current_start:].strip() if statement: statements.append(statement) return statements async def find_target_table2(self, results: list[dict]) -> str: """ 根据各段解析结果查找出目标表 规则:一个段的目标表如果没有出现在其它各段的来源表中,即为目标表 """ target_table = "" exists = False for i, item in enumerate(results): tt = item["lineage"]["target_table"] # 取一个目标表 db = tt.get("database", "").strip().upper() table = tt.get("table_name", "").strip().upper() exists = False if len(table) > 0: # 从源表里面查找 for j, item2 in enumerate(results): if i != j: # 不跟自身比 st_list = item2["lineage"]["source_tables"] for st in st_list: sdb = st.get("database", "").strip().upper() stable = st.get("table_name", "").strip().upper() if db == sdb and table == stable: # 目标表在源表中存在 exists = True break if exists: break if not exists: # 说明目标表不在其它各段的源表中存在 target_table = ".".join([db, table]) break return target_table.strip().upper() async def ask_question(self, sql_content: str, session: str, dialect: str, owner: str, sql_file: str = "SQL_FILE.SQL"): """ 功能:根据用户问题,解析SQL中的数据血缘关系。先分段解析,再组织合并。 :param sql_content: SQL内容 :param session: 会话 :param dialect: 数据库方言 :param owner: 脚本平台属主 :param var_map: 变量关系 :sql_file: 脚本名称 """ # 通过session保持记忆 t1 = datetime.now() # 最终解析结果 final_result = {} final_result["task_id"] = session final_result["file_name"] = sql_file final_result["owner"] = owner final_result["status"] = "success" final_result["error"] = "" lineage = {} try: # 根据全局SQL找出目标表 self.target_table = await self.find_target_table(sql_content, dialect) print(f"大模型查找目标表:{self.target_table}, {len(self.target_table)}") # 对SQL分段,并发执行解析 sql_list = self.split_sql_statements(sql=sql_content) print("SQL分段数量:", len(sql_list)) # 信号量,控制并发量 semaphore = asyncio.Semaphore(self.concurrency) task_list = [self._parse(semaphore=semaphore, sql=sql, session=session, dialect=dialect, file_name=sql_file) for sql in sql_list] # 并发分段解析 results = await asyncio.gather(*task_list) if len(self.target_table) == 0: # 从各段解析结果中查找出目标表 self.target_table = await self.find_target_table2(results) or "" print(f"从血缘关系中找出目标表:{self.target_table}") if not self.target_table: # 未出现目标表, 直接返回 final_result["status"] = "error", final_result["error"] = "未找到目标表" return final_result for response in results: # 来源表、where、join、target self._merge_source_tables(response["lineage"]["source_tables"]) self._merge_where(response["lineage"].get("where_list",[])) self._merge_join(response["lineage"]["join_list"]) # 合并目标表 self._merge_target(response["lineage"]["target_table"]) # 增加中间表标识 self._add_mid_table_label(self.final_source_table_list, self.final_mid_table_list) # 补充中间表依赖字段 self._add_mid_table_col(self.final_mid_table_list, self.final_target_list) # 多目标表合并 self.final_target_table["src_table_size"] = len(self.final_source_table_list) lineage["target_table"] = self.final_target_table lineage["source_tables"] = self.final_source_table_list lineage["mid_table_list"] = self.final_mid_table_list lineage["join_list"] = self.final_join_list lineage["where_list"] = self.final_where_list final_result["lineage"] = lineage except Exception as e: final_result["status"] = "error" final_result["error"] = str(e) print(str(e)) t2 = datetime.now() print("总共用时(sec):", (t2-t1).seconds) parse_time = (t2 - t1).seconds # 解析时间(秒) final_result["parse_time"] = parse_time final_result["parse_end_time"] = t2.strftime("%Y-%m-%d %H:%M:%S") return final_result def _merge_source_tables(self, source_table_list: list[dict]): """ 合并分段来源表 """ if len(source_table_list) > 0: # 目标源表的keys table_keys = [(t.get("database","").strip().upper(), t.get("table_name","").strip().upper()) for t in self.final_source_table_list] for st in source_table_list: # 将字典转换为元组 db = st.get("database","").strip().upper() table = st.get("table_name","").strip().upper() key = (db, table) if key not in table_keys: table_keys.append(key) self.final_source_table_list.append(st) else: # 合并字段 col_list = st.get("col_list", []) # 找出相同的元素, 库名,表名相同 e = next(filter(lambda item: item["database"].upper()==db and item["table_name"].upper()==table, self.final_source_table_list), None) if e: final_col_list = e.get("col_list",[]) col_keys = [(c.get("col_name","").upper()) for c in final_col_list] for c in col_list: ck = (c["col_name"].upper()) if ck not in col_keys and len(c) > 0: col_keys.append(ck) final_col_list.append(c) @timeit def _merge_where(self, where_list: list[dict]): """ 最终合并所有的where条件 """ if where_list and len(where_list) > 0: self.final_where_list += where_list def _merge_join(self, join_list: list[dict]): """ 最终合并所有的join条件 """ if len(join_list) > 0: self.final_join_list += join_list def _merge_target(self, mid_target_table: dict): """ 合并中间目标表,只有 名称是 target_table_name 的表才是真目标表 目标表可能出现多次(如多次插入),需要合并 字段信息(字段名称和字段来源) 对于非目标表的中间目标表,放到中间表集合中 """ if not mid_target_table: # 中间目标表为空 return else: # 将临时目标表与目标表做合并 db = mid_target_table.get("database", "").strip().upper() table = mid_target_table.get("table_name", "").strip().upper() col_size = mid_target_table.get("col_list",[]) if len(table) == 0 or len(col_size) == 0: # 表名为空 或 无字段 直接返回 return if len(db) == 0: # 缺少数据库名, 检查表名中是否带库名 arr = table.split('.') if len(arr) == 2: # 说明表名中当库名,解析有错,修正 mid_target_table["database"] = arr[0].strip().upper() mid_target_table["table_name"] = arr[1].strip().upper() db = arr[0].strip().upper() table = arr[1].strip().upper() elif len(arr) == 1: # 无库名,只有表名,使用真目标表库名替换 arr = self.target_table.split(".") if len(arr) == 2: db = mid_target_table["database"] = arr[0].strip().upper() key = ".".join([db, table]) if key == self.target_table: # 找到最终目标表 print(f"合并目标表:{mid_target_table}") if not self.final_target_table: # 目标表为空(首次) self.final_target_table = mid_target_table return # 合并字段和字段来源信息 # 合并新字段 new_col_list = [] # 新字段集 col_list = mid_target_table.get("col_list", []) for new_col in col_list: # 检查字段是否存在、忽略大小写 if new_col.strip().upper() not in [col.strip().upper() for col in self.final_target_table.get("col_list", [])]: # 不存在 print(f"合并新字段:{new_col}") self.final_target_table.get("col_list", []).append(new_col.strip().upper()) new_col_list.append(new_col.strip().upper()) # 合并字段来源信息 col_dep_list = mid_target_table.get("col_dep", []) for col_dep in col_dep_list: col_name = col_dep["col_name"].strip().upper() if col_name in new_col_list: # 新字段, 直接添加 self.final_target_table.get("col_dep",[]).append(col_dep) else: # 合并来源信息 # 找到同名字段,合并来源字段 for dep in self.final_target_table.get("col_dep", []): cn = dep["col_name"].strip().upper() if col_name == cn: print(f"合并来源字段, {col_name}") m_from_col = col_dep.get('from_cols', []) f_from_col = dep.get("from_cols", []) print(f"中间表来源字段:{m_from_col}") print(f"目标表来源字段:{f_from_col}") # 将临时中间表的来源字段合并至目标表的来源字段 self._merge_col_dep(m_from_col, f_from_col) else: # 加入中间表集合中 print(f"加入中间表:{mid_target_table}") self.final_mid_table_list.append(mid_target_table) print(f"final tt:{self.final_target_table}") def _merge_col_dep(self, mid_from_cols: list[str], final_from_cols: list[str]): """ 合并临时中间表来源字段 到目标表来源字段中 """ for from_col in mid_from_cols: if len(from_col.split(".")) > 0: # 忽略常量,NULL值 if from_col.strip().upper() not in [fc.strip().upper() for fc in final_from_cols]: print(f"合并字段:{from_col}") final_from_cols.append(from_col.strip().upper()) @staticmethod @timeit def _add_mid_table_label(source_table_list: list[dict], mid_table_list: list[dict]): """ 给最终的来源表增加中间表标识 """ if len(source_table_list) > 0 and len(mid_table_list) > 0: mid_table_keys = [(t["database"].strip().upper(), t["table_name"].strip().upper()) for t in mid_table_list] for st in source_table_list: st["is_temp"] = False # 将字典转换为元组 key = (st["database"].strip().upper(), st["table_name"].strip().upper()) if key in mid_table_keys: # 是中间表 st["is_temp"] = True @staticmethod @timeit def _add_mid_table_col(mid_table_list: list[dict], target_table_list: list[dict]): """ 给中间表补充来源字段,使用对应的目标表col_dep去替换 """ if len(mid_table_list) > 0 and len(target_table_list) > 0: for mt in mid_table_list: for tt in target_table_list: tt["is_temp"] = False # 目标表的数据库名与表名相同 tdn = tt["database"].strip().upper() ttn = tt["table_name"].strip().upper() if len(tdn) == 0 and len(ttn) == 0: # 全部为空 tt["is_temp"] = True continue if mt["database"].strip().upper()==tdn and mt["table_name"].strip().upper()==ttn: mt["col_dep"] = tt["col_dep"] tt["is_temp"] = True if len(mt["col_list"]) == 0: mt["col_list"] = [c.upper() for c in tt["col_list"]] mt["col_size"] = len(mt["col_list"]) @timeit def _merge_final_target_table(self, seg_target_table_list: list[dict], target_table_name: str): """ 合并分段目标表内容 """ # 可能存在多个相同的真正目标表(多次插入目标表情况) for tt in seg_target_table_list: tt["is_target"] = False database = tt["database"].strip() table = tt["table_name"].strip() if len(database) > 0: dt = ".".join([database,table]) else: dt = table if dt.upper() == target_table_name.strip().upper(): tt["is_target"] = True # 合并真正的目标表 i = 0 result = {} final_col_list: list[str] = [] final_col_dep_list: list[dict] = [] # 合并目标表信息 for target_table in seg_target_table_list: if target_table["is_target"]: if i == 0: result["database"] = target_table["database"] result["table_name"] = target_table["table_name"] i = 1 # 合并列信息 col_list = target_table["col_list"] for c in col_list: if c not in final_col_list and len(c.strip()) > 0: final_col_list.append(c.upper()) # 合并列来源信息 col_dep_list = target_table["col_dep"] for col_dep in col_dep_list: cn = col_dep["col_name"].strip().upper() from_cols = col_dep["from_cols"] existed = False if len(cn) > 0: for fcd in final_col_dep_list: if cn == fcd["col_name"].upper(): existed = True # 合并来源字段 final_from_cols = fcd["from_cols"] fcd["from_cols"] += [c.upper() for c in from_cols if c.upper() not in final_from_cols] if not existed: final_col_dep_list.append({ "col_name": cn.upper(), "col_comment": col_dep["col_comment"], "from_cols": [c.upper() for c in from_cols] }) result["col_size"] = len(final_col_list) result["col_list"] = final_col_list result["col_dep"] = final_col_dep_list return result async def __check(self, state: AgentState): """ 检查用户问题是否是有效的SQL语句/SQL片段 """ template = """ 你是 数据库 {dialect} SQL分析助手, 仔细分析SQL语句: {sql} ,判断是否语法正确,如果是 直接返回 1,否则说明错误地方,并返回 0。 ### 分析要求 1. 判断子查询的正确性 2. 判断整体语句的正确性 3. 如果存在语法错误,给出说明 不要重复用户问题、不要分析过程、直接回答。 """ pt = ChatPromptTemplate.from_template(template) chain = pt | self.llm stopped = False response = "" config = {"configurable": {"thread_id": state["session"]}} async for chunk in chain.astream({"sql": state["question"], "dialect": state["dialect"]}, config=config): current_state = self.memory.get(config) if current_state and current_state["channel_values"]["stop"]: stopped = True break response += chunk.content if stopped: return {"stop": True, "error": "用户中止", "status": "canceled", "lineage": {}} if response == "0": return {"status": "invalid", "error": "无效问题或SQL存在语法错误,请重新描述!", "lineage": {}} return {"status": "success"} async def __merge_result(self, state: AgentState): """ 合并所有节点结果,输出血缘关系 """ result = {} result["file_name"] = state["file_name"] source_tables = state["source_tables"] target_table = state["target_table"] # 设置目标表依赖的来源表数量和字段数量 target_table["src_table_size"] = len(source_tables) target_table["col_size"] = len(target_table["col_dep"]) result["target_table"] = target_table result["source_tables"] = source_tables result["where_list"] = state.get("where_list",[]) result["join_list"] = state["join_list"] return {"lineage": result} async def __extract_source(self, state: AgentState): """ 根据SQL,提取其中的来源库表信息 """ dt1 = datetime.now() # 默认DB default_db = "TMP" if len(self.target_table) > 0: arr = self.target_table.split(".") if len(arr) == 2: # 使用识别出的目标表db作为默认DB default_db = arr[0] template = """ 你是SQL数据血缘分析专家,仔细分析SQL片段和上下文信息, 从SQL片段中提取出 来源表信息。 ### SQL和上下文信息如下 - SQL片段:{sql} - 数据库方言: {dialect} - 变量设置信息:{var} - 已有的建表信息: {create_table_info} ### 核心要求: 1. 提取 来源表数据库名称(无则使用 {default_db} 填充) 2. 提取 来源表名,包括 from子句、子查询、嵌套子查询、union语句、with语句中出现的物理表名(注意:不要带库名前缀) 3. 提取 来源表别名(无则留空) 4. 提取来源表的所有字段信息,提取来自几种: - 从select中提取,如果带有表别名,则转换成物理表名 - 从where过滤条件中提取 - 从关联条件(inner join,left join,right join,full join)中提取 - 从group中提取 - 从函数参数中提取,比如函数 COALESCE(T13.ECIF_CUST_ID,T2.RELAT_PTY_ID,'') AS ECIF_CUST_ID,提取出T13.ECIF_CUST_ID和T2.RELAT_PTY_ID - 从表达式参数中提取 5. 判断来源表是否是临时表 6. 不要包含目标表(目标表指最终插入的表) ### 输出规范 {format_instructions} 不要重复用户问题,不要分析中间过程,直接给出答案。 """ parser = PydanticOutputParser(pydantic_object=SourceTableList) pt = ChatPromptTemplate.from_template(template=template).partial(format_instructions=parser.get_format_instructions()) chain = pt | self.llm_coder | parser answer = {} config = {"configurable": {"thread_id": state["session"]}} stopped = False async for chunk in chain.astream({ "sql": state["question"], "dialect": state["dialect"], "var": self.var_list or [], "default_db": default_db, "create_table_info": self.ct_list }): current_state = self.memory.get(config) if current_state and current_state["channel_values"]["stop"]: stopped = True break # 设置目标表字段数量,来源表数量 # 合并相同源表 table_list: list[SourceTable] = [] for table in chunk.source_tables: if table not in table_list: table_list.append(table) else: idx = table_list.index(table) tt = table_list[idx] # 合并相同的列 for col in table.col_list: if col not in tt.col_list: tt.col_list.append(col) answer["source_tables"] = [table.model_dump() for table in table_list] dt2 = datetime.now() print("提取源表用时(sec):", (dt2-dt1).seconds) if stopped: return {"status": "canceled", "error": "用户中止", "lineage": {}} return answer async def __extract_target(self, state: AgentState): """ 根据SQL,提取其中的目标表信息 """ if state["sql_type"] == "1": return {"target_table": {"database":"", "table_name": "", "col_list": [], "col_size": 0, "col_dep": []}} # 默认DB default_db = "TMP" if len(self.target_table) > 0: arr = self.target_table.split(".") if len(arr) == 2: # 使用识别出的目标表db作为默认DB default_db = arr[0] template = """ 你是SQL数据血缘分析专家,仔细分析SQL片段和以下上下文: - SQL片段: {sql} - 来源表信息: {source_tables} - 数据库方言: {dialect} - 变量设置信息:{var} - 已有的建表信息: {create_table_info} ### 核心要求: 1. **目标表识别**: - 提取 INSERT INTO 后的表作为目标表 - 格式:`数据库名.表名`(未指定库名则使用 {default_db} 填充) 2. **字段依赖分析**: - 目标表字段必须溯源到来源表的物理字段 - 字段获取方式分类: - `直取`:直接引用源字段(如 `src.col`) - `函数`:通过函数转换(如 `COALESCE(col1, col2)`) - `表达式`:计算表达式(如 `col1 + col2`) - `常量`:固定值(如 `'2023'`) - `子查询`:嵌套SELECT结果 - 字段来源格式:`源库.源表.字段名`(不要使用表别名,如果源库未指定则使用 {default_db}替换) - 字段来源信息只能包含 库名、表名和字段信息,如果经过函数或表达式处理,则从参数中提取出具体的字段信息 3. **关键约束**: - 每个目标字段必须有详细的来源说明 - 函数参数需完整展开(如 `COALESCE(T1.id, T2.id)` 需解析所有参数) - 表名必须是 英文字符 构成(非汉字构成) ### 输出规范 {format_instructions} 请直接输出JSON格式结果,无需解释过程。 """ parser = PydanticOutputParser(pydantic_object=TargetTable) pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions()) chain = pt | self.llm_coder | parser answer = {} config = {"configurable": {"thread_id": state["session"]}} dt1 = datetime.now() print("开始解析目标表...") stopped = False async for chunk in chain.astream({ "sql": state["question"], "dialect": state["dialect"], "source_tables": state["source_tables"], "var": self.var_list or [], "default_db": default_db, "default_db": default_db, "create_table_info": self.ct_list }): current_state = self.memory.get(config) if current_state and current_state["channel_values"]["stop"]: stopped = True break answer["target_table"] = chunk.model_dump() if stopped: return {"status": "canceled", "error": "用户中止", "lineage": {}} dt2 = datetime.now() print("提取目标表用时(sec):", (dt2 - dt1).seconds) print(f"target:{answer}") return answer async def __extract_where(self, state: AgentState): """ 根据SQL,提取来源表 where 信息 """ # 默认DB default_db = "TMP" if len(self.target_table) > 0: arr = self.target_table.split(".") if len(arr) == 2: # 使用识别出的目标表db作为默认DB default_db = arr[0] dt1 = datetime.now() template = """ 你是SQL数据血缘分析专家,仔细分析 SQL语句和上下文,提取 where 条件信息。 ### SQL语句和上下文如下 - SQL语句: {sql} - 数据库方言: {dialect} - 变量设置信息: {var} ### 提取要求: 1. 提取 所有 来源表 中的 where 条件信息,包含 where 中出现的 表原名(非别名)、 表所在数据库名(如果没有则使用 {default_db} 填充)、字段名、条件操作符和条件值 2. 字段名不能包括函数,比如 length(INT_ORG_NO),提取的字段名是 INT_ORG_NO 3. 如果SQL语句中存在变量,则根据变量设置信息转换成实际值 ### 输出规范 {format_instructions} 不要重复用户问题,不要分析中间过程,直接给出答案。 """ parser = PydanticOutputParser(pydantic_object=WhereCondList) pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions()) chain = pt | self.llm_coder | parser answer = {} config = {"configurable": {"thread_id": state["session"]}} stopped = False async for chunk in chain.astream({ "sql": state["question"], "dialect": state["dialect"], "var": self.var_list or [], "default_db": default_db }): current_state = self.memory.get(config) if current_state and current_state["channel_values"]["stop"]: stopped = True break # 设置目标表字段数量,来源表数量 dt2 = datetime.now() if stopped: return {"status": "canceled", "error": "用户中止", "lineage": {}} return answer async def __extract_join(self, state: AgentState): """ 根据SQL,提取 表关联 信息 """ # 默认DB default_db = "TMP" if len(self.target_table) > 0: arr = self.target_table.split(".") if len(arr) == 2: # 使用识别出的目标表db作为默认DB default_db = arr[0] dt1 = datetime.now() template = """ 你是SQL数据血缘分析专家,负责提取完整的JOIN关系。仔细分析SQL片段: - SQL片段: {sql} - 数据库方言: {dialect} ### 关键要求: 1. **SQL语句中必须存在关联条件(包括inner、left join、right join、full join)** 2. **识别JOIN结构**: - 提取FROM子句中的主表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {default_db} 填充 - 提取每个JOIN子句中的从表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {j_default_db} 填充 - 明确标注JOIN类型(INNER/LEFT/RIGHT/FULL) 3. **关联字段提取**: - 必须提取每对关联字段的完整信息: * 主表端:表别名.字段名 * 从表端:表别名.字段名 - 明确标注关联操作符(=, >, <等) - 多条件JOIN拆分为多个字段对 4. **特殊场景处理**: - 子查询作为表时,表名填写子查询别名 - 隐式JOIN(WHERE条件)需转换为显式JOIN结构 - 多表JOIN时保持原始顺序 - 如果关联字段是 常量,则关联字段取空 - 如果关联字段经函数处,则提取参数中的 字段,如果参数是常量,则关联字段取空 ### 输出规范 {format_instructions} 不要重复用户问题,不要中间思考过程,直接回答。 """ parser = PydanticOutputParser(pydantic_object=JoinRelationList) pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions()) chain = pt | self.llm_coder | parser answer = {} config = {"configurable": {"thread_id": state["session"]}} stopped = False async for chunk in chain.astream({ "sql": state["question"], "dialect": state["dialect"], "default_db": default_db, "j_default_db": default_db }): current_state = self.memory.get(config) if current_state and current_state["channel_values"]["stop"]: stopped = True break # 设置目标表字段数量,来源表数量 answer["join_list"] = [join.model_dump() for join in chunk.join_list] dt2 = datetime.now() if stopped: return {"status": "canceled", "error": "用户中止", "lineage": {}} return answer def trigger_stop(self, session: str): """外部触发中止""" config = {"configurable": {"thread_id": session}} current_state = self.memory.get(config) if current_state: current_state["channel_values"]["stop"] = True async def main(sql_content: str, sql_file: str, dialect: str, owner: str): agent = SqlLineageAgent() result = await agent.ask_question(sql_content=sql_content, session="s-1", dialect=dialect, sql_file=sql_file, owner=owner) return result def read_var_file(var_file: str): """ 读取变量文件 :param var_file: 变量文件(格式excel) """ import pandas as pd # 读取变量excel文件 result = [] df = pd.read_excel(var_file) for row in df.itertuples(): print(f"变量 {row.Index + 1}: {row.变量}, {row.含义}") item = {} arr = row.变量.split("=") item["var_name"] = arr[0] item["var_value"] = arr[1] item["var_comment"] = row.含义 result.append(item) return result def replace_vars(sql_file: str, var_list: list[dict]): """ 替换 sql_file 中的变量 :param sql_file: SQL文件 :param var_list: 变量集合 [{'var_name':xx,'var_value':xxx,'var_comment':xx}] """ new_content = "" encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312'] for encoding in encodings: try: with open(sql_file, 'r', encoding=encoding) as f: # 读取文件内容 content = f.read() # 替换变量 for var in var_list: name = var["var_name"] value = var["var_value"] new_content = content.replace(name, value) content = new_content break except UnicodeDecodeError: print(f"文件{os.path.basename(f.name)} 编码 {encoding} 不支持") continue else: print("无法找到合适的编码") raise Exception(f"无法读取文件 {sql_file}, 未找到合适的编码.") return new_content def get_ddl_list(ddl_dir: str, var_list: list[dict]): """ 获取 DDL目录中定义的所有建表DDL,并替换其中的变量 :param ddl_dir: DDL目录,目录结构 ddl/db/xx.ddl :param var_list: 变量集合 [{'var_name':xx,'var_value':xxx,'var_comment':xx}] """ file_list = [] # 遍历目录,获取所有.ddl文件 for root, dirs, files in os.walk(ddl_dir): for file in files: if file.endswith(".ddl"): file_path = os.path.join(root, file) file_list.append(file_path) print(f"DDL文件数量:{len(file_list)}") # 读取ddl文件,替换其中的变量 result = [] for file in file_list: fp = Path(file) # 文件的名称规则:db__table,解析出数据库和表名 name = fp.name.replace(".ddl", "") arr = name.split("__") db = arr[0] table = arr[1] # key key = f"{db.upper()}.{table.upper()}" ddl_content = "" encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312'] for encoding in encodings: try: with open(file, 'r', encoding=encoding) as f: # 读取文件内容 content = f.read() # 替换变量 for var in var_list: name = var["var_name"] value = var["var_value"] new_content = content.replace(name, value) content = new_content result.append({key: content}) break except UnicodeDecodeError: print(f"文件{os.path.basename(f.name)} 编码 {encoding} 不支持") continue else: print("无法找到合适的编码") raise Exception(f"无法读取文件 {file}, 未找到合适的编码.") print(f"result:{result[0]}") return result if __name__ == "__main__": import sys import os from pathlib import Path import json # 参数检查 if len(sys.argv) <= 4: print(f"usage python sql_lineage_agent_xmgj.py var_file sql_file dialect owner") exit(1) # 变量文件 var_file = sys.argv[1] # sql脚本文件或目录 sql_file = sys.argv[2] # 数据库方言 dialect = sys.argv[3] # 属主 owner = sys.argv[4] print(f"1、解析SQL文件/目录:{sql_file}") print(f"2、SQL文件属主平台:{owner}") print(f"3、变量文件:{var_file}") print(f"4、SQL数据库方言: {dialect}") # 检查变量文件是否存在 var_fp = Path(var_file) if not var_fp.exists(): raise FileNotFoundError(f"变量文件 {var_file} 不存在,请指定正确路径.") # 检查SQL脚本文件是否存在 sql_fp = Path(sql_file) if not sql_fp.exists(): raise FileNotFoundError(f"SQL脚本文件 {sql_file} 不存在,请指定正确路径.") # 读取变量文件,获取变量值 var_list = read_var_file(var_file) # 读取SQL文件, 替换变量 sql_content = replace_vars(sql_file=sql_file, var_list=var_list) # 解析SQL血缘 result = asyncio.run(main(sql_content=sql_content, sql_file=sql_fp.name, dialect=dialect, owner=owner)) # 将结果写入同级目录下,文件后缀为.json target_file = Path(sql_fp.parent.absolute() / (sql_fp.name + ".json")) print(f"写目标文件:{target_file}") # 写文件 with open(target_file, 'w', encoding="utf-8") as t: t.write(json.dumps(result, ensure_ascii=False, indent=2))