The advent of LLMs has completely changed how we extract information from documents. However, images, usually charts and tables, often contain crucial information, and text-based LLMs cannot process media files. With a text-only model, for example, we could only use the text from a PDF file to find answers. But now, with the release of multi-modal LLMs from different labs, it is possible to extract information from images. Multi-modal models like GPT-4V and Gemini Pro Vision have shown great ability to infer data from images. We can use these models to augment a regular retrieval-augmented generation (RAG) pipeline and build a multi-modal RAG (MM-RAG) pipeline. So, in this article, we will create an MM-RAG pipeline using Gemini Pro models, the Chroma vector database, and LangChain.
Learning Objectives
This article was published as a part of the Data Science Blogathon.
LangChain is one of the hottest tools of 2023. It is an open-source framework for building chains of tasks and LLM agents. It has almost all the tools you need to create a functional AI application, from data loaders and vector stores to LLMs from different labs. LangChain has two core value propositions: chains and agents.
We will use LangChain's expression language (LCEL) to build the multi-modal pipeline. Another crucial component of RAG is a vector database. These databases are purpose-built for storing embeddings and can handle millions of them, so whenever we need context-aware retrieval, vector stores become indispensable. The embeddings themselves come from embedding models, which are trained on large amounts of data so that semantically similar texts map to nearby vectors.
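To make similarity search concrete, here is a minimal plain-Python sketch. The three-dimensional vectors are made up for illustration (real embedding models produce hundreds of dimensions), and `cosine_similarity` and `top_k` are hypothetical helpers, not Chroma's API. A vector database performs the same ranking, but with indexes that scale to millions of vectors instead of this brute-force loop.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec, store, k=2):
    """Rank stored (doc_id -> vector) entries by similarity to the query (brute force)."""
    scored = [(cosine_similarity(query_vec, vec), doc_id) for doc_id, vec in store.items()]
    scored.sort(reverse=True)
    return [doc_id for _, doc_id in scored[:k]]


# Hypothetical toy embeddings; in practice an embedding model produces these
store = {
    "benchmark_table": [0.9, 0.1, 0.0],
    "architecture_text": [0.1, 0.9, 0.1],
    "accuracy_chart": [0.8, 0.2, 0.1],
}
query = [1.0, 0.0, 0.0]  # pretend embedding of a question about benchmark results
print(top_k(query, store))  # ['benchmark_table', 'accuracy_chart']
```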
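So, how do we build a RAG pipeline in a multi-modal context? There are three different ways we can create an MM-RAG pipeline:
Option 1: Use a multi-modal embedding model (such as CLIP) to embed both images and texts, retrieve both with similarity search, and pass the raw images and text chunks to a multi-modal LLM for answer synthesis.
Option 2: Use a multi-modal LLM to generate text summaries of the images, embed and retrieve those summaries as text, and pass the text chunks to a text-only LLM for answer synthesis.
Option 3: Use a multi-modal LLM to generate text summaries of the images, embed and retrieve the summaries with a reference to the raw images, and pass the raw images and text chunks to a multi-modal LLM for answer synthesis.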
Each of the options has its pros and cons.
Option 1 can be good for generic images but might struggle with charts and tables.
Option 2 works when you cannot use a multi-modal model frequently, since the model is only needed once per image to generate its summary.
Option 3 offers improved accuracy. MM-LLMs like GPT-4V or Gemini can understand charts and tables, so Option 3 is the choice when the task involves complex image understanding. However, it also costs more.
We will implement the 3rd approach with Gemini Pro Vision.
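Now that we are aware of the concepts and tools required, here is a quick workflow overview:
1. Extract texts, tables, and images from the PDF with Unstructured.
2. Summarize the texts and tables with Gemini Pro, and the images with Gemini Pro Vision.
3. Embed the summaries in a Chroma vector store, keep the raw texts, tables, and images in a document store, and link the two with a multi-vector retriever.
4. Retrieve the raw documents for a query and pass them to Gemini Pro Vision to generate the answer.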
Let’s delve into coding the RAG pipeline.
PDFs and other data formats often have tables and pictures embedded in them, and these are not as easy to extract as text. To extract them, we need purpose-built tools like Unstructured. Unstructured is an open-source tool for pre-processing images and files such as HTML, PDF, and Word documents. It can extract embedded images and tables from files, relying on OCR for image-based content. Unstructured requires Poppler and Tesseract to be installed on the system.
Install Tesseract and Poppler.
```
!sudo apt install -y tesseract-ocr
!sudo apt-get install -y poppler-utils
```
Now, we can install Unstructured along with other required libraries.
```
!pip install "unstructured[all-docs]" langchain langchain_community \
    chromadb langchain-experimental google-cloud-aiplatform
```
We use the partition_pdf function to extract images and tables.
```python
from unstructured.partition.pdf import partition_pdf

image_path = "./"
pdf_elements = partition_pdf(
    "mistral.pdf",
    chunking_strategy="by_title",
    extract_images_in_pdf=True,
    max_characters=3000,
    new_after_n_chars=2800,
    combine_text_under_n_chars=2000,
    image_output_dir_path=image_path,
)
```
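This partitions the PDF and saves the extracted images to the given path. It also chunks the text: with chunking_strategy="by_title", a new chunk starts at each section title; max_characters is a hard limit on chunk size, new_after_n_chars is a soft limit after which a new chunk is started, and combine_text_under_n_chars merges sections shorter than that into larger chunks.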
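The interaction between the three limits is easier to see on toy data. The sketch below is a deliberately simplified, hypothetical chunker (`toy_chunk` is not part of Unstructured, and the real by_title strategy also tracks element types and section metadata); it only shows how a hard cap, a soft cap, and a merge threshold shape the output.

```python
def toy_chunk(sections, max_characters, new_after_n_chars, combine_text_under_n_chars):
    """Simplified illustration of hard/soft chunk limits; NOT Unstructured's algorithm."""
    chunks = []
    current = ""
    for section in sections:
        # Hard limit: split any section longer than max_characters into pieces
        pieces = [section[i:i + max_characters] for i in range(0, len(section), max_characters)]
        for piece in pieces:
            if not current:
                current = piece
            elif (len(current) < combine_text_under_n_chars
                  and len(current) + 1 + len(piece) <= max_characters):
                # Current chunk is still small: merge the next piece into it
                current += "\n" + piece
            else:
                # Current chunk is big enough: close it and start a new one
                chunks.append(current)
                current = piece
            # Soft limit: once a chunk reaches new_after_n_chars, close it
            if len(current) >= new_after_n_chars:
                chunks.append(current)
                current = ""
    if current:
        chunks.append(current)
    return chunks


# Tiny limits so the effect is visible; the article uses 3000 / 2800 / 2000
sections = ["aaa", "bb", "c" * 12, "dd"]
print(toy_chunk(sections, max_characters=10, new_after_n_chars=8, combine_text_under_n_chars=5))
# ['aaa\nbb', 'cccccccccc', 'cc\ndd']
```

The two short sections are merged, the long one is cut at the hard cap, and its leftover tail is merged with the next small section.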
Now, we separate texts and tables into different groups.
```python
# Categorize elements by type
def categorize_elements(raw_pdf_elements):
    """
    Categorize extracted elements from a PDF into tables and texts.
    raw_pdf_elements: List of unstructured.documents.elements
    """
    tables = []
    texts = []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Table" in str(type(element)):
            tables.append(str(element))
        elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
            texts.append(str(element))
    return texts, tables


texts, tables = categorize_elements(pdf_elements)
```
We will use Gemini Pro to get short summaries of the text chunks and tables. LangChain has a summarization chain for this purpose, but we will use the expression language to build a simple summarizing chain instead. To use Gemini models with LangChain, you need a GCP account: enable the Vertex AI API and configure your credentials.
```python
from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda


# Generate summaries of text elements
def generate_text_summaries(texts, tables, summarize_texts=False):
    """
    Summarize text elements
    texts: List of str
    tables: List of str
    summarize_texts: Bool to summarize texts
    """

    # Prompt
    prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
    These summaries will be embedded and used to retrieve the raw text or table elements. \
    Give a concise summary of the table or text that is well-optimized for retrieval. Table \
    or text: {element} """
    prompt = PromptTemplate.from_template(prompt_text)

    # Fallback that returns a placeholder message if the LLM call fails
    empty_response = RunnableLambda(
        lambda x: AIMessage(content="Error processing document")
    )

    # Text summary chain
    model = VertexAI(
        temperature=0, model_name="gemini-pro", max_output_tokens=1024
    ).with_fallbacks([empty_response])
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

    # Initialize empty summaries
    text_summaries = []
    table_summaries = []

    # Apply to text if texts are provided and summarization is requested
    if texts and summarize_texts:
        text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})
    elif texts:
        text_summaries = texts

    # Apply to tables if tables are provided
    if tables:
        table_summaries = summarize_chain.batch(tables, {"max_concurrency": 1})

    return text_summaries, table_summaries


# Get text and table summaries. The summaries must stay aligned with `texts`,
# because the retriever later pairs each summary with its raw chunk by position.
text_summaries, table_summaries = generate_text_summaries(
    texts, tables, summarize_texts=True
)
```
The chain has four components: a dictionary that passes the input into the prompt's {element} variable, the prompt template, the LLM, and a string output parser. The output of each component is the input to the next.
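The `|` composition can be illustrated without LangChain. The sketch below uses a hypothetical `Step` class as a toy stand-in for a Runnable (it is not LangChain's implementation), and the "LLM" just upper-cases its prompt, so the only point is how each component's output feeds the next.

```python
class Step:
    """Toy stand-in for a LangChain Runnable: wraps a function and supports `|`."""

    def __init__(self, func):
        self.func = func

    def __or__(self, other):
        # Compose: the output of self becomes the input of other
        return Step(lambda x: other.func(self.func(x)))

    def invoke(self, x):
        return self.func(x)


# Hypothetical stand-ins for the four components of the summarizing chain
to_dict = Step(lambda x: {"element": x})                            # {"element": lambda x: x}
prompt = Step(lambda d: f"Summarize for retrieval: {d['element']}")  # prompt template
fake_llm = Step(lambda p: {"content": p.upper()})                   # pretend model output
parser = Step(lambda msg: msg["content"])                           # string output parser

chain = to_dict | prompt | fake_llm | parser
print(chain.invoke("mistral beats llama 2"))
# SUMMARIZE FOR RETRIEVAL: MISTRAL BEATS LLAMA 2
```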
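The RunnableLambda that returns an AIMessage acts as a fallback: if the call to the LLM fails, with_fallbacks returns the placeholder message "Error processing document" instead of raising an error.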
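The fallback pattern itself is simple. Here is a minimal sketch with a hypothetical `WithFallback` wrapper (not LangChain's class) and a simulated failing model; like with_fallbacks' default, it catches any `Exception` from the primary call, which keeps one bad element from aborting a whole batch of summaries.

```python
class WithFallback:
    """Toy stand-in for .with_fallbacks(): try the primary, else use the fallback."""

    def __init__(self, primary, fallback):
        self.primary = primary
        self.fallback = fallback

    def invoke(self, x):
        try:
            return self.primary(x)
        except Exception:
            # Any error in the primary call triggers the fallback
            return self.fallback(x)


def flaky_model(prompt):
    raise TimeoutError("model unavailable")  # simulate a failed API call


safe_model = WithFallback(flaky_model, lambda x: "Error processing document")
print(safe_model.invoke("Summarize this table"))  # Error processing document
```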
As discussed earlier, we will use a vision model to get text descriptions of images. You can use any vision model, such as GPT-4V, LLaVA, or Gemini. Here, we will use Gemini Pro Vision.
To process images, we will convert them to base64 format and pass them to the Gemini Pro Vision model with a default prompt.
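Here is what that conversion produces, as a small standard-library sketch. `to_data_url` and `from_data_url` are hypothetical helpers, and the four bytes are just the start of a JPEG header rather than a real image; the resulting string has the same `data:image/jpeg;base64,...` shape that the code below passes in the `image_url` field.

```python
import base64


def to_data_url(image_bytes, mime_type="image/jpeg"):
    """Encode raw image bytes as a base64 data URL."""
    b64 = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{b64}"


def from_data_url(data_url):
    """Recover the raw bytes from a data URL."""
    _, b64 = data_url.split(",", 1)
    return base64.b64decode(b64)


# First bytes of a JPEG file (the FF D8 FF signature), used as toy data
fake_jpeg = b"\xff\xd8\xff\xe0"
url = to_data_url(fake_jpeg)
print(url)  # data:image/jpeg;base64,/9j/4A==
assert from_data_url(url) == fake_jpeg
```

This is also why base64 JPEG strings start with `/9j/`: that is the encoding of the JPEG signature bytes.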
```python
import base64
import os

from langchain_core.messages import HumanMessage


def encode_image(image_path):
    """Getting the base64 string"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def image_summarize(img_base64, prompt):
    """Make image summary"""
    model = ChatVertexAI(model_name="gemini-pro-vision", max_output_tokens=1024)

    # invoke() replaces the deprecated model(...) call style
    msg = model.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                    },
                ]
            )
        ]
    )
    return msg.content


def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images
    path: Path to list of .jpg files extracted by Unstructured
    """

    # Store base64 encoded images
    img_base64_list = []

    # Store image summaries
    image_summaries = []

    # Prompt
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
    These summaries will be embedded and used to retrieve the raw image. \
    Give a concise summary of the image that is well optimized for retrieval."""

    # Apply to images
    for img_file in sorted(os.listdir(path)):
        if img_file.endswith(".jpg"):
            img_path = os.path.join(path, img_file)
            base64_image = encode_image(img_path)
            img_base64_list.append(base64_image)
            image_summaries.append(image_summarize(base64_image, prompt))

    return img_base64_list, image_summaries


fpath = "./"

# Image summaries
img_base64_list, image_summaries = generate_img_summaries(fpath)
```
As discussed earlier, we will store embeddings of the text, table, and image summaries in a vector store and keep the original documents in an in-memory document store. At query time, we search the summaries and return the original documents linked to the matching vectors. LangChain's multi-vector retriever does exactly this, and here is how we can build one.
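The core idea of "search the summaries, return the originals" fits in a few lines. In this sketch, the list and dict are hypothetical stand-ins for the vector store and document store, and `overlap_score` (shared-word counting) is a toy replacement for embedding similarity; it is not how MultiVectorRetriever scores documents.

```python
import uuid

# Hypothetical in-memory stand-ins for the vector store and the docstore
vector_index = []  # list of (summary, doc_id)
docstore = {}      # doc_id -> raw document (text, table, or base64 image)


def add_documents(summaries, contents):
    """Index each summary under a fresh id and store the raw content under the same id."""
    for summary, content in zip(summaries, contents):
        doc_id = str(uuid.uuid4())
        vector_index.append((summary, doc_id))
        docstore[doc_id] = content


def overlap_score(query, summary):
    """Toy relevance score: number of shared lowercase words (real systems compare embeddings)."""
    return len(set(query.lower().split()) & set(summary.lower().split()))


def retrieve(query, k=1):
    """Search the summaries, but return the raw documents they point to."""
    ranked = sorted(vector_index, key=lambda pair: overlap_score(query, pair[0]), reverse=True)
    return [docstore[doc_id] for _, doc_id in ranked[:k]]


add_documents(
    ["bar chart comparing mistral and llama 2 on benchmarks", "text describing sliding window attention"],
    ["<base64 image>", "<raw text chunk>"],
)
print(retrieve("mistral vs llama 2 benchmarks"))  # ['<base64 image>']
```

The summary is only used for matching; what comes back is the raw image, which the multi-modal LLM can then read directly.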
```python
import uuid

from langchain.embeddings import VertexAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma


def create_multi_vector_retriever(
    vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
    """
    Create retriever that indexes summaries, but returns raw images or texts
    """

    # Initialize the storage layer
    store = InMemoryStore()
    id_key = "doc_id"

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )

    # Helper function to add documents to the vectorstore and docstore
    def add_documents(retriever, doc_summaries, doc_contents):
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=s, metadata={id_key: doc_ids[i]})
            for i, s in enumerate(doc_summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # Add texts, tables, and images
    # Check that text_summaries is not empty before adding
    if text_summaries:
        add_documents(retriever, text_summaries, texts)
    # Check that table_summaries is not empty before adding
    if table_summaries:
        add_documents(retriever, table_summaries, tables)
    # Check that image_summaries is not empty before adding
    if image_summaries:
        add_documents(retriever, image_summaries, images)

    return retriever


# The vectorstore to use to index the summaries
vectorstore = Chroma(
    collection_name="mm_rag_mistral",
    embedding_function=VertexAIEmbeddings(model_name="textembedding-gecko@latest"),
)

# Create retriever
retriever_multi_vector_img = create_multi_vector_retriever(
    vectorstore,
    text_summaries,
    texts,
    table_summaries,
    tables,
    image_summaries,
    img_base64_list,
)
```
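In the above code, we defined a Chroma vector store and an in-memory document store and passed them to the multi-vector retriever. We embedded the summaries of texts, tables, and images and stored them in the Chroma collection, while the raw texts, tables, and base64-encoded images went into the document store under matching doc_id keys.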
We will use the Langchain Expression Language to build the final pipeline.
```python
import io
import re

from IPython.display import HTML, display
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from PIL import Image


def looks_like_base64(sb):
    """Check if the string looks like base64"""
    return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None


def is_image_data(b64data):
    """
    Check if the base64 data is an image by looking at the start of the data
    """
    image_signatures = {
        b"\xFF\xD8\xFF": "jpg",
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",
    }
    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
        for sig, format in image_signatures.items():
            if header.startswith(sig):
                return True
        return False
    except Exception:
        return False


def resize_base64_image(base64_string, size=(128, 128)):
    """
    Resize an image encoded as a Base64 string
    """
    # Decode the Base64 string
    img_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(img_data))

    # Resize the image
    resized_img = img.resize(size, Image.LANCZOS)

    # Save the resized image to a bytes buffer
    buffered = io.BytesIO()
    resized_img.save(buffered, format=img.format)

    # Encode the resized image to Base64
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def split_image_text_types(docs):
    """
    Split base64-encoded images and texts
    """
    b64_images = []
    texts = []
    for doc in docs:
        # Check if the document is of type Document and extract page_content if so
        if isinstance(doc, Document):
            doc = doc.page_content
        if looks_like_base64(doc) and is_image_data(doc):
            doc = resize_base64_image(doc, size=(1300, 600))
            b64_images.append(doc)
        else:
            texts.append(doc)
    if len(b64_images) > 0:
        return {"images": b64_images[:1], "texts": []}
    return {"images": b64_images, "texts": texts}


def img_prompt_func(data_dict):
    """
    Join the context into a single string
    """
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []

    # Adding the text for analysis
    text_message = {
        "type": "text",
        "text": (
            "You are an AI scientist tasking with providing factual answers.\n"
            "You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n"
            "Use this information to provide answers related to the user question. \n"
            f"User-provided question: {data_dict['question']}\n\n"
            "Text and / or tables:\n"
            f"{formatted_texts}"
        ),
    }
    messages.append(text_message)

    # Adding image(s) to the messages if present
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            image_message = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            messages.append(image_message)
    return [HumanMessage(content=messages)]


def multi_modal_rag_chain(retriever):
    """
    Multi-modal RAG chain
    """

    # Multi-modal LLM
    model = ChatVertexAI(
        temperature=0, model_name="gemini-pro-vision", max_output_tokens=1024
    )

    # RAG pipeline
    chain = (
        {
            "context": retriever | RunnableLambda(split_image_text_types),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )

    return chain


# Create RAG chain
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)
```
The RAG chain is created using the multi-vector retriever, Gemini Pro Vision, and a function that prepends an instruction to ground the LLM response.
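The chain has multiple components chained together: the retriever returns the raw documents, split_image_text_types separates base64 images from text (resizing the images and keeping only the first image when one is found), img_prompt_func combines the question, texts, and images into a single multi-modal message, Gemini Pro Vision generates the answer, and StrOutputParser returns it as a string.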
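The dictionary at the head of the chain sends the same question down two branches: one retrieves and splits the context, the other passes the question through unchanged. The sketch below mimics that fan-out with hypothetical stand-ins (`fake_retriever`, `split_types`, `fan_out`); unlike the article's split_image_text_types, this simplified router keeps both images and texts and detects JPEGs only by the `/9j/` base64 prefix.

```python
def fake_retriever(question):
    # Hypothetical: pretend these raw documents came back from the docstore
    return ["/9j/4AAQ", "<table text>"]


def split_types(docs):
    # Simplified routing: base64-encoded JPEG data starts with "/9j/" (the FF D8 FF signature)
    images = [d for d in docs if d.startswith("/9j/")]
    texts = [d for d in docs if not d.startswith("/9j/")]
    return {"images": images, "texts": texts}


def fan_out(branches, value):
    """Like an LCEL dict: every branch receives the same input; results are collected by key."""
    return {key: fn(value) for key, fn in branches.items()}


inputs = fan_out(
    {"context": lambda q: split_types(fake_retriever(q)), "question": lambda q: q},
    "compare mistral and llama2",
)
print(inputs)
# {'context': {'images': ['/9j/4AAQ'], 'texts': ['<table text>']}, 'question': 'compare mistral and llama2'}
```

The resulting dictionary has exactly the `context` and `question` keys that img_prompt_func reads.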
We can now invoke the chain and get our queries answered. You can run the retriever to see if it retrieves the right documents.
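```python
query = """compare and contrast between mistral and llama2 across benchmarks and
explain the reasoning in detail"""

# Returns the raw documents (texts, tables, or base64 images) behind the best-matching
# summaries; the number of results is set by the retriever's search_kwargs (e.g. {"k": 4})
docs = retriever_multi_vector_img.get_relevant_documents(query)
docs[0]
```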
I used the official Mistral paper from arXiv as the reference PDF. Running the cell returned a chart from the paper, which seems correct given the query.
Now, invoke the RAG chain with the query to see if the chain is working as intended.
```python
chain_multimodal_rag.invoke(query)
```
To get better answers, try playing with the prompts. Often, you get better responses just by tweaking the prompt a bit.
So, this was all about building MM-RAG in Langchain.
RAG paired with vector databases has filled a crucial gap in LLMs by reducing hallucination. RAG and vector search have created a new sub-genre of technology, and with vision-capable LLMs, we can now retrieve information from images as well. We can process the images embedded in files along with the text and get answers from both.
So, in this article, we used LangChain, Chroma, Gemini, and Unstructured to build a multi-modal RAG pipeline.
Here are the key takeaways.
A. LangChain is an open-source framework that simplifies the creation of applications using large language models. It can be used for various tasks, including chatbots, document analysis, code analysis, question answering, and generative tasks.
Agents are more complex than chains. Instead of following a fixed sequence of steps, they use an LLM to decide which steps (tools) to execute next, based on the results of earlier steps. Agents are often used for tasks that require a lot of creativity or reasoning, for example, data analysis and code generation.
A. Multimodal AI refers to machine learning models that can process and understand multiple modalities of data, such as images, text, and audio.
A. A RAG pipeline retrieves relevant documents from an external knowledge base (often a vector store) and passes them to an LLM as context, so the model can ground its answer in that data.
A. LlamaIndex is designed specifically for search and retrieval applications, while LangChain offers more flexibility for building custom AI agents.
The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.