|
|
@@ -0,0 +1,1298 @@
|
|
|
+
|
|
|
+
|
|
|
+from langchain_core.prompts import ChatPromptTemplate
|
|
|
+from langgraph.graph import START, StateGraph, END
|
|
|
+
|
|
|
+from llmops.agents.datadev.llm import get_llm, get_llm_coder
|
|
|
+from typing import List
|
|
|
+from typing_extensions import TypedDict
|
|
|
+from pydantic import BaseModel, Field
|
|
|
+from llmops.agents.datadev.memory.memory_saver_with_expiry2 import MemorySaverWithExpiry
|
|
|
+from langchain_core.output_parsers import PydanticOutputParser
|
|
|
+from config import linage_agent_config
|
|
|
+from llmops.agents.datadev.tools.timeit import timeit
|
|
|
+
|
|
|
+import asyncio
|
|
|
+from datetime import datetime
|
|
|
+
|
|
|
class Column(BaseModel):
    """
    字段信息
    """
    # NOTE(review): the class docstring and Field descriptions are emitted into
    # the JSON schema handed to the LLM via PydanticOutputParser, so the
    # Chinese text is behavior-relevant and reproduced byte-for-byte.
    col_name: str = Field(description="字段名称")

    def __eq__(self, other):
        # A column's identity is its name; anything that is not a Column
        # compares unequal (used for de-duplicating extracted columns).
        return isinstance(other, Column) and self.col_name == other.col_name
|
|
|
+
|
|
|
class ColumnDep(BaseModel):
    """
    目标表字段信息,包含依赖关系
    """
    # NOTE(review): docstring and Field descriptions flow into the LLM prompt
    # via PydanticOutputParser.get_format_instructions(), so the Chinese text
    # is behavior-relevant and left untouched.
    # Target-table column name.
    col_name: str = Field(description="目标表字段名称")
    # Target-table column comment / description.
    col_comment: str = Field(description="目标表字段说明")
    # Upstream columns, formatted "db.table.col" or "mid_table.col".
    from_cols: List[str] = Field(description="字段来源信息,格式是 源库名.源表名.源字段名 或 中间表名.中间表字段名")
    # How the value is derived: 1 direct copy, 2 function, 3 expression.
    dep_type: List[str] = Field(description="字段获取方式 1:直取 2:函数 3:表达式")
    # Free-text rationale for each chosen source column.
    desp: List[str] = Field(description="详细解释选取来源字段的原因")
|
|
|
+
|
|
|
class SourceTable(BaseModel):
    """
    单个来源信息表信息
    """
    # NOTE(review): docstring / Field descriptions feed the LLM output schema
    # (PydanticOutputParser), so the Chinese text is kept byte-for-byte.
    database: str = Field(description="来源数据库名称")
    table_name: str = Field(description="来源表名称")
    table_name_alias: str = Field(default="", description="来源表别名")
    col_list: List[Column] = Field(description="来源表的字段集合")
    is_temp: bool = Field(description="是否是临时表")

    def __eq__(self, other):
        # A source table's identity is (database, table_name); alias, column
        # list and the temp flag are deliberately ignored so duplicates found
        # in different fragments merge together.
        return (
            isinstance(other, SourceTable)
            and self.database == other.database
            and self.table_name == other.table_name
        )
|
|
|
+
|
|
|
class SourceTableList(BaseModel):
    """
    所有来源信息表信息
    """
    # Wrapper model so the LLM can return every source table in one object.
    # (Chinese docstring/descriptions are part of the generated JSON schema.)
    source_tables: List[SourceTable] = Field(description="所有来源表信息")
|
|
|
+
|
|
|
class WhereCondItem(BaseModel):
    """
    单个表过滤条件信息
    """
    # One WHERE predicate, normalised to (db, table, column, operator, value).
    # (Chinese docstring/descriptions are part of the generated JSON schema.)
    database: str = Field(description="数据库名")
    table_name: str = Field(description="表原名")
    col_name: str = Field(description="条件字段名称")
    # Comparison operator, e.g. =, >, <, in; defaults to equality.
    operator: str = Field(default="=", description="条件操作符,如=,>,<,in")
    value: str = Field(description="条件值")
|
|
|
+
|
|
|
+
|
|
|
class WhereCondList(BaseModel):
    """
    表过滤条件集合
    """
    # Wrapper model: all WHERE conditions extracted from one fragment.
    where_list: List[WhereCondItem] = Field(description="所有的where条件")
|
|
|
+
|
|
|
+
|
|
|
class JoinColumnPair(BaseModel):
    """关联字段对"""
    # One ON-clause column pairing between the main table and the joined table.
    # (Chinese docstring/descriptions are part of the generated JSON schema.)
    main_table_alias: str = Field(description="主表别名")
    main_column: str = Field(description="主表关联字段")
    join_table_alias: str = Field(description="从表别名")
    join_column: str = Field(description="从表关联字段")
    # Join comparison operator; defaults to equality.
    operator: str = Field(default="=", description="关联操作符,如=,>,<")
|
|
|
+
|
|
|
+
|
|
|
class JoinRelation(BaseModel):
    """完整的JOIN关系"""
    # One complete JOIN between two physical tables, with all column pairs.
    # (Chinese docstring/descriptions are part of the generated JSON schema.)
    main_database: str = Field(description="主表数据库名")
    main_table: str = Field(description="主表物理表名")
    main_alias: str = Field(description="主表别名")
    # One of INNER/LEFT/RIGHT/FULL.
    join_type: str = Field(description="JOIN类型:INNER/LEFT/RIGHT/FULL")
    join_database: str = Field(description="从表数据库名")
    join_table: str = Field(description="从表物理表名")
    join_alias: str = Field(description="从表别名")
    column_pairs: List[JoinColumnPair] = Field(description="关联字段对列表")
|
|
|
+
|
|
|
+
|
|
|
class JoinRelationList(BaseModel):
    """
    join关系集合
    """
    # Wrapper model: all JOIN relations extracted from one fragment.
    join_list: List[JoinRelation] = Field(description="所有的join关系")
|
|
|
+
|
|
|
+
|
|
|
class TargetTable(BaseModel):
    """
    目标表信息
    """
    # The table being inserted into, plus per-column dependency records.
    # (Chinese docstring/descriptions are part of the generated JSON schema.)
    database: str = Field(default="", description="目标表的数据库名称")
    table_name: str = Field(description="目标表名称")
    # Number of target columns (redundant with len(col_list); kept for output).
    col_size: int = Field(description="目标表字段数量")
    col_list: List[str] = Field(description="目标表字段集合")
    col_dep: List[ColumnDep] = Field(description="目标表字段依赖信息")
    # Number of distinct source tables this target depends on.
    src_table_size: int = Field(description="依赖来源表数量")
|
|
|
+
|
|
|
+
|
|
|
class CreateTable(BaseModel):
    """
    建表信息
    """
    # CREATE TABLE metadata extracted from DDL statements; collected in
    # SqlLineageAgent.ct_list and replayed as context for later fragments.
    database: str = Field(description="数据库名")
    table: str = Field(description="表名")
    col_list: list[str] = Field(description="字段名集合")
    # Source tables referenced by the DDL, formatted "db.table".
    source_table_list: list[str] = Field(description="来源表集合")
|
|
|
+
|
|
|
+
|
|
|
class AgentState(TypedDict):
    """Mutable state threaded through the LangGraph nodes for one fragment."""
    question: str       # the SQL fragment under analysis
    session: str        # session id, used as the checkpointer thread id
    dialect: str        # database dialect (e.g. hive, oracle)
    sql: str            # SQL text with line numbers
    sql_type: str       # statement class: "1" create, "2" insert, "3" set, "4" other
    lineage: dict       # merged lineage payload produced by the merge node
    status: str         # "success" / "invalid" / "canceled" / "error"
    error: str          # human-readable error message, empty on success
    stop: bool          # user-abort flag, polled from the checkpointer
    file_name: str      # script file the fragment came from
    target_table: dict  # target-table info extracted for this fragment
    source_tables: list[dict]  # source-table info extracted for this fragment
    where_list: list[dict]     # WHERE conditions extracted for this fragment
    join_list: list[dict]      # JOIN relations extracted for this fragment
|
|
|
+
|
|
|
+
|
|
|
+class SqlLineageAgent:
|
|
|
+ """
|
|
|
+ 解析用户提供的SQL语句/SQL脚本中的数据血缘关系
|
|
|
+ """
|
|
|
    def __init__(self):
        """Load configuration, build the LLMs, the checkpointer and the graph.

        NOTE(review): all `final_*` accumulators below are instance state that
        ask_question() fills but never resets, so one agent instance appears
        to be single-use per script — confirm before reusing an instance.
        """
        # Read configuration.
        c = linage_agent_config
        # Database dialect.
        self.db_dialect = c["dialect"]
        # Concurrency limit for fragment parsing.
        self.concurrency = c["concurrency"]

        # LLM instances (general-purpose and code-oriented).
        self.llm = get_llm()
        self.llm_coder = get_llm_coder()

        # Checkpointer with expiry (600s TTL, cleaned every 60s).
        self.memory = MemorySaverWithExpiry(expire_time=600, clean_interval=60)
        # Build the graph.
        self.graph = self._build_graph()
        # Variables assigned by SET statements inside the SQL script.
        self.var_list: list[str] = []
        # Final merged set of source tables.
        self.final_source_table_list: list[dict] = []
        # Final merged set of intermediate (temp) tables.
        self.final_mid_table_list: list[dict] = []
        # Final merged set of WHERE conditions.
        self.final_where_list: list[dict] = []
        # Final merged set of JOIN relations.
        self.final_join_list: list[dict] = []
        # Final set of per-fragment target tables.
        self.final_target_list: list[dict] = []

        # Final target table; merged across fragments.
        self.final_target_table = {}

        # Target table, formatted "DATABASE.TABLE".
        self.target_table = ""

        # CREATE TABLE metadata collected from DDL fragments.
        self.ct_list = []
|
|
|
+
|
|
|
+
|
|
|
+ def _build_graph(self):
|
|
|
+ # 构造计算流程
|
|
|
+ graph_builder = StateGraph(AgentState)
|
|
|
+
|
|
|
+ # 添加节点
|
|
|
+ graph_builder.add_node("_sql_type", self._sql_type)
|
|
|
+ graph_builder.add_node("_extract_set", self._extract_set)
|
|
|
+ graph_builder.add_node("_extract_create_table", self._extract_create_table)
|
|
|
+ graph_builder.add_node("_invalid", self._invalid)
|
|
|
+ graph_builder.add_node("source", self.__extract_source)
|
|
|
+ graph_builder.add_node("target", self.__extract_target)
|
|
|
+ graph_builder.add_node("where", self.__extract_where)
|
|
|
+ graph_builder.add_node("join", self.__extract_join)
|
|
|
+ graph_builder.add_node("merge", self.__merge_result)
|
|
|
+
|
|
|
+ # 添加边
|
|
|
+ graph_builder.add_edge(START, "_sql_type")
|
|
|
+ graph_builder.add_conditional_edges("_sql_type", path=self.__parallel_route2, path_map={
|
|
|
+ "_invalid": "_invalid",
|
|
|
+ "_extract_set": "_extract_set",
|
|
|
+ "_extract_create_table": "_extract_create_table",
|
|
|
+ "where": "where",
|
|
|
+ "join": "join"
|
|
|
+ })
|
|
|
+ graph_builder.add_edge("_invalid", END)
|
|
|
+ graph_builder.add_edge("_extract_set", END)
|
|
|
+ graph_builder.add_edge("_extract_set", END)
|
|
|
+ graph_builder.add_edge("_extract_create_table", END)
|
|
|
+ graph_builder.add_edge("where", "source")
|
|
|
+ graph_builder.add_edge("join", "source")
|
|
|
+ graph_builder.add_edge("source", "target")
|
|
|
+ graph_builder.add_edge("target", "merge")
|
|
|
+ graph_builder.add_edge("merge", END)
|
|
|
+
|
|
|
+ return graph_builder.compile(checkpointer=self.memory)
|
|
|
+
|
|
|
+ def _sql_type(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 判断SQL的操作类型
|
|
|
+ """
|
|
|
+ template = """
|
|
|
+ 你是数据库 {dialect} 专家,对 SQL语句: {sql} 中的操作进行分类, 直接返回分类ID(数字)
|
|
|
+
|
|
|
+ ### 操作分类标准如下:
|
|
|
+ 1:create, 如建表、建视图等
|
|
|
+ 2:insert, 向表中插入数据
|
|
|
+ 3:set、设置参数操作、变量赋值操作
|
|
|
+ 4:其它
|
|
|
+
|
|
|
+ 不要重复用户问题,不要中间分析过程,直接回答。
|
|
|
+ """
|
|
|
+ pt = ChatPromptTemplate.from_template(template)
|
|
|
+ chain = pt | self.llm
|
|
|
+ response = chain.invoke({"dialect": state["dialect"], "sql": state["question"]})
|
|
|
+ return {"sql_type": response.content}
|
|
|
+
|
|
|
+ def __parallel_route(self, state: AgentState):
|
|
|
+ """并发节点"""
|
|
|
+ status = state["status"]
|
|
|
+ if status == "invalid" or status == "canceled" or state["sql_type"] == "3": # 无效问题 或 用户中止
|
|
|
+ return END
|
|
|
+
|
|
|
+ return "continue"
|
|
|
+
|
|
|
+ def __parallel_route2(self, state: AgentState):
|
|
|
+ """并发节点"""
|
|
|
+ if state["sql_type"] == "4": # 非create和非insert操作
|
|
|
+ return "_invalid"
|
|
|
+ if state["sql_type"] == "3": # 设置参数
|
|
|
+ return "_extract_set"
|
|
|
+ if state["sql_type"] == "1": # 建表
|
|
|
+ return "_extract_create_table"
|
|
|
+
|
|
|
+ return ["where", "join"]
|
|
|
+
|
|
|
+ def _invalid(self, state: AgentState):
|
|
|
+ """
|
|
|
+ sql_type == 4,不关注的语句,直接跳过
|
|
|
+ """
|
|
|
+ return {
|
|
|
+ "lineage": {
|
|
|
+ "source_tables": [],
|
|
|
+ "join_list": [],
|
|
|
+ "where_list": [],
|
|
|
+ "target_table": {}
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
    async def _extract_set(self, state: AgentState):
        """Extract SET / variable-assignment content from the fragment.

        The extracted text is appended to self.var_list and replayed as
        context when parsing later fragments. Returns an empty lineage
        payload since SET statements contribute no tables.
        """
        # Defensive guard; the router only sends sql_type == "3" here.
        if state["sql_type"] != "3":
            return state

        template = """
        你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 参数设置的全部内容,包括变量赋值、注释。
        不要重复用户问题,不需要中间思考过程,直接回答。
        """
        pt = ChatPromptTemplate.from_template(template)
        chain = pt | self.llm_coder
        response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]})
        # NOTE(review): fragments are parsed concurrently (asyncio.gather), so
        # appends from different fragments may interleave in arbitrary order.
        self.var_list.append(response.content)
        return {
            "lineage": {
                "source_tables": [],
                "join_list": [],
                "where_list": [],
                "target_table": {}
            }
        }
|
|
|
+
|
|
|
+
|
|
|
    async def _extract_create_table(self, state: AgentState):
        """Extract CREATE TABLE metadata (db, table, columns, source tables).

        The parsed CreateTable record is appended to self.ct_list and used as
        context by the source/target extraction prompts. Returns an empty
        lineage payload: DDL fragments terminate at this node.
        """
        # Defensive guard; the router only sends sql_type == "1" here.
        if state["sql_type"] != "1":
            return state

        template = """
        你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 建表信息。

        ### 提取要求
        1、提取建表对应的数据库名(无则留空)
        2、提取建表对应的表名
        3、提取建表对应的字段名
        4、提取建表对应的来源表名,来源表名格式是 来源库库名.来源表名

        ### 输出规范
        {format_instructions}

        不要重复用户问题,不需要中间思考过程,直接回答。
        """
        parser = PydanticOutputParser(pydantic_object=CreateTable)
        pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
        chain = pt | self.llm_coder | parser
        response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]})
        # Debug trace of the parsed DDL record.
        print(f"extract-create-table: {response}")
        self.ct_list.append(response.model_dump())
        return {
            "lineage": {
                "source_tables": [],
                "join_list": [],
                "where_list": [],
                "target_table": {}
            }
        }
|
|
|
+
|
|
|
+
|
|
|
+ async def _parse(self, semaphore, sql: str, session: str, dialect: str, file_name: str = "SQL_FILE.SQL"):
|
|
|
+ """
|
|
|
+ 解析SQL(片段)的血缘关系
|
|
|
+ """
|
|
|
+ async with semaphore: # 进入信号量区域(限制并发)
|
|
|
+ config = {"configurable": {"thread_id": session}}
|
|
|
+ # 调用任务流
|
|
|
+ response = await self.graph.ainvoke({
|
|
|
+ "question": sql,
|
|
|
+ "session": session,
|
|
|
+ "dialect": dialect or self.db_dialect,
|
|
|
+ "file_name": file_name,
|
|
|
+ "stop": False,
|
|
|
+ "error": "",
|
|
|
+ "sql_type": "0",
|
|
|
+ "status": "success"}, config)
|
|
|
+ return response
|
|
|
+
|
|
|
    async def find_target_table(self, sql: str, dialect: str):
        """Ask the LLM for the script's final target table.

        :param sql: the full SQL script (all fragments)
        :param dialect: database dialect
        :return: upper-cased "DATABASE.TABLE" (database may be empty)
        """
        template = """
        你是SQL数据血缘分析专家,参照数据库方言{dialect},仔细分析SQL语句,找出目标表并返回。
        ### SQL片段如下
        {sql}

        ### 目标表标准
        1、目标表一定是出现在 insert into 后的 表
        2、最后 insert into 后的表为目标表
        3、目标表只有一张

        ### 核心要求
        1、目标表的返回格式是 数据库名.表名 (如果无数据库名,则留空)
        2、SQL语句中可能包括多个insert into语句,但要从全局分析,找出最终插入的目标表

        不要重复用户问题、不要中间分析过程,直接回答。
        """
        pt = ChatPromptTemplate.from_template(template)
        chain = pt | self.llm_coder
        response = await chain.ainvoke({"sql": sql, "dialect": dialect})
        # Normalise for later exact comparisons against "DB.TABLE" keys.
        return response.content.strip().upper()
|
|
|
+
|
|
|
    def split_sql_statements(self, sql):
        """Split SQL text on semicolons, honouring strings and comments.

        Semicolons inside single/double-quoted strings, `-- line` comments
        and `/* block */` comments are NOT treated as statement separators.
        Backslash escapes are respected inside quoted strings.

        :param sql: string containing one or more SQL statements
        :return: list of trimmed, non-empty SQL statements
        """
        # Scanner state flags.
        in_single_quote = False
        in_double_quote = False
        in_line_comment = False
        in_block_comment = False
        escape_next = False

        statements = []
        current_start = 0  # index where the current statement began
        i = 0

        while i < len(sql):
            char = sql[i]

            # A backslash inside a string escapes the next character.
            if escape_next:
                escape_next = False
                i += 1
                continue

            # While inside a string or comment, only look for its terminator;
            # semicolons in these regions are plain text.
            if in_line_comment:
                if char == '\n':
                    in_line_comment = False
            elif in_block_comment:
                if char == '*' and i + 1 < len(sql) and sql[i + 1] == '/':
                    in_block_comment = False
                    i += 1  # skip the '/'
            elif in_single_quote:
                if char == "'":
                    in_single_quote = False
                elif char == '\\':
                    escape_next = True
            elif in_double_quote:
                if char == '"':
                    in_double_quote = False
                elif char == '\\':
                    escape_next = True
            else:
                # Plain SQL: check whether we enter a string/comment, or hit
                # a genuine statement separator.
                if char == "'":
                    in_single_quote = True
                elif char == '"':
                    in_double_quote = True
                elif char == '-' and i + 1 < len(sql) and sql[i + 1] == '-':
                    in_line_comment = True
                    i += 1  # skip the second '-'
                elif char == '/' and i + 1 < len(sql) and sql[i + 1] == '*':
                    in_block_comment = True
                    i += 1  # skip the '*'
                elif char == ';':
                    # Real separator: emit the statement if it is non-empty.
                    statement = sql[current_start:i].strip()
                    if statement:
                        statements.append(statement)
                    current_start = i + 1

            i += 1

        # Trailing statement without a terminating semicolon.
        if current_start < len(sql):
            statement = sql[current_start:].strip()
            if statement:
                statements.append(statement)

        return statements
|
|
|
+
|
|
|
+ async def find_target_table2(self, results: list[dict]) -> str:
|
|
|
+ """
|
|
|
+ 根据各段解析结果查找出目标表
|
|
|
+ 规则:一个段的目标表如果没有出现在其它各段的来源表中,即为目标表
|
|
|
+ """
|
|
|
+ target_table = ""
|
|
|
+ exists = False
|
|
|
+ for i, item in enumerate(results):
|
|
|
+ tt = item["lineage"]["target_table"]
|
|
|
+ # 取一个目标表
|
|
|
+ db = tt.get("database", "").strip().upper()
|
|
|
+ table = tt.get("table_name", "").strip().upper()
|
|
|
+ exists = False
|
|
|
+ if len(table) > 0:
|
|
|
+ # 从源表里面查找
|
|
|
+ for j, item2 in enumerate(results):
|
|
|
+ if i != j: # 不跟自身比
|
|
|
+ st_list = item2["lineage"]["source_tables"]
|
|
|
+ for st in st_list:
|
|
|
+ sdb = st.get("database", "").strip().upper()
|
|
|
+ stable = st.get("table_name", "").strip().upper()
|
|
|
+ if db == sdb and table == stable: # 目标表在源表中存在
|
|
|
+ exists = True
|
|
|
+ break
|
|
|
+ if exists:
|
|
|
+ break
|
|
|
+ if not exists: # 说明目标表不在其它各段的源表中存在
|
|
|
+ target_table = ".".join([db, table])
|
|
|
+ break
|
|
|
+
|
|
|
+ return target_table.strip().upper()
|
|
|
+
|
|
|
+ async def ask_question(self, sql_content: str, session: str, dialect: str, owner: str, sql_file: str = "SQL_FILE.SQL"):
|
|
|
+ """
|
|
|
+ 功能:根据用户问题,解析SQL中的数据血缘关系。先分段解析,再组织合并。
|
|
|
+ :param sql_content: SQL内容
|
|
|
+ :param session: 会话
|
|
|
+ :param dialect: 数据库方言
|
|
|
+ :param owner: 脚本平台属主
|
|
|
+ :param var_map: 变量关系
|
|
|
+ :sql_file: 脚本名称
|
|
|
+ """
|
|
|
+
|
|
|
+
|
|
|
+ # 通过session保持记忆
|
|
|
+ t1 = datetime.now()
|
|
|
+ # 最终解析结果
|
|
|
+ final_result = {}
|
|
|
+ final_result["task_id"] = session
|
|
|
+ final_result["file_name"] = sql_file
|
|
|
+ final_result["owner"] = owner
|
|
|
+ final_result["status"] = "success"
|
|
|
+ final_result["error"] = ""
|
|
|
+ lineage = {}
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 根据全局SQL找出目标表
|
|
|
+ self.target_table = await self.find_target_table(sql_content, dialect)
|
|
|
+ print(f"大模型查找目标表:{self.target_table}, {len(self.target_table)}")
|
|
|
+
|
|
|
+ # 对SQL分段,并发执行解析
|
|
|
+ sql_list = self.split_sql_statements(sql=sql_content)
|
|
|
+ print("SQL分段数量:", len(sql_list))
|
|
|
+
|
|
|
+ # 信号量,控制并发量
|
|
|
+ semaphore = asyncio.Semaphore(self.concurrency)
|
|
|
+ task_list = [self._parse(semaphore=semaphore, sql=sql, session=session, dialect=dialect, file_name=sql_file) for sql in sql_list]
|
|
|
+ # 并发分段解析
|
|
|
+ results = await asyncio.gather(*task_list)
|
|
|
+
|
|
|
+ if len(self.target_table) == 0:
|
|
|
+ # 从各段解析结果中查找出目标表
|
|
|
+ self.target_table = await self.find_target_table2(results) or ""
|
|
|
+ print(f"从血缘关系中找出目标表:{self.target_table}")
|
|
|
+ if not self.target_table: # 未出现目标表, 直接返回
|
|
|
+ final_result["status"] = "error",
|
|
|
+ final_result["error"] = "未找到目标表"
|
|
|
+ return final_result
|
|
|
+
|
|
|
+ for response in results:
|
|
|
+ # 来源表、where、join、target
|
|
|
+ self._merge_source_tables(response["lineage"]["source_tables"])
|
|
|
+ self._merge_where(response["lineage"].get("where_list",[]))
|
|
|
+ self._merge_join(response["lineage"]["join_list"])
|
|
|
+ # 合并目标表
|
|
|
+ self._merge_target(response["lineage"]["target_table"])
|
|
|
+
|
|
|
+ # 增加中间表标识
|
|
|
+ self._add_mid_table_label(self.final_source_table_list, self.final_mid_table_list)
|
|
|
+ # 补充中间表依赖字段
|
|
|
+ self._add_mid_table_col(self.final_mid_table_list, self.final_target_list)
|
|
|
+
|
|
|
+ # 多目标表合并
|
|
|
+ self.final_target_table["src_table_size"] = len(self.final_source_table_list)
|
|
|
+ lineage["target_table"] = self.final_target_table
|
|
|
+ lineage["source_tables"] = self.final_source_table_list
|
|
|
+ lineage["mid_table_list"] = self.final_mid_table_list
|
|
|
+ lineage["join_list"] = self.final_join_list
|
|
|
+ lineage["where_list"] = self.final_where_list
|
|
|
+ final_result["lineage"] = lineage
|
|
|
+ except Exception as e:
|
|
|
+ final_result["status"] = "error"
|
|
|
+ final_result["error"] = str(e)
|
|
|
+ print(str(e))
|
|
|
+
|
|
|
+ t2 = datetime.now()
|
|
|
+ print("总共用时(sec):", (t2-t1).seconds)
|
|
|
+ parse_time = (t2 - t1).seconds
|
|
|
+ # 解析时间(秒)
|
|
|
+ final_result["parse_time"] = parse_time
|
|
|
+ final_result["parse_end_time"] = t2.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ return final_result
|
|
|
+
|
|
|
    def _merge_source_tables(self, source_table_list: list[dict]):
        """Merge one fragment's source tables into the global list.

        Tables are identified by (database, table_name), case-insensitively.
        New tables are appended; for an already-known table only its unseen
        columns are appended (mutating the stored entry in place).
        """
        if len(source_table_list) > 0:
            # Keys of the already-merged source tables.
            table_keys = [(t.get("database","").strip().upper(), t.get("table_name","").strip().upper()) for t in self.final_source_table_list]
            for st in source_table_list:
                # Build the (db, table) identity tuple.
                db = st.get("database","").strip().upper()
                table = st.get("table_name","").strip().upper()
                key = (db, table)
                if key not in table_keys:
                    table_keys.append(key)
                    self.final_source_table_list.append(st)
                else:
                    # Known table: merge its columns instead.
                    col_list = st.get("col_list", [])
                    # Locate the stored entry with the same db + table name.
                    e = next(filter(lambda item: item["database"].upper()==db and item["table_name"].upper()==table, self.final_source_table_list), None)
                    if e:
                        final_col_list = e.get("col_list",[])
                        col_keys = [(c.get("col_name","").upper()) for c in final_col_list]
                        for c in col_list:
                            ck = (c["col_name"].upper())
                            # len(c) > 0 skips empty column dicts.
                            if ck not in col_keys and len(c) > 0:
                                col_keys.append(ck)
                                final_col_list.append(c)
|
|
|
+
|
|
|
+ @timeit
|
|
|
+ def _merge_where(self, where_list: list[dict]):
|
|
|
+ """
|
|
|
+ 最终合并所有的where条件
|
|
|
+ """
|
|
|
+ if where_list and len(where_list) > 0:
|
|
|
+ self.final_where_list += where_list
|
|
|
+
|
|
|
+
|
|
|
+ def _merge_join(self, join_list: list[dict]):
|
|
|
+ """
|
|
|
+ 最终合并所有的join条件
|
|
|
+ """
|
|
|
+ if len(join_list) > 0:
|
|
|
+ self.final_join_list += join_list
|
|
|
+
|
|
|
+
|
|
|
    def _merge_target(self, mid_target_table: dict):
        """Fold one fragment's target table into the final target / mid tables.

        Only the table whose "DB.TABLE" key equals self.target_table is the
        real target; it may appear several times (multiple inserts), in which
        case its columns and column-source info are merged. Any other
        fragment target is recorded as an intermediate (mid) table.
        """
        if not mid_target_table:  # nothing extracted for this fragment
            return
        else:  # merge this fragment target with the final target
            db = mid_target_table.get("database", "").strip().upper()
            table = mid_target_table.get("table_name", "").strip().upper()
            # NOTE(review): despite the name, col_size holds the column LIST
            # (only its length is used below).
            col_size = mid_target_table.get("col_list",[])
            if len(table) == 0 or len(col_size) == 0:  # no name or no columns
                return
            if len(db) == 0:  # missing db — maybe the table name embeds it
                arr = table.split('.')
                if len(arr) == 2:  # "db.table" was misparsed into table_name; fix it
                    mid_target_table["database"] = arr[0].strip().upper()
                    mid_target_table["table_name"] = arr[1].strip().upper()
                    db = arr[0].strip().upper()
                    table = arr[1].strip().upper()
                elif len(arr) == 1:  # bare table name: borrow the real target's db
                    arr = self.target_table.split(".")
                    if len(arr) == 2:
                        db = mid_target_table["database"] = arr[0].strip().upper()

            key = ".".join([db, table])

            if key == self.target_table:  # this IS the final target table
                print(f"合并目标表:{mid_target_table}")
                if not self.final_target_table:  # first occurrence: adopt as-is
                    self.final_target_table = mid_target_table
                    return

                # Merge columns and their source info into the final target.
                # 1) columns not seen before.
                new_col_list = []  # names added in this merge pass
                col_list = mid_target_table.get("col_list", [])
                for new_col in col_list:
                    # Case-insensitive existence check.
                    if new_col.strip().upper() not in [col.strip().upper() for col in self.final_target_table.get("col_list", [])]:  # unseen
                        print(f"合并新字段:{new_col}")
                        self.final_target_table.get("col_list", []).append(new_col.strip().upper())
                        new_col_list.append(new_col.strip().upper())

                # 2) column dependency (source) records.
                col_dep_list = mid_target_table.get("col_dep", [])
                for col_dep in col_dep_list:
                    col_name = col_dep["col_name"].strip().upper()
                    if col_name in new_col_list:  # brand-new column: append record
                        self.final_target_table.get("col_dep",[]).append(col_dep)
                    else:  # existing column: merge its source columns
                        # Find the same-named record and merge from_cols into it.
                        for dep in self.final_target_table.get("col_dep", []):
                            cn = dep["col_name"].strip().upper()
                            if col_name == cn:
                                print(f"合并来源字段, {col_name}")
                                m_from_col = col_dep.get('from_cols', [])
                                f_from_col = dep.get("from_cols", [])
                                print(f"中间表来源字段:{m_from_col}")
                                print(f"目标表来源字段:{f_from_col}")
                                # Merge the fragment's source columns into the
                                # final record (mutates f_from_col in place).
                                self._merge_col_dep(m_from_col, f_from_col)
            else:
                # Not the final target: record it as an intermediate table.
                print(f"加入中间表:{mid_target_table}")
                self.final_mid_table_list.append(mid_target_table)

        print(f"final tt:{self.final_target_table}")
|
|
|
+
|
|
|
+ def _merge_col_dep(self, mid_from_cols: list[str], final_from_cols: list[str]):
|
|
|
+ """
|
|
|
+ 合并临时中间表来源字段 到目标表来源字段中
|
|
|
+ """
|
|
|
+ for from_col in mid_from_cols:
|
|
|
+ if len(from_col.split(".")) > 0: # 忽略常量,NULL值
|
|
|
+ if from_col.strip().upper() not in [fc.strip().upper() for fc in final_from_cols]:
|
|
|
+ print(f"合并字段:{from_col}")
|
|
|
+ final_from_cols.append(from_col.strip().upper())
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @timeit
|
|
|
+ def _add_mid_table_label(source_table_list: list[dict], mid_table_list: list[dict]):
|
|
|
+ """
|
|
|
+ 给最终的来源表增加中间表标识
|
|
|
+ """
|
|
|
+ if len(source_table_list) > 0 and len(mid_table_list) > 0:
|
|
|
+ mid_table_keys = [(t["database"].strip().upper(), t["table_name"].strip().upper()) for t in mid_table_list]
|
|
|
+ for st in source_table_list:
|
|
|
+ st["is_temp"] = False
|
|
|
+ # 将字典转换为元组
|
|
|
+ key = (st["database"].strip().upper(), st["table_name"].strip().upper())
|
|
|
+ if key in mid_table_keys: # 是中间表
|
|
|
+ st["is_temp"] = True
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @timeit
|
|
|
+ def _add_mid_table_col(mid_table_list: list[dict], target_table_list: list[dict]):
|
|
|
+ """
|
|
|
+ 给中间表补充来源字段,使用对应的目标表col_dep去替换
|
|
|
+ """
|
|
|
+ if len(mid_table_list) > 0 and len(target_table_list) > 0:
|
|
|
+ for mt in mid_table_list:
|
|
|
+ for tt in target_table_list:
|
|
|
+ tt["is_temp"] = False
|
|
|
+ # 目标表的数据库名与表名相同
|
|
|
+ tdn = tt["database"].strip().upper()
|
|
|
+ ttn = tt["table_name"].strip().upper()
|
|
|
+ if len(tdn) == 0 and len(ttn) == 0: # 全部为空
|
|
|
+ tt["is_temp"] = True
|
|
|
+ continue
|
|
|
+ if mt["database"].strip().upper()==tdn and mt["table_name"].strip().upper()==ttn:
|
|
|
+ mt["col_dep"] = tt["col_dep"]
|
|
|
+ tt["is_temp"] = True
|
|
|
+ if len(mt["col_list"]) == 0:
|
|
|
+ mt["col_list"] = [c.upper() for c in tt["col_list"]]
|
|
|
+ mt["col_size"] = len(mt["col_list"])
|
|
|
+
|
|
|
+ @timeit
|
|
|
+ def _merge_final_target_table(self, seg_target_table_list: list[dict], target_table_name: str):
|
|
|
+ """
|
|
|
+ 合并分段目标表内容
|
|
|
+ """
|
|
|
+ # 可能存在多个相同的真正目标表(多次插入目标表情况)
|
|
|
+ for tt in seg_target_table_list:
|
|
|
+ tt["is_target"] = False
|
|
|
+ database = tt["database"].strip()
|
|
|
+ table = tt["table_name"].strip()
|
|
|
+ if len(database) > 0:
|
|
|
+ dt = ".".join([database,table])
|
|
|
+ else:
|
|
|
+ dt = table
|
|
|
+ if dt.upper() == target_table_name.strip().upper():
|
|
|
+ tt["is_target"] = True
|
|
|
+
|
|
|
+ # 合并真正的目标表
|
|
|
+ i = 0
|
|
|
+ result = {}
|
|
|
+ final_col_list: list[str] = []
|
|
|
+ final_col_dep_list: list[dict] = []
|
|
|
+ # 合并目标表信息
|
|
|
+ for target_table in seg_target_table_list:
|
|
|
+ if target_table["is_target"]:
|
|
|
+ if i == 0:
|
|
|
+ result["database"] = target_table["database"]
|
|
|
+ result["table_name"] = target_table["table_name"]
|
|
|
+ i = 1
|
|
|
+ # 合并列信息
|
|
|
+ col_list = target_table["col_list"]
|
|
|
+ for c in col_list:
|
|
|
+ if c not in final_col_list and len(c.strip()) > 0:
|
|
|
+ final_col_list.append(c.upper())
|
|
|
+
|
|
|
+ # 合并列来源信息
|
|
|
+ col_dep_list = target_table["col_dep"]
|
|
|
+ for col_dep in col_dep_list:
|
|
|
+ cn = col_dep["col_name"].strip().upper()
|
|
|
+ from_cols = col_dep["from_cols"]
|
|
|
+ existed = False
|
|
|
+ if len(cn) > 0:
|
|
|
+ for fcd in final_col_dep_list:
|
|
|
+ if cn == fcd["col_name"].upper():
|
|
|
+ existed = True
|
|
|
+ # 合并来源字段
|
|
|
+ final_from_cols = fcd["from_cols"]
|
|
|
+ fcd["from_cols"] += [c.upper() for c in from_cols if c.upper() not in final_from_cols]
|
|
|
+ if not existed:
|
|
|
+ final_col_dep_list.append({
|
|
|
+ "col_name": cn.upper(),
|
|
|
+ "col_comment": col_dep["col_comment"],
|
|
|
+ "from_cols": [c.upper() for c in from_cols]
|
|
|
+ })
|
|
|
+
|
|
|
+ result["col_size"] = len(final_col_list)
|
|
|
+ result["col_list"] = final_col_list
|
|
|
+ result["col_dep"] = final_col_dep_list
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
    async def __check(self, state: AgentState):
        """Ask the LLM whether the fragment is syntactically valid SQL.

        Streams the answer so the user-abort flag (``stop`` in the
        checkpointer) can be polled between chunks. Returns a state update:
        canceled / invalid / success.

        NOTE(review): this node is not registered in the graph built by
        _build_graph — confirm whether it is wired in elsewhere.
        """
        template = """
        你是 数据库 {dialect} SQL分析助手, 仔细分析SQL语句: {sql} ,判断是否语法正确,如果是 直接返回 1,否则说明错误地方,并返回 0。

        ### 分析要求
        1. 判断子查询的正确性
        2. 判断整体语句的正确性
        3. 如果存在语法错误,给出说明

        不要重复用户问题、不要分析过程、直接回答。
        """
        pt = ChatPromptTemplate.from_template(template)
        chain = pt | self.llm
        stopped = False
        response = ""
        config = {"configurable": {"thread_id": state["session"]}}
        async for chunk in chain.astream({"sql": state["question"], "dialect": state["dialect"]}, config=config):
            # Poll the checkpointer for a user-initiated abort.
            current_state = self.memory.get(config)
            if current_state and current_state["channel_values"]["stop"]:
                stopped = True
                break
            response += chunk.content

        if stopped:
            return {"stop": True, "error": "用户中止", "status": "canceled", "lineage": {}}

        # The prompt asks for a bare "0" on syntax errors.
        if response == "0":
            return {"status": "invalid", "error": "无效问题或SQL存在语法错误,请重新描述!", "lineage": {}}

        return {"status": "success"}
|
|
|
+
|
|
|
+ async def __merge_result(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 合并所有节点结果,输出血缘关系
|
|
|
+ """
|
|
|
+ result = {}
|
|
|
+ result["file_name"] = state["file_name"]
|
|
|
+ source_tables = state["source_tables"]
|
|
|
+ target_table = state["target_table"]
|
|
|
+ # 设置目标表依赖的来源表数量和字段数量
|
|
|
+ target_table["src_table_size"] = len(source_tables)
|
|
|
+ target_table["col_size"] = len(target_table["col_dep"])
|
|
|
+
|
|
|
+ result["target_table"] = target_table
|
|
|
+ result["source_tables"] = source_tables
|
|
|
+ result["where_list"] = state.get("where_list",[])
|
|
|
+ result["join_list"] = state["join_list"]
|
|
|
+
|
|
|
+ return {"lineage": result}
|
|
|
+
|
|
|
    async def __extract_source(self, state: AgentState):
        """Extract source-table information from the fragment via the LLM.

        Streams the structured output so the user-abort flag can be polled
        between chunks; on each chunk the (progressively more complete)
        table list is re-deduplicated, so the last iteration wins.
        Returns ``{"source_tables": [...]}`` or a canceled-state update.
        """
        dt1 = datetime.now()

        # Default database to substitute when the SQL omits one.
        default_db = "TMP"
        if len(self.target_table) > 0:
            arr = self.target_table.split(".")
            if len(arr) == 2:  # borrow the recognised target table's db
                default_db = arr[0]

        template = """
        你是SQL数据血缘分析专家,仔细分析SQL片段和上下文信息, 从SQL片段中提取出 来源表信息。

        ### SQL和上下文信息如下
        - SQL片段:{sql}
        - 数据库方言: {dialect}
        - 变量设置信息:{var}
        - 已有的建表信息: {create_table_info}

        ### 核心要求:
        1. 提取 来源表数据库名称(无则使用 {default_db} 填充)
        2. 提取 来源表名,包括 from子句、子查询、嵌套子查询、union语句、with语句中出现的物理表名(注意:不要带库名前缀)
        3. 提取 来源表别名(无则留空)
        4. 提取来源表的所有字段信息,提取来自几种:
           - 从select中提取,如果带有表别名,则转换成物理表名
           - 从where过滤条件中提取
           - 从关联条件(inner join,left join,right join,full join)中提取
           - 从group中提取
           - 从函数参数中提取,比如函数 COALESCE(T13.ECIF_CUST_ID,T2.RELAT_PTY_ID,'') AS ECIF_CUST_ID,提取出T13.ECIF_CUST_ID和T2.RELAT_PTY_ID
           - 从表达式参数中提取
        5. 判断来源表是否是临时表
        6. 不要包含目标表(目标表指最终插入的表)

        ### 输出规范
        {format_instructions}

        不要重复用户问题,不要分析中间过程,直接给出答案。
        """
        parser = PydanticOutputParser(pydantic_object=SourceTableList)
        pt = ChatPromptTemplate.from_template(template=template).partial(format_instructions=parser.get_format_instructions())
        chain = pt | self.llm_coder | parser
        answer = {}
        config = {"configurable": {"thread_id": state["session"]}}

        stopped = False
        async for chunk in chain.astream({
            "sql": state["question"],
            "dialect": state["dialect"],
            "var": self.var_list or [],
            "default_db": default_db,
            "create_table_info": self.ct_list
        }):
            # Poll the checkpointer for a user-initiated abort.
            current_state = self.memory.get(config)
            if current_state and current_state["channel_values"]["stop"]:
                stopped = True
                break
            # De-duplicate identical source tables (SourceTable.__eq__
            # compares (database, table_name)) and union their columns.
            table_list: list[SourceTable] = []
            for table in chunk.source_tables:
                if table not in table_list:
                    table_list.append(table)
                else:
                    idx = table_list.index(table)
                    tt = table_list[idx]
                    # Merge columns of the duplicate into the kept entry.
                    for col in table.col_list:
                        if col not in tt.col_list:
                            tt.col_list.append(col)

            answer["source_tables"] = [table.model_dump() for table in table_list]

        dt2 = datetime.now()
        print("提取源表用时(sec):", (dt2-dt1).seconds)
        if stopped:
            return {"status": "canceled", "error": "用户中止", "lineage": {}}

        return answer
|
|
|
+
|
|
|
+
|
|
|
+ async def __extract_target(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 根据SQL,提取其中的目标表信息
|
|
|
+ """
|
|
|
+ if state["sql_type"] == "1":
|
|
|
+ return {"target_table": {"database":"", "table_name": "", "col_list": [], "col_size": 0, "col_dep": []}}
|
|
|
+
|
|
|
+ # 默认DB
|
|
|
+ default_db = "TMP"
|
|
|
+ if len(self.target_table) > 0:
|
|
|
+ arr = self.target_table.split(".")
|
|
|
+ if len(arr) == 2: # 使用识别出的目标表db作为默认DB
|
|
|
+ default_db = arr[0]
|
|
|
+
|
|
|
+ template = """
|
|
|
+ 你是SQL数据血缘分析专家,仔细分析SQL片段和以下上下文:
|
|
|
+ - SQL片段: {sql}
|
|
|
+ - 来源表信息: {source_tables}
|
|
|
+ - 数据库方言: {dialect}
|
|
|
+ - 变量设置信息:{var}
|
|
|
+ - 已有的建表信息: {create_table_info}
|
|
|
+
|
|
|
+ ### 核心要求:
|
|
|
+ 1. **目标表识别**:
|
|
|
+ - 提取 INSERT INTO 后的表作为目标表
|
|
|
+ - 格式:`数据库名.表名`(未指定库名则使用 {default_db} 填充)
|
|
|
+ 2. **字段依赖分析**:
|
|
|
+ - 目标表字段必须溯源到来源表的物理字段
|
|
|
+ - 字段获取方式分类:
|
|
|
+ - `直取`:直接引用源字段(如 `src.col`)
|
|
|
+ - `函数`:通过函数转换(如 `COALESCE(col1, col2)`)
|
|
|
+ - `表达式`:计算表达式(如 `col1 + col2`)
|
|
|
+ - `常量`:固定值(如 `'2023'`)
|
|
|
+ - `子查询`:嵌套SELECT结果
|
|
|
+ - 字段来源格式:`源库.源表.字段名`(不要使用表别名,如果源库未指定则使用 {default_db}替换)
|
|
|
+ - 字段来源信息只能包含 库名、表名和字段信息,如果经过函数或表达式处理,则从参数中提取出具体的字段信息
|
|
|
+ 3. **关键约束**:
|
|
|
+ - 每个目标字段必须有详细的来源说明
|
|
|
+ - 函数参数需完整展开(如 `COALESCE(T1.id, T2.id)` 需解析所有参数)
|
|
|
+ - 表名必须是 英文字符 构成(非汉字构成)
|
|
|
+
|
|
|
+ ### 输出规范
|
|
|
+ {format_instructions}
|
|
|
+
|
|
|
+ 请直接输出JSON格式结果,无需解释过程。
|
|
|
+ """
|
|
|
+ parser = PydanticOutputParser(pydantic_object=TargetTable)
|
|
|
+ pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
|
|
|
+ chain = pt | self.llm_coder | parser
|
|
|
+ answer = {}
|
|
|
+ config = {"configurable": {"thread_id": state["session"]}}
|
|
|
+ dt1 = datetime.now()
|
|
|
+ print("开始解析目标表...")
|
|
|
+
|
|
|
+ stopped = False
|
|
|
+ async for chunk in chain.astream({
|
|
|
+ "sql": state["question"],
|
|
|
+ "dialect": state["dialect"],
|
|
|
+ "source_tables": state["source_tables"],
|
|
|
+ "var": self.var_list or [],
|
|
|
+ "default_db": default_db,
|
|
|
+ "default_db": default_db,
|
|
|
+ "create_table_info": self.ct_list
|
|
|
+ }):
|
|
|
+ current_state = self.memory.get(config)
|
|
|
+ if current_state and current_state["channel_values"]["stop"]:
|
|
|
+ stopped = True
|
|
|
+ break
|
|
|
+ answer["target_table"] = chunk.model_dump()
|
|
|
+
|
|
|
+ if stopped:
|
|
|
+ return {"status": "canceled", "error": "用户中止", "lineage": {}}
|
|
|
+
|
|
|
+ dt2 = datetime.now()
|
|
|
+ print("提取目标表用时(sec):", (dt2 - dt1).seconds)
|
|
|
+ print(f"target:{answer}")
|
|
|
+ return answer
|
|
|
+
|
|
|
+
|
|
|
+ async def __extract_where(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 根据SQL,提取来源表 where 信息
|
|
|
+ """
|
|
|
+ # 默认DB
|
|
|
+ default_db = "TMP"
|
|
|
+ if len(self.target_table) > 0:
|
|
|
+ arr = self.target_table.split(".")
|
|
|
+ if len(arr) == 2: # 使用识别出的目标表db作为默认DB
|
|
|
+ default_db = arr[0]
|
|
|
+
|
|
|
+ dt1 = datetime.now()
|
|
|
+ template = """
|
|
|
+ 你是SQL数据血缘分析专家,仔细分析 SQL语句和上下文,提取 where 条件信息。
|
|
|
+ ### SQL语句和上下文如下
|
|
|
+ - SQL语句: {sql}
|
|
|
+ - 数据库方言: {dialect}
|
|
|
+ - 变量设置信息: {var}
|
|
|
+
|
|
|
+ ### 提取要求:
|
|
|
+ 1. 提取 所有 来源表 中的 where 条件信息,包含 where 中出现的 表原名(非别名)、 表所在数据库名(如果没有则使用 {default_db} 填充)、字段名、条件操作符和条件值
|
|
|
+ 2. 字段名不能包括函数,比如 length(INT_ORG_NO),提取的字段名是 INT_ORG_NO
|
|
|
+ 3. 如果SQL语句中存在变量,则根据变量设置信息转换成实际值
|
|
|
+
|
|
|
+ ### 输出规范
|
|
|
+ {format_instructions}
|
|
|
+
|
|
|
+ 不要重复用户问题,不要分析中间过程,直接给出答案。
|
|
|
+ """
|
|
|
+ parser = PydanticOutputParser(pydantic_object=WhereCondList)
|
|
|
+ pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
|
|
|
+
|
|
|
+ chain = pt | self.llm_coder | parser
|
|
|
+ answer = {}
|
|
|
+ config = {"configurable": {"thread_id": state["session"]}}
|
|
|
+ stopped = False
|
|
|
+
|
|
|
+ async for chunk in chain.astream({
|
|
|
+ "sql": state["question"],
|
|
|
+ "dialect": state["dialect"],
|
|
|
+ "var": self.var_list or [],
|
|
|
+ "default_db": default_db
|
|
|
+ }):
|
|
|
+ current_state = self.memory.get(config)
|
|
|
+ if current_state and current_state["channel_values"]["stop"]:
|
|
|
+ stopped = True
|
|
|
+ break
|
|
|
+ # 设置目标表字段数量,来源表数量
|
|
|
+ dt2 = datetime.now()
|
|
|
+ if stopped:
|
|
|
+ return {"status": "canceled", "error": "用户中止", "lineage": {}}
|
|
|
+ return answer
|
|
|
+
|
|
|
+ async def __extract_join(self, state: AgentState):
|
|
|
+ """
|
|
|
+ 根据SQL,提取 表关联 信息
|
|
|
+ """
|
|
|
+ # 默认DB
|
|
|
+ default_db = "TMP"
|
|
|
+ if len(self.target_table) > 0:
|
|
|
+ arr = self.target_table.split(".")
|
|
|
+ if len(arr) == 2: # 使用识别出的目标表db作为默认DB
|
|
|
+ default_db = arr[0]
|
|
|
+
|
|
|
+ dt1 = datetime.now()
|
|
|
+ template = """
|
|
|
+ 你是SQL数据血缘分析专家,负责提取完整的JOIN关系。仔细分析SQL片段:
|
|
|
+ - SQL片段: {sql}
|
|
|
+ - 数据库方言: {dialect}
|
|
|
+
|
|
|
+ ### 关键要求:
|
|
|
+ 1. **SQL语句中必须存在关联条件(包括inner、left join、right join、full join)**
|
|
|
+ 2. **识别JOIN结构**:
|
|
|
+ - 提取FROM子句中的主表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {default_db} 填充
|
|
|
+ - 提取每个JOIN子句中的从表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {j_default_db} 填充
|
|
|
+ - 明确标注JOIN类型(INNER/LEFT/RIGHT/FULL)
|
|
|
+ 3. **关联字段提取**:
|
|
|
+ - 必须提取每对关联字段的完整信息:
|
|
|
+ * 主表端:表别名.字段名
|
|
|
+ * 从表端:表别名.字段名
|
|
|
+ - 明确标注关联操作符(=, >, <等)
|
|
|
+ - 多条件JOIN拆分为多个字段对
|
|
|
+ 4. **特殊场景处理**:
|
|
|
+ - 子查询作为表时,表名填写子查询别名
|
|
|
+ - 隐式JOIN(WHERE条件)需转换为显式JOIN结构
|
|
|
+ - 多表JOIN时保持原始顺序
|
|
|
+ - 如果关联字段是 常量,则关联字段取空
|
|
|
+ - 如果关联字段经函数处,则提取参数中的 字段,如果参数是常量,则关联字段取空
|
|
|
+
|
|
|
+ ### 输出规范
|
|
|
+ {format_instructions}
|
|
|
+
|
|
|
+ 不要重复用户问题,不要中间思考过程,直接回答。
|
|
|
+ """
|
|
|
+ parser = PydanticOutputParser(pydantic_object=JoinRelationList)
|
|
|
+ pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
|
|
|
+
|
|
|
+ chain = pt | self.llm_coder | parser
|
|
|
+ answer = {}
|
|
|
+ config = {"configurable": {"thread_id": state["session"]}}
|
|
|
+
|
|
|
+ stopped = False
|
|
|
+
|
|
|
+ async for chunk in chain.astream({
|
|
|
+ "sql": state["question"],
|
|
|
+ "dialect": state["dialect"],
|
|
|
+ "default_db": default_db,
|
|
|
+ "j_default_db": default_db
|
|
|
+ }):
|
|
|
+ current_state = self.memory.get(config)
|
|
|
+ if current_state and current_state["channel_values"]["stop"]:
|
|
|
+ stopped = True
|
|
|
+ break
|
|
|
+ # 设置目标表字段数量,来源表数量
|
|
|
+ answer["join_list"] = [join.model_dump() for join in chunk.join_list]
|
|
|
+
|
|
|
+ dt2 = datetime.now()
|
|
|
+ if stopped:
|
|
|
+ return {"status": "canceled", "error": "用户中止", "lineage": {}}
|
|
|
+
|
|
|
+ return answer
|
|
|
+
|
|
|
+ def trigger_stop(self, session: str):
|
|
|
+ """外部触发中止"""
|
|
|
+ config = {"configurable": {"thread_id": session}}
|
|
|
+ current_state = self.memory.get(config)
|
|
|
+ if current_state:
|
|
|
+ current_state["channel_values"]["stop"] = True
|
|
|
+
|
|
|
+
|
|
|
async def main(sql_content: str, sql_file: str, dialect: str, owner: str):
    """Run the lineage agent once over *sql_content* and return its result."""
    lineage_agent = SqlLineageAgent()
    return await lineage_agent.ask_question(
        sql_content=sql_content,
        session="s-1",
        dialect=dialect,
        sql_file=sql_file,
        owner=owner,
    )
|
|
|
+
|
|
|
+
|
|
|
def read_var_file(var_file: str):
    """
    Read the variable definition file (Excel).

    Each row is expected to carry a ``变量`` cell of the form ``NAME=VALUE``
    and a ``含义`` (meaning) cell.

    :param var_file: variable file (Excel format)
    :return: list of dicts ``{"var_name", "var_value", "var_comment"}``
    """
    import pandas as pd

    # Read the variable Excel file.
    result = []
    df = pd.read_excel(var_file)
    for row in df.itertuples():
        print(f"变量 {row.Index + 1}: {row.变量}, {row.含义}")
        # BUGFIX: use partition instead of split("=") + arr[1] so that
        # (a) values containing '=' are kept intact, and
        # (b) a malformed row without '=' yields an empty value instead of
        #     raising IndexError.
        name, _, value = row.变量.partition("=")
        result.append({
            "var_name": name,
            "var_value": value,
            "var_comment": row.含义,
        })
    return result
|
|
|
+
|
|
|
def replace_vars(sql_file: str, var_list: list[dict]):
    """
    Read *sql_file* and substitute every variable occurrence in its content.

    :param sql_file: SQL file path
    :param var_list: variables as ``[{'var_name':xx,'var_value':xxx,'var_comment':xx}]``
    :return: file content with all variables replaced
    :raises Exception: if the file cannot be decoded with any candidate encoding
    """
    import os

    encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312']
    for encoding in encodings:
        try:
            with open(sql_file, 'r', encoding=encoding) as f:
                # Read the file content (decode errors surface here).
                content = f.read()
        except UnicodeDecodeError:
            print(f"文件{os.path.basename(sql_file)} 编码 {encoding} 不支持")
            continue
        # Replace variables. BUGFIX: the original only assigned the return value
        # inside this loop, so an empty var_list made the function return ""
        # instead of the file content.
        for var in var_list:
            content = content.replace(var["var_name"], var["var_value"])
        return content

    print("无法找到合适的编码")
    raise Exception(f"无法读取文件 {sql_file}, 未找到合适的编码.")
|
|
|
+
|
|
|
+
|
|
|
def get_ddl_list(ddl_dir: str, var_list: list[dict]):
    """
    Collect every ``*.ddl`` file under *ddl_dir* and substitute variables.

    File names follow the ``db__table.ddl`` convention; the returned key is
    ``DB.TABLE`` upper-cased.

    :param ddl_dir: DDL root directory, layout ``ddl/db/xx.ddl``
    :param var_list: variables as ``[{'var_name':xx,'var_value':xxx,'var_comment':xx}]``
    :return: list of ``{"DB.TABLE": ddl_text}`` dicts
    :raises Exception: if a DDL file cannot be decoded with any candidate encoding
    """
    import os
    from pathlib import Path

    # Walk the directory tree and collect all .ddl files.
    file_list = []
    for root, dirs, files in os.walk(ddl_dir):
        for file in files:
            if file.endswith(".ddl"):
                file_list.append(os.path.join(root, file))

    print(f"DDL文件数量:{len(file_list)}")
    # Read each ddl file and replace its variables.
    result = []
    encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312']
    for file in file_list:
        fp = Path(file)
        # File name convention: db__table.ddl -> database and table name.
        stem = fp.name.replace(".ddl", "")
        arr = stem.split("__")
        db = arr[0]
        table = arr[1]
        key = f"{db.upper()}.{table.upper()}"
        for encoding in encodings:
            try:
                with open(file, 'r', encoding=encoding) as f:
                    # Read the file content (decode errors surface here).
                    content = f.read()
            except UnicodeDecodeError:
                print(f"文件{fp.name} 编码 {encoding} 不支持")
                continue
            # Replace variables.
            for var in var_list:
                content = content.replace(var["var_name"], var["var_value"])
            result.append({key: content})
            break
        else:
            print("无法找到合适的编码")
            raise Exception(f"无法读取文件 {file}, 未找到合适的编码.")

    # BUGFIX: the original printed result[0] unconditionally and crashed with
    # IndexError when the directory contained no .ddl files.
    if result:
        print(f"result:{result[0]}")
    return result
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    import sys
    import os
    from pathlib import Path
    import json

    # Argument check: expects var_file, sql_file, dialect and owner (argv[1..4]).
    if len(sys.argv) <= 4:
        print(f"usage python sql_lineage_agent_xmgj.py var_file sql_file dialect owner")
        exit(1)
    # Variable file (Excel, NAME=VALUE rows).
    var_file = sys.argv[1]
    # SQL script file or directory.
    sql_file = sys.argv[2]
    # Database dialect.
    dialect = sys.argv[3]
    # Owning platform of the SQL file.
    owner = sys.argv[4]

    print(f"1、解析SQL文件/目录:{sql_file}")
    print(f"2、SQL文件属主平台:{owner}")
    print(f"3、变量文件:{var_file}")
    print(f"4、SQL数据库方言: {dialect}")

    # Check that the variable file exists.
    var_fp = Path(var_file)
    if not var_fp.exists():
        raise FileNotFoundError(f"变量文件 {var_file} 不存在,请指定正确路径.")

    # Check that the SQL script file exists.
    sql_fp = Path(sql_file)
    if not sql_fp.exists():
        raise FileNotFoundError(f"SQL脚本文件 {sql_file} 不存在,请指定正确路径.")

    # Read the variable file to obtain the variable values.
    var_list = read_var_file(var_file)

    # Read the SQL file and replace variables.
    sql_content = replace_vars(sql_file=sql_file, var_list=var_list)

    # Parse SQL lineage.
    result = asyncio.run(main(sql_content=sql_content, sql_file=sql_fp.name, dialect=dialect, owner=owner))
    # Write the result next to the SQL file, with a .json suffix appended.
    target_file = Path(sql_fp.parent.absolute() / (sql_fp.name + ".json"))
    print(f"写目标文件:{target_file}")
    # Write the file.
    with open(target_file, 'w', encoding="utf-8") as t:
        t.write(json.dumps(result, ensure_ascii=False, indent=2))
|
|
|
+
|