from pydantic import BaseModel, Field from langchain.agents.tools import BaseTool from typing import Type from llmops.agents.datadev.rag.schema_handler import SchemaHandler class GetDatabaseSchemaToolInput(BaseModel): """ GetTableSchemaTool 输入参数结构 """ query: str = Field(description="检索文本") class GetDatabaseSchemaTool(BaseTool): # 工具名称 name: str = "get_database_schema" # 工具描述 description: str = ( "根据用户问题,从知识库中检索数据库schema定义信息", "输入是查找的文本字符串", "返回的形式是json数组" ) # 工具参数结构 args_schema: Type[BaseModel] = GetDatabaseSchemaToolInput # schema查询器 schema_handler: Type[SchemaHandler] = SchemaHandler() def _run(self, query: str): result = "" try: result = self.query(query) print("schema=====:", result) except Exception as e: print(f"调用工具出现异常{str(e)}") result = "" return result def _arun(self, query_list: list[str]): return self._run(query_list) def query(self, query: str): result = self.schema_handler.query_mulsimilar(query_list=query.split(","), top_k=10, similarity_threshold=0.45) return result if __name__ == '__main__': tool = GetTableSchemaTool() result = tool.run("查询机构编号是100的机构信息") print(result)