In this article, we will learn how to enhance RAG performance with CRAG. The term RAG has been floating around for a while, and for good reason. Large language models have made it possible to build solutions for problems that were previously difficult, and question answering over large amounts of data is one of them. It is now practical thanks to LLMs, AI frameworks, and supporting tools such as vector databases.
Instead of only matching keywords and metadata to find similar texts, we can compute the cosine similarity between text embeddings to retrieve relevant matches and then pass the matched text chunks to an LLM to generate a coherent answer. This method is called RAG (Retrieval Augmented Generation). But is vector retrieval always sufficient? Can we rely on RAG when the retrieved chunks do not contain the answer to the question? This is where CRAG, or Corrective Retrieval Augmented Generation, comes into the picture.
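As a quick illustration of the similarity step, cosine similarity between two embedding vectors can be computed in a few lines. This is a minimal sketch using NumPy; the short toy vectors stand in for real embeddings, which typically have hundreds of dimensions.

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Cosine similarity = dot product divided by the product of the vector norms
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy 4-dimensional "embeddings" for illustration only
query_vec = np.array([0.1, 0.3, 0.5, 0.1])
chunk_vec = np.array([0.2, 0.25, 0.45, 0.1])
print(cosine_similarity(query_vec, chunk_vec))  # values close to 1.0 indicate high similarity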
RAG works well for question answering over text documents, and the process is straightforward. We extract the contents from documents, pre-process them, compute embeddings, and store them in a vector database. We then compute similarity scores between the query and the text chunks to find the most semantically similar ones. These chunks are fed to an LLM to generate a human-readable answer.
This is simple yet effective for most use cases. However, it is not always enough. Finding relevant documents using cosine similarity alone is not always ideal, and throwing the top-k text chunks at an LLM is risky when the cost of false information is high.
To mitigate this, the primary knowledge sources can be supplemented with external sources such as the web. It has been observed that web access enhances the QA capability of LLMs; much of the success of Bard (Gemini Pro) and Perplexity AI comes from integrating web search with LLMs.
Observe the performance gap between Gemini Pro with web access and vanilla Gemini Pro on the LMSys chatbot leaderboard.
Corrective RAG is based on the same principle: it introduces the internet as an additional source of knowledge that supplements the primary knowledge base. So, let’s understand how it works.
The word “corrective” in CRAG refers to a corrective module added to the existing RAG pipeline. This module is responsible for correcting wrong retrieval results. The idea was proposed in the paper Corrective Retrieval Augmented Generation, which describes how to build a CRAG system along with benchmarks. So, let’s look at the fundamental architecture of CRAG.
As you can observe, there are three new additions to a conventional RAG architecture: an evaluator, knowledge refinement, and knowledge searching.
The evaluator is a language model responsible for classifying a retrieved text as correct, incorrect, or ambiguous. The authors used a fine-tuned T5-large model as the evaluator, but any LLM can be used. The LLM is given the question and a retrieved text chunk and asked whether the chunk is relevant. The texts are then classified as correct, incorrect, or ambiguous, so the accuracy of the evaluator plays a crucial role here.
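As a rough sketch of what such an evaluator could look like when built with a general-purpose LLM rather than the authors’ fine-tuned T5, the prompt below asks the model to label a chunk with one of the three classes. The prompt wording is my own, and `llm` is assumed to be a chat model like the one we define later in the implementation section.

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

# Hypothetical three-class evaluator prompt (correct / incorrect / ambiguous)
evaluator_prompt = PromptTemplate(
    template="""You are an evaluator judging whether a retrieved document can answer a user question.
Document: {context}
Question: {question}
Respond with a JSON object containing a single key 'label' whose value is one of
'correct', 'incorrect', or 'ambiguous'. No preamble or explanation.""",
    input_variables=["context", "question"],
)
# Assuming `llm` is a chat model (defined later in this article):
# evaluator_chain = evaluator_prompt | llm | JsonOutputParser()
# verdict = evaluator_chain.invoke({"context": chunk_text, "question": user_question})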
Once chunks are classified as correct, they undergo further pruning to yield a refined source of knowledge. The text chunks are decomposed into small knowledge strips (one to two sentences), and the evaluator is used again to filter out irrelevant strips. The remaining strips are rejoined and sent to the LLM for answer generation.
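A minimal sketch of this decomposition step might look like the following. The sentence-level split is naive, and `grade_strip` is a hypothetical stand-in for another evaluator call that returns True when a strip is judged relevant.

import re

def refine_chunk(chunk_text: str, question: str, grade_strip) -> str:
    """Split a chunk into small strips, keep the relevant ones, and rejoin them.

    `grade_strip(question, strip)` is a hypothetical callable backed by the evaluator.
    """
    # Naive sentence split; the paper uses short strips of one or two sentences
    strips = [s.strip() for s in re.split(r"(?<=[.!?])\s+", chunk_text) if s.strip()]
    relevant = [s for s in strips if grade_strip(question, s)]
    return " ".join(relevant)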
This is applied when a chunk is classified as ambiguous or incorrect. When a chunk is found to be irrelevant, we discard it and use a web search API to find relevant results from the internet. So, instead of using the incorrect chunks, we use sources from the internet for final answer generation.
In the case of ambiguity, however, we apply both knowledge refinement and web search. The irrelevant strips are weeded out, and new information from the internet is added. The final concatenated chunks are sent to the LLM for answer generation.
This combination of an evaluator, knowledge refinement, and web search can significantly improve the performance of RAG-based QA systems.
Now that we understand the concepts behind CRAG, let’s implement them with LangGraph.
LangGraph is an extension of the LangChain ecosystem. It allows us to build AI applications, including agents and RAG pipelines, as graphs. It treats workflows as cyclic graph structures, where each node represents a function or a LangChain Runnable object and edges are the connections between nodes. It also provides a stateful solution where a global state object can be shared among nodes.
LangGraph’s main features include nodes (any function or LangChain Runnable), edges that connect nodes and define the direction of the workflow, and a shared state object that persists data across nodes. LangGraph leverages these to facilitate cyclic LLM call execution with state persistence, which is crucial for agentic behavior. The architecture draws inspiration from Pregel and Apache Beam.
We will use LangGraph to build our Corrective RAG pipeline.
Let’s understand the structure of our pipeline. We will build a CRAG pipeline, but for brevity we will use two evaluator classes instead of three: a chunk is either relevant or irrelevant. As the evaluator, we will use Mixtral 8x7B served by Together AI. You could also use a re-ranker such as Cohere Rerank as the evaluator; it returns the relevant documents along with their relevance scores in decreasing order, which can be used to classify documents by applying thresholds for each category.
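For example, if you would rather try Cohere’s re-ranker as the evaluator, a sketch along the following lines could work. Treat it as an assumption-laden illustration: the model name, the response fields (results, index, relevance_score), and the 0.5 threshold are based on my reading of the Cohere SDK and should be verified against its current documentation.

import cohere
from typing import List

co = cohere.Client("Your Cohere API key")

def grade_with_rerank(question: str, chunks: List[str], threshold: float = 0.5) -> List[str]:
    """Keep only the chunks whose rerank relevance score clears the threshold."""
    response = co.rerank(
        model="rerank-english-v3.0",  # assumed model name; check Cohere's docs
        query=question,
        documents=chunks,
    )
    # Each result is assumed to expose the original index and a relevance score
    return [chunks[r.index] for r in response.results if r.relevance_score >= threshold]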
We will use the Tavily Search API to run web searches for irrelevant chunks. Get API keys for both Together and Tavily before moving ahead. The same Mixtral model will also be used as the final LLM for answer generation. You can use other LLMs such as Gemini, GPT models, Mistral Medium, etc.
This is our workflow.
Create a Python virtual environment and install the following libraries.
! pip install --quiet langchain_community langchain-openai langchainhub chromadb \
langchain langgraph tavily-python sentence-transformers
Now, set up API keys for Together and Tavily as environment variables.
import os
os.environ["TOGETHER_API_KEY"] = "Your Key"
os.environ["TAVILY_API_KEY"] = "Your Key"
Import the libraries.
import json
import operator
from typing import Annotated, Sequence, TypedDict

from langchain import hub
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai.chat_models import ChatOpenAI
In this step, we will use one of my blog posts as the document and LangChain’s WebBaseLoader to load text from the web page. We will use LangChain’s recursive text splitter to split the documents and index the chunks in a Chroma database. For embeddings, we use BAAI/bge-base-en-v1.5 from the Sentence Transformers library; you can use any other model you wish.
# Load
url = "https://www.analyticsvidhya.com/blog/2023/10/introduction-to-hnsw-hierarchical-navigable-small-world/"
loader = WebBaseLoader(url)
docs = loader.load()

# Split
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=100
)
all_splits = text_splitter.split_documents(docs)

# Embed
embedding = SentenceTransformerEmbeddings(model_name="BAAI/bge-base-en-v1.5")

# Index
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-chroma",
    embedding=embedding,
)
retriever = vectorstore.as_retriever()
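Optionally, you can sanity-check the retriever with a quick query before wiring it into the graph; the question below is just an illustrative query against the HNSW blog post.

# Optional: quick sanity check of the retriever
sample_docs = retriever.get_relevant_documents("How does HNSW build its layered graph?")
for doc in sample_docs:
    print(doc.page_content[:200], "\n---")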
Define the LLM you will use. As discussed before, we will use a fine-tuned Mixtral model from Nous Research, served via Together AI.
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key=TOGETHER_API_KEY,
    model="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
)
Since the Together API is compatible with the OpenAI SDK, all we need to change are the base URL, the API key, and the model name.
As mentioned earlier, LangGraph builds applications as a graph structure and lets us use a state object to share data between nodes. So, let’s define the state class.
from typing import Any, Dict, TypedDict

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    keys: Dict[str, Any]
GraphState is a TypedDict with a single attribute, “keys”: a dictionary that stores all the downstream data we will need after each node.
We will now create the first node of our graph. As we know, nodes in LangGraph are plain functions or tools. The first node of our pipeline is the retriever, responsible for retrieving documents from the vector store.
def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {"documents": documents, "question": question}}
The next node we will work on grades the documents. We will use the LLM defined earlier to grade each chunk as “yes” or “no.” If any chunk is graded irrelevant, we set the state key “run_web_search” to “Yes”.
def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with relevant documents
    """
    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    prompt = PromptTemplate(
        template="""You are a grader assessing the relevance of a retrieved
        document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keywords related to the user question,
        grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out
        erroneous retrievals. \n
        Give a binary score 'yes' or 'no' to indicate whether the document
        is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble
        or explanation.
        """,
        input_variables=["question", "context"],
    )
    chain = prompt | llm | JsonOutputParser()

    # Score each retrieved document
    filtered_docs = []
    search = "No"  # Default: do not opt for web search to supplement retrieval
    for d in documents:
        score = chain.invoke(
            {
                "question": question,
                "context": d.page_content,
            }
        )
        grade = score["score"]
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search
            continue
    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }
In the above code, the chain is defined using LangChain Expression Language (LCEL): the prompt is piped to the LLM, and the LLM output is piped to a JSON output parser.
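As a standalone illustration of this piping syntax (and not part of the CRAG pipeline itself), here is a toy chain that reuses the `llm` defined earlier; the prompt text is purely illustrative.

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# A toy LCEL chain: prompt -> LLM -> string parser
toy_prompt = PromptTemplate(
    template="Summarize the following text in one sentence:\n\n{text}",
    input_variables=["text"],
)
toy_chain = toy_prompt | llm | StrOutputParser()
print(toy_chain.invoke({"text": "LCEL lets us compose prompts, models, and parsers with the | operator."}))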
The query needs to be re-written before sending it to the search API. This increases the chances of getting better web search results.
def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Create a prompt template with format instructions and the query
    prompt = PromptTemplate(
        template="""You are generating a question that is well optimized for retrieval. \n
        Look at the input and try to reason about the underlying semantic intent / meaning. \n
        Here is the initial question:
        \n ------- \n
        {question}
        \n ------- \n
        Provide an improved question without any preamble, only respond with the
        updated question: """,
        input_variables=["question"],
    )

    # Prompt -> LLM -> string output
    chain = prompt | llm | StrOutputParser()
    better_question = chain.invoke({"question": question})

    return {
        "keys": {"documents": documents, "question": better_question}
    }
In this node, we will define a function that uses the Tavily API to fetch the top K results from a web search. The search results are concatenated and appended to the documents list before being sent to the generation node.
def web_search(state):
    """
    Web search based on the re-phrased question using the Tavily API.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Web results appended to documents.
    """
    print("---WEB SEARCH---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    tool = TavilySearchResults()
    docs = tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    print(web_results)
    documents.append(web_results)

    return {"keys": {"documents": documents, "question": question}}
In this node, the documents are sent to the LLM along with the query, and the output is added to the state dictionary.
def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains the generation
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt pulled from the LangChain hub
    prompt = hub.pull("rlm/rag-prompt")

    # Post-processing: join document contents into a single context string
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }
We have defined all the nodes that we need. Now, we can define the workflow and add nodes to it.
import pprint
from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform query
workflow.add_node("web_search", web_search)  # web search
We are done with the nodes; now we need to define the edges. Edges determine the direction of the workflow. In LangGraph, there are two types of edges: regular edges, which always route from one node to a fixed next node, and conditional edges, which decide the next node based on the output of a function.
In our case, we need a conditional edge after the grading node: if the documents are relevant, we run the generation node; otherwise, we run the transform-query node.
def decide_to_generate(state):
    """
    Determines whether to generate an answer or re-generate a question for web search.

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Next node to call
    """
    print("---DECIDE TO GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    filtered_documents = state_dict["documents"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        # Some documents were filtered out during relevance checking,
        # so we re-write the query and run a web search
        print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
        return "transform_query"
    else:
        # We have relevant documents, so generate the answer
        print("---DECISION: GENERATE---")
        return "generate"
Now connect the respective nodes and set the entry point, the node from which the workflow starts.
# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)
Finally, compile the workflow and run it by passing a query.
# Compile
app = workflow.compile()

# Run
inputs = {
    "keys": {
        "question": "Who is the author of the HNSW paper?",
    }
}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value["keys"]["generation"])
The article does not directly mention the author of the HNSW paper, so the retriever could not find any relevant text chunks in the vector store. Even though this is a trivial question, a plain RAG pipeline would have failed to answer it. With CRAG, this was not a problem: when the retrieved documents are irrelevant, we fall back on a web search.
CRAG is a pivotal enhancement to RAG, addressing its inherent gaps by incorporating the internet as an additional knowledge source. This article explored CRAG and its implementation with LangGraph, showing how an evaluator, knowledge refinement, and web search fortify the conventional RAG pipeline and significantly boost QA performance.
Q. What is LangGraph?
A. LangGraph is an open-source library for building stateful, cyclic, multi-actor agent systems. It is built on top of the LangChain ecosystem.
Q. What is RAG?
A. RAG stands for Retrieval Augmented Generation. In RAG, documents are split, embedded, and stored in a vector database. User queries are then matched against these embeddings, and the top-k retrieved chunks are sent to an LLM for answer generation.
Q. What is Corrective RAG?
A. Corrective RAG uses an evaluator LLM to distill the relevant documents from all retrieved documents and, if needed, uses external knowledge sources such as web search to supplement answer generation.
Q. When should I use LangGraph instead of LangChain?
A. LangGraph is preferred for building cyclic, multi-actor agents, while LangChain is better suited to creating chains or directed acyclic workflows.
Q. What is a RAG pipeline?
A. A RAG pipeline retrieves documents from external data sources, processes and stores them in a knowledge base, and provides tools to query them.