get_database_schema_tool.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from pydantic import BaseModel, Field
  2. from langchain.agents.tools import BaseTool
  3. from typing import Type
  4. from llmops.agents.datadev.rag.schema_handler import SchemaHandler
  5. class GetDatabaseSchemaToolInput(BaseModel):
  6. """
  7. GetTableSchemaTool 输入参数结构
  8. """
  9. query: str = Field(description="检索文本")
  10. class GetDatabaseSchemaTool(BaseTool):
  11. # 工具名称
  12. name: str = "get_database_schema"
  13. # 工具描述
  14. description: str = (
  15. "根据用户问题,从知识库中检索数据库schema定义信息",
  16. "输入是查找的文本字符串",
  17. "返回的形式是json数组"
  18. )
  19. # 工具参数结构
  20. args_schema: Type[BaseModel] = GetDatabaseSchemaToolInput
  21. # schema查询器
  22. schema_handler: Type[SchemaHandler] = SchemaHandler()
  23. def _run(self, query: str):
  24. result = ""
  25. try:
  26. result = self.query(query)
  27. print("schema=====:", result)
  28. except Exception as e:
  29. print(f"调用工具出现异常{str(e)}")
  30. result = ""
  31. return result
  32. def _arun(self, query_list: list[str]):
  33. return self._run(query_list)
  34. def query(self, query: str):
  35. result = self.schema_handler.query_mulsimilar(query_list=query.split(","), top_k=10, similarity_threshold=0.45)
  36. return result
  37. if __name__ == '__main__':
  38. tool = GetTableSchemaTool()
  39. result = tool.run("查询机构编号是100的机构信息")
  40. print(result)