- from langchain_core.prompts import ChatPromptTemplate
- from langgraph.graph import START, StateGraph, END
- from llmops.agents.datadev.llm import get_llm, get_llm_coder
- from typing import List
- from typing_extensions import TypedDict
- from pydantic import BaseModel, Field
- from llmops.agents.datadev.memory.memory_saver_with_expiry2 import MemorySaverWithExpiry
- from langchain_core.output_parsers import PydanticOutputParser
- from config import linage_agent_config
- from llmops.agents.datadev.tools.timeit import timeit
- import asyncio
- import os
- from datetime import datetime
- from pathlib import Path
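- # SqlLineageAgent: extracts data lineage from SQL scripts using LangGraph + LLM calls.
- # A script is split into statements and parsed concurrently (bounded by a semaphore);
- # per statement, where/join extraction runs in parallel, followed by source-table and
- # target-table extraction, and the per-statement results are then merged into a single
- # lineage result (target table, source tables, intermediate tables, joins, filters).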
- class Column(BaseModel):
- """
- 字段信息
- """
- col_name: str = Field(description="字段名称")
- def __eq__(self, other):
- if isinstance(other, Column):
- return self.col_name == other.col_name
- return False
- class ColumnDep(BaseModel):
- """
- 目标表字段信息,包含依赖关系
- """
- col_name: str = Field(description="目标表字段名称")
- col_comment: str = Field(description="目标表字段说明")
- from_cols: List[str] = Field(description="字段来源信息,格式是 源库名.源表名.源字段名 或 中间表名.中间表字段名")
- dep_type: List[str] = Field(description="字段获取方式 1:直取 2:函数 3:表达式")
- desp: List[str] = Field(description="详细解释选取来源字段的原因")
- class SourceTable(BaseModel):
- """
- 单个来源信息表信息
- """
- database: str = Field(description="来源数据库名称")
- table_name: str = Field(description="来源表名称")
- table_name_alias: str = Field(default="", description="来源表别名")
- col_list: List[Column] = Field(description="来源表的字段集合")
- is_temp: bool = Field(description="是否是临时表")
- def __eq__(self, other):
- if isinstance(other, SourceTable):
- return self.database == other.database and self.table_name == other.table_name
- return False
- class SourceTableList(BaseModel):
- """
- 所有来源信息表信息
- """
- source_tables: List[SourceTable] = Field(description="所有来源表信息")
- class WhereCondItem(BaseModel):
- """
- 单个表过滤条件信息
- """
- database: str = Field(description="数据库名")
- table_name: str = Field(description="表原名")
- col_name: str = Field(description="条件字段名称")
- operator: str = Field(default="=", description="条件操作符,如=,>,<,in")
- value: str = Field(description="条件值")
- class WhereCondList(BaseModel):
- """
- 表过滤条件集合
- """
- where_list: List[WhereCondItem] = Field(description="所有的where条件")
- class JoinColumnPair(BaseModel):
- """关联字段对"""
- main_table_alias: str = Field(description="主表别名")
- main_column: str = Field(description="主表关联字段")
- join_table_alias: str = Field(description="从表别名")
- join_column: str = Field(description="从表关联字段")
- operator: str = Field(default="=", description="关联操作符,如=,>,<")
- class JoinRelation(BaseModel):
- """完整的JOIN关系"""
- main_database: str = Field(description="主表数据库名")
- main_table: str = Field(description="主表物理表名")
- main_alias: str = Field(description="主表别名")
- join_type: str = Field(description="JOIN类型:INNER/LEFT/RIGHT/FULL")
- join_database: str = Field(description="从表数据库名")
- join_table: str = Field(description="从表物理表名")
- join_alias: str = Field(description="从表别名")
- column_pairs: List[JoinColumnPair] = Field(description="关联字段对列表")
- class JoinRelationList(BaseModel):
- """
- join关系集合
- """
- join_list: List[JoinRelation] = Field(description="所有的join关系")
- class TargetTable(BaseModel):
- """
- 目标表信息
- """
- database: str = Field(default="", description="目标表的数据库名称")
- table_name: str = Field(description="目标表名称")
- col_size: int = Field(description="目标表字段数量")
- col_list: List[str] = Field(description="目标表字段集合")
- col_dep: List[ColumnDep] = Field(description="目标表字段依赖信息")
- src_table_size: int = Field(description="依赖来源表数量")
- class CreateTable(BaseModel):
- """
- 建表信息
- """
- database: str = Field(description="数据库名")
- table: str = Field(description="表名")
- col_list: list[str] = Field(description="字段名集合")
- source_table_list: list[str] = Field(description="来源表集合")
- class AgentState(TypedDict):
- """
- Agent状态
- """
- question: str # SQL statement/fragment to analyze
- session: str # session id (checkpointer thread_id)
- dialect: str # database dialect
- sql: str # 带行号的SQL语句
- sql_type: str # sql操作类型
- lineage: dict # merged lineage result for this fragment
- status: str # success / invalid / canceled
- error: str # error message, if any
- stop: bool # external cancellation flag
- file_name: str # 脚本文件
- target_table: dict # 目标表信息
- source_tables: list[dict] # 来源表信息
- where_list: list[dict] # where条件信息
- join_list: list[dict] # 表关联信息
- class SqlLineageAgent:
- """
- 解析用户提供的SQL语句/SQL脚本中的数据血缘关系
- """
- def __init__(self):
- # 读取配置文件
- c = linage_agent_config
- # 数据库方言
- self.db_dialect = c["dialect"]
- # 并发度
- self.concurrency = c["concurrency"]
- # 大模型实例
- self.llm = get_llm()
- self.llm_coder = get_llm_coder()
- self.memory = MemorySaverWithExpiry(expire_time=600, clean_interval=60)
- # 构建图
- self.graph = self._build_graph()
- # SQL 语句中设置的变量
- self.var_list: list[str] = []
- # 最终的来源表集合
- self.final_source_table_list: list[dict] = []
- # 最终的中间表集合
- self.final_mid_table_list: list[dict] = []
- # 最终的where集合
- self.final_where_list: list[dict] = []
- # 最终的join集合
- self.final_join_list: list[dict] = []
- # 最终的目标表集合
- self.final_target_list: list[dict] = []
- # 最终目标表,需合并
- self.final_target_table = {}
- # 目标表 格式 库名.表名
- self.target_table = ""
- # 内部建表信息
- self.ct_list = []
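- # Note: the final_* accumulators above are instance-level and are never reset between
- # ask_question calls, so one SqlLineageAgent instance per script appears to be assumed.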
- def _build_graph(self):
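- # Rough sketch of the per-statement graph wired below:
- #   START -> _sql_type --(route by type)--> _invalid | _extract_set | _extract_create_table | [where, join]
- #   where, join -> source -> target -> merge -> END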
- # 构造计算流程
- graph_builder = StateGraph(AgentState)
- # 添加节点
- graph_builder.add_node("_sql_type", self._sql_type)
- graph_builder.add_node("_extract_set", self._extract_set)
- graph_builder.add_node("_extract_create_table", self._extract_create_table)
- graph_builder.add_node("_invalid", self._invalid)
- graph_builder.add_node("source", self.__extract_source)
- graph_builder.add_node("target", self.__extract_target)
- graph_builder.add_node("where", self.__extract_where)
- graph_builder.add_node("join", self.__extract_join)
- graph_builder.add_node("merge", self.__merge_result)
- # 添加边
- graph_builder.add_edge(START, "_sql_type")
- graph_builder.add_conditional_edges("_sql_type", path=self.__parallel_route2, path_map={
- "_invalid": "_invalid",
- "_extract_set": "_extract_set",
- "_extract_create_table": "_extract_create_table",
- "where": "where",
- "join": "join"
- })
- graph_builder.add_edge("_invalid", END)
- graph_builder.add_edge("_extract_set", END)
- graph_builder.add_edge("_extract_create_table", END)
- graph_builder.add_edge("where", "source")
- graph_builder.add_edge("join", "source")
- graph_builder.add_edge("source", "target")
- graph_builder.add_edge("target", "merge")
- graph_builder.add_edge("merge", END)
- return graph_builder.compile(checkpointer=self.memory)
- def _sql_type(self, state: AgentState):
- """
- 判断SQL的操作类型
- """
- template = """
- 你是数据库 {dialect} 专家,对 SQL语句: {sql} 中的操作进行分类, 直接返回分类ID(数字)
-
- ### 操作分类标准如下:
- 1:create, 如建表、建视图等
- 2:insert, 向表中插入数据
- 3:set、设置参数操作、变量赋值操作
- 4:其它
-
- 不要重复用户问题,不要中间分析过程,直接回答。
- """
- pt = ChatPromptTemplate.from_template(template)
- chain = pt | self.llm
- response = chain.invoke({"dialect": state["dialect"], "sql": state["question"]})
- return {"sql_type": response.content}
- def __parallel_route(self, state: AgentState):
- """并发节点"""
- status = state["status"]
- if status == "invalid" or status == "canceled" or state["sql_type"] == "3": # 无效问题 或 用户中止
- return END
- return "continue"
- def __parallel_route2(self, state: AgentState):
- """并发节点"""
- if state["sql_type"] == "4": # 非create和非insert操作
- return "_invalid"
- if state["sql_type"] == "3": # 设置参数
- return "_extract_set"
- if state["sql_type"] == "1": # 建表
- return "_extract_create_table"
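- # Returning a list of node names makes LangGraph run both branches (where and join) in parallel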
- return ["where", "join"]
- def _invalid(self, state: AgentState):
- """
- sql_type == 4,不关注的语句,直接跳过
- """
- return {
- "lineage": {
- "source_tables": [],
- "join_list": [],
- "where_list": [],
- "target_table": {}
- }
- }
- async def _extract_set(self, state: AgentState):
- """
- 提取 SQL 中设置的参数变量
- """
- if state["sql_type"] != "3":
- return state
- template = """
- 你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 参数设置的全部内容,包括变量赋值、注释。
- 不要重复用户问题,不需要中间思考过程,直接回答。
- """
- pt = ChatPromptTemplate.from_template(template)
- chain = pt | self.llm_coder
- response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]})
- self.var_list.append(response.content)
- return {
- "lineage": {
- "source_tables": [],
- "join_list": [],
- "where_list": [],
- "target_table": {}
- }
- }
- async def _extract_create_table(self, state: AgentState):
- """
- 提取建表信息
- """
- if state["sql_type"] != "1":
- return state
- template = """
- 你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 建表信息。
-
- ### 提取要求
- 1、提取建表对应的数据库名(无则留空)
- 2、提取建表对应的表名
- 3、提取建表对应的字段名
- 4、提取建表对应的来源表名,来源表名格式是 来源库库名.来源表名
-
- ### 输出规范
- {format_instructions}
-
- 不要重复用户问题,不需要中间思考过程,直接回答。
- """
- parser = PydanticOutputParser(pydantic_object=CreateTable)
- pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
- chain = pt | self.llm_coder | parser
- response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]})
- #
- print(f"extract-create-table: {response}")
- self.ct_list.append(response.model_dump())
- return {
- "lineage": {
- "source_tables": [],
- "join_list": [],
- "where_list": [],
- "target_table": {}
- }
- }
- async def _parse(self, semaphore, sql: str, session: str, dialect: str, file_name: str = "SQL_FILE.SQL"):
- """
- 解析SQL(片段)的血缘关系
- """
- async with semaphore: # 进入信号量区域(限制并发)
- config = {"configurable": {"thread_id": session}}
- # 调用任务流
- response = await self.graph.ainvoke({
- "question": sql,
- "session": session,
- "dialect": dialect or self.db_dialect,
- "file_name": file_name,
- "stop": False,
- "error": "",
- "sql_type": "0",
- "status": "success"}, config)
- return response
- async def find_target_table(self, sql: str, dialect: str):
- """
- 找出目标表
- """
- template = """
- 你是SQL数据血缘分析专家,参照数据库方言{dialect},仔细分析SQL语句,找出目标表并返回。
- ### SQL片段如下
- {sql}
-
- ### 目标表标准
- 1、目标表一定是出现在 insert into 后的 表
- 2、最后 insert into 后的表为目标表
- 3、目标表只有一张
-
- ### 核心要求
- 1、目标表的返回格式是 数据库名.表名 (如果无数据库名,则留空)
- 2、SQL语句中可能包括多个insert into语句,但要从全局分析,找出最终插入的目标表
-
- 不要重复用户问题、不要中间分析过程,直接回答。
- """
- pt = ChatPromptTemplate.from_template(template)
- chain = pt | self.llm_coder
- response = await chain.ainvoke({"sql": sql, "dialect": dialect})
- return response.content.strip().upper()
- def split_sql_statements(self, sql):
- """
- 将SQL语句按分号分段,同时正确处理字符串和注释中的分号
- 参数:
- sql: 包含一个或多个SQL语句的字符串
- 返回:
- 分割后的SQL语句列表
- """
- # 状态变量
- in_single_quote = False
- in_double_quote = False
- in_line_comment = False
- in_block_comment = False
- escape_next = False
- statements = []
- current_start = 0
- i = 0
- while i < len(sql):
- char = sql[i]
- # 处理转义字符
- if escape_next:
- escape_next = False
- i += 1
- continue
- # 处理字符串和注释中的情况(这些地方的分号不应该作为分隔符)
- if in_line_comment:
- if char == '\n':
- in_line_comment = False
- elif in_block_comment:
- if char == '*' and i + 1 < len(sql) and sql[i + 1] == '/':
- in_block_comment = False
- i += 1 # 跳过下一个字符
- elif in_single_quote:
- if char == "'":
- in_single_quote = False
- elif char == '\\':
- escape_next = True
- elif in_double_quote:
- if char == '"':
- in_double_quote = False
- elif char == '\\':
- escape_next = True
- else:
- # 不在字符串或注释中,检查是否进入这些状态
- if char == "'":
- in_single_quote = True
- elif char == '"':
- in_double_quote = True
- elif char == '-' and i + 1 < len(sql) and sql[i + 1] == '-':
- in_line_comment = True
- i += 1 # 跳过下一个字符
- elif char == '/' and i + 1 < len(sql) and sql[i + 1] == '*':
- in_block_comment = True
- i += 1 # 跳过下一个字符
- elif char == ';':
- # 找到真正的分号分隔符
- statement = sql[current_start:i].strip()
- if statement:
- statements.append(statement)
- current_start = i + 1
- i += 1
- # 添加最后一个语句(如果没有以分号结尾)
- if current_start < len(sql):
- statement = sql[current_start:].strip()
- if statement:
- statements.append(statement)
- return statements
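- # Illustrative example (hypothetical input): semicolons inside string literals and
- # comments do not act as separators:
- #   split_sql_statements("INSERT INTO t VALUES ('a;b'); -- note; here\nSELECT 1")
- #   -> ["INSERT INTO t VALUES ('a;b')", "-- note; here\nSELECT 1"]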
- async def find_target_table2(self, results: list[dict]) -> str:
- """
- 根据各段解析结果查找出目标表
- 规则:一个段的目标表如果没有出现在其它各段的来源表中,即为目标表
- """
- target_table = ""
- exists = False
- for i, item in enumerate(results):
- tt = item["lineage"]["target_table"]
- # 取一个目标表
- db = tt.get("database", "").strip().upper()
- table = tt.get("table_name", "").strip().upper()
- exists = False
- if len(table) > 0:
- # 从源表里面查找
- for j, item2 in enumerate(results):
- if i != j: # 不跟自身比
- st_list = item2["lineage"]["source_tables"]
- for st in st_list:
- sdb = st.get("database", "").strip().upper()
- stable = st.get("table_name", "").strip().upper()
- if db == sdb and table == stable: # 目标表在源表中存在
- exists = True
- break
- if exists:
- break
- if not exists: # 说明目标表不在其它各段的源表中存在
- target_table = ".".join([db, table])
- break
- return target_table.strip().upper()
- async def ask_question(self, sql_content: str, session: str, dialect: str, owner: str, sql_file: str = "SQL_FILE.SQL"):
- """
- 功能:根据用户问题,解析SQL中的数据血缘关系。先分段解析,再组织合并。
- :param sql_content: SQL内容
- :param session: 会话
- :param dialect: 数据库方言
- :param owner: 脚本平台属主
- :param sql_file: 脚本名称
- """
- # 通过session保持记忆
- t1 = datetime.now()
- # 最终解析结果
- final_result = {}
- final_result["task_id"] = session
- final_result["file_name"] = sql_file
- final_result["owner"] = owner
- final_result["status"] = "success"
- final_result["error"] = ""
- lineage = {}
- try:
- # 根据全局SQL找出目标表
- self.target_table = await self.find_target_table(sql_content, dialect)
- print(f"大模型查找目标表:{self.target_table}, {len(self.target_table)}")
- # 对SQL分段,并发执行解析
- sql_list = self.split_sql_statements(sql=sql_content)
- print("SQL分段数量:", len(sql_list))
- # 信号量,控制并发量
- semaphore = asyncio.Semaphore(self.concurrency)
- task_list = [self._parse(semaphore=semaphore, sql=sql, session=session, dialect=dialect, file_name=sql_file) for sql in sql_list]
- # 并发分段解析
- results = await asyncio.gather(*task_list)
- if len(self.target_table) == 0:
- # 从各段解析结果中查找出目标表
- self.target_table = await self.find_target_table2(results) or ""
- print(f"从血缘关系中找出目标表:{self.target_table}")
- if not self.target_table: # 未出现目标表, 直接返回
- final_result["status"] = "error"
- final_result["error"] = "未找到目标表"
- return final_result
- for response in results:
- # 来源表、where、join、target
- self._merge_source_tables(response["lineage"]["source_tables"])
- self._merge_where(response["lineage"].get("where_list",[]))
- self._merge_join(response["lineage"]["join_list"])
- # 合并目标表
- self._merge_target(response["lineage"]["target_table"])
- # 增加中间表标识
- self._add_mid_table_label(self.final_source_table_list, self.final_mid_table_list)
- # 补充中间表依赖字段
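- # NOTE: final_target_list is only initialized in __init__ and never populated, so as
- # written this call is effectively a no-op.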
- self._add_mid_table_col(self.final_mid_table_list, self.final_target_list)
- # 多目标表合并
- self.final_target_table["src_table_size"] = len(self.final_source_table_list)
- lineage["target_table"] = self.final_target_table
- lineage["source_tables"] = self.final_source_table_list
- lineage["mid_table_list"] = self.final_mid_table_list
- lineage["join_list"] = self.final_join_list
- lineage["where_list"] = self.final_where_list
- final_result["lineage"] = lineage
- except Exception as e:
- final_result["status"] = "error"
- final_result["error"] = str(e)
- print(str(e))
- t2 = datetime.now()
- print("总共用时(sec):", (t2-t1).seconds)
- parse_time = (t2 - t1).seconds
- # 解析时间(秒)
- final_result["parse_time"] = parse_time
- final_result["parse_end_time"] = t2.strftime("%Y-%m-%d %H:%M:%S")
- return final_result
- def _merge_source_tables(self, source_table_list: list[dict]):
- """
- 合并分段来源表
- """
- if len(source_table_list) > 0:
- # 目标源表的keys
- table_keys = [(t.get("database","").strip().upper(), t.get("table_name","").strip().upper()) for t in self.final_source_table_list]
- for st in source_table_list:
- # 将字典转换为元组
- db = st.get("database","").strip().upper()
- table = st.get("table_name","").strip().upper()
- key = (db, table)
- if key not in table_keys:
- table_keys.append(key)
- self.final_source_table_list.append(st)
- else:
- # 合并字段
- col_list = st.get("col_list", [])
- # 找出相同的元素, 库名,表名相同
- e = next(filter(lambda item: item["database"].upper()==db and item["table_name"].upper()==table, self.final_source_table_list), None)
- if e:
- final_col_list = e.get("col_list",[])
- col_keys = [(c.get("col_name","").upper()) for c in final_col_list]
- for c in col_list:
- ck = (c["col_name"].upper())
- if ck not in col_keys and len(c) > 0:
- col_keys.append(ck)
- final_col_list.append(c)
- @timeit
- def _merge_where(self, where_list: list[dict]):
- """
- 最终合并所有的where条件
- """
- if where_list and len(where_list) > 0:
- self.final_where_list += where_list
- def _merge_join(self, join_list: list[dict]):
- """
- 最终合并所有的join条件
- """
- if len(join_list) > 0:
- self.final_join_list += join_list
- def _merge_target(self, mid_target_table: dict):
- """
- 合并中间目标表,只有 名称是 target_table_name 的表才是真目标表
- 目标表可能出现多次(如多次插入),需要合并 字段信息(字段名称和字段来源)
- 对于非目标表的中间目标表,放到中间表集合中
- """
- if not mid_target_table: # 中间目标表为空
- return
- else: # 将临时目标表与目标表做合并
- db = mid_target_table.get("database", "").strip().upper()
- table = mid_target_table.get("table_name", "").strip().upper()
- col_list = mid_target_table.get("col_list", [])
- if len(table) == 0 or len(col_list) == 0: # 表名为空 或 无字段 直接返回
- return
- if len(db) == 0: # 缺少数据库名, 检查表名中是否带库名
- arr = table.split('.')
- if len(arr) == 2: # 说明表名中当库名,解析有错,修正
- mid_target_table["database"] = arr[0].strip().upper()
- mid_target_table["table_name"] = arr[1].strip().upper()
- db = arr[0].strip().upper()
- table = arr[1].strip().upper()
- elif len(arr) == 1: # 无库名,只有表名,使用真目标表库名替换
- arr = self.target_table.split(".")
- if len(arr) == 2:
- db = mid_target_table["database"] = arr[0].strip().upper()
- key = ".".join([db, table])
- if key == self.target_table: # 找到最终目标表
- print(f"合并目标表:{mid_target_table}")
- if not self.final_target_table: # 目标表为空(首次)
- self.final_target_table = mid_target_table
- return
- # 合并字段和字段来源信息
- # 合并新字段
- new_col_list = [] # 新字段集
- col_list = mid_target_table.get("col_list", [])
- for new_col in col_list:
- # 检查字段是否存在、忽略大小写
- if new_col.strip().upper() not in [col.strip().upper() for col in self.final_target_table.get("col_list", [])]: # 不存在
- print(f"合并新字段:{new_col}")
- self.final_target_table.get("col_list", []).append(new_col.strip().upper())
- new_col_list.append(new_col.strip().upper())
- # 合并字段来源信息
- col_dep_list = mid_target_table.get("col_dep", [])
- for col_dep in col_dep_list:
- col_name = col_dep["col_name"].strip().upper()
- if col_name in new_col_list: # 新字段, 直接添加
- self.final_target_table.get("col_dep",[]).append(col_dep)
- else: # 合并来源信息
- # 找到同名字段,合并来源字段
- for dep in self.final_target_table.get("col_dep", []):
- cn = dep["col_name"].strip().upper()
- if col_name == cn:
- print(f"合并来源字段, {col_name}")
- m_from_col = col_dep.get('from_cols', [])
- f_from_col = dep.get("from_cols", [])
- print(f"中间表来源字段:{m_from_col}")
- print(f"目标表来源字段:{f_from_col}")
- # 将临时中间表的来源字段合并至目标表的来源字段
- self._merge_col_dep(m_from_col, f_from_col)
- else:
- # 加入中间表集合中
- print(f"加入中间表:{mid_target_table}")
- self.final_mid_table_list.append(mid_target_table)
- print(f"final tt:{self.final_target_table}")
- def _merge_col_dep(self, mid_from_cols: list[str], final_from_cols: list[str]):
- """
- 合并临时中间表来源字段 到目标表来源字段中
- """
- for from_col in mid_from_cols:
- if len(from_col.split(".")) > 1: # skip constants/NULL values that are not qualified column references
- if from_col.strip().upper() not in [fc.strip().upper() for fc in final_from_cols]:
- print(f"合并字段:{from_col}")
- final_from_cols.append(from_col.strip().upper())
- @staticmethod
- @timeit
- def _add_mid_table_label(source_table_list: list[dict], mid_table_list: list[dict]):
- """
- 给最终的来源表增加中间表标识
- """
- if len(source_table_list) > 0 and len(mid_table_list) > 0:
- mid_table_keys = [(t["database"].strip().upper(), t["table_name"].strip().upper()) for t in mid_table_list]
- for st in source_table_list:
- st["is_temp"] = False
- # 将字典转换为元组
- key = (st["database"].strip().upper(), st["table_name"].strip().upper())
- if key in mid_table_keys: # 是中间表
- st["is_temp"] = True
- @staticmethod
- @timeit
- def _add_mid_table_col(mid_table_list: list[dict], target_table_list: list[dict]):
- """
- 给中间表补充来源字段,使用对应的目标表col_dep去替换
- """
- if len(mid_table_list) > 0 and len(target_table_list) > 0:
- for mt in mid_table_list:
- for tt in target_table_list:
- tt["is_temp"] = False
- # 目标表的数据库名与表名相同
- tdn = tt["database"].strip().upper()
- ttn = tt["table_name"].strip().upper()
- if len(tdn) == 0 and len(ttn) == 0: # 全部为空
- tt["is_temp"] = True
- continue
- if mt["database"].strip().upper()==tdn and mt["table_name"].strip().upper()==ttn:
- mt["col_dep"] = tt["col_dep"]
- tt["is_temp"] = True
- if len(mt["col_list"]) == 0:
- mt["col_list"] = [c.upper() for c in tt["col_list"]]
- mt["col_size"] = len(mt["col_list"])
- @timeit
- def _merge_final_target_table(self, seg_target_table_list: list[dict], target_table_name: str):
- """
- 合并分段目标表内容
- """
- # 可能存在多个相同的真正目标表(多次插入目标表情况)
- for tt in seg_target_table_list:
- tt["is_target"] = False
- database = tt["database"].strip()
- table = tt["table_name"].strip()
- if len(database) > 0:
- dt = ".".join([database,table])
- else:
- dt = table
- if dt.upper() == target_table_name.strip().upper():
- tt["is_target"] = True
- # 合并真正的目标表
- i = 0
- result = {}
- final_col_list: list[str] = []
- final_col_dep_list: list[dict] = []
- # 合并目标表信息
- for target_table in seg_target_table_list:
- if target_table["is_target"]:
- if i == 0:
- result["database"] = target_table["database"]
- result["table_name"] = target_table["table_name"]
- i = 1
- # 合并列信息
- col_list = target_table["col_list"]
- for c in col_list:
- if c not in final_col_list and len(c.strip()) > 0:
- final_col_list.append(c.upper())
- # 合并列来源信息
- col_dep_list = target_table["col_dep"]
- for col_dep in col_dep_list:
- cn = col_dep["col_name"].strip().upper()
- from_cols = col_dep["from_cols"]
- existed = False
- if len(cn) > 0:
- for fcd in final_col_dep_list:
- if cn == fcd["col_name"].upper():
- existed = True
- # 合并来源字段
- final_from_cols = fcd["from_cols"]
- fcd["from_cols"] += [c.upper() for c in from_cols if c.upper() not in final_from_cols]
- if not existed:
- final_col_dep_list.append({
- "col_name": cn.upper(),
- "col_comment": col_dep["col_comment"],
- "from_cols": [c.upper() for c in from_cols]
- })
- result["col_size"] = len(final_col_list)
- result["col_list"] = final_col_list
- result["col_dep"] = final_col_dep_list
- return result
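- # Note: _merge_final_target_table above and __check below are not wired into the graph
- # or called from ask_question in this version; they appear to be optional/legacy helpers.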
- async def __check(self, state: AgentState):
- """
- 检查用户问题是否是有效的SQL语句/SQL片段
- """
- template = """
- 你是 数据库 {dialect} SQL分析助手, 仔细分析SQL语句: {sql} ,判断是否语法正确,如果是 直接返回 1,否则说明错误地方,并返回 0。
-
- ### 分析要求
- 1. 判断子查询的正确性
- 2. 判断整体语句的正确性
- 3. 如果存在语法错误,给出说明
-
- 不要重复用户问题、不要分析过程、直接回答。
- """
- pt = ChatPromptTemplate.from_template(template)
- chain = pt | self.llm
- stopped = False
- response = ""
- config = {"configurable": {"thread_id": state["session"]}}
- async for chunk in chain.astream({"sql": state["question"], "dialect": state["dialect"]}, config=config):
- current_state = self.memory.get(config)
- if current_state and current_state["channel_values"]["stop"]:
- stopped = True
- break
- response += chunk.content
- if stopped:
- return {"stop": True, "error": "用户中止", "status": "canceled", "lineage": {}}
- if response == "0":
- return {"status": "invalid", "error": "无效问题或SQL存在语法错误,请重新描述!", "lineage": {}}
- return {"status": "success"}
- async def __merge_result(self, state: AgentState):
- """
- 合并所有节点结果,输出血缘关系
- """
- result = {}
- result["file_name"] = state["file_name"]
- source_tables = state["source_tables"]
- target_table = state["target_table"]
- # 设置目标表依赖的来源表数量和字段数量
- target_table["src_table_size"] = len(source_tables)
- target_table["col_size"] = len(target_table["col_dep"])
- result["target_table"] = target_table
- result["source_tables"] = source_tables
- result["where_list"] = state.get("where_list",[])
- result["join_list"] = state["join_list"]
- return {"lineage": result}
- async def __extract_source(self, state: AgentState):
- """
- 根据SQL,提取其中的来源库表信息
- """
- dt1 = datetime.now()
- # 默认DB
- default_db = "TMP"
- if len(self.target_table) > 0:
- arr = self.target_table.split(".")
- if len(arr) == 2: # 使用识别出的目标表db作为默认DB
- default_db = arr[0]
- template = """
- 你是SQL数据血缘分析专家,仔细分析SQL片段和上下文信息, 从SQL片段中提取出 来源表信息。
-
- ### SQL和上下文信息如下
- - SQL片段:{sql}
- - 数据库方言: {dialect}
- - 变量设置信息:{var}
- - 已有的建表信息: {create_table_info}
-
- ### 核心要求:
- 1. 提取 来源表数据库名称(无则使用 {default_db} 填充)
- 2. 提取 来源表名,包括 from子句、子查询、嵌套子查询、union语句、with语句中出现的物理表名(注意:不要带库名前缀)
- 3. 提取 来源表别名(无则留空)
- 4. 提取来源表的所有字段信息,提取来自几种:
- - 从select中提取,如果带有表别名,则转换成物理表名
- - 从where过滤条件中提取
- - 从关联条件(inner join,left join,right join,full join)中提取
- - 从group中提取
- - 从函数参数中提取,比如函数 COALESCE(T13.ECIF_CUST_ID,T2.RELAT_PTY_ID,'') AS ECIF_CUST_ID,提取出T13.ECIF_CUST_ID和T2.RELAT_PTY_ID
- - 从表达式参数中提取
- 5. 判断来源表是否是临时表
- 6. 不要包含目标表(目标表指最终插入的表)
-
- ### 输出规范
- {format_instructions}
-
- 不要重复用户问题,不要分析中间过程,直接给出答案。
- """
- parser = PydanticOutputParser(pydantic_object=SourceTableList)
- pt = ChatPromptTemplate.from_template(template=template).partial(format_instructions=parser.get_format_instructions())
- chain = pt | self.llm_coder | parser
- answer = {}
- config = {"configurable": {"thread_id": state["session"]}}
- stopped = False
- async for chunk in chain.astream({
- "sql": state["question"],
- "dialect": state["dialect"],
- "var": self.var_list or [],
- "default_db": default_db,
- "create_table_info": self.ct_list
- }):
- current_state = self.memory.get(config)
- if current_state and current_state["channel_values"]["stop"]:
- stopped = True
- break
- # Dedupe source tables in the parsed result and merge their column lists
- table_list: list[SourceTable] = []
- for table in chunk.source_tables:
- if table not in table_list:
- table_list.append(table)
- else:
- idx = table_list.index(table)
- tt = table_list[idx]
- # 合并相同的列
- for col in table.col_list:
- if col not in tt.col_list:
- tt.col_list.append(col)
- answer["source_tables"] = [table.model_dump() for table in table_list]
- dt2 = datetime.now()
- print("提取源表用时(sec):", (dt2-dt1).seconds)
- if stopped:
- return {"status": "canceled", "error": "用户中止", "lineage": {}}
- return answer
- async def __extract_target(self, state: AgentState):
- """
- 根据SQL,提取其中的目标表信息
- """
- if state["sql_type"] == "1":
- return {"target_table": {"database":"", "table_name": "", "col_list": [], "col_size": 0, "col_dep": []}}
- # 默认DB
- default_db = "TMP"
- if len(self.target_table) > 0:
- arr = self.target_table.split(".")
- if len(arr) == 2: # 使用识别出的目标表db作为默认DB
- default_db = arr[0]
- template = """
- 你是SQL数据血缘分析专家,仔细分析SQL片段和以下上下文:
- - SQL片段: {sql}
- - 来源表信息: {source_tables}
- - 数据库方言: {dialect}
- - 变量设置信息:{var}
- - 已有的建表信息: {create_table_info}
-
- ### 核心要求:
- 1. **目标表识别**:
- - 提取 INSERT INTO 后的表作为目标表
- - 格式:`数据库名.表名`(未指定库名则使用 {default_db} 填充)
- 2. **字段依赖分析**:
- - 目标表字段必须溯源到来源表的物理字段
- - 字段获取方式分类:
- - `直取`:直接引用源字段(如 `src.col`)
- - `函数`:通过函数转换(如 `COALESCE(col1, col2)`)
- - `表达式`:计算表达式(如 `col1 + col2`)
- - `常量`:固定值(如 `'2023'`)
- - `子查询`:嵌套SELECT结果
- - 字段来源格式:`源库.源表.字段名`(不要使用表别名,如果源库未指定则使用 {default_db}替换)
- - 字段来源信息只能包含 库名、表名和字段信息,如果经过函数或表达式处理,则从参数中提取出具体的字段信息
- 3. **关键约束**:
- - 每个目标字段必须有详细的来源说明
- - 函数参数需完整展开(如 `COALESCE(T1.id, T2.id)` 需解析所有参数)
- - 表名必须是 英文字符 构成(非汉字构成)
-
- ### 输出规范
- {format_instructions}
-
- 请直接输出JSON格式结果,无需解释过程。
- """
- parser = PydanticOutputParser(pydantic_object=TargetTable)
- pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
- chain = pt | self.llm_coder | parser
- answer = {}
- config = {"configurable": {"thread_id": state["session"]}}
- dt1 = datetime.now()
- print("开始解析目标表...")
- stopped = False
- async for chunk in chain.astream({
- "sql": state["question"],
- "dialect": state["dialect"],
- "source_tables": state["source_tables"],
- "var": self.var_list or [],
- "default_db": default_db,
- "create_table_info": self.ct_list
- }):
- current_state = self.memory.get(config)
- if current_state and current_state["channel_values"]["stop"]:
- stopped = True
- break
- answer["target_table"] = chunk.model_dump()
- if stopped:
- return {"status": "canceled", "error": "用户中止", "lineage": {}}
- dt2 = datetime.now()
- print("提取目标表用时(sec):", (dt2 - dt1).seconds)
- print(f"target:{answer}")
- return answer
- async def __extract_where(self, state: AgentState):
- """
- 根据SQL,提取来源表 where 信息
- """
- # 默认DB
- default_db = "TMP"
- if len(self.target_table) > 0:
- arr = self.target_table.split(".")
- if len(arr) == 2: # 使用识别出的目标表db作为默认DB
- default_db = arr[0]
- dt1 = datetime.now()
- template = """
- 你是SQL数据血缘分析专家,仔细分析 SQL语句和上下文,提取 where 条件信息。
- ### SQL语句和上下文如下
- - SQL语句: {sql}
- - 数据库方言: {dialect}
- - 变量设置信息: {var}
-
- ### 提取要求:
- 1. 提取 所有 来源表 中的 where 条件信息,包含 where 中出现的 表原名(非别名)、 表所在数据库名(如果没有则使用 {default_db} 填充)、字段名、条件操作符和条件值
- 2. 字段名不能包括函数,比如 length(INT_ORG_NO),提取的字段名是 INT_ORG_NO
- 3. 如果SQL语句中存在变量,则根据变量设置信息转换成实际值
-
- ### 输出规范
- {format_instructions}
-
- 不要重复用户问题,不要分析中间过程,直接给出答案。
- """
- parser = PydanticOutputParser(pydantic_object=WhereCondList)
- pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
- chain = pt | self.llm_coder | parser
- answer = {}
- config = {"configurable": {"thread_id": state["session"]}}
- stopped = False
- async for chunk in chain.astream({
- "sql": state["question"],
- "dialect": state["dialect"],
- "var": self.var_list or [],
- "default_db": default_db
- }):
- current_state = self.memory.get(config)
- if current_state and current_state["channel_values"]["stop"]:
- stopped = True
- break
- # Collect the parsed where conditions (mirrors the other extract nodes)
- answer["where_list"] = [w.model_dump() for w in chunk.where_list]
- dt2 = datetime.now()
- if stopped:
- return {"status": "canceled", "error": "用户中止", "lineage": {}}
- return answer
- async def __extract_join(self, state: AgentState):
- """
- 根据SQL,提取 表关联 信息
- """
- # 默认DB
- default_db = "TMP"
- if len(self.target_table) > 0:
- arr = self.target_table.split(".")
- if len(arr) == 2: # 使用识别出的目标表db作为默认DB
- default_db = arr[0]
- dt1 = datetime.now()
- template = """
- 你是SQL数据血缘分析专家,负责提取完整的JOIN关系。仔细分析SQL片段:
- - SQL片段: {sql}
- - 数据库方言: {dialect}
-
- ### 关键要求:
- 1. **SQL语句中必须存在关联条件(包括inner、left join、right join、full join)**
- 2. **识别JOIN结构**:
- - 提取FROM子句中的主表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {default_db} 填充
- - 提取每个JOIN子句中的从表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {j_default_db} 填充
- - 明确标注JOIN类型(INNER/LEFT/RIGHT/FULL)
- 3. **关联字段提取**:
- - 必须提取每对关联字段的完整信息:
- * 主表端:表别名.字段名
- * 从表端:表别名.字段名
- - 明确标注关联操作符(=, >, <等)
- - 多条件JOIN拆分为多个字段对
- 4. **特殊场景处理**:
- - 子查询作为表时,表名填写子查询别名
- - 隐式JOIN(WHERE条件)需转换为显式JOIN结构
- - 多表JOIN时保持原始顺序
- - 如果关联字段是 常量,则关联字段取空
- - 如果关联字段经函数处,则提取参数中的 字段,如果参数是常量,则关联字段取空
-
- ### 输出规范
- {format_instructions}
- 不要重复用户问题,不要中间思考过程,直接回答。
- """
- parser = PydanticOutputParser(pydantic_object=JoinRelationList)
- pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
- chain = pt | self.llm_coder | parser
- answer = {}
- config = {"configurable": {"thread_id": state["session"]}}
- stopped = False
- async for chunk in chain.astream({
- "sql": state["question"],
- "dialect": state["dialect"],
- "default_db": default_db,
- "j_default_db": default_db
- }):
- current_state = self.memory.get(config)
- if current_state and current_state["channel_values"]["stop"]:
- stopped = True
- break
- # Collect the parsed join relations from the streamed chunk
- answer["join_list"] = [join.model_dump() for join in chunk.join_list]
- dt2 = datetime.now()
- if stopped:
- return {"status": "canceled", "error": "用户中止", "lineage": {}}
- return answer
- def trigger_stop(self, session: str):
- """外部触发中止"""
- config = {"configurable": {"thread_id": session}}
- current_state = self.memory.get(config)
- if current_state:
- current_state["channel_values"]["stop"] = True
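- # Cancellation: trigger_stop flips the "stop" flag stored in the checkpointer state;
- # the streaming loops in __check / __extract_* poll that flag between chunks and abort.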
- async def main(sql_content: str, sql_file: str, dialect: str, owner: str):
- agent = SqlLineageAgent()
- result = await agent.ask_question(sql_content=sql_content, session="s-1", dialect=dialect, sql_file=sql_file, owner=owner)
- return result
- def read_var_file(var_file: str):
- """
- 读取变量文件
- :param var_file: 变量文件(格式excel)
- """
- import pandas as pd
- # 读取变量excel文件
- result = []
- df = pd.read_excel(var_file)
- for row in df.itertuples():
- print(f"变量 {row.Index + 1}: {row.变量}, {row.含义}")
- item = {}
- arr = row.变量.split("=")
- item["var_name"] = arr[0]
- item["var_value"] = arr[1]
- item["var_comment"] = row.含义
- result.append(item)
- return result
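- # Expected spreadsheet layout (illustrative, inferred from the column access above):
- # a "变量" column holding NAME=VALUE pairs and a "含义" column with a description,
- # e.g. 变量 = "${pt_dt}=20240101", 含义 = "data date" (these sample values are hypothetical).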
- def replace_vars(sql_file: str, var_list: list[dict]):
- """
- 替换 sql_file 中的变量
- :param sql_file: SQL文件
- :param var_list: 变量集合 [{'var_name':xx,'var_value':xxx,'var_comment':xx}]
- """
- new_content = ""
- encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312']
- for encoding in encodings:
- try:
- with open(sql_file, 'r', encoding=encoding) as f:
- # 读取文件内容
- content = f.read()
- # 替换变量
- for var in var_list:
- name = var["var_name"]
- value = var["var_value"]
- content = content.replace(name, value)
- new_content = content
- break
- except UnicodeDecodeError:
- print(f"文件{os.path.basename(f.name)} 编码 {encoding} 不支持")
- continue
- else:
- print("无法找到合适的编码")
- raise Exception(f"无法读取文件 {sql_file}, 未找到合适的编码.")
- return new_content
- def get_ddl_list(ddl_dir: str, var_list: list[dict]):
- """
- 获取 DDL目录中定义的所有建表DDL,并替换其中的变量
- :param ddl_dir: DDL目录,目录结构 ddl/db/xx.ddl
- :param var_list: 变量集合 [{'var_name':xx,'var_value':xxx,'var_comment':xx}]
- """
- file_list = []
- # 遍历目录,获取所有.ddl文件
- for root, dirs, files in os.walk(ddl_dir):
- for file in files:
- if file.endswith(".ddl"):
- file_path = os.path.join(root, file)
- file_list.append(file_path)
- print(f"DDL文件数量:{len(file_list)}")
- # 读取ddl文件,替换其中的变量
- result = []
- for file in file_list:
- fp = Path(file)
- # 文件的名称规则:db__table,解析出数据库和表名
- name = fp.name.replace(".ddl", "")
- arr = name.split("__")
- db = arr[0]
- table = arr[1]
- # key
- key = f"{db.upper()}.{table.upper()}"
- ddl_content = ""
- encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312']
- for encoding in encodings:
- try:
- with open(file, 'r', encoding=encoding) as f:
- # 读取文件内容
- content = f.read()
- # 替换变量
- for var in var_list:
- name = var["var_name"]
- value = var["var_value"]
- content = content.replace(name, value)
- result.append({key: content})
- break
- except UnicodeDecodeError:
- print(f"文件{os.path.basename(f.name)} 编码 {encoding} 不支持")
- continue
- else:
- print("无法找到合适的编码")
- raise Exception(f"无法读取文件 {file}, 未找到合适的编码.")
- print(f"result:{result[0]}")
- return result
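- # Note: get_ddl_list is not called from the __main__ entry point below; it appears
- # intended to supply CREATE TABLE context through a separate path.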
- if __name__ == "__main__":
- import sys
- import os
- from pathlib import Path
- import json
- # 参数检查
- if len(sys.argv) <= 4:
- print("usage: python sql_lineage_agent_xmgj.py var_file sql_file dialect owner")
- exit(1)
- # 变量文件
- var_file = sys.argv[1]
- # sql脚本文件或目录
- sql_file = sys.argv[2]
- # 数据库方言
- dialect = sys.argv[3]
- # 属主
- owner = sys.argv[4]
- print(f"1、解析SQL文件/目录:{sql_file}")
- print(f"2、SQL文件属主平台:{owner}")
- print(f"3、变量文件:{var_file}")
- print(f"4、SQL数据库方言: {dialect}")
- # 检查变量文件是否存在
- var_fp = Path(var_file)
- if not var_fp.exists():
- raise FileNotFoundError(f"变量文件 {var_file} 不存在,请指定正确路径.")
- # 检查SQL脚本文件是否存在
- sql_fp = Path(sql_file)
- if not sql_fp.exists():
- raise FileNotFoundError(f"SQL脚本文件 {sql_file} 不存在,请指定正确路径.")
- # 读取变量文件,获取变量值
- var_list = read_var_file(var_file)
- # 读取SQL文件, 替换变量
- sql_content = replace_vars(sql_file=sql_file, var_list=var_list)
- # 解析SQL血缘
- result = asyncio.run(main(sql_content=sql_content, sql_file=sql_fp.name, dialect=dialect, owner=owner))
- # 将结果写入同级目录下,文件后缀为.json
- target_file = Path(sql_fp.parent.absolute() / (sql_fp.name + ".json"))
- print(f"写目标文件:{target_file}")
- # 写文件
- with open(target_file, 'w', encoding="utf-8") as t:
- t.write(json.dumps(result, ensure_ascii=False, indent=2))
|