In this tutorial, you will learn how to build, train, and evaluate a CNN model using JAX, Flax, and Optax on the MNIST dataset. It walks through setting up the environment, preprocessing the data, defining the CNN architecture, training the model, and finally testing it on custom images. Along the way, it shows how JAX's strong numerical performance, Flax's flexible neural network API, and Optax's sophisticated optimization tools combine to train and evaluate a deep learning model efficiently. The goal of this guide is to explain how these tools streamline deep learning workflows and help you build better models.
For deep learning models to be efficient and scalable, developers need solid tools for numerical computation, model design, and optimization. So how do JAX, Flax, and Optax collectively address the challenges inherent in developing complex ML models? Let's find out.
JAX is a high-performance numerical computing library with a familiar NumPy-like syntax. It excels in scenarios requiring hardware acceleration or automatic differentiation. Key features include:
- Automatic differentiation with jax.grad
- Just-in-time compilation to XLA via jax.jit
- Automatic vectorization with jax.vmap
- The same code runs on CPU, GPU, and TPU
Flax is a JAX-based library for building neural networks. It's designed to be both user-friendly and customizable:
- The Linen API (flax.linen) provides familiar building blocks such as Conv, Dense, pooling, and activation functions
- The @nn.compact decorator lets you define a whole model concisely inside __call__
- Training utilities such as flax.training.train_state bundle the parameters, the apply function, and the optimizer state
Optax simplifies gradient processing and optimization, offering:
- Standard optimizers such as SGD (with momentum), Adam, and RMSProp
- Composable gradient transformations that can be chained together
- Common loss functions such as softmax cross-entropy
Together, these libraries offer a powerful, modular ecosystem for building and training deep learning models efficiently.
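To make the division of labor concrete, here is a minimal, self-contained sketch on a toy problem (not part of the tutorial's pipeline): Flax defines a one-layer model, JAX computes the gradients, and Optax turns them into a parameter update.

import jax
import jax.numpy as jnp
from flax import linen as nn
import optax

# A toy linear model defined with Flax's Linen API
model = nn.Dense(features=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))

# A mean-squared-error loss; JAX differentiates it automatically
def loss_fn(params, x, y):
  pred = model.apply(params, x)
  return jnp.mean((pred - y) ** 2)

x, y = jnp.ones((4, 3)), jnp.zeros((4, 1))
grads = jax.grad(loss_fn)(params, x, y)

# Optax converts gradients into parameter updates
tx = optax.sgd(learning_rate=0.1)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)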
However, to explore JAX and its capabilities, you first need to set the libraries up on your system. Here is a brief overview of how to install JAX and the companion libraries used in this tutorial.
!pip install --upgrade -q pip jax jaxlib flax optax tensorflow-datasets
Installs the required libraries:
- jax and jaxlib: the core numerical computing library and its hardware-specific backend
- flax: the neural network library built on top of JAX
- optax: the gradient processing and optimization library
- tensorflow-datasets: used to download and load MNIST
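As an optional sanity check (assuming the installation succeeded), you can print the installed versions before moving on:

import jax, flax, optax
import tensorflow_datasets as tfds
print(jax.__version__, flax.__version__, optax.__version__, tfds.__version__)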
To harness the power of JAX, Flax, and Optax, the first step is to import the necessary libraries into your development environment. This section will guide you through the process of importing these key libraries, ensuring that you have everything set up for the efficient execution of machine learning tasks. By correctly importing JAX, Flax, and Optax, you’re laying the foundation for creating high-performance models that can leverage advanced features like GPU/TPU acceleration and automatic differentiation. Let’s get started with the essential imports!
import jax
import jax.numpy as jnp # JAX NumPy
from flax import linen as nn # The Linen API
from flax.training import train_state
import optax # The Optax gradient processing and optimization library
import numpy as np # Ordinary NumPy
import tensorflow_datasets as tfds # TFDS for MNIST
In this section, we load and preprocess the MNIST dataset, a standard benchmark in machine learning consisting of images of handwritten digits. Preparing the data correctly ensures that the model can learn from it. We will show how to download the dataset, convert the images to floating point, and structure the data for training and evaluation.
def get_datasets():
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  # Split into training/test sets
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  # Convert to floating point in [0, 1]
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
  return train_ds, test_ds
train_ds, test_ds = get_datasets()
We use TFDS to load and preprocess the MNIST dataset:
- tfds.builder('mnist') and download_and_prepare() fetch the dataset
- as_dataset(split=..., batch_size=-1) loads each split into memory in one go, and tfds.as_numpy converts it to NumPy arrays
- Pixel values are cast to float32 and scaled from [0, 255] to [0, 1]
The function returns train_ds and test_ds dictionaries with keys 'image' and 'label'.
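If you want to confirm the data loaded as expected, a quick check of the array shapes for the standard MNIST split looks like this:

print(train_ds['image'].shape)  # (60000, 28, 28, 1)
print(test_ds['image'].shape)   # (10000, 28, 28, 1)
print(train_ds['label'][:5])    # integer class labels in the range 0-9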
CNNs are the architecture of choice for image classification problems, and in this section we build one with the JAX + Flax + Optax stack. Thanks to their stacked convolutional layers, CNNs learn spatial hierarchies in image data on their own. Below we define the convolutional and pooling layers, the activation functions, and the final output layer that classifies the digits in the MNIST dataset.
class CNN(nn.Module):
  # nn.compact lets us define submodules (and register their parameters) inline in __call__
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # Flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)  # There are 10 classes in MNIST
    return x
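Before moving on, it can help to verify that the architecture produces the expected output shape. The following minimal sketch (using a throwaway PRNG key and a dummy batch; it is not part of the training pipeline) initializes the model and confirms it returns ten logits per image:

demo_rng = jax.random.PRNGKey(42)
dummy_batch = jnp.ones((1, 28, 28, 1))
demo_variables = CNN().init(demo_rng, dummy_batch)
demo_logits = CNN().apply(demo_variables, dummy_batch)
print(demo_logits.shape)  # (1, 10)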
To train and evaluate our Convolutional Neural Network (CNN), we need the right measures of its performance. Here we define the metrics, loss and accuracy, that we will track on both the training and the test set.
def compute_metrics(logits, labels):
  loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, num_classes=10)))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy
  }
  return metrics
We define metrics to evaluate model performance:
- Loss: the mean softmax cross-entropy between the logits and the one-hot encoded labels
- Accuracy: the fraction of predictions whose argmax matches the true label
The function returns a metrics dictionary with keys 'loss' and 'accuracy'.
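As a quick illustration (with made-up logits and labels rather than data from the pipeline), you can call the function directly:

demo_key = jax.random.PRNGKey(0)
demo_logits = jax.random.normal(demo_key, (4, 10))
demo_labels = jnp.array([3, 1, 7, 0])
print(compute_metrics(demo_logits, demo_labels))  # {'loss': ..., 'accuracy': ...}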
We define the functions responsible for training the model on the dataset and evaluating its performance. These functions handle the forward pass, loss calculation, backpropagation, and tracking the model’s accuracy during both training and validation phases.
@jax.jit
def train_step(state, batch):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = jnp.mean(optax.softmax_cross_entropy(
        logits=logits,
        labels=jax.nn.one_hot(batch['label'], num_classes=10)))
    return loss, logits
  # Differentiate the loss w.r.t. the parameters; has_aux=True also returns the logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  # Apply the gradients using the optimizer stored in the train state
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits, batch['label'])
  return state, metrics

@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])
Training Step: train_step defines a loss function over the parameters, uses jax.value_and_grad with has_aux=True to obtain both the gradients and the logits, applies the gradients via state.apply_gradients, and returns the updated state together with the batch metrics.
Evaluation Step: eval_step runs a forward pass on a batch with the given parameters and returns the metrics.
Both functions are compiled with @jax.jit for faster execution.
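If the has_aux=True pattern is new to you, here is a standalone sketch (a toy function unrelated to the CNN) showing that the wrapped function returns a (loss, aux) pair while gradients are taken with respect to the loss alone:

def demo_loss(w):
  loss = jnp.sum(w ** 2)
  aux = {'w_mean': jnp.mean(w)}
  return loss, aux

(demo_value, demo_aux), demo_grads = jax.value_and_grad(demo_loss, has_aux=True)(jnp.array([1.0, 2.0]))
print(demo_value, demo_aux, demo_grads)  # 5.0, {'w_mean': 1.5}, [2. 4.]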
We wrap the training process in a function that iterates over the dataset once per epoch. During each epoch, the data is shuffled, the model is updated batch by batch from the computed gradients, and the performance metrics are averaged to track progress.
def train_epoch(state, train_ds, batch_size, epoch, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  # Shuffle the dataset and split the indices into minibatches
  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # Skip the incomplete final batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)
  # Average the metrics over all batches in this epoch
  training_batch_metrics = jax.device_get(batch_metrics)
  training_epoch_metrics = {
      k: np.mean([metrics[k] for metrics in training_batch_metrics])
      for k in training_batch_metrics[0]}
  print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))
  return state, training_epoch_metrics

def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  eval_summary = jax.tree.map(lambda x: x.item(), metrics)
  return eval_summary['loss'], eval_summary['accuracy']
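To see what the shuffling logic does in isolation, here is a tiny example with a hypothetical dataset of 10 examples and a batch size of 4, showing how a permutation of indices becomes minibatches:

demo_rng = jax.random.PRNGKey(1)
demo_perm = jax.random.permutation(demo_rng, 10)  # a shuffled ordering of 0..9
demo_perm = demo_perm[:(10 // 4) * 4]             # drop the incomplete batch (2 leftover examples)
print(demo_perm.reshape((2, 4)))                  # 2 batches of 4 indices each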
This step runs the training loop, evaluating the model on the test set after every epoch. By checking both training and test metrics, we confirm that the model is actually learning and that it generalizes to data it has never encountered before.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

cnn = CNN()
params = cnn.init(init_rng, jnp.ones([1, 28, 28, 1]))['params']

# SGD with Nesterov momentum
nesterov_momentum = 0.9
learning_rate = 0.001
tx = optax.sgd(learning_rate=learning_rate, momentum=nesterov_momentum, nesterov=True)
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)

# Initialize lists to store metrics for graph visualization
training_losses = []
training_accuracies = []
testing_losses = []
testing_accuracies = []

num_epochs = 10
batch_size = 64

for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one optimization epoch over the training set
  state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))
  # Store metrics for graph visualization
  training_losses.append(train_metrics['loss'])
  training_accuracies.append(train_metrics['accuracy'])
  testing_losses.append(test_loss)
  testing_accuracies.append(test_accuracy)
In this step, we visualize the training and testing metrics such as accuracy and loss over time. This helps to identify trends, diagnose potential issues like overfitting or underfitting, and assess the overall performance of the model during training.
import matplotlib.pyplot as plt
# Graph visualization for training/testing loss and accuracy
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(14, 5))
# Plot for Loss
plt.subplot(1, 2, 1)
plt.plot(epochs, training_losses, label='Training Loss', marker='o')
plt.plot(epochs, testing_losses, label='Testing Loss', marker='o')
plt.title('Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot for Accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, training_accuracies, label='Training Accuracy', marker='o')
plt.plot(epochs, testing_accuracies, label='Testing Accuracy', marker='o')
plt.title('Accuracy Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()
We will now demonstrate how to use the trained model to make predictions on custom images. This allows you to evaluate the model’s performance on unseen data and test its ability to generalize to new, real-world examples.
from google.colab import files
from PIL import Image
import numpy as np

# Step 1: Upload the image file
uploaded = files.upload()

# Step 2: Process the uploaded image
def load_and_preprocess_image(file_path):
  img = Image.open(file_path).convert('L')  # Convert to grayscale
  img = img.resize((28, 28))  # Resize to 28x28
  img = np.array(img) / 255.0  # Normalize pixel values to [0, 1]
  img = img.reshape(1, 28, 28, 1)  # Add batch and channel dimensions
  return img

# Step 3: Load and preprocess each uploaded image
for file_name in uploaded.keys():
  test_image = load_and_preprocess_image(file_name)
  print(f"Processed image from {file_name}.")

# Convert to a JAX array
test_image_jax = jnp.array(test_image, dtype=jnp.float32)

# Step 4: Use your trained model for predictions
logits = state.apply_fn({'params': state.params}, test_image_jax)
prediction = jnp.argmax(logits, axis=-1)
print(f"Predicted class: {prediction[0]}")

# Display the uploaded image
plt.imshow(test_image[0].squeeze(), cmap='gray')
plt.title(f"Predicted Class: {prediction[0]}")
plt.axis('off')
plt.show()
After uploading, each image is preprocessed to match the input format the model was trained on: grayscale, 28x28 pixels, values scaled to [0, 1], with batch and channel dimensions added.
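If a prediction looks off, an optional check (using the variables from the snippet above) confirms that the preprocessed image really matches the training format:

print(test_image_jax.shape, test_image_jax.dtype)  # expected: (1, 28, 28, 1) float32
print(float(test_image_jax.min()), float(test_image_jax.max()))  # values should lie in [0, 1]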
This step-by-step guide demonstrated the power and flexibility of JAX, Flax, and Optax in building a robust deep learning pipeline for image classification. By leveraging their unique features like efficient hardware acceleration, modular design, and advanced optimization capabilities, we trained a Convolutional Neural Network (CNN) on the MNIST dataset with ease. The integration with TensorFlow Datasets (TFDS) simplified data loading and preprocessing, while visualizing metrics provided valuable insights into the model's performance.
The pipeline culminated in testing the model on custom images, showcasing its practical application. This approach is not only scalable for more complex datasets but also serves as a foundation for exploring cutting-edge deep learning techniques.
Here is the Colab link: Click Here.
Q. What is JAX, and why is it used in this tutorial?
A. JAX is a high-performance numerical computing library that offers features like automatic differentiation and GPU/TPU acceleration. We use it here to efficiently compute gradients and execute deep learning operations seamlessly on hardware accelerators.
Q. What is Flax, and how does it relate to JAX?
A. Flax is a lightweight, modular library built on JAX, designed for flexibility and scalability. Its @nn.compact API simplifies model definitions, making it easier to experiment with different architectures while leveraging JAX's powerful features.
Q. What does Optax contribute to the pipeline?
A. Optax offers a comprehensive suite of optimization algorithms and tools for gradient processing, such as SGD with momentum, which we use to train the CNN efficiently.
Q. Why use TensorFlow Datasets (TFDS)?
A. TFDS simplifies dataset handling by providing pre-built datasets like MNIST, along with tools for automatic downloading, preprocessing, and splitting into training and testing sets.