Building a Multi-Modal RAG Pipeline with Langchain

Sunil Kumar Last Updated : 15 Jan, 2024
10 min read

Introduction

The advent of LLMs has completely changed how we extract information from documents. However, images such as charts and tables often contain crucial information, and text-based LLMs cannot process media files; until recently, we could only use the text in a PDF to find answers. With the release of multi-modal LLMs from different labs, it is now possible to extract information from images as well. Multi-modal models like GPT-4V and Gemini Pro Vision have shown a strong ability to infer data from images, and we can use them to augment a regular RAG pipeline into a multi-modal one. So, in this article, we will build an MM-RAG pipeline using the Gemini Pro models, the Chroma vector database, and Langchain.

Learning Objectives

  • Get a brief primer on Langchain and vector databases.
  • Explore different approaches to building multi-modal RAG pipelines.
  • Build a multi-vector retriever in Langchain using Chroma.
  • Create a RAG pipeline for multi-modal data.

This article was published as a part of the Data Science Blogathon.

Langchain and Vector Databases

Langchain is one of the hottest tools of 2023. It is an open-source framework for building chains of tasks and LLM agents, and it has almost all the tools you need to create a functional AI application: data loaders, vector stores, and LLM integrations from different labs. Langchain has two core value propositions: chains and agents.

  • Chains are sequences of Langchain components, where the output of one component becomes the input of the next. To make development easier, Langchain provides an expression language called LCEL.
  • Agents are autonomous bots powered by LLMs. Unlike chains, their steps are not rigid; based on the data and the descriptions of the tools it has access to, an agent decides which tools to use.

We will be using Langchain’s expression language to build the multi-modal pipeline. Another crucial part of RAG is the vector database. These databases are purpose-built for storing embeddings of data and can handle millions of them, so whenever we need context-aware retrieval, a vector store becomes indispensable. To get embeddings of data, embedding models are used; these models have been trained on large quantities of data to capture the semantic similarity of texts.
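
As an illustration, here is a minimal LCEL sketch, assuming the same Vertex AI credentials that are configured later in this article: a prompt template, an LLM, and an output parser composed with the pipe operator.

from langchain.llms import VertexAI
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Prompt -> LLM -> string parser, wired together with the LCEL pipe operator
prompt = PromptTemplate.from_template("Summarize in one sentence: {text}")
model = VertexAI(model_name="gemini-pro", temperature=0)
chain = prompt | model | StrOutputParser()

print(chain.invoke({"text": "Vector databases store embeddings for fast similarity search."}))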


Approaches to Building Multi-Modal RAG Pipeline

So, how do we build a RAG pipeline in a multi-modal context? There are three different ways we can create an MM-RAG pipeline.

  • Option 1: Use a multi-modal embedding model like CLIP or Imagebind to create embeddings of images and texts. Retrieve both using similarity search and pass the documents to a multi-modal LLM.
  • Option 2: Use a multi-modal model to create summaries of images. Retrieve the summaries from vector stores and pass them to an LLM for Q&A.
  • Option 3: Use a multi-modal LLM to get descriptions of images. Embed the text descriptions using any embedding model of choice and store the original documents in a docstore. Each summary keeps a reference to its original image or text chunk; at query time, retrieve the summaries and pass the referenced original documents to a multi-modal LLM for answer generation.

Each of the options has its pros and cons.

Option 1 can work well for generic images but may struggle with charts and tables, since general-purpose multi-modal embeddings capture little of their fine-grained structure.

Option 2 is suitable when you cannot call a multi-modal model frequently: each image is summarized once, and only the text summary is used at query time.

Option 3 offers the best accuracy. MM-LLMs like GPT-4V or Gemini can understand charts and tables, so it is the choice when complex image understanding is involved. But it also costs more, because the multi-modal model is used both for summarization and for answer generation.


We will implement the 3rd approach with Gemini Pro Vision. 

Building RAG Pipeline

Now that we are aware of the concepts and the tools required, here is a quick workflow overview.

  • We start with extracting images and texts from files.
  • Get summaries of the texts, tables, and images from an LLM (a vision model for the images).
  • Embed the summaries in Chroma and store the original documents in an in-memory docstore.
  • Create a multi-vector retriever. It runs a similarity search over the summaries and returns the corresponding original documents from the docstore.
  • Pass the documents to an MM-LLM to get answers.

Let’s delve into coding the RAG pipeline. 

Dependencies

PDFs and other data formats often have tables and pictures embedded in them, and it is not possible to extract them as easily as text. To achieve this, we need purpose-built tools like Unstructured. Unstructured is an open-source tool for pre-processing images and files such as HTML, PDF, and Word docs. It can extract embedded images from files using OCR. Unstructured requires Poppler and Tesseract to be installed on the system.

Install Tesseract and Poppler.

!sudo apt install tesseract-ocr
!sudo apt-get install poppler-utils
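
Optionally, you can verify that both system dependencies are available on the path before moving on:

# Optional: confirm Tesseract and Poppler are installed
!tesseract --version
!pdftoppm -v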

Now, we can install Unstructured along with other required libraries. 

!pip install "unstructured[all-docs]" langchain langchain_community \
 chromadb langchain-experimental
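
Depending on your environment, the Vertex AI wrappers used later may also need the Google Cloud SDK for Python; if you hit an import error when creating the Gemini models, installing it usually resolves the issue:

# Only if the Vertex AI integrations complain about a missing SDK
!pip install google-cloud-aiplatform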

Extract Images and Tables

We use the partition_pdf function to extract images and tables.

from unstructured.partition.pdf import partition_pdf

image_path = "./"
pdf_elements = partition_pdf(
    "mistral.pdf",
    chunking_strategy="by_title",
    extract_images_in_pdf=True,
    max_characters=3000,
    new_after_n_chars=2800,
    combine_text_under_n_chars=2000,
    image_output_dir_path=image_path
    )

This will partition the PDF and extract images to the given path. It also chunks the texts by the strategy and the character limits provided.
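
If you want to see what partition_pdf produced before categorizing the elements, a quick, optional count of the element types can help:

from collections import Counter

# Optional: count the element types returned by partition_pdf
print(Counter(type(el).__name__ for el in pdf_elements))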

Now, we separate texts and tables into different groups.

# Categorize elements by type
def categorize_elements(raw_pdf_elements):
    """
    Categorize extracted elements from a PDF into tables and texts.
    raw_pdf_elements: List of unstructured.documents.elements
    """
    tables = []
    texts = []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Table" in str(type(element)):
            tables.append(str(element))
        elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
            texts.append(str(element))
    return texts, tables

texts, tables = categorize_elements(pdf_elements)

Text and Table Summaries

We will use Gemini Pro to get short summaries of text chunks. Langchain has a summarizing chain for this purpose; we will use the expression language to build a simple summarizing chain. To use Gemini models with Langchain, you need to set up a GCP account, enable Vertex AI, and configure the credentials.
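
Here is a minimal sketch of that credential setup, assuming you authenticate with a service-account key file; the key path and project ID below are placeholders for your own values.

import os
import vertexai  # installed with the google-cloud-aiplatform package

# Placeholders: replace with your own service-account key and GCP project details
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service-account.json"
vertexai.init(project="your-gcp-project-id", location="us-central1")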

from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda


# Generate summaries of text elements
def generate_text_summaries(texts, tables, summarize_texts=False):
    """
    Summarize text elements
    texts: List of str
    tables: List of str
    summarize_texts: Bool to summarize texts
    """

    # Prompt
    prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
    These summaries will be embedded and used to retrieve the raw text or table elements. \
    Give a concise summary of the table or text that is well-optimized for retrieval. Table \
    or text: {element} """
    prompt = PromptTemplate.from_template(prompt_text)
    empty_response = RunnableLambda(
        lambda x: AIMessage(content="Error processing document")
    )
    # Text summary chain
    model = VertexAI(
        temperature=0, model_name="gemini-pro", max_output_tokens=1024
    ).with_fallbacks([empty_response])
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

    # Initialize empty summaries
    text_summaries = []
    table_summaries = []

    # Apply to text if texts are provided and summarization is requested
    if texts and summarize_texts:
        text_summaries = summarize_chain.batch(texts, {"max_concurrency": 1})
    elif texts:
        text_summaries = texts

    # Apply to tables if tables are provided
    if tables:
        table_summaries = summarize_chain.batch(tables, {"max_concurrency": 1})

    return text_summaries, table_summaries


# Get text, table summaries (note that only texts[9:] are summarized here)
text_summaries, table_summaries = generate_text_summaries(
    texts[9:], tables, summarize_texts=True
)

In the chain, there are four components. The first is a dictionary that maps the input onto the {element} placeholder, the second builds the prompt from the template, the third is the LLM, and the final one is a string output parser. The output of each component is used as the input of the next.

The RunnableLambda that returns an AIMessage acts as a fallback: if the LLM call fails, the chain returns the placeholder message instead of raising an error.
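
As an optional sanity check, you can compare a raw chunk with its summary; note that text_summaries[0] corresponds to texts[9], since the first nine chunks were skipped in the call above.

# Optional: compare the first summarized chunk with its source text
print(texts[9][:300])
print(text_summaries[0])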

Image Summaries

As we discussed earlier, we will use a vision model to get text descriptions of images. You can use any vision model, such as GPT-4V, Llava, or Gemini. Here, we will use Gemini Pro Vision.

To process images, we will convert them to base64 format and pass them to the Gemini Pro Vision model with a default prompt. 

import base64
import os

from langchain_core.messages import HumanMessage


def encode_image(image_path):
    """Getting the base64 string"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def image_summarize(img_base64, prompt):
    """Make image summary"""
    model = ChatVertexAI(model_name="gemini-pro-vision", max_output_tokens=1024)

    msg = model(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                    },
                ]
            )
        ]
    )
    return msg.content

def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images
    path: Path to list of .jpg files extracted by Unstructured
    """

    # Store base64 encoded images
    img_base64_list = []

    # Store image summaries
    image_summaries = []

    # Prompt
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
    These summaries will be embedded and used to retrieve the raw image. \
    Give a concise summary of the image that is well optimized for retrieval."""

    # Apply to images
    for img_file in sorted(os.listdir(path)):
        if img_file.endswith(".jpg"):
            img_path = os.path.join(path, img_file)
            base64_image = encode_image(img_path)
            img_base64_list.append(base64_image)
            image_summaries.append(image_summarize(base64_image, prompt))

    return img_base64_list, image_summaries

fpath = "./"
# Image summaries
img_base64_list, image_summaries = generate_img_summaries(fpath)
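
To spot-check the results inside a notebook, a small display helper (not part of the pipeline, just a convenience added here) can render a base64-encoded image next to its generated summary.

from IPython.display import HTML, display

def plt_img_base64(img_base64):
    """Render a base64-encoded image inline in the notebook."""
    display(HTML(f'<img src="data:image/jpeg;base64,{img_base64}" />'))

# Spot-check the first extracted image and its summary
plt_img_base64(img_base64_list[0])
print(image_summaries[0])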

Multi-Vector Retriever

As we discussed earlier, we will store embeddings of the image and table descriptions in a vector store and store the original documents in an in-memory document store. Then, we retrieve the original documents corresponding to the retrieved vectors from the vector store. Langchain has a multi-vector retriever to achieve this. So, here is how we can build a multi-vector retriever.

import uuid

from langchain.embeddings import VertexAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma


def create_multi_vector_retriever(
    vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
    """
    Create retriever that indexes summaries, but returns raw images or texts
    """

    # Initialize the storage layer
    store = InMemoryStore()
    id_key = "doc_id"

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )
    # Helper function to add documents to the vectorstore and docstore
    def add_documents(retriever, doc_summaries, doc_contents):
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=s, metadata={id_key: doc_ids[i]})
            for i, s in enumerate(doc_summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # Add texts, tables, and images
    # Check that text_summaries is not empty before adding
    if text_summaries:
        add_documents(retriever, text_summaries, texts)
    # Check that table_summaries is not empty before adding
    if table_summaries:
        add_documents(retriever, table_summaries, tables)
    # Check that image_summaries is not empty before adding
    if image_summaries:
        add_documents(retriever, image_summaries, images)

    return retriever

# The vectorstore to use to index the summaries
vectorstore = Chroma(
    collection_name="mm_rag_mistral",
    embedding_function=VertexAIEmbeddings(model_name="textembedding-gecko@latest"),
)

# Create retriever
retriever_multi_vector_img = create_multi_vector_retriever(
    vectorstore,
    text_summaries,
    texts[9:],  # align the raw text chunks with the summaries generated from texts[9:]
    table_summaries,
    tables,
    image_summaries,
    img_base64_list,
)

In the above code, we defined a Chroma vector store and an in-memory document store and passed them to the multi-vector retriever. The summaries of texts, tables, and images are embedded and stored in the Chroma collection, while the original documents are stored in the in-memory docstore under matching IDs.
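
As a quick, optional check that everything landed in the right place, you can count the embedded summaries in Chroma and the raw documents in the docstore; the two counts should match. This assumes the get() and yield_keys() accessors exposed by the Chroma wrapper and InMemoryStore.

# Optional: number of embedded summaries vs. number of raw documents
print(len(vectorstore.get()["ids"]))
print(len(list(retriever_multi_vector_img.docstore.yield_keys())))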

RAG Pipeline

We will use the Langchain Expression Language to build the final pipeline. 

import io
import re

from IPython.display import HTML, display
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from PIL import Image

def looks_like_base64(sb):
    """Check if the string looks like base64"""
    return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None


def is_image_data(b64data):
    """
    Check if the base64 data is an image by looking at the start of the data
    """
    image_signatures = {
        b"\xFF\xD8\xFF": "jpg",
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",
    }
    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
        for sig, format in image_signatures.items():
            if header.startswith(sig):
                return True
        return False
    except Exception:
        return False

def resize_base64_image(base64_string, size=(128, 128)):
    """
    Resize an image encoded as a Base64 string
    """
    # Decode the Base64 string
    img_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(img_data))

    # Resize the image
    resized_img = img.resize(size, Image.LANCZOS)

    # Save the resized image to a bytes buffer
    buffered = io.BytesIO()
    resized_img.save(buffered, format=img.format)

    # Encode the resized image to Base64
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def split_image_text_types(docs):
    """
    Split base64-encoded images and texts
    """
    b64_images = []
    texts = []
    for doc in docs:
        # Check if the document is of type Document and extract page_content if so
        if isinstance(doc, Document):
            doc = doc.page_content
        if looks_like_base64(doc) and is_image_data(doc):
            doc = resize_base64_image(doc, size=(1300, 600))
            b64_images.append(doc)
        else:
            texts.append(doc)
    if len(b64_images) > 0:
        return {"images": b64_images[:1], "texts": []}
    return {"images": b64_images, "texts": texts}
  
def img_prompt_func(data_dict):
    """
    Join the context into a single string
    """
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []

    # Adding the text for analysis
    text_message = {
        "type": "text",
        "text": (
            "You are an AI scientist tasking with providing factual answers.\n"
            "You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n"
            "Use this information to provide answers related to the user question. \n"
            f"User-provided question: {data_dict['question']}\n\n"
            "Text and / or tables:\n"
            f"{formatted_texts}"
        ),
    }
    messages.append(text_message)
    # Adding image(s) to the messages if present
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            image_message = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            messages.append(image_message)
    return [HumanMessage(content=messages)]

def multi_modal_rag_chain(retriever):
    """
    Multi-modal RAG chain
    """

    # Multi-modal LLM
    model = ChatVertexAI(
        temperature=0, model_name="gemini-pro-vision", max_output_tokens=1024
    )

    # RAG pipeline
    chain = (
        {
            "context": retriever | RunnableLambda(split_image_text_types),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )

    return chain


# Create RAG chain
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)

The RAG chain is created using the multi-vector retriever, Gemini Pro Vision, and a function that prepends an instruction to ground the LLM response.

The chain has multiple components chained together.

  • The first component is a dictionary holding the context and the user question. The value of the context key is itself a small chain: the retriever followed by a function that separates base64 images from text.
  • The resulting dictionary is passed to a function that builds the prompt and prepends an instruction to ground the LLM's response.
  • The Chat model receives the output of the previous step and generates a response based on the query and the contexts.
  • Finally, the output is parsed using the StrOutputParser.

We can now invoke the chain and get our queries answered. You can run the retriever to see if it retrieves the right documents.

query = """compare and contrast between mistral and llama2 across benchmarks and 
explain the reasoning in detail"""
docs = retriever_multi_vector_img.get_relevant_documents(query, limit=1)
docs[0]

I have used the official Mistral arXiv paper as the reference PDF. Running the cell returned a benchmark comparison chart, which seems like the right document for the query I asked.
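
Since the retrieved document is a base64 string, you can render it inline with the plt_img_base64 helper sketched earlier:

# Render the retrieved base64 image (using the helper defined above)
plt_img_base64(docs[0])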


Now, invoke the RAG chain with the query to see if the chain is working as intended.

chain_multimodal_rag.invoke(query)

To get better answers, try playing with the prompts. Often, you get better responses just by tweaking the prompt a bit.

So, this was all about building MM-RAG in Langchain. 

Conclusion

RAG paired with vector databases filled a crucial gap for LLMs: reducing hallucination. RAG and vector search have created a new sub-genre of technology, and with vision-capable LLMs we can now retrieve information from images as well, getting answers from the images embedded in files alongside the text.

So, in this article, we used Langchain, Chroma, Gemini, and Unstructured to build a multi-modal RAG pipeline.

Here are the key takeaways.

  • In a standard RAG pipeline, much of the information locked inside images stays untouched. Multi-modal RAG fills this gap by augmenting the existing pipeline with vision-capable LLMs.
  • There are different approaches to building MM-RAG. Summarizing images with an MM-LLM, retrieving the original documents via the similarity between the summaries and the query, and passing those originals to an MM-LLM for answering provides the best accuracy.
  • Langchain is an open-source framework for building LLM workflows and agents.
  • It provides a multi-vector retriever for retrieving original documents from a docstore via their summaries stored in a vector store.

Frequently Asked Questions

Q1. What is Langchain used for?

A. LangChain is an open-source framework that simplifies the creation of applications using large language models. It can be used for various tasks, including chatbots, document analysis, code analysis, question answering, and generative tasks.

Q2. What is the difference between chains and agents in Langchain?

A. Chains follow a fixed, predefined sequence of steps, while agents are more complex: they decide which steps to execute and which of their available tools to use, and they can adapt based on intermediate results. Agents are often used for tasks that require more reasoning, for example, data analysis and code generation.

Q3. What is multimodal AI?

A. Multimodal AI refers to the Machine Learning models that can process and understand various modalities of data such as image, text, audio, etc.

Q4. What is a RAG pipeline?

A. A RAG (Retrieval-Augmented Generation) pipeline indexes documents in a knowledge base, typically a vector store, retrieves the documents most relevant to a query, and passes them to an LLM as context for generating the answer.

Q5. What is the difference between the Langchain and Llama Index?

A. LlamaIndex is designed explicitly for search and retrieval applications, while Langchain offers more flexibility for creating custom LLM workflows and agents.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.

