Flax is an advanced neural network library built on top of JAX, aimed at giving researchers and developers a flexible, high-performance toolset for building complex machine learning models. Flax’s seamless integration with JAX enables automatic differentiation, Just-In-Time (JIT) compilation, and support for hardware accelerators, making it ideal for both experimental research and production.
This blog will explore Flax’s core features, compare them to other frameworks, and provide a practical example using Flax’s functional programming approach.
Flax is a high-performance neural network library built on top of JAX, designed to provide researchers and developers with the flexibility and efficiency needed to build cutting-edge machine learning models. Flax leverages JAX’s capabilities, such as automatic differentiation and Just-In-Time (JIT) compilation, to offer a powerful framework for both research and production environments.
Flax distinguishes itself from other deep learning frameworks like TensorFlow, PyTorch, and Keras through its unique design principles:
- Functional programming: models are pure functions with no hidden state; parameters are passed in explicitly, which makes experiments easier to debug and reproduce.
- Deep JAX integration: because models are pure functions, JAX transformations such as jit, grad, and vmap compose naturally with them.
- Explicit parameter and state management: parameters live in immutable pytrees (such as FrozenDict), while mutable state like batch statistics is handled through dedicated variable collections.
- The Linen API: a concise, modular way to define layers and full models.
Before building models with Flax, it’s essential to set up your development environment with the necessary libraries. We’ll install the latest versions of JAX, JAXlib, and Flax. JAX is the backbone that provides high-performance numerical computing, while Flax builds upon it to offer a flexible neural network framework.
# Install the latest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
Explanation: We import JAX along with its random and numpy (jnp) modules, typing helpers for annotations, and Flax itself, including the Linen API (flax.linen, conventionally imported as nn) that we'll use to define models.
Linear regression is a foundational machine learning technique used to model the relationship between a dependent variable and one or more independent variables. In Flax, we can implement linear regression using a single dense (fully connected) layer.
First, let’s instantiate a dense layer with Flax’s Linen API.
# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)
Explanation: nn.Dense creates a single fully connected layer with 5 output features. The input dimension is not specified here; it will be inferred from the data when the parameters are initialized.
In Flax, model parameters are not stored within the model itself. Instead, you need to initialize them using a random key and dummy input data. This process leverages Flax’s lazy initialization, where parameter shapes are inferred based on the input data.
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # Dummy input data
params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes
Explanation: random.split derives two PRNG keys from a seed, one for the dummy input and one for initialization. model.init performs a shape-inference pass over the dummy input and returns the parameter pytree; tree_map then shows that the kernel has shape (10, 5) and the bias has shape (5,).
Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.
After initializing the parameters, you can perform a forward pass to compute the model’s output for a given input.
model.apply(params, x)
Explanation: Because Flax modules are stateless, the parameters must be passed in explicitly: model.apply(params, x) runs the forward pass and returns the layer's output for x.
With the model initialized, we can perform gradient descent to train our linear regression model. We’ll generate synthetic data and define a mean squared error (MSE) loss function.
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5
# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
Explanation: We sample a ground-truth weight matrix W and bias b, store them in a FrozenDict with the same structure as the model's parameters, and then generate noisy training data with y = xW + b plus Gaussian noise.
Next, we’ll define the mean squared error (MSE) loss function and perform gradient descent using JAX’s JIT compilation for efficiency.
# Define the MSE loss function.
@jax.jit
def mse(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x, y).
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y - pred, y - pred) / 2.0
    # Vectorize the single-pair loss and average it over all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
Explanation: squared_error computes the loss for a single (x, y) pair, jax.vmap vectorizes it over the batch, and jnp.mean averages the per-sample losses. The @jax.jit decorator compiles the whole function for speed.
We’ll set the learning rate and define functions to compute gradients and update model parameters.
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    return params

for i in range(101):
    # Perform one gradient update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)
Explanation: jax.value_and_grad returns both the loss and its gradients in a single call. update_params applies one gradient-descent step to every leaf of the parameter pytree via jax.tree_util.tree_map, and the loop runs for 101 steps, printing the loss every 10 iterations.
Instead of writing the update rule by hand, we can optimize with Optax, a gradient processing and optimization library for JAX. Below, we switch to the Adam optimizer and rerun the same training loop.
import optax

tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)
Explanation: optax.adam creates the optimizer, tx.init builds its state from the current parameters, tx.update transforms the raw gradients into updates, and optax.apply_updates applies those updates to the parameter pytree.
Benefits of Using Optax:
- Composable gradient transformations (clipping, learning-rate schedules, weight decay) can be chained into a single optimizer, as sketched below.
- A wide range of ready-made optimizers (SGD, Adam, AdamW, and more) with a consistent update/apply interface.
- Less hand-written update code, so the training loop stays short and readable.
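As a quick illustration of that composability, here is a minimal sketch of a chained optimizer (the clipping threshold and learning rate are arbitrary example values); it slots into the training loop above unchanged:

# Clip gradients by global norm, then apply Adam: one composable optimizer.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=0.01))
opt_state = tx.init(params)
# tx.update(...) and optax.apply_updates(...) are then used exactly as above.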
After training, you may want to save your model’s parameters for later use or deployment. Flax provides robust serialization utilities to facilitate this process.
from flax import serialization
# Serialize parameters to bytes.
bytes_output = serialization.to_bytes(params)
# Serialize parameters to a dictionary.
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
Explanation: serialization.to_bytes produces a compact msgpack byte string, while serialization.to_state_dict returns a nested Python dictionary; both preserve the structure of the parameter pytree.
To load the parameters back, use the from_bytes method together with a parameter template that has the correct structure.
# Load the model back using the serialized bytes.
loaded_params = serialization.from_bytes(params, bytes_output)
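In practice you would usually write the bytes to a file and later restore them into a freshly constructed model; here is a minimal sketch of that round trip (the params.msgpack file name is just an example):

# Save the serialized parameters to disk.
with open('params.msgpack', 'wb') as f:
    f.write(bytes_output)

# Later: rebuild the model, create a parameter template with the right
# structure, and overwrite it with the stored values.
model = nn.Dense(features=5)
template = model.init(random.key(0), jnp.ones((10,)))
with open('params.msgpack', 'rb') as f:
    restored_params = serialization.from_bytes(template, f.read())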
Flax's flexibility shines when defining custom models beyond simple linear regression. In this section, we'll explore how to create custom multi-layer perceptrons (MLPs) and manage state within your models.
Modules in Flax are subclasses of nn.Module and represent layers or entire models. Here’s how to define a custom MLP with a sequence of dense layers and activation functions.
class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        # we automatically know what to do with lists, dicts of submodules
        self.layers = [nn.Dense(feat) for feat in self.features]
        # for single submodules, we would just write:
        # self.layer1 = nn.Dense(feat1)

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
Explanation: setup registers the list of Dense submodules, and __call__ chains them with ReLU activations between all but the last layer. As before, model.init infers the parameter shapes from a dummy input and model.apply runs the forward pass.
Attempting to call the model directly without using apply will result in an error:
try:
    y = model(x)  # Returns an error
except AttributeError as e:
    print(e)
Explanation: Flax modules do not store their parameters, so calling the module directly has no variables to work with and raises an error; the forward pass must go through apply (or init).
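If you want to call a module interactively (for example in a notebook), Flax's bind method returns a copy of the module with the variables attached; a minimal sketch:

# bind() attaches the variables to a copy of the module,
# so the copy can be called directly without apply().
bound_model = model.bind(params)
y = bound_model(x)
print('output shape:', y.shape)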
An alternative and more concise way to define submodules is by using the @nn.compact decorator within the __call__ method.
class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
            # providing a name is optional though!
            # the default autonames would be "Dense_0", "Dense_1", ...
        return x
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))
model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
Explanation: With @nn.compact, the submodules are declared inline inside __call__ rather than in setup. Providing names is optional; without them, Flax auto-names the layers Dense_0, Dense_1, and so on. The resulting parameters and outputs match the setup-based version.
Differences Between setup and @nn.compact:
- With setup, submodules are declared explicitly as attributes and are available to every method of the module, which is useful when a module exposes more than one method (see the sketch below).
- With @nn.compact, submodules are declared inline inside a single __call__ method, which keeps simple architectures short and keeps each layer's definition next to its use.
- Both styles produce the same parameter structure; the choice is about readability and whether you need multiple methods.
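To make that trade-off concrete, here is a minimal sketch of a setup-based module with two public methods (the AutoEncoder name and layer sizes are illustrative, not from the original article):

class AutoEncoder(nn.Module):
    latent_dim: int = 2
    out_dim: int = 4

    def setup(self):
        # Submodules defined in setup() are visible to every method.
        self.encoder = nn.Dense(self.latent_dim)
        self.decoder = nn.Dense(self.out_dim)

    def encode(self, x):
        return self.encoder(x)

    def __call__(self, x):
        return self.decoder(self.encode(x))

ae = AutoEncoder()
variables = ae.init(random.key(0), jnp.ones((1, 4)))
# apply() can target any method, not just __call__:
z = ae.apply(variables, jnp.ones((1, 4)), method=ae.encode)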
Sometimes, you might need to define custom layers not provided by Flax. Here’s how to create a simple dense layer from scratch using the @nn.compact approach.
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init,  # Initialization function
                            (inputs.shape[-1], self.features))  # Shape info.
        y = jnp.dot(inputs, kernel)
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias
        return y
key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4, 4))
model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameters:\n', params)
print('output:\n', y)
Explanation: self.param registers a trainable parameter under the given name, taking an initialization function and the shape to initialize. The kernel shape is inferred from the last dimension of the input, and both parameters end up in the 'params' collection.
Key Points:
- Parameters are created lazily: the initializers only run when model.init is called with concrete input shapes.
- kernel_init and bias_init are ordinary dataclass fields with defaults, so they can be overridden when the module is constructed, as in the sketch below.
- The resulting parameter pytree uses the same 'kernel' and 'bias' names as the built-in nn.Dense layer.
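Because kernel_init and bias_init are plain dataclass fields, a caller can swap in a different initializer without touching the layer code; a small usage sketch (the normal initializer and stddev are arbitrary choices):

# Same custom layer, but with normally-distributed initial weights.
alt_model = SimpleDense(features=3,
                        kernel_init=nn.initializers.normal(stddev=0.01))
alt_params = alt_model.init(random.key(42), jnp.ones((4, 4)))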
In addition to parameters, neural networks often maintain state variables, such as running statistics in batch normalization. Flax allows you to manage these variables using the variable method.
Example: Bias Adder with Running Mean
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        # Check if 'mean' variable is initialized.
        is_initialized = self.has_variable('batch_stats', 'mean')
        # Initialize running average of the mean.
        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s),
                                x.shape[1:])
        # Initialize bias parameter.
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)
        return x - ra_mean.value + bias
# Initialize and apply the model.
key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10, 5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
Explanation: self.variable declares a variable in the 'batch_stats' collection, separate from the trainable parameters. During init, has_variable returns False, so only the initial zero state is created; during apply with mutable=['batch_stats'], the running mean is updated and returned alongside the output.
Handling both parameters and state variables (like running means) can be complex. Here’s an example of integrating parameter updates with state variable updates using Optax.
for val in [1.0, 2.0, 3.0]:
    x = val * jnp.ones((10, 5))
    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
    old_state, params = flax.core.pop(variables, 'params')
    variables = flax.core.freeze({'params': params, **updated_state})
    print('updated state:\n', updated_state)  # Shows only the mutable part
from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):

    def loss(params):
        y, updated_state = apply_fn({'params': params, **state},
                                    x, mutable=list(state.keys()))
        l = ((x - y) ** 2).sum()
        return l, updated_state

    (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return opt_state, params, state
x = jnp.ones((10,5))
variables = model.init(random.key(0), x)
state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
for _ in range(3):
    opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
    print('Updated state: ', state)
Explanation: The loss function returns the updated state as an auxiliary output (has_aux=True), so a single call yields the loss, its gradients, and the new batch statistics. tx and apply_fn are marked as static arguments for jax.jit, and the optimizer state, parameters, and model state are all threaded through the loop explicitly.
Note: The function signature can be verbose and may not work with jax.jit() directly because some function arguments are not “valid JAX types.” Flax provides a convenient wrapper called TrainState to simplify this process. Refer to flax.training.train_state.TrainState for more information.
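As a rough illustration of that pattern, here is a minimal TrainState sketch using the simple Dense model from earlier (the stateful example above would additionally need to carry its batch_stats collection alongside the TrainState):

from flax.training import train_state

# Bundle apply_fn, params and the optimizer into one pytree-friendly object.
dense = nn.Dense(features=5)
dense_params = dense.init(random.key(0), jnp.ones((10,)))['params']
ts = train_state.TrainState.create(
    apply_fn=dense.apply,
    params=dense_params,
    tx=optax.sgd(learning_rate=0.02))

@jax.jit
def train_step(ts, x_batch, y_batch):
    def loss_fn(p):
        preds = ts.apply_fn({'params': p}, x_batch)
        return ((preds - y_batch) ** 2).mean()
    grads = jax.grad(loss_fn)(ts.params)
    return ts.apply_gradients(grads=grads)  # returns a new TrainState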
JAX released an experimental converter called jax2tf, which allows converting trained Flax models into TensorFlow SavedModel format (so it can be used for TF Hub, TF.lite, TF.js, or other downstream applications). The repository contains more documentation and has various examples for Flax.
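As a rough sketch of what that conversion can look like (the TensorFlow wrapping details, the stand-in model, and the SavedModel path below are assumptions; see the jax2tf examples for the authoritative recipe):

import tensorflow as tf
from jax.experimental import jax2tf

# A trained Flax model (here just a freshly-initialized Dense layer as a stand-in).
export_model = nn.Dense(features=5)
export_params = export_model.init(random.key(0), jnp.ones((10,)))

# Wrap the Flax forward pass as a TensorFlow-callable function.
tf_predict = tf.function(
    jax2tf.convert(lambda x: export_model.apply(export_params, x)),
    input_signature=[tf.TensorSpec(shape=(10,), dtype=tf.float32)],
    autograph=False)

# Attach it to a tf.Module and export a SavedModel.
tf_module = tf.Module()
tf_module.predict = tf_predict
tf.saved_model.save(tf_module, '/tmp/flax_savedmodel')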
Flax is a versatile and powerful neural network library that leverages JAX’s high-performance capabilities. From setting up simple linear regression models to defining complex custom architectures and managing state, Flax provides a flexible framework for research and production environments.
In this guide, we covered:
- Installing JAX, JAXlib, and Flax, and importing the Linen API.
- Building and training a linear regression model with nn.Dense, first with hand-written gradient descent and then with Optax.
- Saving and restoring parameters with Flax's serialization utilities.
- Defining custom modules with setup and @nn.compact, including a dense layer built from scratch with self.param.
- Managing non-parameter state such as running statistics with variable collections and the mutable argument.
By mastering these fundamentals, you’re well-equipped to harness Flax’s full potential in your machine-learning projects. Whether you’re conducting academic research, developing production-ready models, or exploring innovative architectures, Flax offers the tools and flexibility to support your endeavours.
Frequently Asked Questions
Q1. What is Flax and what is it used for?
Ans. Flax is an advanced neural network library built on JAX, designed for high flexibility and performance. It is used by researchers and developers to build complex machine learning models efficiently, leveraging JAX’s automatic differentiation and JIT compilation for optimized computation.
Q2. How does Flax differ from frameworks like TensorFlow, PyTorch, and Keras?
Ans. Flax stands out due to its adoption of a functional programming paradigm, where models are treated as pure functions without hidden state. This promotes ease of debugging and reproducibility. It also has deep integration with JAX, enabling seamless use of transformations like jit, grad, and vmap for enhanced optimization.
Q3. What is the Linen API?
Ans. The Linen API is Flax’s high-level, user-friendly API for defining neural network layers and models. It emphasizes clarity and modularity, making complex architectures easier to build, understand, and extend.
Q4. What is Optax and why use it with Flax?
Ans. The Optax library provides advanced gradient processing and optimization tools for JAX. When used with Flax, it simplifies the training process through composable optimizers, reducing manual coding and enhancing flexibility with support for a variety of optimization algorithms.
Q5. How does Flax handle model parameters and state?
Ans. Flax uses immutable data structures like FrozenDict for parameter management, ensuring functional purity. Model state, such as running statistics for batch normalization, can be managed using variable collections and updated with the mutable argument during the forward pass.