PaliGemma is an open-source, state-of-the-art vision-language model (VLM) that Google released alongside other announcements at Google I/O 2024. Built from open components, the SigLIP vision model and the Gemma language model, it is lightweight and flexible and draws inspiration from PaLI-3. It accepts images and text as input, produces text as output, and supports several languages. It is intended as a base model for a range of vision-language tasks, including text reading, object detection and segmentation, visual question answering, and captioning of images and short videos.
In contrast to other VLMs that have struggled with object detection and segmentation, notably OpenAI’s GPT-4o, Google Gemini, and Anthropic’s Claude 3, PaliGemma offers a wide range of capabilities and can be fine-tuned for improved performance on specific tasks.
In today’s blog, we will walk through the pipeline for fine-tuning the PaliGemma model and deploying it on a cloud service provider. Throughout the tutorial, we will use Roboflow for easy dataset access in the desired format, Kaggle for loading the model weights, and finally, Azure Virtual Machines for deployment. A Colab instance with an NVIDIA T4 GPU is sufficient for the task.
Before reading this blog, you should be familiar with Python programming and the training process for large language models (LLMs). Although not compulsory, having a rudimentary understanding of JAX (or related technologies like Keras) would be beneficial when examining the sample code snippets.
To fine-tune PaliGemma, we will follow the steps below:
First-time users must request access to PaliGemma through Kaggle and configure an API key: log in to Kaggle, accept the terms on the PaliGemma model card, and generate an API key from your account settings.
Once all is done, set the environment variables as shown below.
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
To fine-tune the PaliGemma model, we will use the big_vision project maintained by Google Research. The code below clones the repository and installs the dependencies needed for this notebook.
import os
import sys

# Colab runtimes with remote TPUs are not supported by this notebook.
if "COLAB_TPU_ADDR" in os.environ:
    raise RuntimeError("It seems you are using Colab with remote TPUs, which is not supported.")

# Fetch the big_vision repository if Python doesn't know about it and install
# the dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
    !git clone --quiet --branch=main --depth=1 \
        https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to the Python import path.
if "big_vision_repo" not in sys.path:
    sys.path.append("big_vision_repo")

# Install missing dependencies. Assumes jax~=0.4.25 with GPU is available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"
Code Output
The code below will import the necessary frameworks, like JAX, to complete the model setup.
import base64
import functools
import html
import io
import os
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from IPython.core.display import display, HTML
from PIL import Image
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
backend = jax.lib.xla_bridge.get_backend()
print(f"JAX version: {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices: {jax.device_count()}")
For any fine-tuning task with PaliGemma, the data must be in the PaliGemma JSONL format. You might not be familiar with this format, as it is not a common data format (like YOLO) for image tasks, but JSONL (JSON Lines) is often used for training large models because it allows efficient line-by-line processing. Below is an example of data stored in the JSONL format.
{"name": "John Doe", "age": 30, "city": "New York"}
{"name": "Jane Smith", "age": 25, "city": "Los Angeles"}
{"name": "Sam Brown", "age": 22, "city": "Chicago"}
Creating data in JSONL format is straightforward; below is sample code that writes generic image-label pairs to a JSONL file.
import json
import os

# Directory containing the images
image_dir = '/path/to/images'

# Dictionary containing the image labels
labels = {
    "image1.jpg": "label1",
    "image2.jpg": "label2",
    "image3.jpg": "label3"
}

# Create a list of dictionaries with image path and label
data = []
for image_name, label in labels.items():
    image_path = os.path.join(image_dir, image_name)
    data.append({"image_path": image_path, "label": label})

# Write the data to a JSONL file
with open('images_labels.jsonl', 'w') as file:
    for entry in data:
        file.write(json.dumps(entry) + '\n')
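For PaliGemma fine-tuning specifically, the convention used by the big_vision example data (and by Roboflow's PaliGemma export) is to pair each image with a prefix (the prompt, for example "caption en" for captioning) and a suffix (the target text) rather than a generic label. The snippet below is a minimal sketch of writing such records; the captions dictionary and output file name are hypothetical placeholders.

import json

# Hypothetical captions keyed by image file name (replace with your own data).
captions = {
    "image1.jpg": "a red car parked on the street",
    "image2.jpg": "two dogs playing in the park",
}

# Each line pairs an image with a `prefix` (the prompt fed to the model)
# and a `suffix` (the text the model should learn to produce).
with open("_annotations.train.jsonl", "w") as f:
    for image_name, caption in captions.items():
        record = {"image": image_name, "prefix": "caption en", "suffix": caption}
        f.write(json.dumps(record) + "\n")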
Here, however, we will use Roboflow to make this step easier. Roboflow fully supports the PaliGemma JSONL format, so any dataset from Roboflow Universe can be downloaded in this format using your Roboflow API key. Below is a code snippet showing how to do this.
# Install the required dependencies to download and parse a dataset
!pip install roboflow supervision

from google.colab import userdata
from roboflow import Roboflow

ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("workspace-user-id").project("sample-project-name")
version = project.version(1)  # replace 1 with your dataset version number
dataset = version.download("PaliGemma")
Now that we have completed the model setup and imported the data in the desired format, we can obtain the PaliGemma weights and fine-tune the model.
This step involves downloading the PaliGemma weights from Kaggle. To keep computation manageable on limited resources, we will use the paligemma-3b-pt-224 version. JAX/FLAX PaliGemma 3B is available in three versions, differing in input image resolution (224, 448, and 896) and input text sequence length (128, 512, and 512 tokens, respectively).
The float16 version of the model checkpoint can be downloaded from Kaggle by running the following code. This process may be a bit time-consuming.
import os
import kagglehub

MODEL_PATH = "./pt_224_128.params.f16.npz"
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = kagglehub.model_download(
        'google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')
    print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
    print("Downloading the model tokenizer...")
    !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
    print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR = "./longcap100"
if not os.path.exists(DATA_DIR):
    print("Downloading the dataset...")
    !gsutil -m -q cp -n -r gs://longcap100/ .
    print(f"Data path: {DATA_DIR}")
Code Output
The next step is to configure the model and move it onto the Colab T4 GPU. Start by defining `model_config` as a `FrozenConfigDict`, which freezes certain parameters and reduces memory usage. Then create an instance of the `paligemma.Model` class using `model_config`, load the model parameters into RAM, and define a `decode` function to sample outputs from the model. Once that is done, the model can be moved to the T4 GPU. The code below covers both steps.
# Define model
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True,
            "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(),
                           eos_token=tokenizer.eos_id())
# Move model to T4 GPU
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
    if name.startswith("llm/layers/attn/"): return True
    if name.startswith("llm/"): return False
    if name.startswith("img/"): return False
    raise ValueError(f"Unexpected param name {name}")

trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Ignore the (harmless) "Some donated buffers were not usable" warning.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
    return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,
                        params, trainable)

# Loading all params at once - albeit much faster and more succinct -
# requires more RAM than the T4 Colab runtimes have by default.
# Instead we do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves,
                                                trainable_leaves)):
    params[idx] = big_vision.utils.reshard(params[idx], sharding)
    params[idx] = maybe_cast_to_f32(params[idx], trainable)
    params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)
# Print params to show what the model is made of.
def parameter_overview(params):
    for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
        print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
Code Output
This completes everything needed for the fine-tuning process, so we can proceed to the next step.
Before proceeding to the fine-tuning step, a few more checks and preprocessing steps must be performed. These are standard procedures whose code is lengthy, so it is not reproduced in full here; the details can be found in the open-source resources referenced later in this article. A broad overview of the steps is given below.
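For reference, here is a condensed sketch of the two key preprocessing helpers, in the spirit of the big_vision PaliGemma example notebook. Treat the names and details as illustrative rather than a drop-in implementation: images are resized to 224x224 and rescaled to [-1, 1], while the prefix is tokenized with full attention and no loss, and the suffix with causal attention and loss. The sketch relies on the `np`, `tf`, and `tokenizer` objects created earlier; the training loop below additionally relies on helpers such as `train_data_iterator`, `update_fn`, `make_predictions`, and `render_example` defined in the same reference notebook.

def preprocess_image(image, size=224):
    # Convert to an RGB array, resize to the model resolution, and
    # rescale pixel values from [0, 255] to [-1, 1].
    image = np.asarray(image)
    if image.ndim == 2:  # grayscale -> RGB
        image = np.stack((image,) * 3, axis=-1)
    image = image[..., :3]  # drop alpha channel if present
    image = tf.image.resize(
        tf.constant(image), (size, size), method="bilinear", antialias=True)
    return image.numpy() / 127.5 - 1.0

def preprocess_tokens(prefix, suffix=None, seqlen=None):
    # Prefix tokens get full attention (mask_ar=0) and no loss (mask_loss=0);
    # suffix tokens are predicted autoregressively (mask_ar=1, mask_loss=1).
    separator = "\n"
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)
    mask_loss = [0] * len(tokens)
    if suffix:
        suffix = tokenizer.encode(suffix, add_eos=True)
        tokens += suffix
        mask_ar += [1] * len(suffix)
        mask_loss += [1] * len(suffix)
    mask_input = [1] * len(tokens)  # 1 on real tokens, 0 on padding
    if seqlen:
        padding = [0] * max(0, seqlen - len(tokens))
        tokens = tokens[:seqlen] + padding
        mask_ar = mask_ar[:seqlen] + padding
        mask_loss = mask_loss[:seqlen] + padding
        mask_input = mask_input[:seqlen] + padding
    return tuple(np.array(x) for x in (tokens, mask_ar, mask_loss, mask_input))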
With all these steps done, we can now fine-tune the model. The code below does exactly this. It runs the training loop over 64 steps, displaying the learning rate (lr) and loss at each step. Every 16 steps, it outputs the model’s predictions for the same set of images, allowing you to observe the improvement in the model’s ability to predict descriptions. Early in training, predictions may contain errors like repeated or incomplete sentences, but as training progresses, the accuracy of the descriptions improves. By step 64, the model’s predictions should closely match the descriptions from the training data.
BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4
train_data_it = train_data_iterator()
sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS + 1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS + 1):
    # Make list of N training examples.
    examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Training step and report training loss
    learning_rate = sched_fn(step)
    params, loss = update_fn(params, batch, learning_rate)

    loss = jax.device_get(loss)
    print(f"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}")

    if (step % EVAL_STEPS) == 0:
        print(f"Model predictions at step {step}")
        html_out = ""
        for image, caption in make_predictions(
                validation_data_iterator(), num_examples=4, batch_size=4):
            html_out += render_example(image, caption)
        display(HTML(html_out))
You can now test the fine-tuned model using the pre-defined `make_predictions` function, which iterates over the images and runs inference on each one, letting us check how well the fine-tuned model performs.
print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
html_out += render_example(image, caption)
display(HTML(html_out))
Below is a sample of the model outputs over the iterations. For demonstration purposes, fine-tuning was run for only 30 steps here; the dataset, number of steps, and other hyperparameters will vary based on your usage and requirements.
Once fine-tuning is complete and the model predictions have been checked, save the model so it can be reused or deployed later, using the code below:
flat, _ = big_vision.utils.tree_flatten_with_names(params)
with open("/content/fine-tuned-PaliGemma-3b-pt-224.f16.npz", "wb") as f:
    np.savez(f, **{k: v for k, v in flat})
For deployment, we will rely on the Roboflow Inference Server and run it on an AWS EC2 instance. The Roboflow Inference Server lets you deploy computer vision models to various devices, including EC2, and it relies on Docker to run. If Docker is not already installed on the device(s) on which you want to run inference, install it by following the official Docker installation instructions. Once Docker is installed, run the following command to install the Roboflow Inference packages on your EC2 instance.
pip install inference supervision
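The inference server itself typically runs as a Docker container. Assuming the standard CPU image name and port from Roboflow's public documentation (verify against the current docs, and use the GPU image if your instance has one), starting the server looks roughly like this:

docker run -it --rm -p 9001:9001 roboflow/roboflow-inference-server-cpu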
With the Roboflow Inference Server running, the fine-tuned model can be served from the EC2 instance.
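Once the server is up, a client can send images to it for inference. Below is a rough sketch using the `inference_sdk` HTTP client; the URL, API key, and `model_id` are placeholders, and serving a custom fine-tuned PaliGemma checkpoint through Roboflow requires first making the weights available to the platform, so treat this as an illustration rather than a finished deployment script.

from inference_sdk import InferenceHTTPClient

# Point the client at the locally running inference server.
client = InferenceHTTPClient(
    api_url="http://localhost:9001",      # local Roboflow Inference Server
    api_key="YOUR_ROBOFLOW_API_KEY",      # placeholder
)

# Hypothetical image and model identifiers; adapt to your own project.
result = client.infer("sample.jpg", model_id="your-project/1")
print(result)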
In this blog, we walked through the end-to-end process of fine-tuning and deploying PaliGemma, a cutting-edge vision-language model from Google. Starting with installing the necessary dependencies and setting up the environment, we leveraged several tools and platforms: Kaggle for accessing the model weights, Roboflow for dataset preparation, and Azure Virtual Machines for deployment. By following these steps, you can harness PaliGemma for a range of vision-language tasks such as object detection, image captioning, and visual question answering. I hope this guide provides a clear and practical path to enhancing your projects with advanced AI capabilities.
In addition to this blog, here are a few more interesting reads that inspired it.
The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.
A. You must be familiar with Python programming and have experience training large language models (LLMs). Knowledge of JAX or Keras is beneficial for understanding the code snippets. Additionally, you’ll need access to Kaggle to download the model weights and datasets and an Azure account to deploy the model.
A. First, log in to your Kaggle account and request access to the PaliGemma model through its model card on Kaggle. Accept the terms and generate an API key from your Kaggle settings. Store the API key securely in your Colab instance and use it to download the model weights.
A. Your dataset should be in JSONL format, where each line in the file is a JSON object. For example:
{"image_path": "/path/to/image1.jpg", "label": "label1"}
{"image_path": "/path/to/image2.jpg", "label": "label2"}
You can use tools like Roboflow to prepare and download datasets in the required JSONL format.
A. You need to set the model configuration to be compatible with your environment, such as a Colab T4 GPU. Load the model weights and tokenizer, and appropriately set up the model parameters and data sharding. Use JAX and the necessary libraries to prepare the model for training.
A. After fine-tuning your model, save the model parameters. Set up an Azure Virtual Machine (VM) to host your model. Transfer the fine-tuned model to the VM and use Azure’s deployment services to make it accessible for inference. The specific deployment steps on Azure will depend on your VM configuration and preferred deployment method.