# sql_lineage_agent_xmgj.py
  1. from langchain_core.prompts import ChatPromptTemplate
  2. from langgraph.graph import START, StateGraph, END
  3. from llmops.agents.datadev.llm import get_llm, get_llm_coder
  4. from typing import List
  5. from typing_extensions import TypedDict
  6. from pydantic import BaseModel, Field
  7. from llmops.agents.datadev.memory.memory_saver_with_expiry2 import MemorySaverWithExpiry
  8. from langchain_core.output_parsers import PydanticOutputParser
  9. from config import linage_agent_config
  10. from llmops.agents.datadev.tools.timeit import timeit
  11. import asyncio
  12. from datetime import datetime
  13. class Column(BaseModel):
  14. """
  15. 字段信息
  16. """
  17. col_name: str = Field(description="字段名称")
  18. def __eq__(self, other):
  19. if isinstance(other, Column):
  20. return self.col_name == other.col_name
  21. return False
# Target-table column together with its lineage back to the source columns.
class ColumnDep(BaseModel):
    """
    目标表字段信息,包含依赖关系
    """
    col_name: str = Field(description="目标表字段名称")
    col_comment: str = Field(description="目标表字段说明")
    # entries look like "src_db.src_table.src_col" or "mid_table.col"
    from_cols: List[str] = Field(description="字段来源信息,格式是 源库名.源表名.源字段名 或 中间表名.中间表字段名")
    # "1" direct copy, "2" function result, "3" expression result
    dep_type: List[str] = Field(description="字段获取方式 1:直取 2:函数 3:表达式")
    desp: List[str] = Field(description="详细解释选取来源字段的原因")
  31. class SourceTable(BaseModel):
  32. """
  33. 单个来源信息表信息
  34. """
  35. database: str = Field(description="来源数据库名称")
  36. table_name: str = Field(description="来源表名称")
  37. table_name_alias: str = Field(default="", description="来源表别名")
  38. col_list: List[Column] = Field(description="来源表的字段集合")
  39. is_temp: bool = Field(description="是否是临时表")
  40. def __eq__(self, other):
  41. if isinstance(other, SourceTable):
  42. return self.database == other.database and self.table_name == other.table_name
  43. return False
# Wrapper holding every source table extracted from one SQL segment.
class SourceTableList(BaseModel):
    """
    所有来源信息表信息
    """
    source_tables: List[SourceTable] = Field(description="所有来源表信息")
# One WHERE-clause predicate on a specific table/column.
class WhereCondItem(BaseModel):
    """
    单个表过滤条件信息
    """
    database: str = Field(description="数据库名")
    table_name: str = Field(description="表原名")
    col_name: str = Field(description="条件字段名称")
    operator: str = Field(default="=", description="条件操作符,如=,>,<,in")
    value: str = Field(description="条件值")
# Wrapper holding every WHERE predicate extracted from one SQL segment.
class WhereCondList(BaseModel):
    """
    表过滤条件集合
    """
    where_list: List[WhereCondItem] = Field(description="所有的where条件")
# One comparison pair inside a JOIN ... ON clause.
class JoinColumnPair(BaseModel):
    """关联字段对"""
    main_table_alias: str = Field(description="主表别名")
    main_column: str = Field(description="主表关联字段")
    join_table_alias: str = Field(description="从表别名")
    join_column: str = Field(description="从表关联字段")
    operator: str = Field(default="=", description="关联操作符,如=,>,<")
# A complete JOIN between two physical tables, with its ON-clause pairs.
class JoinRelation(BaseModel):
    """完整的JOIN关系"""
    main_database: str = Field(description="主表数据库名")
    main_table: str = Field(description="主表物理表名")
    main_alias: str = Field(description="主表别名")
    join_type: str = Field(description="JOIN类型:INNER/LEFT/RIGHT/FULL")
    join_database: str = Field(description="从表数据库名")
    join_table: str = Field(description="从表物理表名")
    join_alias: str = Field(description="从表别名")
    column_pairs: List[JoinColumnPair] = Field(description="关联字段对列表")
# Wrapper holding every JOIN relation extracted from one SQL segment.
class JoinRelationList(BaseModel):
    """
    join关系集合
    """
    join_list: List[JoinRelation] = Field(description="所有的join关系")
# The table the SQL ultimately inserts into, with per-column lineage.
class TargetTable(BaseModel):
    """
    目标表信息
    """
    database: str = Field(default="", description="目标表的数据库名称")
    table_name: str = Field(description="目标表名称")
    col_size: int = Field(description="目标表字段数量")
    col_list: List[str] = Field(description="目标表字段集合")
    col_dep: List[ColumnDep] = Field(description="目标表字段依赖信息")
    src_table_size: int = Field(description="依赖来源表数量")
# A parsed CREATE TABLE statement: new table identity plus its source tables.
class CreateTable(BaseModel):
    """
    建表信息
    """
    database: str = Field(description="数据库名")
    table: str = Field(description="表名")
    col_list: list[str] = Field(description="字段名集合")
    # entries formatted as "src_db.src_table"
    source_table_list: list[str] = Field(description="来源表集合")
# Mutable state threaded through the LangGraph nodes.
class AgentState(TypedDict):
    """
    Agent状态
    """
    question: str               # the SQL text (segment) being analysed
    session: str                # session id, used as the graph thread id
    dialect: str                # SQL dialect
    sql: str                    # SQL statement with line numbers
    sql_type: str               # operation category: "1" create, "2" insert, "3" set, "4" other
    lineage: dict               # merged lineage payload
    status: str                 # "success" / "invalid" / "canceled" / "error"
    error: str                  # error message, if any
    stop: bool                  # user-requested cancellation flag
    file_name: str              # script file name
    target_table: dict          # extracted target-table info
    source_tables: list[dict]   # extracted source-table info
    where_list: list[dict]      # extracted WHERE conditions
    join_list: list[dict]       # extracted JOIN relations
class SqlLineageAgent:
    """
    Parse the data-lineage relationships contained in a user-supplied
    SQL statement or SQL script.
    """
    def __init__(self):
        # load agent configuration
        c = linage_agent_config
        # default SQL dialect
        self.db_dialect = c["dialect"]
        # max number of SQL segments parsed concurrently
        self.concurrency = c["concurrency"]
        # LLM instances: general-purpose and code-specialised
        self.llm = get_llm()
        self.llm_coder = get_llm_coder()
        # checkpoint store with expiry; also polled for the user "stop" flag
        self.memory = MemorySaverWithExpiry(expire_time=600, clean_interval=60)
        # build the LangGraph pipeline
        self.graph = self._build_graph()
        # variables assigned via SET statements in the script
        self.var_list: list[str] = []
        # final merged source tables
        self.final_source_table_list: list[dict] = []
        # final merged mid (temporary) tables
        self.final_mid_table_list: list[dict] = []
        # final merged WHERE conditions
        self.final_where_list: list[dict] = []
        # final merged JOIN relations
        self.final_join_list: list[dict] = []
        # per-segment target tables, to be merged
        self.final_target_list: list[dict] = []
        # the merged final target table
        self.final_target_table = {}
        # target table formatted as "db.table"
        self.target_table = ""
        # CREATE TABLE statements found inside the script
        self.ct_list = []
        # NOTE(review): these accumulators live on the instance and are never
        # reset, so reusing one agent for a second ask_question() call would
        # mix results from both scripts — confirm agents are single-use.
  156. def _build_graph(self):
  157. # 构造计算流程
  158. graph_builder = StateGraph(AgentState)
  159. # 添加节点
  160. graph_builder.add_node("_sql_type", self._sql_type)
  161. graph_builder.add_node("_extract_set", self._extract_set)
  162. graph_builder.add_node("_extract_create_table", self._extract_create_table)
  163. graph_builder.add_node("_invalid", self._invalid)
  164. graph_builder.add_node("source", self.__extract_source)
  165. graph_builder.add_node("target", self.__extract_target)
  166. graph_builder.add_node("where", self.__extract_where)
  167. graph_builder.add_node("join", self.__extract_join)
  168. graph_builder.add_node("merge", self.__merge_result)
  169. # 添加边
  170. graph_builder.add_edge(START, "_sql_type")
  171. graph_builder.add_conditional_edges("_sql_type", path=self.__parallel_route2, path_map={
  172. "_invalid": "_invalid",
  173. "_extract_set": "_extract_set",
  174. "_extract_create_table": "_extract_create_table",
  175. "where": "where",
  176. "join": "join"
  177. })
  178. graph_builder.add_edge("_invalid", END)
  179. graph_builder.add_edge("_extract_set", END)
  180. graph_builder.add_edge("_extract_set", END)
  181. graph_builder.add_edge("_extract_create_table", END)
  182. graph_builder.add_edge("where", "source")
  183. graph_builder.add_edge("join", "source")
  184. graph_builder.add_edge("source", "target")
  185. graph_builder.add_edge("target", "merge")
  186. graph_builder.add_edge("merge", END)
  187. return graph_builder.compile(checkpointer=self.memory)
    def _sql_type(self, state: AgentState):
        """
        Classify the SQL operation via the LLM.

        The returned {"sql_type": ...} is expected to be one of:
        "1" create, "2" insert, "3" set/variable assignment, "4" other.
        """
        template = """
        你是数据库 {dialect} 专家,对 SQL语句: {sql} 中的操作进行分类, 直接返回分类ID(数字)
        ### 操作分类标准如下:
        1:create, 如建表、建视图等
        2:insert, 向表中插入数据
        3:set、设置参数操作、变量赋值操作
        4:其它
        不要重复用户问题,不要中间分析过程,直接回答。
        """
        pt = ChatPromptTemplate.from_template(template)
        chain = pt | self.llm
        response = chain.invoke({"dialect": state["dialect"], "sql": state["question"]})
        # NOTE(review): raw LLM text is used as the category id downstream —
        # any extra characters in the reply would break the routing; confirm.
        return {"sql_type": response.content}
  205. def __parallel_route(self, state: AgentState):
  206. """并发节点"""
  207. status = state["status"]
  208. if status == "invalid" or status == "canceled" or state["sql_type"] == "3": # 无效问题 或 用户中止
  209. return END
  210. return "continue"
  211. def __parallel_route2(self, state: AgentState):
  212. """并发节点"""
  213. if state["sql_type"] == "4": # 非create和非insert操作
  214. return "_invalid"
  215. if state["sql_type"] == "3": # 设置参数
  216. return "_extract_set"
  217. if state["sql_type"] == "1": # 建表
  218. return "_extract_create_table"
  219. return ["where", "join"]
  220. def _invalid(self, state: AgentState):
  221. """
  222. sql_type == 4,不关注的语句,直接跳过
  223. """
  224. return {
  225. "lineage": {
  226. "source_tables": [],
  227. "join_list": [],
  228. "where_list": [],
  229. "target_table": {}
  230. }
  231. }
    async def _extract_set(self, state: AgentState):
        """
        Extract parameter/variable SET statements from the SQL.

        The extracted text is accumulated in self.var_list (fed into later
        prompts); the node itself contributes an empty lineage payload.
        """
        if state["sql_type"] != "3":  # only handle SET-type segments
            return state
        template = """
        你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 参数设置的全部内容,包括变量赋值、注释。
        不要重复用户问题,不需要中间思考过程,直接回答。
        """
        pt = ChatPromptTemplate.from_template(template)
        chain = pt | self.llm_coder
        response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]})
        self.var_list.append(response.content)
        return {
            "lineage": {
                "source_tables": [],
                "join_list": [],
                "where_list": [],
                "target_table": {}
            }
        }
    async def _extract_create_table(self, state: AgentState):
        """
        Extract CREATE TABLE information (db, table, columns, source tables).

        The parsed CreateTable dict is stashed in self.ct_list for later
        prompts; the node itself contributes an empty lineage payload.
        """
        if state["sql_type"] != "1":  # only handle CREATE-type segments
            return state
        template = """
        你是数据库 {dialect} SQL分析专家,从SQL语句: {sql} 中提取 建表信息。
        ### 提取要求
        1、提取建表对应的数据库名(无则留空)
        2、提取建表对应的表名
        3、提取建表对应的字段名
        4、提取建表对应的来源表名,来源表名格式是 来源库库名.来源表名
        ### 输出规范
        {format_instructions}
        不要重复用户问题,不需要中间思考过程,直接回答。
        """
        parser = PydanticOutputParser(pydantic_object=CreateTable)
        pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
        chain = pt | self.llm_coder | parser
        response = await chain.ainvoke({"dialect": state["dialect"], "sql": state["question"]})
        print(f"extract-create-table: {response}")
        self.ct_list.append(response.model_dump())
        return {
            "lineage": {
                "source_tables": [],
                "join_list": [],
                "where_list": [],
                "target_table": {}
            }
        }
    async def _parse(self, semaphore, sql: str, session: str, dialect: str, file_name: str = "SQL_FILE.SQL"):
        """
        Run the lineage graph over one SQL segment.

        :param semaphore: asyncio.Semaphore bounding concurrent graph runs
        :param sql: the SQL segment to analyse
        :param session: session id, used as the graph thread id
        :param dialect: SQL dialect (falls back to the configured default)
        :param file_name: script file name carried through the state
        :return: final graph state for this segment
        """
        async with semaphore:  # bound the number of concurrent LLM pipelines
            config = {"configurable": {"thread_id": session}}
            # kick off the task flow with a fresh initial state
            response = await self.graph.ainvoke({
                "question": sql,
                "session": session,
                "dialect": dialect or self.db_dialect,
                "file_name": file_name,
                "stop": False,
                "error": "",
                "sql_type": "0",
                "status": "success"}, config)
            return response
    async def find_target_table(self, sql: str, dialect: str):
        """
        Ask the LLM for the final target table of the whole script.

        :param sql: full SQL text
        :param dialect: SQL dialect
        :return: upper-cased "db.table" string (db part may be empty)
        """
        template = """
        你是SQL数据血缘分析专家,参照数据库方言{dialect},仔细分析SQL语句,找出目标表并返回。
        ### SQL片段如下
        {sql}
        ### 目标表标准
        1、目标表一定是出现在 insert into 后的 表
        2、最后 insert into 后的表为目标表
        3、目标表只有一张
        ### 核心要求
        1、目标表的返回格式是 数据库名.表名 (如果无数据库名,则留空)
        2、SQL语句中可能包括多个insert into语句,但要从全局分析,找出最终插入的目标表
        不要重复用户问题、不要中间分析过程,直接回答。
        """
        pt = ChatPromptTemplate.from_template(template)
        chain = pt | self.llm_coder
        response = await chain.ainvoke({"sql": sql, "dialect": dialect})
        return response.content.strip().upper()
  324. def split_sql_statements(self, sql):
  325. """
  326. 将SQL语句按分号分段,同时正确处理字符串和注释中的分号
  327. 参数:
  328. sql: 包含一个或多个SQL语句的字符串
  329. 返回:
  330. 分割后的SQL语句列表
  331. """
  332. # 状态变量
  333. in_single_quote = False
  334. in_double_quote = False
  335. in_line_comment = False
  336. in_block_comment = False
  337. escape_next = False
  338. statements = []
  339. current_start = 0
  340. i = 0
  341. while i < len(sql):
  342. char = sql[i]
  343. # 处理转义字符
  344. if escape_next:
  345. escape_next = False
  346. i += 1
  347. continue
  348. # 处理字符串和注释中的情况(这些地方的分号不应该作为分隔符)
  349. if in_line_comment:
  350. if char == '\n':
  351. in_line_comment = False
  352. elif in_block_comment:
  353. if char == '*' and i + 1 < len(sql) and sql[i + 1] == '/':
  354. in_block_comment = False
  355. i += 1 # 跳过下一个字符
  356. elif in_single_quote:
  357. if char == "'":
  358. in_single_quote = False
  359. elif char == '\\':
  360. escape_next = True
  361. elif in_double_quote:
  362. if char == '"':
  363. in_double_quote = False
  364. elif char == '\\':
  365. escape_next = True
  366. else:
  367. # 不在字符串或注释中,检查是否进入这些状态
  368. if char == "'":
  369. in_single_quote = True
  370. elif char == '"':
  371. in_double_quote = True
  372. elif char == '-' and i + 1 < len(sql) and sql[i + 1] == '-':
  373. in_line_comment = True
  374. i += 1 # 跳过下一个字符
  375. elif char == '/' and i + 1 < len(sql) and sql[i + 1] == '*':
  376. in_block_comment = True
  377. i += 1 # 跳过下一个字符
  378. elif char == ';':
  379. # 找到真正的分号分隔符
  380. statement = sql[current_start:i].strip()
  381. if statement:
  382. statements.append(statement)
  383. current_start = i + 1
  384. i += 1
  385. # 添加最后一个语句(如果没有以分号结尾)
  386. if current_start < len(sql):
  387. statement = sql[current_start:].strip()
  388. if statement:
  389. statements.append(statement)
  390. return statements
  391. async def find_target_table2(self, results: list[dict]) -> str:
  392. """
  393. 根据各段解析结果查找出目标表
  394. 规则:一个段的目标表如果没有出现在其它各段的来源表中,即为目标表
  395. """
  396. target_table = ""
  397. exists = False
  398. for i, item in enumerate(results):
  399. tt = item["lineage"]["target_table"]
  400. # 取一个目标表
  401. db = tt.get("database", "").strip().upper()
  402. table = tt.get("table_name", "").strip().upper()
  403. exists = False
  404. if len(table) > 0:
  405. # 从源表里面查找
  406. for j, item2 in enumerate(results):
  407. if i != j: # 不跟自身比
  408. st_list = item2["lineage"]["source_tables"]
  409. for st in st_list:
  410. sdb = st.get("database", "").strip().upper()
  411. stable = st.get("table_name", "").strip().upper()
  412. if db == sdb and table == stable: # 目标表在源表中存在
  413. exists = True
  414. break
  415. if exists:
  416. break
  417. if not exists: # 说明目标表不在其它各段的源表中存在
  418. target_table = ".".join([db, table])
  419. break
  420. return target_table.strip().upper()
  421. async def ask_question(self, sql_content: str, session: str, dialect: str, owner: str, sql_file: str = "SQL_FILE.SQL"):
  422. """
  423. 功能:根据用户问题,解析SQL中的数据血缘关系。先分段解析,再组织合并。
  424. :param sql_content: SQL内容
  425. :param session: 会话
  426. :param dialect: 数据库方言
  427. :param owner: 脚本平台属主
  428. :param var_map: 变量关系
  429. :sql_file: 脚本名称
  430. """
  431. # 通过session保持记忆
  432. t1 = datetime.now()
  433. # 最终解析结果
  434. final_result = {}
  435. final_result["task_id"] = session
  436. final_result["file_name"] = sql_file
  437. final_result["owner"] = owner
  438. final_result["status"] = "success"
  439. final_result["error"] = ""
  440. lineage = {}
  441. try:
  442. # 根据全局SQL找出目标表
  443. self.target_table = await self.find_target_table(sql_content, dialect)
  444. print(f"大模型查找目标表:{self.target_table}, {len(self.target_table)}")
  445. # 对SQL分段,并发执行解析
  446. sql_list = self.split_sql_statements(sql=sql_content)
  447. print("SQL分段数量:", len(sql_list))
  448. # 信号量,控制并发量
  449. semaphore = asyncio.Semaphore(self.concurrency)
  450. task_list = [self._parse(semaphore=semaphore, sql=sql, session=session, dialect=dialect, file_name=sql_file) for sql in sql_list]
  451. # 并发分段解析
  452. results = await asyncio.gather(*task_list)
  453. if len(self.target_table) == 0:
  454. # 从各段解析结果中查找出目标表
  455. self.target_table = await self.find_target_table2(results) or ""
  456. print(f"从血缘关系中找出目标表:{self.target_table}")
  457. if not self.target_table: # 未出现目标表, 直接返回
  458. final_result["status"] = "error",
  459. final_result["error"] = "未找到目标表"
  460. return final_result
  461. for response in results:
  462. # 来源表、where、join、target
  463. self._merge_source_tables(response["lineage"]["source_tables"])
  464. self._merge_where(response["lineage"].get("where_list",[]))
  465. self._merge_join(response["lineage"]["join_list"])
  466. # 合并目标表
  467. self._merge_target(response["lineage"]["target_table"])
  468. # 增加中间表标识
  469. self._add_mid_table_label(self.final_source_table_list, self.final_mid_table_list)
  470. # 补充中间表依赖字段
  471. self._add_mid_table_col(self.final_mid_table_list, self.final_target_list)
  472. # 多目标表合并
  473. self.final_target_table["src_table_size"] = len(self.final_source_table_list)
  474. lineage["target_table"] = self.final_target_table
  475. lineage["source_tables"] = self.final_source_table_list
  476. lineage["mid_table_list"] = self.final_mid_table_list
  477. lineage["join_list"] = self.final_join_list
  478. lineage["where_list"] = self.final_where_list
  479. final_result["lineage"] = lineage
  480. except Exception as e:
  481. final_result["status"] = "error"
  482. final_result["error"] = str(e)
  483. print(str(e))
  484. t2 = datetime.now()
  485. print("总共用时(sec):", (t2-t1).seconds)
  486. parse_time = (t2 - t1).seconds
  487. # 解析时间(秒)
  488. final_result["parse_time"] = parse_time
  489. final_result["parse_end_time"] = t2.strftime("%Y-%m-%d %H:%M:%S")
  490. return final_result
  491. def _merge_source_tables(self, source_table_list: list[dict]):
  492. """
  493. 合并分段来源表
  494. """
  495. if len(source_table_list) > 0:
  496. # 目标源表的keys
  497. table_keys = [(t.get("database","").strip().upper(), t.get("table_name","").strip().upper()) for t in self.final_source_table_list]
  498. for st in source_table_list:
  499. # 将字典转换为元组
  500. db = st.get("database","").strip().upper()
  501. table = st.get("table_name","").strip().upper()
  502. key = (db, table)
  503. if key not in table_keys:
  504. table_keys.append(key)
  505. self.final_source_table_list.append(st)
  506. else:
  507. # 合并字段
  508. col_list = st.get("col_list", [])
  509. # 找出相同的元素, 库名,表名相同
  510. e = next(filter(lambda item: item["database"].upper()==db and item["table_name"].upper()==table, self.final_source_table_list), None)
  511. if e:
  512. final_col_list = e.get("col_list",[])
  513. col_keys = [(c.get("col_name","").upper()) for c in final_col_list]
  514. for c in col_list:
  515. ck = (c["col_name"].upper())
  516. if ck not in col_keys and len(c) > 0:
  517. col_keys.append(ck)
  518. final_col_list.append(c)
  519. @timeit
  520. def _merge_where(self, where_list: list[dict]):
  521. """
  522. 最终合并所有的where条件
  523. """
  524. if where_list and len(where_list) > 0:
  525. self.final_where_list += where_list
  526. def _merge_join(self, join_list: list[dict]):
  527. """
  528. 最终合并所有的join条件
  529. """
  530. if len(join_list) > 0:
  531. self.final_join_list += join_list
    def _merge_target(self, mid_target_table: dict):
        """
        Merge one segment's target-table entry.

        Only the entry whose "db.table" equals self.target_table is the real
        target: its columns and column lineage are merged into
        self.final_target_table (the target may be inserted into several
        times). Every other entry is collected as a mid (temporary) table.
        """
        if not mid_target_table:  # segment produced no target table
            return
        else:
            db = mid_target_table.get("database", "").strip().upper()
            table = mid_target_table.get("table_name", "").strip().upper()
            col_size = mid_target_table.get("col_list",[])  # NOTE: holds the column list despite the name
            if len(table) == 0 or len(col_size) == 0:  # no table name or no columns — skip
                return
            if len(db) == 0:  # missing db: the table name may embed it
                arr = table.split('.')
                if len(arr) == 2:  # "db.table" was parsed into table_name — fix it up
                    mid_target_table["database"] = arr[0].strip().upper()
                    mid_target_table["table_name"] = arr[1].strip().upper()
                    db = arr[0].strip().upper()
                    table = arr[1].strip().upper()
                elif len(arr) == 1:  # bare table name: borrow the real target's db
                    arr = self.target_table.split(".")
                    if len(arr) == 2:
                        db = mid_target_table["database"] = arr[0].strip().upper()
            key = ".".join([db, table])
            if key == self.target_table:  # this entry describes the real target
                print(f"合并目标表:{mid_target_table}")
                if not self.final_target_table:  # first occurrence: adopt as-is
                    self.final_target_table = mid_target_table
                    return
                # merge columns not yet seen (case-insensitive)
                new_col_list = []  # names added in this call
                col_list = mid_target_table.get("col_list", [])
                for new_col in col_list:
                    if new_col.strip().upper() not in [col.strip().upper() for col in self.final_target_table.get("col_list", [])]:
                        print(f"合并新字段:{new_col}")
                        self.final_target_table.get("col_list", []).append(new_col.strip().upper())
                        new_col_list.append(new_col.strip().upper())
                # merge the per-column lineage entries
                col_dep_list = mid_target_table.get("col_dep", [])
                for col_dep in col_dep_list:
                    col_name = col_dep["col_name"].strip().upper()
                    if col_name in new_col_list:  # brand-new column: append the whole entry
                        self.final_target_table.get("col_dep",[]).append(col_dep)
                    else:  # known column: union its source columns
                        for dep in self.final_target_table.get("col_dep", []):
                            cn = dep["col_name"].strip().upper()
                            if col_name == cn:
                                print(f"合并来源字段, {col_name}")
                                m_from_col = col_dep.get('from_cols', [])
                                f_from_col = dep.get("from_cols", [])
                                print(f"中间表来源字段:{m_from_col}")
                                print(f"目标表来源字段:{f_from_col}")
                                # fold the segment entry's sources into the target's
                                self._merge_col_dep(m_from_col, f_from_col)
            else:
                # not the real target: record it as a mid (temporary) table
                print(f"加入中间表:{mid_target_table}")
                self.final_mid_table_list.append(mid_target_table)
            print(f"final tt:{self.final_target_table}")
  596. def _merge_col_dep(self, mid_from_cols: list[str], final_from_cols: list[str]):
  597. """
  598. 合并临时中间表来源字段 到目标表来源字段中
  599. """
  600. for from_col in mid_from_cols:
  601. if len(from_col.split(".")) > 0: # 忽略常量,NULL值
  602. if from_col.strip().upper() not in [fc.strip().upper() for fc in final_from_cols]:
  603. print(f"合并字段:{from_col}")
  604. final_from_cols.append(from_col.strip().upper())
  605. @staticmethod
  606. @timeit
  607. def _add_mid_table_label(source_table_list: list[dict], mid_table_list: list[dict]):
  608. """
  609. 给最终的来源表增加中间表标识
  610. """
  611. if len(source_table_list) > 0 and len(mid_table_list) > 0:
  612. mid_table_keys = [(t["database"].strip().upper(), t["table_name"].strip().upper()) for t in mid_table_list]
  613. for st in source_table_list:
  614. st["is_temp"] = False
  615. # 将字典转换为元组
  616. key = (st["database"].strip().upper(), st["table_name"].strip().upper())
  617. if key in mid_table_keys: # 是中间表
  618. st["is_temp"] = True
  619. @staticmethod
  620. @timeit
  621. def _add_mid_table_col(mid_table_list: list[dict], target_table_list: list[dict]):
  622. """
  623. 给中间表补充来源字段,使用对应的目标表col_dep去替换
  624. """
  625. if len(mid_table_list) > 0 and len(target_table_list) > 0:
  626. for mt in mid_table_list:
  627. for tt in target_table_list:
  628. tt["is_temp"] = False
  629. # 目标表的数据库名与表名相同
  630. tdn = tt["database"].strip().upper()
  631. ttn = tt["table_name"].strip().upper()
  632. if len(tdn) == 0 and len(ttn) == 0: # 全部为空
  633. tt["is_temp"] = True
  634. continue
  635. if mt["database"].strip().upper()==tdn and mt["table_name"].strip().upper()==ttn:
  636. mt["col_dep"] = tt["col_dep"]
  637. tt["is_temp"] = True
  638. if len(mt["col_list"]) == 0:
  639. mt["col_list"] = [c.upper() for c in tt["col_list"]]
  640. mt["col_size"] = len(mt["col_list"])
  641. @timeit
  642. def _merge_final_target_table(self, seg_target_table_list: list[dict], target_table_name: str):
  643. """
  644. 合并分段目标表内容
  645. """
  646. # 可能存在多个相同的真正目标表(多次插入目标表情况)
  647. for tt in seg_target_table_list:
  648. tt["is_target"] = False
  649. database = tt["database"].strip()
  650. table = tt["table_name"].strip()
  651. if len(database) > 0:
  652. dt = ".".join([database,table])
  653. else:
  654. dt = table
  655. if dt.upper() == target_table_name.strip().upper():
  656. tt["is_target"] = True
  657. # 合并真正的目标表
  658. i = 0
  659. result = {}
  660. final_col_list: list[str] = []
  661. final_col_dep_list: list[dict] = []
  662. # 合并目标表信息
  663. for target_table in seg_target_table_list:
  664. if target_table["is_target"]:
  665. if i == 0:
  666. result["database"] = target_table["database"]
  667. result["table_name"] = target_table["table_name"]
  668. i = 1
  669. # 合并列信息
  670. col_list = target_table["col_list"]
  671. for c in col_list:
  672. if c not in final_col_list and len(c.strip()) > 0:
  673. final_col_list.append(c.upper())
  674. # 合并列来源信息
  675. col_dep_list = target_table["col_dep"]
  676. for col_dep in col_dep_list:
  677. cn = col_dep["col_name"].strip().upper()
  678. from_cols = col_dep["from_cols"]
  679. existed = False
  680. if len(cn) > 0:
  681. for fcd in final_col_dep_list:
  682. if cn == fcd["col_name"].upper():
  683. existed = True
  684. # 合并来源字段
  685. final_from_cols = fcd["from_cols"]
  686. fcd["from_cols"] += [c.upper() for c in from_cols if c.upper() not in final_from_cols]
  687. if not existed:
  688. final_col_dep_list.append({
  689. "col_name": cn.upper(),
  690. "col_comment": col_dep["col_comment"],
  691. "from_cols": [c.upper() for c in from_cols]
  692. })
  693. result["col_size"] = len(final_col_list)
  694. result["col_list"] = final_col_list
  695. result["col_dep"] = final_col_dep_list
  696. return result
    async def __check(self, state: AgentState):
        """
        Validate that the user input is a syntactically correct SQL
        statement/fragment. The LLM verdict is streamed so that a
        user-initiated "stop" (observed via the checkpoint store) can
        interrupt mid-generation.
        """
        template = """
        你是 数据库 {dialect} SQL分析助手, 仔细分析SQL语句: {sql} ,判断是否语法正确,如果是 直接返回 1,否则说明错误地方,并返回 0。
        ### 分析要求
        1. 判断子查询的正确性
        2. 判断整体语句的正确性
        3. 如果存在语法错误,给出说明
        不要重复用户问题、不要分析过程、直接回答。
        """
        pt = ChatPromptTemplate.from_template(template)
        chain = pt | self.llm
        stopped = False
        response = ""
        config = {"configurable": {"thread_id": state["session"]}}
        async for chunk in chain.astream({"sql": state["question"], "dialect": state["dialect"]}, config=config):
            # poll the checkpoint store for a user-initiated stop
            current_state = self.memory.get(config)
            if current_state and current_state["channel_values"]["stop"]:
                stopped = True
                break
            response += chunk.content
        if stopped:
            return {"stop": True, "error": "用户中止", "status": "canceled", "lineage": {}}
        # NOTE(review): the prompt allows "0" plus an explanation, but only a
        # bare "0" reply is treated as invalid here — confirm intended.
        if response == "0":
            return {"status": "invalid", "error": "无效问题或SQL存在语法错误,请重新描述!", "lineage": {}}
        return {"status": "success"}
  725. async def __merge_result(self, state: AgentState):
  726. """
  727. 合并所有节点结果,输出血缘关系
  728. """
  729. result = {}
  730. result["file_name"] = state["file_name"]
  731. source_tables = state["source_tables"]
  732. target_table = state["target_table"]
  733. # 设置目标表依赖的来源表数量和字段数量
  734. target_table["src_table_size"] = len(source_tables)
  735. target_table["col_size"] = len(target_table["col_dep"])
  736. result["target_table"] = target_table
  737. result["source_tables"] = source_tables
  738. result["where_list"] = state.get("where_list",[])
  739. result["join_list"] = state["join_list"]
  740. return {"lineage": result}
    async def __extract_source(self, state: AgentState):
        """
        Extract the source tables (and their referenced columns) from the
        SQL segment via the code LLM, streaming so a user "stop" request
        (observed via the checkpoint store) can abort mid-generation.
        """
        dt1 = datetime.now()
        # default database used when a source table carries no db prefix
        default_db = "TMP"
        if len(self.target_table) > 0:
            arr = self.target_table.split(".")
            if len(arr) == 2:  # reuse the recognised target table's db as the default
                default_db = arr[0]
        template = """
        你是SQL数据血缘分析专家,仔细分析SQL片段和上下文信息, 从SQL片段中提取出 来源表信息。
        ### SQL和上下文信息如下
        - SQL片段:{sql}
        - 数据库方言: {dialect}
        - 变量设置信息:{var}
        - 已有的建表信息: {create_table_info}
        ### 核心要求:
        1. 提取 来源表数据库名称(无则使用 {default_db} 填充)
        2. 提取 来源表名,包括 from子句、子查询、嵌套子查询、union语句、with语句中出现的物理表名(注意:不要带库名前缀)
        3. 提取 来源表别名(无则留空)
        4. 提取来源表的所有字段信息,提取来自几种:
        - 从select中提取,如果带有表别名,则转换成物理表名
        - 从where过滤条件中提取
        - 从关联条件(inner join,left join,right join,full join)中提取
        - 从group中提取
        - 从函数参数中提取,比如函数 COALESCE(T13.ECIF_CUST_ID,T2.RELAT_PTY_ID,'') AS ECIF_CUST_ID,提取出T13.ECIF_CUST_ID和T2.RELAT_PTY_ID
        - 从表达式参数中提取
        5. 判断来源表是否是临时表
        6. 不要包含目标表(目标表指最终插入的表)
        ### 输出规范
        {format_instructions}
        不要重复用户问题,不要分析中间过程,直接给出答案。
        """
        parser = PydanticOutputParser(pydantic_object=SourceTableList)
        pt = ChatPromptTemplate.from_template(template=template).partial(format_instructions=parser.get_format_instructions())
        chain = pt | self.llm_coder | parser
        answer = {}
        config = {"configurable": {"thread_id": state["session"]}}
        stopped = False
        async for chunk in chain.astream({
            "sql": state["question"],
            "dialect": state["dialect"],
            "var": self.var_list or [],
            "default_db": default_db,
            "create_table_info": self.ct_list
        }):
            # poll the checkpoint store for a user-initiated stop
            current_state = self.memory.get(config)
            if current_state and current_state["channel_values"]["stop"]:
                stopped = True
                break
            # de-duplicate the tables of the (progressively) parsed result
            # NOTE(review): indentation reconstructed — this merge appears to
            # run per streamed chunk with the last chunk winning; confirm.
            table_list: list[SourceTable] = []
            for table in chunk.source_tables:
                if table not in table_list:
                    table_list.append(table)
                else:
                    # same (database, table_name): union the column lists
                    idx = table_list.index(table)
                    tt = table_list[idx]
                    for col in table.col_list:
                        if col not in tt.col_list:
                            tt.col_list.append(col)
            answer["source_tables"] = [table.model_dump() for table in table_list]
        dt2 = datetime.now()
        print("提取源表用时(sec):", (dt2-dt1).seconds)
        if stopped:
            return {"status": "canceled", "error": "用户中止", "lineage": {}}
        return answer
  812. async def __extract_target(self, state: AgentState):
  813. """
  814. 根据SQL,提取其中的目标表信息
  815. """
  816. if state["sql_type"] == "1":
  817. return {"target_table": {"database":"", "table_name": "", "col_list": [], "col_size": 0, "col_dep": []}}
  818. # 默认DB
  819. default_db = "TMP"
  820. if len(self.target_table) > 0:
  821. arr = self.target_table.split(".")
  822. if len(arr) == 2: # 使用识别出的目标表db作为默认DB
  823. default_db = arr[0]
  824. template = """
  825. 你是SQL数据血缘分析专家,仔细分析SQL片段和以下上下文:
  826. - SQL片段: {sql}
  827. - 来源表信息: {source_tables}
  828. - 数据库方言: {dialect}
  829. - 变量设置信息:{var}
  830. - 已有的建表信息: {create_table_info}
  831. ### 核心要求:
  832. 1. **目标表识别**:
  833. - 提取 INSERT INTO 后的表作为目标表
  834. - 格式:`数据库名.表名`(未指定库名则使用 {default_db} 填充)
  835. 2. **字段依赖分析**:
  836. - 目标表字段必须溯源到来源表的物理字段
  837. - 字段获取方式分类:
  838. - `直取`:直接引用源字段(如 `src.col`)
  839. - `函数`:通过函数转换(如 `COALESCE(col1, col2)`)
  840. - `表达式`:计算表达式(如 `col1 + col2`)
  841. - `常量`:固定值(如 `'2023'`)
  842. - `子查询`:嵌套SELECT结果
  843. - 字段来源格式:`源库.源表.字段名`(不要使用表别名,如果源库未指定则使用 {default_db}替换)
  844. - 字段来源信息只能包含 库名、表名和字段信息,如果经过函数或表达式处理,则从参数中提取出具体的字段信息
  845. 3. **关键约束**:
  846. - 每个目标字段必须有详细的来源说明
  847. - 函数参数需完整展开(如 `COALESCE(T1.id, T2.id)` 需解析所有参数)
  848. - 表名必须是 英文字符 构成(非汉字构成)
  849. ### 输出规范
  850. {format_instructions}
  851. 请直接输出JSON格式结果,无需解释过程。
  852. """
  853. parser = PydanticOutputParser(pydantic_object=TargetTable)
  854. pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
  855. chain = pt | self.llm_coder | parser
  856. answer = {}
  857. config = {"configurable": {"thread_id": state["session"]}}
  858. dt1 = datetime.now()
  859. print("开始解析目标表...")
  860. stopped = False
  861. async for chunk in chain.astream({
  862. "sql": state["question"],
  863. "dialect": state["dialect"],
  864. "source_tables": state["source_tables"],
  865. "var": self.var_list or [],
  866. "default_db": default_db,
  867. "default_db": default_db,
  868. "create_table_info": self.ct_list
  869. }):
  870. current_state = self.memory.get(config)
  871. if current_state and current_state["channel_values"]["stop"]:
  872. stopped = True
  873. break
  874. answer["target_table"] = chunk.model_dump()
  875. if stopped:
  876. return {"status": "canceled", "error": "用户中止", "lineage": {}}
  877. dt2 = datetime.now()
  878. print("提取目标表用时(sec):", (dt2 - dt1).seconds)
  879. print(f"target:{answer}")
  880. return answer
  881. async def __extract_where(self, state: AgentState):
  882. """
  883. 根据SQL,提取来源表 where 信息
  884. """
  885. # 默认DB
  886. default_db = "TMP"
  887. if len(self.target_table) > 0:
  888. arr = self.target_table.split(".")
  889. if len(arr) == 2: # 使用识别出的目标表db作为默认DB
  890. default_db = arr[0]
  891. dt1 = datetime.now()
  892. template = """
  893. 你是SQL数据血缘分析专家,仔细分析 SQL语句和上下文,提取 where 条件信息。
  894. ### SQL语句和上下文如下
  895. - SQL语句: {sql}
  896. - 数据库方言: {dialect}
  897. - 变量设置信息: {var}
  898. ### 提取要求:
  899. 1. 提取 所有 来源表 中的 where 条件信息,包含 where 中出现的 表原名(非别名)、 表所在数据库名(如果没有则使用 {default_db} 填充)、字段名、条件操作符和条件值
  900. 2. 字段名不能包括函数,比如 length(INT_ORG_NO),提取的字段名是 INT_ORG_NO
  901. 3. 如果SQL语句中存在变量,则根据变量设置信息转换成实际值
  902. ### 输出规范
  903. {format_instructions}
  904. 不要重复用户问题,不要分析中间过程,直接给出答案。
  905. """
  906. parser = PydanticOutputParser(pydantic_object=WhereCondList)
  907. pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
  908. chain = pt | self.llm_coder | parser
  909. answer = {}
  910. config = {"configurable": {"thread_id": state["session"]}}
  911. stopped = False
  912. async for chunk in chain.astream({
  913. "sql": state["question"],
  914. "dialect": state["dialect"],
  915. "var": self.var_list or [],
  916. "default_db": default_db
  917. }):
  918. current_state = self.memory.get(config)
  919. if current_state and current_state["channel_values"]["stop"]:
  920. stopped = True
  921. break
  922. # 设置目标表字段数量,来源表数量
  923. dt2 = datetime.now()
  924. if stopped:
  925. return {"status": "canceled", "error": "用户中止", "lineage": {}}
  926. return answer
  927. async def __extract_join(self, state: AgentState):
  928. """
  929. 根据SQL,提取 表关联 信息
  930. """
  931. # 默认DB
  932. default_db = "TMP"
  933. if len(self.target_table) > 0:
  934. arr = self.target_table.split(".")
  935. if len(arr) == 2: # 使用识别出的目标表db作为默认DB
  936. default_db = arr[0]
  937. dt1 = datetime.now()
  938. template = """
  939. 你是SQL数据血缘分析专家,负责提取完整的JOIN关系。仔细分析SQL片段:
  940. - SQL片段: {sql}
  941. - 数据库方言: {dialect}
  942. ### 关键要求:
  943. 1. **SQL语句中必须存在关联条件(包括inner、left join、right join、full join)**
  944. 2. **识别JOIN结构**:
  945. - 提取FROM子句中的主表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {default_db} 填充
  946. - 提取每个JOIN子句中的从表信息(包括 数据库名和物理表名),如果无数据库名,则使用 {j_default_db} 填充
  947. - 明确标注JOIN类型(INNER/LEFT/RIGHT/FULL)
  948. 3. **关联字段提取**:
  949. - 必须提取每对关联字段的完整信息:
  950. * 主表端:表别名.字段名
  951. * 从表端:表别名.字段名
  952. - 明确标注关联操作符(=, >, <等)
  953. - 多条件JOIN拆分为多个字段对
  954. 4. **特殊场景处理**:
  955. - 子查询作为表时,表名填写子查询别名
  956. - 隐式JOIN(WHERE条件)需转换为显式JOIN结构
  957. - 多表JOIN时保持原始顺序
  958. - 如果关联字段是 常量,则关联字段取空
  959. - 如果关联字段经函数处,则提取参数中的 字段,如果参数是常量,则关联字段取空
  960. ### 输出规范
  961. {format_instructions}
  962. 不要重复用户问题,不要中间思考过程,直接回答。
  963. """
  964. parser = PydanticOutputParser(pydantic_object=JoinRelationList)
  965. pt = ChatPromptTemplate.from_template(template).partial(format_instructions=parser.get_format_instructions())
  966. chain = pt | self.llm_coder | parser
  967. answer = {}
  968. config = {"configurable": {"thread_id": state["session"]}}
  969. stopped = False
  970. async for chunk in chain.astream({
  971. "sql": state["question"],
  972. "dialect": state["dialect"],
  973. "default_db": default_db,
  974. "j_default_db": default_db
  975. }):
  976. current_state = self.memory.get(config)
  977. if current_state and current_state["channel_values"]["stop"]:
  978. stopped = True
  979. break
  980. # 设置目标表字段数量,来源表数量
  981. answer["join_list"] = [join.model_dump() for join in chunk.join_list]
  982. dt2 = datetime.now()
  983. if stopped:
  984. return {"status": "canceled", "error": "用户中止", "lineage": {}}
  985. return answer
  986. def trigger_stop(self, session: str):
  987. """外部触发中止"""
  988. config = {"configurable": {"thread_id": session}}
  989. current_state = self.memory.get(config)
  990. if current_state:
  991. current_state["channel_values"]["stop"] = True
  992. async def main(sql_content: str, sql_file: str, dialect: str, owner: str):
  993. agent = SqlLineageAgent()
  994. result = await agent.ask_question(sql_content=sql_content, session="s-1", dialect=dialect, sql_file=sql_file, owner=owner)
  995. return result
  996. def read_var_file(var_file: str):
  997. """
  998. 读取变量文件
  999. :param var_file: 变量文件(格式excel)
  1000. """
  1001. import pandas as pd
  1002. # 读取变量excel文件
  1003. result = []
  1004. df = pd.read_excel(var_file)
  1005. for row in df.itertuples():
  1006. print(f"变量 {row.Index + 1}: {row.变量}, {row.含义}")
  1007. item = {}
  1008. arr = row.变量.split("=")
  1009. item["var_name"] = arr[0]
  1010. item["var_value"] = arr[1]
  1011. item["var_comment"] = row.含义
  1012. result.append(item)
  1013. return result
  1014. def replace_vars(sql_file: str, var_list: list[dict]):
  1015. """
  1016. 替换 sql_file 中的变量
  1017. :param sql_file: SQL文件
  1018. :param var_list: 变量集合 [{'var_name':xx,'var_value':xxx,'var_comment':xx}]
  1019. """
  1020. new_content = ""
  1021. encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312']
  1022. for encoding in encodings:
  1023. try:
  1024. with open(sql_file, 'r', encoding=encoding) as f:
  1025. # 读取文件内容
  1026. content = f.read()
  1027. # 替换变量
  1028. for var in var_list:
  1029. name = var["var_name"]
  1030. value = var["var_value"]
  1031. new_content = content.replace(name, value)
  1032. content = new_content
  1033. break
  1034. except UnicodeDecodeError:
  1035. print(f"文件{os.path.basename(f.name)} 编码 {encoding} 不支持")
  1036. continue
  1037. else:
  1038. print("无法找到合适的编码")
  1039. raise Exception(f"无法读取文件 {sql_file}, 未找到合适的编码.")
  1040. return new_content
  1041. def get_ddl_list(ddl_dir: str, var_list: list[dict]):
  1042. """
  1043. 获取 DDL目录中定义的所有建表DDL,并替换其中的变量
  1044. :param ddl_dir: DDL目录,目录结构 ddl/db/xx.ddl
  1045. :param var_list: 变量集合 [{'var_name':xx,'var_value':xxx,'var_comment':xx}]
  1046. """
  1047. file_list = []
  1048. # 遍历目录,获取所有.ddl文件
  1049. for root, dirs, files in os.walk(ddl_dir):
  1050. for file in files:
  1051. if file.endswith(".ddl"):
  1052. file_path = os.path.join(root, file)
  1053. file_list.append(file_path)
  1054. print(f"DDL文件数量:{len(file_list)}")
  1055. # 读取ddl文件,替换其中的变量
  1056. result = []
  1057. for file in file_list:
  1058. fp = Path(file)
  1059. # 文件的名称规则:db__table,解析出数据库和表名
  1060. name = fp.name.replace(".ddl", "")
  1061. arr = name.split("__")
  1062. db = arr[0]
  1063. table = arr[1]
  1064. # key
  1065. key = f"{db.upper()}.{table.upper()}"
  1066. ddl_content = ""
  1067. encodings = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252', 'gbk', 'gb2312']
  1068. for encoding in encodings:
  1069. try:
  1070. with open(file, 'r', encoding=encoding) as f:
  1071. # 读取文件内容
  1072. content = f.read()
  1073. # 替换变量
  1074. for var in var_list:
  1075. name = var["var_name"]
  1076. value = var["var_value"]
  1077. new_content = content.replace(name, value)
  1078. content = new_content
  1079. result.append({key: content})
  1080. break
  1081. except UnicodeDecodeError:
  1082. print(f"文件{os.path.basename(f.name)} 编码 {encoding} 不支持")
  1083. continue
  1084. else:
  1085. print("无法找到合适的编码")
  1086. raise Exception(f"无法读取文件 {file}, 未找到合适的编码.")
  1087. print(f"result:{result[0]}")
  1088. return result
  1089. if __name__ == "__main__":
  1090. import sys
  1091. import os
  1092. from pathlib import Path
  1093. import json
  1094. # 参数检查
  1095. if len(sys.argv) <= 4:
  1096. print(f"usage python sql_lineage_agent_xmgj.py var_file sql_file dialect owner")
  1097. exit(1)
  1098. # 变量文件
  1099. var_file = sys.argv[1]
  1100. # sql脚本文件或目录
  1101. sql_file = sys.argv[2]
  1102. # 数据库方言
  1103. dialect = sys.argv[3]
  1104. # 属主
  1105. owner = sys.argv[4]
  1106. print(f"1、解析SQL文件/目录:{sql_file}")
  1107. print(f"2、SQL文件属主平台:{owner}")
  1108. print(f"3、变量文件:{var_file}")
  1109. print(f"4、SQL数据库方言: {dialect}")
  1110. # 检查变量文件是否存在
  1111. var_fp = Path(var_file)
  1112. if not var_fp.exists():
  1113. raise FileNotFoundError(f"变量文件 {var_file} 不存在,请指定正确路径.")
  1114. # 检查SQL脚本文件是否存在
  1115. sql_fp = Path(sql_file)
  1116. if not sql_fp.exists():
  1117. raise FileNotFoundError(f"SQL脚本文件 {sql_file} 不存在,请指定正确路径.")
  1118. # 读取变量文件,获取变量值
  1119. var_list = read_var_file(var_file)
  1120. # 读取SQL文件, 替换变量
  1121. sql_content = replace_vars(sql_file=sql_file, var_list=var_list)
  1122. # 解析SQL血缘
  1123. result = asyncio.run(main(sql_content=sql_content, sql_file=sql_fp.name, dialect=dialect, owner=owner))
  1124. # 将结果写入同级目录下,文件后缀为.json
  1125. target_file = Path(sql_fp.parent.absolute() / (sql_fp.name + ".json"))
  1126. print(f"写目标文件:{target_file}")
  1127. # 写文件
  1128. with open(target_file, 'w', encoding="utf-8") as t:
  1129. t.write(json.dumps(result, ensure_ascii=False, indent=2))