A Comprehensive Guide to RAG-to-SQL on Google Cloud

Ritika Last Updated : 18 Oct, 2024
11 min read

Text-to-SQL technologies frequently struggle to capture the complete context and meaning of a user’s request, resulting in queries that do not match the user’s intent. While developers work hard to enhance these systems, it is worth asking whether there is a better method.

Enter RAG-to-SQL—a new approach that combines natural language understanding with powerful data retrieval to generate accurate SQL queries. By blending the best of natural language processing and information retrieval, RAG-to-SQL offers a more reliable way to turn everyday language into meaningful insights from your database.

In this article, we’ll explore how RAG-to-SQL can transform the way we interact with databases, especially using Google Cloud services such as BigQuery and Vertex AI. 

Learning Objectives

  • Identify the limitations of Text-to-SQL systems in accurately capturing user intent.
  • Understand the advantages of RAG-to-SQL as a new paradigm for generating more reliable SQL queries.
  • Implement the RAG-to-SQL approach using Google Cloud services like BigQuery and Vertex AI.
  • Learn how to integrate and utilize multiple Google Cloud tools for RAG-to-SQL implementation.

This article was published as a part of the Data Science Blogathon.

Limitations of Traditional Text-to-SQL Approaches

The main idea behind LLM-based text-to-SQL models is to let people who do not know SQL query a database in natural language. Existing text-to-SQL frameworks rely mainly on the LLM's own knowledge to convert a natural language question into a SQL query. This can lead to wrong or invalid SQL queries. The RAG-to-SQL approach, explained in the next section, addresses this problem.

What is RAG-to-SQL?

To overcome the shortcomings of text-to-SQL, we can use RAG-to-SQL. The major issue every text-to-SQL system faces is integrating domain information about the database. The RAG-to-SQL architecture addresses this difficulty by adding contextual data (metadata, DDL, queries, and more). This data is then “trained”, which here means embedded and indexed in a vector store rather than used to fine-tune a model, and made available for retrieval.
Furthermore, the “retriever” selects the most relevant context and forwards it to the LLM together with the user query. The end result is more precise SQL.
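To make the retriever idea concrete, here is a minimal sketch in plain Python. It is not the Vertex AI embedding model used later in this article: it ranks documents with a toy bag-of-words cosine similarity, and the schema descriptions are hypothetical. The point is only to show how the "most relevant context" is selected for a question.

```python
import math
import re
from collections import Counter


def bag_of_words(text: str) -> Counter:
    """Lower-case the text and count its alphanumeric tokens."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(question: str, documents: list[str], k: int = 1) -> list[str]:
    """Return the k documents most similar to the question."""
    q = bag_of_words(question)
    ranked = sorted(documents, key=lambda d: cosine(q, bag_of_words(d)), reverse=True)
    return ranked[:k]


# Hypothetical schema descriptions standing in for real table metadata.
schemas = [
    "table trips: columns start_station_name, end_station_name, tripduration",
    "table stations: columns station_id, name, capacity, latitude, longitude",
]

top = retrieve("What is the capacity of each station?", schemas, k=1)
print(top[0])
```

A real system replaces bag_of_words with an embedding model and the sorted call with a vector database lookup, but the selection step is the same: only the best-matching schema reaches the prompt.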

Setting Up RAG-to-SQL with Google Cloud: A Step-by-Step Guide

Follow a detailed guide to implement RAG-to-SQL using Google Cloud services such as BigQuery and Vertex AI.

Pre-requisites for Code 

To follow along and run this code, you will need a Google Cloud Platform (GCP) account with payment information set up. New accounts receive a free trial of $300 in credits valid for 90 days, so no charges are incurred during the trial. Details for account setup: Link

Code Flowchart

Below is the code flowchart, which describes the main blocks of code at a high level. We can refer to it as we proceed.

Figure: RAG to SQL

The code implementation can be divided into three major blocks:

  • SQL Query Chain: This chain generates the appropriate SQL query from the user question and the relevant table schema fetched from the vector DB.
  • Interpret Chain: This chain takes the SQL query from the previous chain, runs it in BigQuery, and then uses the results to generate a response with an appropriate prompt.
  • Agent Chain: This is the final chain, which encapsulates the two chains above. Whenever a user question comes in, it decides whether to call the SQL query tool or answer the question directly. It routes user queries to various tools based on the task required to answer the question.

Step 1: Installing the Required Libraries

In a Colab notebook, install the libraries below, which are required for this implementation.

! pip install langchain==0.0.340 --quiet
! pip install chromadb==0.4.13 --quiet
! pip install google-cloud-bigquery[pandas] --quiet
! pip install google-cloud-aiplatform --quiet

Step 2: Configuring Your Google Cloud Project and Credentials

Now we declare some variables to initialise our GCP project and BigQuery dataset. Using these variables, we can access the BigQuery tables in GCP from our notebook.

You can view these details in your GCP console. In BigQuery you can create a dataset and, within the dataset, add or upload a table; for details see Create Dataset and Create Table.

VERTEX_PROJECT = "Your GCP Project ID" # @param{type: "string"}
VERTEX_REGION = "us-central1" # @param{type: "string"}

BIGQUERY_DATASET = "Big Query Dataset Name" # @param{type: "string"}
BIGQUERY_PROJECT = "Vertex Project ID" # @param{type: "string"}

Now authenticate and log in to GCP Vertex AI from your Colab notebook using the code below.

from google.colab import auth
auth.authenticate_user()

import vertexai
vertexai.init(project=VERTEX_PROJECT, location=VERTEX_REGION)

Step 3: Building a Vector Database for Table Schema Storage

Next, we create a vector database that contains the schemas of the tables in our dataset, and build a retriever on top of it so that we can incorporate RAG into our workflow.

First, connect to BigQuery using the Python BigQuery client and fetch the table schemas.

from google.cloud import bigquery
import json

#Fetching Schemas of Tables

bq_client = bigquery.Client(project=VERTEX_PROJECT)
bq_tables = bq_client.list_tables(dataset=f"{BIGQUERY_PROJECT}.{BIGQUERY_DATASET}")
schemas = []
for bq_table in bq_tables:
   t = bq_client.get_table(f"{BIGQUERY_PROJECT}.{BIGQUERY_DATASET}.{bq_table.table_id}")
   schema_fields = [f.to_api_repr() for f in t.schema]
   schema = f"The schema for table {bq_table.table_id} is the following: \n```{json.dumps(schema_fields, indent=1)}```"
   schemas.append(schema)

print(f"Found {len(schemas)} tables in dataset {BIGQUERY_PROJECT}:{BIGQUERY_DATASET}")
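It helps to see what one of these schema documents looks like before it is embedded. The sketch below rebuilds the same text format without calling BigQuery. The helper name schema_document and the two fields are hypothetical; the fields are only assumed to resemble the dictionaries returned by to_api_repr().

```python
import json

FENCE = "`" * 3  # three backticks, built this way to keep the example readable


def schema_document(table_id: str, fields: list[dict]) -> str:
    """Build the text that gets embedded for one table."""
    return (
        f"The schema for table {table_id} is the following: \n"
        f"{FENCE}{json.dumps(fields, indent=1)}{FENCE}"
    )


# Hypothetical fields, assumed to be shaped like SchemaField.to_api_repr() output.
fields = [
    {"name": "tripduration", "type": "INTEGER", "mode": "NULLABLE"},
    {"name": "start_station_name", "type": "STRING", "mode": "NULLABLE"},
]

doc = schema_document("trips", fields)
print(doc)
```

Each table becomes one such string, so the retriever later returns whole table schemas rather than individual columns.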

Next, store the schemas in a vector database such as Chroma. The store is persisted to a folder called “data”; create it if it does not already exist.

from langchain.embeddings import VertexAIEmbeddings
from langchain.vectorstores import Chroma

embeddings = VertexAIEmbeddings()
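try: # Avoid duplicated documents when this cell is re-run
  vector_store.delete_collection()
except Exception: # e.g. NameError on the first run, when vector_store does not exist yet
  print("No need to clean the vector store")
vector_store = Chroma.from_texts(schemas, embedding=embeddings, persist_directory='./data')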
n_docs = len(vector_store.get()['ids'])
retriever = vector_store.as_retriever(search_kwargs={'k': 2})
print(f"The vector store has {n_docs} documents")

Step 4: Instantiating LLM Models for SQL Query, Interpretation, and Agent Chains

We will instantiate three LLMs, one for each chain.

The first is the query model, which generates the SQL query from the user question and the table schema retrieved from the vector DB. For this we use the “codechat-bison” model, which specializes in generating code in different programming languages and is therefore appropriate for our use case.

The other two models are created without a model_name, so they use the ChatVertexAI default. Here that is taken to be “gemini-1.5-flash-001”, a Gemini model optimized for chat and quick responses. The default depends on the installed LangChain version, so pass model_name explicitly if you need a specific model.

from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI

query_model = ChatVertexAI(model_name="codechat-bison", max_output_tokens=1000)
interpret_data_model = ChatVertexAI(max_output_tokens=1000)
agent_model = ChatVertexAI(max_output_tokens=1024)

Step 5: Constructing the SQL Query Chain

Below is the SQL prompt used to generate the SQL query for the input user question.

SQL_PROMPT = """You are a SQL and BigQuery expert.

Your job is to create a query for BigQuery in SQL.

The following paragraph contains the schema of the table used for a query. It is encoded in JSON format.

{context}

Create a BigQuery SQL query for the following user input, using the above table.
And Use only columns mentioned in schema for the SQL query

The user and the agent have done this conversation so far:
{chat_history}

Follow these restrictions strictly:
- Only return the SQL code.
- Do not add backticks or any markup. Only write the query as output. NOTHING ELSE.
- In FROM, always use the full table path, using `{project}` as project and `{dataset}` as dataset.
- Always transform country names to full uppercase. For instance, if the country is Japan, you should use JAPAN in the query.

User input: {question}

SQL query:
"""

Now we define a function that retrieves the most relevant document, i.e. the table schema, for the user's question.

from langchain.schema.vectorstore import VectorStoreRetriever
def get_documents(retriever: VectorStoreRetriever, question: str) -> str:
  # Return only the first (most relevant) document, or "" if nothing is retrieved
  documents = retriever.get_relevant_documents(question)
  if not documents:
    return ""
  return documents[0].page_content + "\n"

Then we define the LLM chain using LangChain Expression Language (LCEL) syntax. Note that the prompt has five placeholder variables; we create a partial prompt by filling in two of them, project and dataset. The remaining variables are populated from the incoming request dictionary, which contains the input and the chat history, while the context variable is populated by the get_documents function defined above.
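The two-stage filling of the prompt can be sketched in plain Python first. This is only an illustration, not LangChain code: TEMPLATE is a shortened stand-in for SQL_PROMPT, and "my-project" and "my_dataset" are hypothetical values.

```python
from functools import partial

# A shortened stand-in for SQL_PROMPT with the same five placeholders.
TEMPLATE = (
    "Schema:\n{context}\n"
    "History:\n{chat_history}\n"
    "Use tables under `{project}.{dataset}`.\n"
    "User input: {question}\n"
)

# Stage 1: fix the values that stay the same for every request.
fill_prompt = partial(TEMPLATE.format, project="my-project", dataset="my_dataset")

# Stage 2: supply the values that change with each request.
prompt_text = fill_prompt(
    context="table trips: tripduration INTEGER",
    chat_history="",
    question="What is the longest trip?",
)
print(prompt_text)
```

The LangChain code that follows does the same thing with PromptTemplate.partial, which returns a new prompt that only needs the three remaining variables.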

from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser

prompt_template = PromptTemplate(
    input_variables=["context", "chat_history", "question", "project", "dataset"],
    template=SQL_PROMPT)

partial_prompt = prompt_template.partial(project=BIGQUERY_PROJECT,
                                         dataset=BIGQUERY_DATASET)

# Input will be like {"input": "SOME_QUESTION", "chat_history": "HISTORY"}
docs = {"context": lambda x: get_documents(retriever, x['input'])}
question = {"question": itemgetter("input")}
chat_history = {"chat_history": itemgetter("chat_history")}
query_chain = docs | question | chat_history | partial_prompt | query_model
query = query_chain | StrOutputParser()
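One detail in the chain above is easy to miss: docs, question and chat_history are plain Python dictionaries, so the | between them is the dict union operator (Python 3.9+), not a pipeline step. The three dictionaries merge into one mapping, and only when that mapping meets the prompt does LangChain turn it into a parallel step in which every callable receives the same request. The sketch below imitates this in plain Python; run_parallel is a hypothetical stand-in for that parallel step, and the context lambda stands in for get_documents.

```python
from operator import itemgetter

# Each step is a one-entry dict mapping an output key to a callable.
docs = {"context": lambda x: f"schema relevant to: {x['input']}"}
question = {"question": itemgetter("input")}
chat_history = {"chat_history": itemgetter("chat_history")}

# Python 3.9+ dict union: the three dicts merge into a single mapping.
merged = docs | question | chat_history


def run_parallel(mapping: dict, request: dict) -> dict:
    """Apply every callable in the mapping to the same request."""
    return {key: fn(request) for key, fn in mapping.items()}


request = {"input": "Longest trip?", "chat_history": ""}
prompt_inputs = run_parallel(merged, request)
print(prompt_inputs)
```

This is why each lambda and itemgetter reads from the original request dictionary (the "input" and "chat_history" keys) rather than from the output of the dictionary before it.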

Let us test our chain using LangChain's ConsoleCallbackHandler, which shows each step of the chain execution in detail.

from langchain.callbacks.tracers import ConsoleCallbackHandler
# Example
x = {"input": "Highest duration of trip where start station was from Atlantic Ave & Fort Greene Pl ", "chat_history": ""}
print(query.invoke(x, config={'callbacks': [ConsoleCallbackHandler()]}))
Figure: Output of chain execution
Figure: Final SQL query output

Step 6: Refining the SQL Chain Output for Interpretation

We need to refine the output of the SQL chain above so that it also includes the other variables, which are then passed on to the second chain, the interpret chain.

from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.schema.runnable import RunnableLambda

#Refine the Chain output to include other variables in output in dictionary format
def _dict_to_json(x: dict) -> str:
  return "```\n" + json.dumps(x) + "\n```"

query_response_schema = [
    ResponseSchema(name="query", description="SQL query to solve the user question."),
    ResponseSchema(name="question", description="Question asked by the user."),
    ResponseSchema(name="context", description="Documents retrieved from the vector store.")
  ]
query_output_parser = StructuredOutputParser.from_response_schemas(query_response_schema)
query_output_json = docs | question | {"query": query} | RunnableLambda(_dict_to_json) | StrOutputParser()
query_output = query_output_json | query_output_parser
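The refined chain serialises a dictionary to fenced JSON and then parses it back into a dictionary. The round trip can be sketched with the standard library alone. The function parse_fenced_json is a hypothetical stand-in for the structured output parser, not LangChain's implementation; it uses a greedy match so that backticks inside the retrieved context do not cut the JSON short.

```python
import json
import re

FENCE = "`" * 3  # three backticks, built this way to keep the example readable


def dict_to_fenced_json(payload: dict) -> str:
    """Serialise a dict as JSON wrapped in a markdown code fence."""
    return f"{FENCE}\n{json.dumps(payload)}\n{FENCE}"


def parse_fenced_json(text: str, expected_keys: list[str]) -> dict:
    """Extract the JSON between the outermost fences and check its keys."""
    match = re.search(FENCE + r"(?:json)?(.*)" + FENCE, text, re.DOTALL)
    if match is None:
        raise ValueError("no fenced JSON block found")
    data = json.loads(match.group(1))
    missing = [key for key in expected_keys if key not in data]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    return data


payload = {
    "query": "SELECT 1",
    "question": "A test question",
    "context": "schema: " + FENCE + "[]" + FENCE,  # context may itself contain fences
}
roundtrip = parse_fenced_json(
    dict_to_fenced_json(payload), ["query", "question", "context"]
)
print(roundtrip["query"])
```

The result is a dictionary with the three keys declared in the response schema, which is exactly the input shape the interpret chain expects.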

Let's try executing this chain.

# Example
x = {"input": "Give me top 2 start stations where trip duration was highest?", "chat_history": ""}
output = query_output.invoke(x)  # Output is now a dictionary, input for the next chain
Figure: Output of refined SQL chain

Above we can see that the output of the refined chain is a dictionary containing the generated SQL query together with the question and the retrieved context.

Step 7: Building the Interpret Chain for Query Results

Now we build the next chain, which takes the output of the SQL query chain defined above. It runs the SQL query from the previous chain in BigQuery and then uses the results to generate a response with an appropriate prompt.

INTERPRET_PROMPT = """You are a BigQuery expert. You are also expert in extracting data from CSV.

The following paragraph describes the schema of the table used for a query. It is encoded in JSON format.

{context}

A user asked this question:
{question}

To find the answer, the following SQL query was run in BigQuery:
```
{query}
```

The result of that query was the following table in CSV format:
```
{result}
```

Based on those results, provide a brief answer to the user question.

Follow these restrictions strictly:
- Do not add any explanation about how the answer is obtained, just write the answer.
- Extract any value related to the answer only from the result of the query. Do not use any other data source.
- Just write the answer, omit the question from your answer, this is a chat, just provide the answer.
- If you cannot find the answer in the result, do not make up any data, just say "I cannot find the answer"
"""
from google.cloud import bigquery
def get_bq_csv(bq_client: bigquery.Client, query: str) -> str:
  cleaned_query = clean_query(query)
  df = bq_client.query(cleaned_query, location="US").to_dataframe()
  return df.to_csv(index=False)


def clean_query(query: str):
  query = query.replace("```sql","")
  cleaned_query = query.replace("```","")

  return cleaned_query

We defined two functions above: clean_query strips markdown code fences (the triple backticks and the sql tag) from the generated SQL query, and get_bq_csv runs the cleaned query in BigQuery and returns the output table in CSV format.
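To see the cleaning step in isolation, here is a slightly more defensive variant of clean_query, written as a sketch rather than a replacement for the function above. It assumes the model may write the language tag in any letter case, and it also trims the surrounding whitespace.

```python
import re

FENCE = "`" * 3  # three backticks, built this way to keep the example readable


def clean_query(query: str) -> str:
    """Remove markdown code fences (with an optional 'sql' tag) from model output."""
    without_fences = re.sub(FENCE + r"(?:sql)?", "", query, flags=re.IGNORECASE)
    return without_fences.strip()


raw = f"{FENCE}sql\nSELECT name FROM stations LIMIT 2\n{FENCE}"
cleaned = clean_query(raw)
print(cleaned)
```

A query that arrives without any fences passes through unchanged, apart from the trimmed whitespace.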

# Pick the fields produced by the previous chain. The names end in _in so that
# they do not shadow the `query` and `question` objects of the SQL query chain.
query_in = {"query": itemgetter("query")}
context_in = {"context": itemgetter("context")}
question_in = {"question": itemgetter("question")}
# Run the generated SQL in BigQuery and expose the CSV output as "result"
query_result = {"result": lambda x: get_bq_csv(bq_client, x["query"])}

prompt = PromptTemplate(
    input_variables=["question", "query", "result", "context"],
    template=INTERPRET_PROMPT)

run_bq_chain = context_in | question_in | query_in | query_result | prompt
run_bq_result = run_bq_chain | interpret_data_model | StrOutputParser()

Let’s execute the chain and test it.

# Example
x = {"input": "Give me top 2 start stations where trip duration was highest?", "chat_history": ""}
final_response = run_bq_result.invoke(query_output.invoke(x))
print(final_response)
Figure: Output of interpret chain

Step 8: Implementing the Agent Chain for Dynamic Query Routing

Now we build the final chain, the agent chain. When a user asks a question, it decides whether to use the SQL query tool or to answer directly. In other words, it routes user queries to the appropriate tool based on the work needed to answer the question.

We define an agent memory, an agent prompt, and a tool function.

from langchain.memory import ConversationBufferWindowMemory

agent_memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=10,
    return_messages=True)
AGENT_PROMPT = """You are a very powerful assistant that can answer questions using BigQuery.

You can invoke the tool user_question_tool to answer questions using BigQuery.

Always use the tools to try to answer the questions. Use the chat history for context. Never try to use any other external information.

Assume that the user may write with misspellings, fix the spelling of the user before passing the question to any tool.

Don't mention what tool you have used in your answer.
"""
from langchain.tools import tool
from langchain.callbacks.tracers import ConsoleCallbackHandler

@tool
def user_question_tool(question) -> str:
  """Useful to answer natural language questions from users using BigQuery."""
  # To trace every chain step, use: config = {'callbacks': [ConsoleCallbackHandler()]}
  config = {}
  memory = agent_memory.buffer_as_str.strip()
  question = {"input": question, "chat_history": memory}
  query = query_output.invoke(question, config=config)
  print("\n\n******************\n\n")
  print(query['query'])
  print("\n\n******************\n\n")
  result = run_bq_result.invoke(query, config=config)
  return result.strip()
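The agent memory defined above is a window memory: it keeps only the last k exchanges, and the tool passes that history to the SQL chain as plain text. The sketch below shows the idea with a deque. WindowMemory is a hypothetical class written for illustration, not LangChain's implementation, and the conversation is made up.

```python
from collections import deque


class WindowMemory:
    """Keep only the last k (user, assistant) exchanges."""

    def __init__(self, k: int) -> None:
        # A deque with maxlen drops the oldest exchange automatically.
        self._turns = deque(maxlen=k)

    def save(self, user: str, assistant: str) -> None:
        self._turns.append((user, assistant))

    def as_text(self) -> str:
        return "\n".join(f"Human: {u}\nAI: {a}" for u, a in self._turns)


memory = WindowMemory(k=2)
memory.save("Top 2 start stations?", "Station A and Station B")
memory.save("Their capacity?", "30 and 25")
memory.save("Which is larger?", "Station A")

history = memory.as_text()
print(history)
```

With k=2 the first exchange has already been dropped, which keeps the prompt short at the cost of forgetting older context. The article uses k=10.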

We now bring together all the main components of the agent and initialize it.

from langchain.agents import AgentType, initialize_agent, AgentExecutor

# Passed through to the agent constructor; sets the agent's system prompt
agent_kwargs = {"system_message": AGENT_PROMPT}
agent_tools = [user_question_tool]

agent_memory.clear()

agent = initialize_agent(
    tools=agent_tools,
    llm=agent_model,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    memory=agent_memory,
    agent_kwargs=agent_kwargs,
    max_iterations=5,
    early_stopping_method='generate',
    verbose=True)

Let's run the agent now.

q = "Give me top 2 start stations where trip duration was highest?"
agent.invoke(q)
Figure: Output

Now ask the agent a follow-up question.

q = "What is the capacity for  each of these station name?"
agent.invoke(q)
Figure: Final output for follow-up question

Observations

The agent processed the complex question accurately. It also generated a correct answer to the follow-up question by using the chat history, and it drew on another table to get the capacity information of the Citi Bike stations.

Conclusion

The RAG-to-SQL approach represents a significant advancement in addressing the limitations of traditional Text-to-SQL models by incorporating contextual data and leveraging retrieval techniques. This methodology enhances query accuracy by retrieving relevant schema information from vector databases, allowing for more precise SQL generation. Implementing RAG-to-SQL within Google Cloud services like BigQuery and Vertex AI demonstrates its scalability and effectiveness in real-world applications. By automating the decision-making process in query handling, RAG-to-SQL opens new possibilities for non-technical users to interact seamlessly with databases while maintaining high precision.

Key Takeaways

  • RAG-to-SQL addresses the common pitfalls of traditional Text-to-SQL models by integrating metadata.
  • The agent-based system efficiently decides how to process user queries, improving usability.
  • RAG-to-SQL allows non-technical users to generate complex SQL queries with natural language inputs.
  • The approach is successfully implemented using services like BigQuery and Vertex AI.

Frequently Asked Questions

Q1. Is GCP Vertex AI access free?

A. No, but first-time registrants get a 90-day trial with $300 in credits; you only need to provide card details to get access. Nothing is charged to the card. If your usage goes beyond the $300 in credits, Google asks you to enable a paid account before you can continue using the service, so no amount is deducted automatically.

Q2. What is the key benefit of using RAG-to-SQL?

A. It automates the selection of the table schema that is fed to the LLM. If we are working with multiple tables, we do not need to feed all the table schemas at once; the schema relevant to the user query is fetched from the vector store. This makes the approach more efficient than conventional text-to-SQL systems.

Q3. How can agents be useful for this use case?

A. If we are building a holistic chatbot, it may require many tools besides the SQL query tool. We can give the agent multiple tools, such as web search, the database SQL query tool, other RAG tools, or function-calling API tools. The agent can then handle different types of user queries by choosing the tool suited to the task needed to answer each query.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.

I am a professional working as data scientist after finishing my MBA in Business Analytics and Finance. A keen learner who loves to explore and understand and simplify stuff! I am currently learning about advanced ML and NLP techniques and reading up on various topics related to it including research papers .
