Natural Language Queries on Pandas with LangChain and DuckDB
- April 11, 2024
A lot of my reporting these days seems to revolve around pandas. I like how DataFrames can be quickly turned into charts using matplotlib. I often find myself pulling data into a DataFrame, filtering it and using the .plot options to visualise the data.
Since Mark Needham posed the question “Can an LLM write better pandas than me?” (TL;DR: it doesn’t at the moment), I have been wondering whether text-to-SQL (text-to-cypher’s more popular cousin) could yield better results. Naturally, there’s only one tool for the job: DuckDB.
I’m happy to say that early signs are promising.
Converting DataFrames to DuckDB
I was pleasantly surprised to find that DuckDB can already run SQL directly against pandas DataFrames, which makes the job almost trivial.
If you want to follow along, you'll need to install the duckdb, pandas, langchain and langchain-openai dependencies (the last one provides the ChatOpenAI class used later):
%pip install duckdb pandas langchain langchain-openai
Next, something to query. Naturally, my mind went instantly to the League Two standings.
import pandas as pd
data = """Position Team Played Won Drawn Lost Goals For Goals Against Goal Difference Points
1 Stockport County 42 24 11 7 84 42 42 83
2 Wrexham 43 23 10 10 78 51 27 79
3 Mansfield Town 42 21 13 8 81 43 38 76
4 Milton Keynes Dons 43 22 8 13 73 57 16 74
5 Crewe Alexandra 43 19 13 11 68 58 10 70
6 Barrow 41 18 13 10 57 45 12 67
7 Crawley Town 42 20 5 17 66 61 5 65
8 AFC Wimbledon 43 16 14 13 55 44 11 62
9 Walsall 42 17 11 14 63 61 2 62
10 Doncaster Rovers 42 18 7 17 59 63 -4 61
11 Harrogate Town 43 17 10 16 53 60 -7 61
12 Gillingham 43 17 9 17 40 53 -13 60
13 Bradford City 42 15 12 15 50 54 -4 57
14 Morecambe 43 17 9 17 63 74 -11 57
15 Notts County 42 16 7 19 83 79 4 55
16 Newport County 43 16 7 20 60 69 -9 55
17 Accrington Stanley 42 15 9 18 56 60 -4 54
18 Tranmere Rovers 43 15 6 22 61 63 -2 51
19 Swindon Town 42 13 11 18 70 74 -4 50
20 Salford City 43 12 11 20 62 78 -16 47
21 Grimsby Town 42 9 16 17 52 70 -18 43
22 Sutton United 43 9 12 22 51 76 -25 39
23 Colchester United 41 9 11 21 52 72 -20 38
24 Forest Green Rovers 43 9 9 25 41 71 -30 36
"""
# Column names from the header line, with the spaces removed
headers = ["Position", "Team", "Played", "Won", "Drawn", "Lost",
           "GoalsFor", "GoalsAgainst", "GoalDifference", "Points"]
rows = []
# Skip the header line and parse each team's row. Splitting on any whitespace
# works whether the columns are separated by tabs or by spaces.
for line in data.strip().split("\n")[1:]:
    fields = line.split()
    # The first field is the position and the last eight are numbers;
    # everything in between is the (possibly multi-word) team name
    team = " ".join(fields[1:-8])
    rows.append([fields[0], team, *fields[-8:]])
# Create the dataframe
league_table = pd.DataFrame(rows, columns=headers)
# Convert every column except Team to integers
numeric_columns = [c for c in headers if c != "Team"]
league_table[numeric_columns] = league_table[numeric_columns].astype(int)
Position | Team | Played | Won | Drawn | Lost | GoalsFor | GoalsAgainst | GoalDifference | Points |
---|---|---|---|---|---|---|---|---|---|
1 | Stockport County | 42 | 24 | 11 | 7 | 84 | 42 | 42 | 83 |
2 | Wrexham | 43 | 23 | 10 | 10 | 78 | 51 | 27 | 79 |
3 | Mansfield Town | 42 | 21 | 13 | 8 | 81 | 43 | 38 | 76 |
4 | Milton Keynes Dons | 43 | 22 | 8 | 13 | 73 | 57 | 16 | 74 |
5 | Crewe Alexandra | 43 | 19 | 13 | 11 | 68 | 58 | 10 | 70 |
6 | Barrow | 41 | 18 | 13 | 10 | 57 | 45 | 12 | 67 |
7 | Crawley Town | 42 | 20 | 5 | 17 | 66 | 61 | 5 | 65 |
8 | AFC Wimbledon | 43 | 16 | 14 | 13 | 55 | 44 | 11 | 62 |
9 | Walsall | 42 | 17 | 11 | 14 | 63 | 61 | 2 | 62 |
10 | Doncaster Rovers | 42 | 18 | 7 | 17 | 59 | 63 | -4 | 61 |
11 | Harrogate Town | 43 | 17 | 10 | 16 | 53 | 60 | -7 | 61 |
12 | Gillingham | 43 | 17 | 9 | 17 | 40 | 53 | -13 | 60 |
13 | Bradford City | 42 | 15 | 12 | 15 | 50 | 54 | -4 | 57 |
14 | Morecambe | 43 | 17 | 9 | 17 | 63 | 74 | -11 | 57 |
15 | Notts County | 42 | 16 | 7 | 19 | 83 | 79 | 4 | 55 |
16 | Newport County | 43 | 16 | 7 | 20 | 60 | 69 | -9 | 55 |
17 | Accrington Stanley | 42 | 15 | 9 | 18 | 56 | 60 | -4 | 54 |
18 | Tranmere Rovers | 43 | 15 | 6 | 22 | 61 | 63 | -2 | 51 |
19 | Swindon Town | 42 | 13 | 11 | 18 | 70 | 74 | -4 | 50 |
20 | Salford City | 43 | 12 | 11 | 20 | 62 | 78 | -16 | 47 |
21 | Grimsby Town | 42 | 9 | 16 | 17 | 52 | 70 | -18 | 43 |
22 | Sutton United | 43 | 9 | 12 | 22 | 51 | 76 | -25 | 39 |
23 | Colchester United | 41 | 9 | 11 | 21 | 52 | 72 | -20 | 38 |
24 | Forest Green Rovers | 43 | 9 | 9 | 25 | 41 | 71 | -30 | 36 |
Querying with DuckDB
DuckDB automatically finds a DataFrame by its variable name (DuckDB calls this a replacement scan), so you can query it directly using duckdb.sql().
import duckdb
duckdb.sql("SELECT * FROM league_table")
┌──────────┬─────────────────────┬────────┬───────┬───────┬───────┬──────────┬──────────────┬────────────────┬────────┐
│ Position │        Team         │ Played │  Won  │ Drawn │ Lost  │ GoalsFor │ GoalsAgainst │ GoalDifference │ Points │
│  int64   │       varchar       │ int64  │ int64 │ int64 │ int64 │  int64   │    int64     │     int64      │ int64  │
├──────────┼─────────────────────┼────────┼───────┼───────┼───────┼──────────┼──────────────┼────────────────┼────────┤
│        1 │ Stockport County    │     42 │    24 │    11 │     7 │       84 │           42 │             42 │     83 │
│        2 │ Wrexham             │     43 │    23 │    10 │    10 │       78 │           51 │             27 │     79 │
│        3 │ Mansfield Town      │     42 │    21 │    13 │     8 │       81 │           43 │             38 │     76 │
│        4 │ Milton Keynes Dons  │     43 │    22 │     8 │    13 │       73 │           57 │             16 │     74 │
│        5 │ Crewe Alexandra     │     43 │    19 │    13 │    11 │       68 │           58 │             10 │     70 │
│        6 │ Barrow              │     41 │    18 │    13 │    10 │       57 │           45 │             12 │     67 │
│        7 │ Crawley Town        │     42 │    20 │     5 │    17 │       66 │           61 │              5 │     65 │
│        8 │ AFC Wimbledon       │     43 │    16 │    14 │    13 │       55 │           44 │             11 │     62 │
│        9 │ Walsall             │     42 │    17 │    11 │    14 │       63 │           61 │              2 │     62 │
│       10 │ Doncaster Rovers    │     42 │    18 │     7 │    17 │       59 │           63 │             -4 │     61 │
│    ·     │          ·          │   ·    │   ·   │   ·   │   ·   │    ·     │      ·       │       ·        │   ·    │
│    ·     │          ·          │   ·    │   ·   │   ·   │   ·   │    ·     │      ·       │       ·        │   ·    │
│    ·     │          ·          │   ·    │   ·   │   ·   │   ·   │    ·     │      ·       │       ·        │   ·    │
│       15 │ Notts County        │     42 │    16 │     7 │    19 │       83 │           79 │              4 │     55 │
│       16 │ Newport County      │     43 │    16 │     7 │    20 │       60 │           69 │             -9 │     55 │
│       17 │ Accrington Stanley  │     42 │    15 │     9 │    18 │       56 │           60 │             -4 │     54 │
│       18 │ Tranmere Rovers     │     43 │    15 │     6 │    22 │       61 │           63 │             -2 │     51 │
│       19 │ Swindon Town        │     42 │    13 │    11 │    18 │       70 │           74 │             -4 │     50 │
│       20 │ Salford City        │     43 │    12 │    11 │    20 │       62 │           78 │            -16 │     47 │
│       21 │ Grimsby Town        │     42 │     9 │    16 │    17 │       52 │           70 │            -18 │     43 │
│       22 │ Sutton United       │     43 │     9 │    12 │    22 │       51 │           76 │            -25 │     39 │
│       23 │ Colchester United   │     41 │     9 │    11 │    21 │       52 │           72 │            -20 │     38 │
│       24 │ Forest Green Rovers │     43 │     9 │     9 │    25 │       41 │           71 │            -30 │     36 │
├──────────┴─────────────────────┴────────┴───────┴───────┴───────┴──────────┴──────────────┴────────────────┴────────┤
│ 24 rows (20 shown)                                                                                       10 columns │
└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
Natural Language Queries in LangChain
An LLM will need to know the structure of the table in order to generate a statement. DuckDB supports the DESC keyword for generating a table schema.
schema = duckdb.sql("DESC league_table")
┌────────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│  column_name   │ column_type │  null   │   key   │ default │  extra  │
│    varchar     │   varchar   │ varchar │ varchar │ varchar │ varchar │
├────────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ Position       │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Team           │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Played         │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Won            │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Drawn          │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Lost           │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ GoalsFor       │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ GoalsAgainst   │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ GoalDifference │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Points         │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
├────────────────┴─────────────┴─────────┴─────────┴─────────┴─────────┤
│ 10 rows                                                    6 columns │
└──────────────────────────────────────────────────────────────────────┘
Now, similar to the Cypher generation process covered in the Building a Neo4j-backed Chatbot with TypeScript course on GraphAcademy, the schema should be passed, along with the question, to a prompt that instructs the LLM to write an SQL statement.
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
import os
llm = ChatOpenAI(
    model="gpt-4",  # the model used for the results in this post
    openai_api_key=os.getenv("OPENAI_API_KEY")
)
sql_prompt = PromptTemplate.from_template("""
Given the following table schema for the table `league_table`:
{schema}
Write an SQL statement to answer the following question:
{question}
Provide all columns from the table - eg select * from.
Provide a limit where applicable.
Respond with only the SQL statement.
""")
sql_chain = sql_prompt | llm | StrOutputParser()
Let’s give it a test:
sql = sql_chain.invoke({
    "schema": schema,
    "question": "Who are the top scorers?"
})
`'SELECT * FROM league_table ORDER BY GoalsFor DESC;'`
This looks reasonable. The LLM has correctly identified that the GoalsFor column should be used to get the top-scoring team.
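The output above is bare SQL, but chat models often wrap SQL in a markdown code fence (three backticks, frequently tagged sql), which duckdb.sql would reject as a syntax error. A minimal sketch of a hypothetical clean_sql helper that strips such a fence before the query runs:

```python
import re

# Matches a markdown code fence (three backticks, optionally tagged "sql")
# and captures whatever sits between the opening and closing fence
FENCED_SQL = re.compile(r"`{3}(?:sql)?\s*(.*?)`{3}", re.DOTALL | re.IGNORECASE)

def clean_sql(text: str) -> str:
    # Hypothetical helper: if the model fenced its answer, keep only the
    # fenced part; otherwise use the text as-is. Either way, trim whitespace.
    match = FENCED_SQL.search(text)
    sql = match.group(1) if match else text
    return sql.strip()

fence = "`" * 3
print(clean_sql(f"{fence}sql\nSELECT * FROM league_table;\n{fence}"))  # SELECT * FROM league_table;
print(clean_sql("SELECT * FROM league_table;"))                        # SELECT * FROM league_table;
```

If this turns out to be needed, it could sit between generation and execution (sql_chain | clean_sql | duckdb.sql), since LangChain coerces plain functions in a chain into runnables.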
The results of this query should then be passed back to the LLM, along with instructions to generate an answer to the question. This is classic Retrieval Augmented Generation (RAG).
answer_chain = PromptTemplate.from_template("""
Use the following data to provide a definitive answer to the user's question:
{context}
The question is: {question}
""") | llm | StrOutputParser()
These chains can then be combined to create an overall chain that takes the question, generates and executes the SQL, then passes that information to the answer generation chain.
from langchain_core.runnables import RunnablePassthrough
qa_chain = (
# Get the SQL schema
RunnablePassthrough.assign(
schema=lambda _: duckdb.sql("DESC league_table"),
)
# Generate the SQL and execute it
| RunnablePassthrough.assign(
context= sql_chain | duckdb.sql,
)
# Pass {context} and {question} to the answer_chain
| answer_chain
# Format the output as a string
| StrOutputParser()
)
So, how does this chain perform on a simple question?
Simple Query - Who has conceded the most?
Which team has conceded the most goals?
qa_chain.invoke({"question": "Which team has conceded the most goals?"})
‘Based on the data provided, Notts County has conceded the most goals with a total of 79 goals against.’
Looking at the original table, that looks correct.
How about a more complicated question that will require more complex SQL?
More complex - Goals scored vs Average
What if we ask which team has scored the most goals, and how that compares to the league average?
qa_chain.invoke({"question": "Which team scored the most goals and how many have they scored compared to the league average?"})
‘Stockport County has scored the most goals with a total of 84, which is 22.42 goals more than the league average.’
What about Swindon
How many games have Swindon Town won this year?
qa_chain.invoke({"question": "How many games have Swindon Town won?"})
‘Swindon Town has won 13 games.’
Can we trick the LLM into answering a question incorrectly?
qa_chain.invoke({"question": "How many more games has Swindon Town won than they have lost?"})
‘Based on the data provided, Swindon Town has won 5 fewer games than they have lost.’
No, we can’t. Nice!
Points-per-game Ratios
How about some more complex reasoning? Can the LLM handle that?
qa_chain.invoke({"question": "What are the points per game values for each team and how does that compare to their league position?"})
Based on the provided data, the points per game values for each team are as follows:
- Stockport County: 1.976
- Wrexham: 1.837
- Mansfield Town: 1.810
- Milton Keynes Dons: 1.721
- Crewe Alexandra: 1.628
- Barrow: 1.634
- Crawley Town: 1.548
- AFC Wimbledon: 1.442
- Walsall: 1.476
- Doncaster Rovers: 1.452
- Notts County: 1.310
- Newport County: 1.279
- Accrington Stanley: 1.286
- Tranmere Rovers: 1.186
- Swindon Town: 1.190
- Salford City: 1.093
- Grimsby Town: 1.024
- Sutton United: 0.907
- Colchester United: 0.927
- Forest Green Rovers: 0.837
From the data, it can be observed that Stockport County has the highest points per game value and is in the first position. On the other hand, Forest Green Rovers has the lowest points per game value and is in the last position. Generally, there is a correlation between the points per game value and the league position, with teams having higher points per game values tending to be in higher positions in the league standings.
I’m pretty happy with that.
Conclusion
I was surprised by two things: firstly that this seemed so easy; and secondly that no-one had provided an example already.
Overall, the LLM (in this case gpt-4 through OpenAI) performed well. I found that the SQL generation could be flaky at times, mainly due to the flavour of SQL that DuckDB expects. The LLM struggled in general when column names contained spaces, despite amending the prompt to use single quotes for column names rather than backticks.
It will be interesting to see how this performs at scale, and on more complex datasets.