Agents
LangChain offers a number of tools and functions for creating SQL Agents, which provide a more flexible way of interacting with SQL databases. The main advantages of using SQL Agents are:
- It can answer questions based on the database's schema as well as on the database's content (like describing a specific table).
- It can recover from errors by running a generated query, catching the traceback and regenerating the query correctly.
- It can query the database as many times as needed to answer the user question.
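The error-recovery behaviour in the second bullet can be pictured as a retry loop. The sketch below is a simplified, self-contained illustration and not LangChain's implementation. `generateQuery` and `runQuery` are hypothetical stand-ins for the LLM and the database, and the limit of 3 attempts is an arbitrary assumption.

```typescript
// Hypothetical stand-ins: in a real agent the LLM writes the SQL and a tool runs it.
type QueryGenerator = (question: string, lastError?: string) => string;
type QueryRunner = (sql: string) => string;

// Try a generated query; on failure, feed the error message back to the
// generator so it can produce a corrected query, up to `maxAttempts` times.
function answerWithRetries(
  question: string,
  generateQuery: QueryGenerator,
  runQuery: QueryRunner,
  maxAttempts = 3
): { sql: string; result: string; attempts: number } {
  let lastError: string | undefined;
  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    const sql = generateQuery(question, lastError);
    try {
      return { sql, result: runQuery(sql), attempts: attempt };
    } catch (err) {
      // Keep the error text so the next generation can take it into account.
      lastError = err instanceof Error ? err.message : String(err);
    }
  }
  throw new Error(`Gave up after ${maxAttempts} attempts: ${lastError}`);
}

// Toy example: the first query misspells the table name, the retry fixes it.
const fakeGenerator: QueryGenerator = (_q, lastError) =>
  lastError ? "SELECT COUNT(*) FROM Artist;" : "SELECT COUNT(*) FROM Artists;";
const fakeRunner: QueryRunner = (sql) => {
  if (sql.includes("Artists")) throw new Error("no such table: Artists");
  return '[{"COUNT(*)":275}]';
};

const outcome = answerWithRetries("How many artists are there?", fakeGenerator, fakeRunner);
console.log(outcome.attempts, outcome.sql); // 2 SELECT COUNT(*) FROM Artist;
```

The real agent feeds the error back through the LLM's scratchpad rather than a function argument. The control flow is the same, though: run, catch, regenerate.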
To initialize the agent, we'll use the createOpenAIToolsAgent function.
This agent uses the SqlToolkit, which contains tools to:
- Create and execute queries
- Check query syntax
- Retrieve table descriptions
- … and more
Setup
First, install the required packages and set your environment variables. This example will use OpenAI as the LLM.
npm install langchain @langchain/community @langchain/openai typeorm sqlite3
export OPENAI_API_KEY="your api key"
# Uncomment the below to use LangSmith. Not required.
# export LANGCHAIN_API_KEY="your api key"
# export LANGCHAIN_TRACING_V2=true
The example below uses a SQLite connection to the Chinook database. Follow these installation steps to create Chinook.db in the same directory as this notebook:
- Save this file as Chinook_Sqlite.sql
- Run sqlite3 Chinook.db
- Run .read Chinook_Sqlite.sql
- Test SELECT * FROM Artist LIMIT 10;
Now Chinook.db is in our directory, and we can interface with it using the TypeORM-driven SqlDatabase class:
import { SqlDatabase } from "langchain/sql_db";
import { DataSource } from "typeorm";
const datasource = new DataSource({
type: "sqlite",
database: "Chinook.db", // the file created in the setup steps above
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
console.log(db.allTables.map((t) => t.tableName));
/**
[
'Album', 'Artist',
'Customer', 'Employee',
'Genre', 'Invoice',
'InvoiceLine', 'MediaType',
'Playlist', 'PlaylistTrack',
'Track'
]
*/
API Reference:
- SqlDatabase from langchain/sql_db
Initializing the Agent
We'll use an OpenAI chat model and an "openai-tools" agent, which will use OpenAI's function-calling API to drive the agent's tool selection and invocations.
The agent will first choose which tables are relevant and then add the schema for those tables, plus a few sample rows, to the prompt.
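To make this concrete, here is a rough sketch of the kind of text that ends up in the prompt for each relevant table: a CREATE TABLE statement followed by a few sample rows. The `TableInfo` shape, the `describeTables` helper and the exact layout are assumptions for illustration only; LangChain's SqlDatabase produces its own format.

```typescript
// Hypothetical in-memory description of a table: its DDL plus a few rows.
interface TableInfo {
  name: string;
  createStatement: string;
  sampleRows: Array<Record<string, string | number>>;
}

// Build the schema-plus-samples text for only the tables judged relevant.
function describeTables(tables: TableInfo[], relevant: string[], sampleCount = 3): string {
  return tables
    .filter((t) => relevant.includes(t.name))
    .map((t) => {
      const rows = t.sampleRows.slice(0, sampleCount);
      const columns = rows.length > 0 ? Object.keys(rows[0]) : [];
      const lines = rows.map((r) => columns.map((c) => String(r[c])).join("\t"));
      return [
        t.createStatement,
        `/* ${rows.length} rows from ${t.name} table:`,
        columns.join("\t"),
        ...lines,
        "*/",
      ].join("\n");
    })
    .join("\n\n");
}

const tables: TableInfo[] = [
  {
    name: "PlaylistTrack",
    createStatement:
      'CREATE TABLE "PlaylistTrack" ("PlaylistId" INTEGER NOT NULL, "TrackId" INTEGER NOT NULL)',
    sampleRows: [
      { PlaylistId: 1, TrackId: 3402 },
      { PlaylistId: 1, TrackId: 3389 },
      { PlaylistId: 1, TrackId: 3390 },
    ],
  },
  {
    name: "Genre",
    createStatement: 'CREATE TABLE "Genre" ("GenreId" INTEGER NOT NULL, "Name" TEXT)',
    sampleRows: [],
  },
];

// Only PlaylistTrack is relevant here, so Genre is left out of the prompt text.
console.log(describeTables(tables, ["PlaylistTrack"]));
```

Keeping only the relevant tables and a handful of rows keeps the prompt short, even for databases with many tables.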
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";
import { createOpenAIToolsAgent, AgentExecutor } from "langchain/agents";
import { SqlToolkit } from "langchain/agents/toolkits/sql";
import { AIMessage } from "langchain/schema";
import { SqlDatabase } from "langchain/sql_db";
import { DataSource } from "typeorm";
const datasource = new DataSource({
type: "sqlite",
database: "Chinook.db", // the file created in the setup steps above
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
const llm = new ChatOpenAI({ model: "gpt-3.5-turbo", temperature: 0 });
const sqlToolKit = new SqlToolkit(db, llm);
const tools = sqlToolKit.getTools();
const SQL_PREFIX = `You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results using the LIMIT clause.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the few relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools.
Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.`;
const SQL_SUFFIX = `Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}`;
const prompt = ChatPromptTemplate.fromMessages([
["system", SQL_PREFIX],
HumanMessagePromptTemplate.fromTemplate("{input}"),
new AIMessage(SQL_SUFFIX.replace("{agent_scratchpad}", "")),
new MessagesPlaceholder("agent_scratchpad"),
]);
const newPrompt = await prompt.partial({
dialect: sqlToolKit.dialect,
top_k: "10",
});
const runnableAgent = await createOpenAIToolsAgent({
llm,
tools,
prompt: newPrompt,
});
const agentExecutor = new AgentExecutor({
agent: runnableAgent,
tools,
});
console.log(
await agentExecutor.invoke({
input:
"List the total sales per country. Which country's customers spent the most?",
})
);
/**
{
input: "List the total sales per country. Which country's customers spent the most?",
output: 'The total sales per country are as follows:\n' +
'\n' +
'1. USA: $523.06\n' +
'2. Canada: $303.96\n' +
'3. France: $195.10\n' +
'4. Brazil: $190.10\n' +
'5. Germany: $156.48\n' +
'6. United Kingdom: $112.86\n' +
'7. Czech Republic: $90.24\n' +
'8. Portugal: $77.24\n' +
'9. India: $75.26\n' +
'10. Chile: $46.62\n' +
'\n' +
"To find out which country's customers spent the most, we can see that the customers from the USA spent the most with a total sales of $523.06."
}
*/
console.log(
await agentExecutor.invoke({
input: "Describe the playlisttrack table",
})
);
/**
{
input: 'Describe the playlisttrack table',
output: 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. Both columns are of type INTEGER and are not nullable (NOT NULL).\n' +
'\n' +
'Here are three sample rows from the `PlaylistTrack` table:\n' +
'\n' +
'| PlaylistId | TrackId |\n' +
'|------------|---------|\n' +
'| 1 | 3402 |\n' +
'| 1 | 3389 |\n' +
'| 1 | 3390 |\n' +
'\n' +
'Please let me know if there is anything else I can help you with.'
}
*/
API Reference:
- ChatPromptTemplate from @langchain/core/prompts
- HumanMessagePromptTemplate from @langchain/core/prompts
- MessagesPlaceholder from @langchain/core/prompts
- ChatOpenAI from @langchain/openai
- createOpenAIToolsAgent from langchain/agents
- AgentExecutor from langchain/agents
- SqlToolkit from langchain/agents/toolkits/sql
- AIMessage from langchain/schema
- SqlDatabase from langchain/sql_db
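The `prompt.partial({ dialect, top_k })` call in the example above pre-fills two of the template variables. After that, only `{input}` (and the scratchpad) must be supplied on each `invoke`. The idea can be sketched with plain string substitution. `fillTemplate` is a hypothetical helper, not LangChain's formatter, and it ignores details such as escaped braces.

```typescript
// Replace each {name} placeholder whose value is provided; leave the rest intact.
function fillTemplate(template: string, values: Record<string, string>): string {
  return template.replace(/\{(\w+)\}/g, (match: string, name: string) =>
    Object.prototype.hasOwnProperty.call(values, name) ? values[name] : match
  );
}

const template =
  "Create a syntactically correct {dialect} query, limited to {top_k} results.\nQuestion: {input}";

// Step 1: "partial" application fixes the values that never change between calls.
const partialPrompt = fillTemplate(template, { dialect: "sqlite", top_k: "10" });
// Step 2: the per-request value is filled in at invocation time.
const finalPrompt = fillTemplate(partialPrompt, { input: "How many artists are there?" });
console.log(finalPrompt);
```

Splitting the fill into two steps lets the dialect and row limit be bound once, when the agent is built, instead of being passed with every question.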
Using a dynamic few-shot prompt
To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case, we'll create a few-shot prompt with an example selector that dynamically builds the few-shot prompt based on the user input. This helps the model write better queries by inserting relevant example queries into the prompt for it to use as a reference.
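The selection step can be illustrated without an embedding model. The sketch below scores each stored example by word overlap (Jaccard similarity) with the user input and keeps the top `k`. This is a deliberately simplified stand-in for the embedding-based semantic search used later in this section, and `selectExamples` is a hypothetical helper.

```typescript
interface Example {
  input: string;
  query: string;
}

// Lower-case the text and split it into a set of word tokens.
function tokens(text: string): Set<string> {
  return new Set(text.toLowerCase().match(/[a-z0-9]+/g) ?? []);
}

// Jaccard similarity |A ∩ B| / |A ∪ B|; 0 when both sets are empty.
function jaccard(a: Set<string>, b: Set<string>): number {
  const shared = Array.from(a).filter((t) => b.has(t)).length;
  const union = a.size + b.size - shared;
  return union === 0 ? 0 : shared / union;
}

// Return the k examples whose inputs overlap most with the user input.
function selectExamples(examples: Example[], userInput: string, k: number): Example[] {
  const target = tokens(userInput);
  return examples
    .map((ex) => ({ ex, score: jaccard(tokens(ex.input), target) }))
    .sort((x, y) => y.score - x.score)
    .slice(0, k)
    .map(({ ex }) => ex);
}

const examples: Example[] = [
  { input: "List all artists.", query: "SELECT * FROM Artist;" },
  { input: "Find the total number of invoices.", query: "SELECT COUNT(*) FROM Invoice;" },
  { input: "How many employees are there", query: 'SELECT COUNT(*) FROM "Employee"' },
];

console.log(selectExamples(examples, "How many artists are there?", 2));
```

Here "How many employees are there" ranks first because it shares the question's phrasing, which also makes its COUNT(*) query a useful template. Embeddings generalize this to paraphrases that share no words.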
First we need some user input SQL query examples:
export const examples = [
{ input: "List all artists.", query: "SELECT * FROM Artist;" },
{
input: "Find all albums for the artist 'AC/DC'.",
query:
"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
},
{
input: "List all tracks in the 'Rock' genre.",
query:
"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
},
{
input: "Find the total duration of all tracks.",
query: "SELECT SUM(Milliseconds) FROM Track;",
},
{
input: "List all customers from Canada.",
query: "SELECT * FROM Customer WHERE Country = 'Canada';",
},
{
input: "How many tracks are there in the album with ID 5?",
query: "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
},
{
input: "Find the total number of invoices.",
query: "SELECT COUNT(*) FROM Invoice;",
},
{
input: "List all tracks that are longer than 5 minutes.",
query: "SELECT * FROM Track WHERE Milliseconds > 300000;",
},
{
input: "Who are the top 5 customers by total purchase?",
query:
"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
},
{
input: "Which albums are from the year 2000?",
query: "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
},
{
input: "How many employees are there",
query: 'SELECT COUNT(*) FROM "Employee"',
},
];
Now we can create an example selector. This will take the actual user input and select some number of examples to add to our few-shot prompt. We'll use a SemanticSimilarityExampleSelector, which will perform a semantic search using the embeddings and vector store we configure to find the examples most similar to our input:
import { HNSWLib } from "@langchain/community/vectorstores/hnswlib";
import { SemanticSimilarityExampleSelector } from "@langchain/core/example_selectors";
import {
FewShotPromptTemplate,
PromptTemplate,
ChatPromptTemplate,
SystemMessagePromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { SqlToolkit } from "langchain/agents/toolkits/sql";
import { SqlDatabase } from "langchain/sql_db";
import { DataSource } from "typeorm";
import { AgentExecutor, createOpenAIToolsAgent } from "langchain/agents";
import { examples } from "./examples.js";
const exampleSelector = await SemanticSimilarityExampleSelector.fromExamples(
examples,
new OpenAIEmbeddings(),
HNSWLib,
{
k: 5,
inputKeys: ["input"],
}
);
// Now we can create our FewShotPromptTemplate, which takes our example selector, an example prompt for formatting each example, and a string prefix and suffix to put before and after our formatted examples:
const SYSTEM_PREFIX = `You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
Here are some examples of user inputs and their corresponding SQL queries:`;
const fewShotPrompt = new FewShotPromptTemplate({
exampleSelector,
examplePrompt: PromptTemplate.fromTemplate(
"User input: {input}\nSQL query: {query}"
),
inputVariables: ["input", "dialect", "top_k"],
prefix: SYSTEM_PREFIX,
suffix: "",
});
// Since our underlying agent is an [OpenAI tools agent](https://js.langchain.com/docs/modules/agents/agent_types/openai_tools_agent), which uses
// OpenAI function calling, our full prompt should be a chat prompt with a human message template and an agentScratchpad MessagesPlaceholder.
// The few-shot prompt will be used for our system message:
const fullPrompt = ChatPromptTemplate.fromMessages([
new SystemMessagePromptTemplate(fewShotPrompt),
["human", "{input}"],
new MessagesPlaceholder("agent_scratchpad"),
]);
// And now we can create our agent with our custom prompt:
const llm = new ChatOpenAI({ model: "gpt-4", temperature: 0 });
const datasource = new DataSource({
type: "sqlite",
database: "Chinook.db", // the file created in the setup steps above
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
const sqlToolKit = new SqlToolkit(db, llm);
const tools = sqlToolKit.getTools();
const newPrompt = await fullPrompt.partial({
dialect: sqlToolKit.dialect,
top_k: "10",
});
const runnableAgent = await createOpenAIToolsAgent({
llm,
tools,
prompt: newPrompt,
});
const agentExecutor = new AgentExecutor({
agent: runnableAgent,
tools,
});
console.log(
await agentExecutor.invoke({ input: "How many artists are there?" })
);
/**
{
input: 'How many artists are there?',
output: 'There are 275 artists.'
}
*/
API Reference:
- HNSWLib from @langchain/community/vectorstores/hnswlib
- SemanticSimilarityExampleSelector from @langchain/core/example_selectors
- FewShotPromptTemplate from @langchain/core/prompts
- PromptTemplate from @langchain/core/prompts
- ChatPromptTemplate from @langchain/core/prompts
- SystemMessagePromptTemplate from @langchain/core/prompts
- MessagesPlaceholder from @langchain/core/prompts
- ChatOpenAI from @langchain/openai
- OpenAIEmbeddings from @langchain/openai
- SqlToolkit from langchain/agents/toolkits/sql
- SqlDatabase from langchain/sql_db
- AgentExecutor from langchain/agents
- createOpenAIToolsAgent from langchain/agents
You can see a LangSmith trace of this example here
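Once the selector has picked its examples, the FewShotPromptTemplate renders each one with the example prompt and places the results between the prefix and the suffix. The sketch below reproduces that assembly with plain strings so the resulting system message is easy to inspect. `buildFewShotText` is a hypothetical helper, and two details are assumptions of this sketch rather than a description of LangChain's exact output: the blank-line separator and the dropping of the empty suffix.

```typescript
// Shape of one stored example, matching the `examples` array above.
interface Example {
  input: string;
  query: string;
}

// Format a single example like the examplePrompt "User input: {input}\nSQL query: {query}".
function formatExample(ex: Example): string {
  return `User input: ${ex.input}\nSQL query: ${ex.query}`;
}

// Join prefix, formatted examples and suffix; empty pieces (e.g. suffix "") are dropped.
function buildFewShotText(
  prefix: string,
  selected: Example[],
  suffix = "",
  separator = "\n\n"
): string {
  return [prefix, ...selected.map(formatExample), suffix]
    .filter((piece) => piece.length > 0)
    .join(separator);
}

const systemText = buildFewShotText(
  "Here are some examples of user inputs and their corresponding SQL queries:",
  [
    { input: "List all artists.", query: "SELECT * FROM Artist;" },
    { input: "Find the total number of invoices.", query: "SELECT COUNT(*) FROM Invoice;" },
  ]
);
console.log(systemText);
```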
Dealing with high-cardinality columns
To filter on columns that contain proper nouns such as addresses, song names or artists, we first need to double-check the spelling, so that the filter matches the values actually stored in the database.
We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.
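The lookup the agent performs can be illustrated without a vector store. The sketch below swaps embedding similarity for plain Levenshtein edit distance, which catches typos but not semantic variants. `lookUpProperNouns` and the short artist list are illustrative assumptions, not part of LangChain.

```typescript
// Classic Levenshtein edit distance, computed with a single rolling row.
function editDistance(a: string, b: string): number {
  const row = Array.from({ length: b.length + 1 }, (_, j) => j);
  for (let i = 1; i <= a.length; i++) {
    let diag = row[0]; // distance for (i-1, j-1)
    row[0] = i;
    for (let j = 1; j <= b.length; j++) {
      const above = row[j]; // distance for (i-1, j)
      const cost = a[i - 1] === b[j - 1] ? 0 : 1;
      row[j] = Math.min(above + 1, row[j - 1] + 1, diag + cost);
      diag = above;
    }
  }
  return row[b.length];
}

// Return the k known proper nouns closest to a possibly misspelled term.
function lookUpProperNouns(term: string, knownNouns: string[], k = 3): string[] {
  const needle = term.toLowerCase();
  return knownNouns
    .map((noun) => ({ noun, distance: editDistance(needle, noun.toLowerCase()) }))
    .sort((x, y) => x.distance - y.distance)
    .slice(0, k)
    .map(({ noun }) => noun);
}

const knownArtists = ["AC/DC", "Aerosmith", "Alice In Chains", "Audioslave"];
console.log(lookUpProperNouns("alis in chain", knownArtists, 1)); // [ 'Alice In Chains' ]
```

The agent would then filter on the corrected value (e.g. WHERE Name = 'Alice In Chains') instead of the user's spelling.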
First, we need the unique values for each entity we want, so we define a function that parses the query result into a list of strings:
import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { AgentExecutor, createOpenAIToolsAgent } from "langchain/agents";
import { SqlToolkit } from "langchain/agents/toolkits/sql";
import { SqlDatabase } from "langchain/sql_db";
import { Tool } from "langchain/tools";
import { createRetrieverTool } from "langchain/tools/retriever";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { DataSource } from "typeorm";
const datasource = new DataSource({
type: "sqlite",
database: "Chinook.db", // the file created in the setup steps above
});
const db = await SqlDatabase.fromDataSourceParams({
appDataSource: datasource,
});
async function queryAsList(query: string): Promise<string[]> {
const res: Array<{ [key: string]: string }> = JSON.parse(await db.run(query))
.flat()
.filter((el: any) => el != null);
const justValues: Array<string> = res.map((item) =>
Object.values(item)[0]
.replace(/\b\d+\b/g, "")
.trim()
);
return justValues;
}
const artists = await queryAsList("SELECT Name FROM Artist");
const albums = await queryAsList("SELECT Title FROM Album");
console.log(albums.slice(0, 5));
/**
[
'For Those About To Rock We Salute You',
'Balls to the Wall',
'Restless and Wild',
'Let There Be Rock',
'Big Ones'
]
*/
// Now we can proceed with creating the custom retriever tool and the final agent:
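// Index both artist names and album titles so either kind of proper noun can be looked up.
const vectorDb = await MemoryVectorStore.fromTexts(
[...artists, ...albums],
{},
new OpenAIEmbeddings()
);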
const retriever = vectorDb.asRetriever(15);
const description = `Use to look up values to filter on.
Input is an approximate spelling of the proper noun, output is valid proper nouns.
Use the noun most similar to the search.`;
const retrieverTool = createRetrieverTool(retriever, {
description,
name: "search_proper_nouns",
}) as unknown as Tool;
const system = `You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool!
You have access to the following tables: {table_names}
If the question does not seem related to the database, just return "I don't know" as the answer.`;
const prompt = ChatPromptTemplate.fromMessages([
["system", system],
["human", "{input}"],
new MessagesPlaceholder("agent_scratchpad"),
]);
const llm = new ChatOpenAI({ model: "gpt-4", temperature: 0 });
const sqlToolKit = new SqlToolkit(db, llm);
const newPrompt = await prompt.partial({
dialect: sqlToolKit.dialect,
top_k: "10",
table_names: db.allTables.map((t) => t.tableName).join(", "),
});
const tools = [...sqlToolKit.getTools(), retrieverTool];
const runnableAgent = await createOpenAIToolsAgent({
llm,
tools,
prompt: newPrompt,
});
const agentExecutor = new AgentExecutor({
agent: runnableAgent,
tools,
});
console.log(
await agentExecutor.invoke({
input: "How many albums does alis in chain have?",
})
);
/**
{
input: 'How many albums does alis in chain have?',
output: 'Alice In Chains has 1 album.'
}
*/
API Reference:
- ChatPromptTemplate from @langchain/core/prompts
- MessagesPlaceholder from @langchain/core/prompts
- ChatOpenAI from @langchain/openai
- OpenAIEmbeddings from @langchain/openai
- AgentExecutor from langchain/agents
- createOpenAIToolsAgent from langchain/agents
- SqlToolkit from langchain/agents/toolkits/sql
- SqlDatabase from langchain/sql_db
- Tool from langchain/tools
- createRetrieverTool from langchain/tools/retriever
- MemoryVectorStore from langchain/vectorstores/memory
You can see a LangSmith trace of this example here
Next steps
To learn more about the built-in generic agent types as well as how to build custom agents, head to the Agents Modules.
The built-in AgentExecutor runs a simple Agent action -> Tool call -> Agent action… loop. To build more complex agent runtimes, head to the LangGraph section.
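That loop can be pictured with a minimal, synchronous sketch. Everything here is hypothetical. The `Agent` function stands in for the LLM, the tool names `list_tables` and `query_sql` are made up, and the cap of 10 iterations is an arbitrary choice. The real AgentExecutor is asynchronous and handles more cases, such as output-parsing errors.

```typescript
// A tool takes a string input and returns a string observation.
type ToolFn = (input: string) => string;

// At each step the agent either requests a tool call or returns a final answer.
type AgentStep =
  | { type: "tool"; tool: string; toolInput: string }
  | { type: "finish"; output: string };

// Hypothetical agent: decides the next step from the observations so far.
type Agent = (input: string, observations: string[]) => AgentStep;

function runAgentLoop(
  agent: Agent,
  tools: Record<string, ToolFn>,
  input: string,
  maxIterations = 10
): string {
  const observations: string[] = [];
  for (let i = 0; i < maxIterations; i++) {
    const step = agent(input, observations);
    if (step.type === "finish") return step.output;
    const known = Object.prototype.hasOwnProperty.call(tools, step.tool);
    // Report unknown tools back to the agent as an observation instead of crashing.
    observations.push(known ? tools[step.tool](step.toolInput) : `Unknown tool: ${step.tool}`);
  }
  throw new Error(`Agent did not finish within ${maxIterations} iterations`);
}

// Toy tools and a scripted agent: list tables, run one query, then answer.
const toyTools: Record<string, ToolFn> = {
  list_tables: () => "Album, Artist, Track",
  query_sql: (sql) => (sql.includes("Artist") ? "275" : "0"),
};
const toyAgent: Agent = (_input, obs) => {
  if (obs.length === 0) return { type: "tool", tool: "list_tables", toolInput: "" };
  if (obs.length === 1) return { type: "tool", tool: "query_sql", toolInput: "SELECT COUNT(*) FROM Artist;" };
  return { type: "finish", output: `There are ${obs[1]} artists.` };
};

console.log(runAgentLoop(toyAgent, toyTools, "How many artists are there?")); // There are 275 artists.
```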