| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- from pydantic import BaseModel, Field
- from langchain.agents.tools import BaseTool
- from typing import Type
- from llmops.agents.datadev.rag.schema_handler import SchemaHandler
- class GetDatabaseSchemaToolInput(BaseModel):
- """
- GetTableSchemaTool 输入参数结构
- """
- query: str = Field(description="检索文本")
- class GetDatabaseSchemaTool(BaseTool):
- # 工具名称
- name: str = "get_database_schema"
- # 工具描述
- description: str = (
- "根据用户问题,从知识库中检索数据库schema定义信息",
- "输入是查找的文本字符串",
- "返回的形式是json数组"
- )
- # 工具参数结构
- args_schema: Type[BaseModel] = GetDatabaseSchemaToolInput
- # schema查询器
- schema_handler: Type[SchemaHandler] = SchemaHandler()
- def _run(self, query: str):
- result = ""
- try:
- result = self.query(query)
- print("schema=====:", result)
- except Exception as e:
- print(f"调用工具出现异常{str(e)}")
- result = ""
- return result
- def _arun(self, query_list: list[str]):
- return self._run(query_list)
- def query(self, query: str):
- result = self.schema_handler.query_mulsimilar(query_list=query.split(","), top_k=10, similarity_threshold=0.45)
- return result
- if __name__ == '__main__':
- tool = GetTableSchemaTool()
- result = tool.run("查询机构编号是100的机构信息")
- print(result)
|