gen_query_sql_tool.py 625 B

123456789101112131415161718192021
  1. from langchain.agents.tools import BaseTool
  2. from pydantic import BaseModel, Field
  3. from typing import Type
  4. class GenQuerySqlToolInput(BaseModel):
  5. query: str = Field(description="查询的文字描述")
  6. table_schema: str = Field(description="表schma信息说明")
  7. dialect: str = Field(default="hive", description="数据库方言")
  8. class GenQuerySqlTool(BaseTool):
  9. name = "gen_query_sql"
  10. description = (
  11. "根据用户描述和表schema信息,生成查询语句"
  12. )
  13. args_schema: Type[BaseModel] = GenQuerySqlToolInput
  14. def _run(self, query: str, table_schema: str, dialect: str):