| 123456789101112131415161718192021 |
- from langchain.agents.tools import BaseTool
- from pydantic import BaseModel, Field
- from typing import Type
class GenQuerySqlToolInput(BaseModel):
    # Argument schema for the `gen_query_sql` tool (it is assigned as that
    # tool's `args_schema`). The Field descriptions below become part of the
    # model's generated JSON schema, so they are runtime text, presumably
    # shown to the LLM that fills in the tool arguments.
    # NOTE: a class docstring is intentionally not used here — pydantic copies
    # `__doc__` into the schema's top-level "description", which would alter
    # the emitted tool schema.

    # Natural-language description of the query to generate
    # ("查询的文字描述").
    query: str = Field(description="查询的文字描述")
    # Description of the table schema the query should target
    # ("表schema信息说明"; the original text misspelled "schema" as "schma").
    table_schema: str = Field(description="表schema信息说明")
    # SQL dialect to generate for ("数据库方言"); defaults to "hive".
    dialect: str = Field(default="hive", description="数据库方言")
- class GenQuerySqlTool(BaseTool):
- name = "gen_query_sql"
- description = (
- "根据用户描述和表schema信息,生成查询语句"
- )
- args_schema: Type[BaseModel] = GenQuerySqlToolInput
- def _run(self, query: str, table_schema: str, dialect: str):
|