Retrieval Augmented Generation (RAG) has revolutionized how large language models access external data, but traditional approaches are limited to text. With the rise of multimodal data, integrating text and visual information is crucial for comprehensive analysis, especially in complex fields like finance and research. Multimodal RAG addresses this by enabling models to process both text and images for better knowledge retrieval and reasoning. This article explores building a multimodal RAG system using Google’s Gemini models, Vertex AI, and LangChain, guiding you through environment setup, data processing, embedding generation, and constructing a robust document search engine.
Multimodal RAG models combine visual and textual information to produce more robust and context-aware outputs. Unlike conventional RAG models, which rely solely on text, multimodal RAG systems are designed to ingest and incorporate visual content such as graphs, charts, and images. This dual-processing capability is particularly valuable for analyzing complex documents where visuals are as informative as the text, such as financial reports, scientific papers, or user manuals.
By processing both text and images, the model develops a deeper understanding of the content, leading to more accurate and insightful responses. This integration reduces the risk of generating misleading or contextually incorrect information (commonly known as hallucination in machine learning), resulting in more reliable outputs for decision-making and analysis.
Here’s a summary of each key technology:
• Gemini: Google’s family of multimodal models, capable of understanding and generating insights from both text and images.
• Vertex AI: Google Cloud’s platform for building, deploying, and managing ML models at scale, including Vector Search for similarity-based retrieval.
• LangChain: a framework for connecting large language models to external data sources, vector stores, and APIs, providing the retriever and chain abstractions used in this article.
The architecture of a multimodal RAG system involves:
• Extracting text, tables, and images from the source documents.
• Summarizing each element with a generative model so it can be compactly embedded.
• Embedding the summaries and indexing them in a vector store, while keeping the raw content in a document store.
• Retrieving the relevant raw text and images for a query through a multi-vector retriever and passing them to a multimodal LLM to generate the final answer.
Now let’s get into the actual coding part. In this section, I will guide you through the steps of building a multimodal RAG system for content and images, using Google Gemini, Vertex AI, and LangChain.
Let’s begin by setting up the environment.
The %pip install command installs all the necessary Python libraries, including google-cloud-aiplatform, langchain, and various document-processing libraries like pypdf.
%pip install -U -q google-cloud-aiplatform langchain-core langchain-google-vertexai langchain-text-splitters langchain-community "unstructured[all-docs]" pypdf pydantic lxml pillow matplotlib opencv-python tiktoken
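# Restart the kernel so the newly installed packages are picked up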
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)
Add the code to authenticate and initialize the Vertex AI environment:
The auth.authenticate_user() function is used for authenticating your Google Cloud account in Google Colab.
import sys
# Additional authentication is required for Google Colab
if "google.colab" in sys.modules:
# Authenticate user to Google Cloud
from google.colab import auth
auth.authenticate_user()
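# Define the Google Cloud project, region, and staging bucket used to initialize Vertex AI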
PROJECT_ID = "YOUR_PROJECT_ID" # @param {type:"string"}
LOCATION = "us-central1" # @param {type:"string"}
# For Vector Search Staging
GCS_BUCKET = "YOUR_BUCKET_NAME" # @param {type:"string"}
GCS_BUCKET_URI = f"gs://{GCS_BUCKET}"
from google.cloud import aiplatform
aiplatform.init(project=PROJECT_ID, location=LOCATION, staging_bucket=GCS_BUCKET_URI)
Add the code for constructing the document repository and integrating LangChain:
The following code imports the LangChain, Vertex AI, IPython, and Unstructured components needed for the retrieval and processing pipeline.
import base64
import os
import re
import uuid
from IPython.display import Image, Markdown, display
from langchain.prompts import PromptTemplate
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_vertexai import (
ChatVertexAI,
VectorSearchVectorStore,
VertexAI,
VertexAIEmbeddings,
)
from langchain_text_splitters import CharacterTextSplitter
from unstructured.partition.pdf import partition_pdf
# from langchain_community.vectorstores import Chroma # Optional
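# Model configuration: Gemini for generation and summarization, text-embedding-004 for embeddings;
# generation output is capped at the smaller of the two models' token limits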
MODEL_NAME = "gemini-1.5-flash"
GEMINI_OUTPUT_TOKEN_LIMIT = 8192
EMBEDDING_MODEL_NAME = "text-embedding-004"
EMBEDDING_TOKEN_LIMIT = 2048
TOKEN_LIMIT = min(GEMINI_OUTPUT_TOKEN_LIMIT, EMBEDDING_TOKEN_LIMIT)
# Download documents and images used in this notebook
!gsutil -m rsync -r gs://github-repo/rag/intro_multimodal_rag/ .
print("Download completed")
pdf_folder_path = "/content/data/" if "google.colab" in sys.modules else "data/"
pdf_file_name = "google-10k-sample-14pages.pdf"
# Extract images, tables, and chunk text from a PDF file.
raw_pdf_elements = partition_pdf(
filename=pdf_file_name,
extract_images_in_pdf=False,
infer_table_structure=True,
chunking_strategy="by_title",
max_characters=4000,
new_after_n_chars=3800,
combine_text_under_n_chars=2000,
image_output_dir_path=pdf_folder_path,
)
# Categorize extracted elements from a PDF into tables and texts.
tables = []
texts = []
for element in raw_pdf_elements:
    if "unstructured.documents.elements.Table" in str(type(element)):
        tables.append(str(element))
    elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
        texts.append(str(element))
# Optional: Enforce a specific token size for texts
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=4000, chunk_overlap=0
)
joined_texts = " ".join(texts)
texts_4k_token = text_splitter.split_text(joined_texts)
def generate_text_summaries(
    texts: list[str], tables: list[str], summarize_texts: bool = False
) -> tuple[list, list]:
    """
    Summarize text elements
    texts: List of str
    tables: List of str
    summarize_texts: Bool to summarize texts
    """
    # Prompt
    prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
These summaries will be embedded and used to retrieve the raw text or table elements. \
Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} """
    prompt = PromptTemplate.from_template(prompt_text)
    empty_response = RunnableLambda(
        lambda x: AIMessage(content="Error processing document")
    )
    # Text summary chain
    model = VertexAI(
        temperature=0, model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT
    ).with_fallbacks([empty_response])
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
    # Initialize empty summaries
    text_summaries = []
    table_summaries = []
    # Apply to text if texts are provided and summarization is requested
    if texts:
        if summarize_texts:
            text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})
        else:
            text_summaries = texts
    # Apply to tables if tables are provided
    if tables:
        table_summaries = summarize_chain.batch(tables, {"max_concurrency": 1})
    return text_summaries, table_summaries
# Get text, table summaries
text_summaries, table_summaries = generate_text_summaries(
texts_4k_token, tables, summarize_texts=True
)
def encode_image(image_path: str) -> str:
    """Getting the base64 string"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
def image_summarize(model: ChatVertexAI, base64_image: str, prompt: str) -> str:
    """Make image summary"""
    msg = model.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{base64_image}"},
                    },
                ]
            )
        ]
    )
    return msg.content
def generate_img_summaries(path: str) -> tuple[list[str], list[str]]:
    """
    Generate summaries and base64 encoded strings for images
    path: Path to the directory containing the .png image files
    """
    # Store base64 encoded images
    img_base64_list = []
    # Store image summaries
    image_summaries = []
    # Prompt
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval.
If it's a table, extract all elements of the table.
If it's a graph, explain the findings in the graph.
Do not include any numbers that are not mentioned in the image.
"""
    model = ChatVertexAI(model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT)
    # Apply to images
    for img_file in sorted(os.listdir(path)):
        if img_file.endswith(".png"):
            base64_image = encode_image(os.path.join(path, img_file))
            img_base64_list.append(base64_image)
            image_summaries.append(image_summarize(model, base64_image, prompt))
    return img_base64_list, image_summaries
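# Because image extraction from the PDF was disabled above, the .png files summarized here
# are the ones downloaded with gsutil into the current directory.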
# Image summaries
img_base64_list, image_summaries = generate_img_summaries(".")
# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings
DIMENSIONS = 768  # Dimensions output by text-embedding-004
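# Create a Vertex AI Vector Search index for the summary embeddings (index creation can take a few minutes)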
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
display_name="mm_rag_langchain_index",
dimensions=DIMENSIONS,
approximate_neighbors_count=150,
leaf_node_embedding_count=500,
leaf_nodes_to_search_percent=7,
description="Multimodal RAG LangChain Index",
index_update_method="STREAM_UPDATE",
)
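# Create an index endpoint and deploy the index to it (deployment typically takes several minutes)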
DEPLOYED_INDEX_ID = "mm_rag_langchain_index_endpoint"
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=DEPLOYED_INDEX_ID,
description="Multimodal RAG LangChain Index Endpoint",
public_endpoint_enabled=True,
)
index_endpoint = index_endpoint.deploy_index(
index=index, deployed_index_id="mm_rag_langchain_deployed_index"
)
index_endpoint.deployed_indexes
# The vectorstore to use to index the summaries
vectorstore = VectorSearchVectorStore.from_components(
project_id=PROJECT_ID,
region=LOCATION,
gcs_bucket_name=GCS_BUCKET,
index_id=index.name,
endpoint_id=index_endpoint.name,
embedding=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),
stream_update=True,
)
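# The docstore keeps the raw documents (text chunks, tables, and base64 images) keyed by doc_id,
# while the vector store above indexes only their summaries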
docstore = InMemoryStore()
id_key = "doc_id"
# Create the multi-vector retriever
retriever_multi_vector_img = MultiVectorRetriever(
vectorstore=vectorstore,
docstore=docstore,
id_key=id_key,
)
• Load data into Document Store and Vector Store
# Raw Document Contents
doc_contents = texts + tables + img_base64_list
doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
summary_docs = [
Document(page_content=s, metadata={id_key: doc_ids[i]})
for i, s in enumerate(text_summaries + table_summaries + image_summaries)
]
retriever_multi_vector_img.docstore.mset(list(zip(doc_ids, doc_contents)))
# If using Vertex AI Vector Search, this will take a while to complete.
# You can cancel this cell and continue later.
retriever_multi_vector_img.vectorstore.add_documents(summary_docs)
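# Helper functions to tell retrieved base64-encoded images apart from plain text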
def looks_like_base64(sb):
    """Check if the string looks like base64"""
    return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None
def is_image_data(b64data):
    """
    Check if the base64 data is an image by looking at the start of the data
    """
    image_signatures = {
        b"\xFF\xD8\xFF": "jpg",
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",
    }
    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
        for sig, format in image_signatures.items():
            if header.startswith(sig):
                return True
        return False
    except Exception:
        return False
def split_image_text_types(docs):
    """
    Split base64-encoded images and texts
    """
    b64_images = []
    texts = []
    for doc in docs:
        # Check if the document is of type Document and extract page_content if so
        if isinstance(doc, Document):
            doc = doc.page_content
        if looks_like_base64(doc) and is_image_data(doc):
            b64_images.append(doc)
        else:
            texts.append(doc)
    return {"images": b64_images, "texts": texts}
def img_prompt_func(data_dict):
    """
    Join the context into a single string
    """
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = [
        {
            "type": "text",
            "text": (
                "You are a financial analyst tasked with providing investment advice.\n"
                "You will be given a mix of text, tables, and image(s) usually of charts or graphs.\n"
                "Use this information to provide investment advice related to the user's question. \n"
                f"User-provided question: {data_dict['question']}\n\n"
                "Text and / or tables:\n"
                f"{formatted_texts}"
            ),
        }
    ]
    # Add image(s) to the messages if present
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            messages.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                }
            )
    return [HumanMessage(content=messages)]
# Create RAG chain
chain_multimodal_rag = (
{
"context": retriever_multi_vector_img | RunnableLambda(split_image_text_types),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
| ChatVertexAI(
temperature=0,
model_name=MODEL_NAME,
max_output_tokens=TOKEN_LIMIT,
) # Multi-modal LLM
| StrOutputParser()
)
query = "What are the EV / NTM and NTM rev growth for MongoDB, Cloudflare, and Datadog?
"
# List of source documents
docs = retriever_multi_vector_img.get_relevant_documents(query, limit=1)
# We get relevant docs
len(docs)
docs
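The call below uses a plt_img_base64 helper that is not defined earlier in this article. A minimal sketch of such a helper, assuming it simply renders a base64-encoded image string inline in the notebook, could look like this:
from IPython.display import HTML, display
def plt_img_base64(img_base64: str) -> None:
    # Render a base64-encoded image string as an inline image (helper assumed, not shown in the original walkthrough)
    display(HTML(f'<img src="data:image/jpeg;base64,{img_base64}" />'))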
plt_img_base64(docs[3])
result = chain_multimodal_rag.invoke(query)
from IPython.display import Markdown as md
md(result)
Multimodal RAG (Retrieval-Augmented Generation) combines text and visual data to enhance information retrieval, enabling more contextually accurate and comprehensive AI responses. By leveraging tools like Gemini, Vertex AI, and LangChain, developers can build intelligent systems that efficiently process both textual and visual data.
Gemini enables understanding of diverse data types, while Vertex AI supports scalable model deployment for real-time applications. LangChain streamlines integration with external APIs and databases, allowing seamless interaction with multiple data sources. Together, these technologies provide powerful capabilities for creating context-aware, data-rich systems for use in areas like content generation, personalized recommendations, and interactive AI assistants.
Q. What is multimodal RAG?
A. Multimodal RAG (Retrieval Augmented Generation) combines text and visual data to improve the accuracy and context of information retrieval, allowing AI systems to provide more comprehensive and relevant responses.
Q. What role does Gemini play in a multimodal RAG system?
A. Gemini, by Google, is designed to process both text and visual data, enabling AI models to understand and generate insights from mixed data types, enhancing the overall performance of multimodal systems.
Q. What is Vertex AI?
A. Vertex AI is a platform by Google Cloud that provides tools for deploying and managing AI models at scale. It streamlines the process of building, training, and optimizing models, making it easier for developers to implement effective multimodal systems.
Q. What is LangChain?
A. LangChain is a framework that helps integrate large language models with external data sources, APIs, and databases. It enables seamless interaction with different types of data, enhancing the capabilities of multimodal RAG systems.
Q. Where can multimodal RAG be applied?
A. Multimodal RAG can be applied in areas like personalized recommendations, content generation, image captioning, healthcare (cross-referencing X-rays with medical records), and AI assistants that provide context-aware responses.