Text-to-SQL technologies frequently struggle to capture the complete context and meaning of a user's request, resulting in queries that do not exactly match what was intended. While developers work hard to enhance these systems, it is worth asking whether there is a better method.
Enter RAG-to-SQL—a new approach that combines natural language understanding with powerful data retrieval to generate accurate SQL queries. By blending the best of natural language processing and information retrieval, RAG-to-SQL offers a more reliable way to turn everyday language into meaningful insights from your database.
In this article, we’ll explore how RAG-to-SQL can transform the way we interact with databases, especially using Google Cloud services such as BigQuery and Vertex AI.
This article was published as a part of the Data Science Blogathon.
The main idea behind the text-to-SQL capability of LLMs is to enable people who do not know SQL to interact with a database and extract information using natural language instead. Existing text-to-SQL frameworks rely mainly on the LLM's own knowledge to convert a natural language query into a SQL query, which can lead to wrongly or invalidly formulated SQL. This is where the new RAG-to-SQL approach comes to our rescue, as explained in the next section.
To overcome the shortcomings of text-to-SQL, we can use the innovative RAG-to-SQL approach. The major issue every text-to-SQL system faces is integrating domain knowledge about the database. The RAG-to-SQL architecture addresses this difficulty by adding contextual data (metadata, DDL, queries, and more). This data is then "trained" and made available for use.
Furthermore, the "retriever" evaluates this context and forwards the most relevant pieces to answer the user query. The end result is greatly improved precision.
Follow a detailed guide to implement RAG-to-SQL using Google Cloud services such as BigQuery and Vertex AI.
To follow and run this code, you will need to set up your GCP account (a Google Cloud account with payment information). Google initially provides a free $300 trial for 90 days, so no charges will be incurred. Details for account setup: Link
Below is the code flowchart, which describes the various blocks of code at a high level. We can refer to it as we proceed.
The code implementation can be divided into 3 major blocks:
In the Colab notebook, we have to install the below libraries required for this implementation.
! pip install langchain==0.0.340 --quiet
! pip install chromadb==0.4.13 --quiet
! pip install google-cloud-bigquery[pandas] --quiet
! pip install google-cloud-aiplatform --quiet
Now we have to declare some variables to initialize our GCP project and BigQuery datasets. Using these variables, we can access the BigQuery tables within GCP from our notebook.
You can view these details in your GCP Cloud Console. In BigQuery, you can create a dataset, and within the dataset you can add or upload a table; for details, see Create Dataset and Create Table.
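For reference, a dataset can also be created programmatically with the BigQuery Python client. Below is a minimal sketch; the project and dataset names are placeholders you should replace with your own.

from google.cloud import bigquery

# Placeholders: substitute your own project and dataset IDs
client = bigquery.Client(project="your-gcp-project-id")
client.create_dataset("your_dataset_name", exists_ok=True)  # no-op if it already exists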
VERTEX_PROJECT = "Your GCP Project ID" # @param{type: "string"}
VERTEX_REGION = "us-central1" # @param{type: "string"}
BIGQUERY_DATASET = "Big Query Dataset Name" # @param{type: "string"}
BIGQUERY_PROJECT = "Vertex Project ID" # @param{type: "string"}
Now authenticate and log in to GCP Vertex AI from your notebook using the below code in Colab.
from google.colab import auth
auth.authenticate_user()
import vertexai
vertexai.init(project=VERTEX_PROJECT, location=VERTEX_REGION)
Now we have to create a vector DB that will contain the schemas of the various tables present in our dataset, and we will build a retriever on top of this vector DB so that we can incorporate RAG into our workflow.
Connect to BigQuery using the BQ client in Python and fetch the schemas of the tables.
from google.cloud import bigquery
import json
#Fetching Schemas of Tables
bq_client = bigquery.Client(project=VERTEX_PROJECT)
bq_tables = bq_client.list_tables(dataset=f"{BIGQUERY_PROJECT}.{BIGQUERY_DATASET}")
schemas = []
for bq_table in bq_tables:
    t = bq_client.get_table(f"{BIGQUERY_PROJECT}.{BIGQUERY_DATASET}.{bq_table.table_id}")
    schema_fields = [f.to_api_repr() for f in t.schema]
    schema = f"The schema for table {bq_table.table_id} is the following: \n```{json.dumps(schema_fields, indent=1)}```"
    schemas.append(schema)

print(f"Found {len(schemas)} tables in dataset {BIGQUERY_PROJECT}:{BIGQUERY_DATASET}")
Store the schemas in a vector DB such as Chroma DB. We need to create a folder called "data", which Chroma will use as its persistence directory.
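If the folder does not already exist, you can create it directly from the notebook. A minimal sketch; depending on the Chroma version, the directory may also be created automatically on first persist.

import os

os.makedirs("data", exist_ok=True)  # Chroma will persist its index here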
from langchain.embeddings import VertexAIEmbeddings
from langchain.vectorstores import Chroma
embeddings = VertexAIEmbeddings()
try:  # Avoid duplicated documents
    vector_store.delete_collection()
except:  # On the very first run, vector_store does not exist yet
    print("No need to clean the vector store")

vector_store = Chroma.from_texts(schemas, embedding=embeddings, persist_directory='./data')
n_docs = len(vector_store.get()['ids'])
retriever = vector_store.as_retriever(search_kwargs={'k': 2})
print(f"The vector store has {n_docs} documents")
We will instantiate three LLMs, one for each of the three chains.
The first is the query model, responsible for generating the SQL query based on the user question and the table schemas retrieved from the vector DB for that question. For this we use the "codechat-bison" model, which specializes in generating code in different programming languages and hence is appropriate for our use case.
The other two models use the default LLM in ChatVertexAI, "gemini-1.5-flash-001", the latest Gemini model optimized for chat and quick responses.
from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
query_model = ChatVertexAI(model_name="codechat-bison", max_output_tokens=1000)
interpret_data_model = ChatVertexAI(max_output_tokens=1000)
agent_model = ChatVertexAI(max_output_tokens=1024)
Below is the SQL prompt used to generate the SQL query for the input user question.
SQL_PROMPT = """You are a SQL and BigQuery expert.
Your job is to create a query for BigQuery in SQL.
The following paragraph contains the schema of the table used for a query. It is encoded in JSON format.
{context}
Create a BigQuery SQL query for the following user input, using the above table.
Use only columns mentioned in the schema for the SQL query.
The user and the agent have done this conversation so far:
{chat_history}
Follow these restrictions strictly:
- Only return the SQL code.
- Do not add backticks or any markup. Only write the query as output. NOTHING ELSE.
- In FROM, always use the full table path, using `{project}` as project and `{dataset}` as dataset.
- Always transform country names to full uppercase. For instance, if the country is Japan, you should use JAPAN in the query.
User input: {question}
SQL query:
"""
Now we will define a function that retrieves the relevant documents, i.e. schemas, for the input user question.
from langchain.schema.vectorstore import VectorStoreRetriever
def get_documents(retriever: VectorStoreRetriever, question: str) -> str:
    # Concatenate the retrieved schema documents into a single string
    output = ""
    for d in retriever.get_relevant_documents(question):
        output += d.page_content
        output += "\n"
    return output
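To verify the helper, you can call it directly; the question string below is only an illustration:

# Print the schema documents fetched for a sample question
print(get_documents(retriever, "trip duration by start station"))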
Then we define the LLM chain using LangChain Expression Language syntax. Note that we define the prompt with five placeholder variables, and then create a partial prompt by filling in the two placeholders project and dataset. The remaining variables are populated from the incoming request dictionary, consisting of the input and the chat history, while the context variable is populated by the get_documents function we defined above. Note also that docs, question, and chat_history below are plain Python dictionaries, so chaining them with | first merges them via Python's dict union into a single mapping, which LangChain then coerces into a runnable that feeds the prompt.
from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser
prompt_template = PromptTemplate(
    input_variables=["context", "chat_history", "question", "project", "dataset"],
    template=SQL_PROMPT)

partial_prompt = prompt_template.partial(project=BIGQUERY_PROJECT,
                                         dataset=BIGQUERY_DATASET)
# Input will be like {"input": "SOME_QUESTION", "chat_history": "HISTORY"}
docs = {"context": lambda x: get_documents(retriever, x['input'])}
question = {"question": itemgetter("input")}
chat_history = {"chat_history": itemgetter("chat_history")}
query_chain = docs | question | chat_history | partial_prompt | query_model
query = query_chain | StrOutputParser()
Let us test our chain using LangChain's console callback handler, which shows each step of the chain execution in detail.
from langchain.callbacks.tracers import ConsoleCallbackHandler
# Example
x = {"input": "Highest duration of trip where start station was from Atlantic Ave & Fort Greene Pl ", "chat_history": ""}
print(query.invoke(x, config={'callbacks': [ConsoleCallbackHandler()]}))
We need to refine the above SQL chain's output so that it includes the other variables in its output as well; this dictionary is then passed on to the second chain, the interpret chain.
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.schema.runnable import RunnableLambda
#Refine the Chain output to include other variables in output in dictionary format
def _dict_to_json(x: dict) -> str:
    return "```\n" + json.dumps(x) + "\n```"

query_response_schema = [
    ResponseSchema(name="query", description="SQL query to solve the user question."),
    ResponseSchema(name="question", description="Question asked by the user."),
    ResponseSchema(name="context", description="Documents retrieved from the vector store.")
]
query_output_parser = StructuredOutputParser.from_response_schemas(query_response_schema)
query_output_json = docs | question | {"query": query} | RunnableLambda(_dict_to_json) | StrOutputParser()
query_output = query_output_json | query_output_parser
Let's try to execute this chain.
# Example
x = {"input": "Give me top 2 start stations where trip duration was highest?", "chat_history": ""}
output = query_output.invoke(x) # Output is now a dictionary, input for the next chain
Above, we can see that the output of the refined chain is a dictionary containing the generated SQL query along with the question and the retrieved context.
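Since the parsed output is a dictionary, you can also inspect its fields individually, using the keys defined in the response schema above (a small usage sketch):

# Inspect the parsed fields from the previous invocation
print(output["query"])     # the generated SQL query
print(output["question"])  # the original user question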
Now we have to build the next chain, which consumes the output of the SQL query chain defined above. This chain takes the SQL query from the previous chain, runs it in BigQuery, and uses the results to generate a response with an appropriate prompt.
INTERPRET_PROMPT = """You are a BigQuery expert. You are also expert in extracting data from CSV.
The following paragraph describes the schema of the table used for a query. It is encoded in JSON format.
{context}
A user asked this question:
{question}
To find the answer, the following SQL query was run in BigQuery:
```
{query}
```
The result of that query was the following table in CSV format:
```
{result}
```
Based on those results, provide a brief answer to the user question.
Follow these restrictions strictly:
- Do not add any explanation about how the answer is obtained, just write the answer.
- Extract any value related to the answer only from the result of the query. Do not use any other data source.
- Just write the answer, omit the question from your answer, this is a chat, just provide the answer.
- If you cannot find the answer in the result, do not make up any data, just say "I cannot find the answer"
"""
from google.cloud import bigquery
def clean_query(query: str) -> str:
    # Strip the markdown fences the LLM may wrap around the SQL
    query = query.replace("```sql", "")
    cleaned_query = query.replace("```", "")
    return cleaned_query

def get_bq_csv(bq_client: bigquery.Client, query: str) -> str:
    # Run the cleaned query in BigQuery and return the result table as CSV
    cleaned_query = clean_query(query)
    df = bq_client.query(cleaned_query, location="US").to_dataframe()
    return df.to_csv(index=False)
We defined two functions: clean_query, which strips the SQL query of backticks and other unnecessary markup characters, and get_bq_csv, which runs the cleaned SQL query in BigQuery and returns the output table in CSV format.
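For instance, here is how clean_query behaves on a typical fenced LLM response (the query text is purely illustrative):

# Illustrative input: an LLM response wrapped in markdown fences
raw = "```sql\nSELECT start_station_name FROM trips LIMIT 2\n```"
print(clean_query(raw))  # prints the bare SELECT statement, fences removed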
# Get the output of the previous chain
query = {"query": itemgetter("query")}
context = {"context": itemgetter("context")}
question = {"question": itemgetter("question")}
query_result = {"result": lambda x: get_bq_csv(bq_client, x["query"])}

prompt = PromptTemplate(
    input_variables=["question", "query", "result", "context"],
    template=INTERPRET_PROMPT)

run_bq_chain = context | question | query | query_result | prompt
run_bq_result = run_bq_chain | interpret_data_model | StrOutputParser()
Let’s execute the chain and test it.
# Example
x = {"input": "Give me top 2 start stations where trip duration was highest?", "chat_history": ""}
final_response = run_bq_result.invoke(query_output.invoke(x))
print(final_response)
Now we will build the final chain, the agent chain. When a user asks a question, the agent decides whether to invoke the SQL query tool or to answer directly. Essentially, it routes user queries to the appropriate tool based on the work that must be completed to answer the user's inquiry.
We define the agent memory, the agent prompt, and the tool function.
from langchain.memory import ConversationBufferWindowMemory
agent_memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=10,
    return_messages=True)
AGENT_PROMPT = """You are a very powerful assistant that can answer questions using BigQuery.
You can invoke the tool user_question_tool to answer questions using BigQuery.
Always use the tools to try to answer the questions. Use the chat history for context. Never try to use any other external information.
Assume that the user may write with misspellings, fix the spelling of the user before passing the question to any tool.
Don't mention what tool you have used in your answer.
"""
from langchain.tools import tool
from langchain.callbacks.tracers import ConsoleCallbackHandler
@tool
def user_question_tool(question) -> str:
    """Useful to answer natural language questions from users using BigQuery."""
    # Swap in this config to trace every chain step in the console:
    # config = {'callbacks': [ConsoleCallbackHandler()]}
    config = {}
    memory = agent_memory.buffer_as_str.strip()
    question = {"input": question, "chat_history": memory}
    query = query_output.invoke(question, config=config)
    print("\n\n******************\n\n")
    print(query['query'])
    print("\n\n******************\n\n")
    result = run_bq_result.invoke(query, config=config)
    return result.strip()
We now bring together all the main components of the agent and initialize it.
from langchain.agents import AgentType, initialize_agent

agent_kwargs = {"system_message": AGENT_PROMPT}
agent_tools = [user_question_tool]

agent_memory.clear()

agent = initialize_agent(
    tools=agent_tools,
    llm=agent_model,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    memory=agent_memory,
    agent_kwargs=agent_kwargs,
    max_iterations=5,
    early_stopping_method='generate',
    verbose=True)
Let's run the agent now.
q = "Give me top 2 start stations where trip duration was highest?"
agent.invoke(q)
Follow-up question to agent.
q = "What is the capacity for each of these station name?"
agent.invoke(q)
The agent accurately processed the complex question and also generated a correct answer for the follow-up question based on the chat history, then utilized another table to get the capacity information for the Citi Bike stations.
The RAG-to-SQL approach represents a significant advancement in addressing the limitations of traditional Text-to-SQL models by incorporating contextual data and leveraging retrieval techniques. This methodology enhances query accuracy by retrieving relevant schema information from vector databases, allowing for more precise SQL generation. Implementing RAG-to-SQL within Google Cloud services like BigQuery and Vertex AI demonstrates its scalability and effectiveness in real-world applications. By automating the decision-making process in query handling, RAG-to-SQL opens new possibilities for non-technical users to interact seamlessly with databases while maintaining high precision.
A. No, but you get a 90-day trial period with $300 in credits when you register for the first time, and you only need to provide card details to get access. No charges are deducted from the card; even if you use services that consume beyond the $300 in credits, Google will ask you to enable a payment account before you can continue using them. So there is no automatic deduction.
A. It allows us to automate retrieval of the table schemas fed to the LLM. If we are using multiple tables, we don't need to feed all the table schemas at once; based on the user query, the relevant table schema is fetched via RAG, thus increasing efficiency over conventional text-to-SQL systems.
A. If we are building a holistic chatbot, it may require a lot of tools apart from the SQL query tool. So we can leverage the agent and provide it with multiple tools, such as web search, a database SQL query tool, other RAG tools, or function-calling API tools. This enables it to handle different types of user queries based on the task that needs to be accomplished, as sketched below.
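As a hedged illustration only (web_search_tool and doc_rag_tool below are hypothetical placeholders you would define yourself, not part of this article's code), registering additional tools with the same agent setup would look like this:

# Hypothetical sketch: an agent with several tools beyond SQL
multi_agent = initialize_agent(
    tools=[user_question_tool, web_search_tool, doc_rag_tool],  # last two are placeholders
    llm=agent_model,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    memory=agent_memory,
    verbose=True)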
The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.