Generative AI, with its remarkable ability to create data that closely resembles real-world examples, has garnered significant attention in recent years. While models like GANs and VAEs have stolen the limelight, a lesser-known family of generative models called “Normalizing Flows” has quietly reshaped the generative modeling landscape.
In this article, we embark on a journey into Normalizing Flows, exploring their unique features and applications and walking through hands-on Python examples that demystify their inner workings. We will cover:
A basic understanding of Normalizing Flows.
Applications of Normalizing Flows, such as density estimation, data generation, variational inference, and data augmentation.
A Python code example to understand Normalizing Flows.
Normalizing Flows, often abbreviated as NFs, are generative models that tackle the challenge of modeling complex probability distributions: both sampling from them and evaluating their densities. They are rooted in the change-of-variables formula from probability theory. The fundamental idea is to start with a simple probability distribution, such as a Gaussian, and gradually transform it into the desired complex distribution through a series of invertible transformations.
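For reference, here is the standard change-of-variables formula behind this idea, written for an invertible, differentiable map f that takes a base sample z to a data point x. The symbols f, z, x, and K are notation introduced here for illustration; the last line covers a chain f = f_K ∘ ... ∘ f_1 of K transformations.

```latex
\begin{aligned}
x &= f(z), \qquad z \sim p_Z, \\
p_X(x) &= p_Z\!\left(f^{-1}(x)\right)\left|\det \frac{\partial f^{-1}(x)}{\partial x}\right|
        = p_Z(z)\left|\det \frac{\partial f(z)}{\partial z}\right|^{-1}, \\
\log p_X(x) &= \log p_Z(z_0) - \sum_{k=1}^{K} \log\left|\det \frac{\partial f_k(z_{k-1})}{\partial z_{k-1}}\right|,
\qquad z_0 = z,\; z_k = f_k(z_{k-1}),\; x = z_K.
\end{aligned}
```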
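The key distinguishing feature of Normalizing Flows is their invertibility. Every transformation can be reversed, and its Jacobian determinant is designed to be tractable, so a flow supports both sampling (running the transformations forward) and exact density evaluation (running them backward). This property sets them apart from GANs, which provide no explicit density, and from VAEs, which only optimize a lower bound on the likelihood.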
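To make this concrete, the sketch below uses a single affine map x = a * z + b with hypothetical values a = -2 and b = 0.5. Sampling pushes base samples forward through the map; density evaluation pulls x back through the inverse and subtracts log|a|, as the change-of-variables rule prescribes. Because the map is affine, the result can be cross-checked against a Normal with mean b and standard deviation |a|.

```python
import math

import torch

torch.manual_seed(0)

# Hypothetical parameters of one invertible map x = a * z + b (requires a != 0)
a, b = -2.0, 0.5
base = torch.distributions.Normal(0.0, 1.0)

# Sampling: push base samples forward through f
z = base.sample((5,))
x = a * z + b

# Density: pull x back through f^{-1} and apply the change-of-variables correction
def log_prob_x(x):
    z = (x - b) / a                              # f^{-1}(x)
    return base.log_prob(z) - math.log(abs(a))   # log p_Z(z) - log|df/dz|

# For an affine map the result must match a Normal(b, |a|)
reference = torch.distributions.Normal(b, abs(a)).log_prob(x)
print(torch.allclose(log_prob_x(x), reference))  # True
```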
Anatomy of a Normalizing Flow
Base Distribution: A simple probability distribution (e.g., Gaussian) from which sampling begins.
Transformations: A series of bijective (invertible) transformations that progressively modify the base distribution.
Inverse Transformations: Every transformation has an inverse. Running the flow forward turns base samples into data samples (generation), and running it backward maps data points to the base space, where their likelihood can be evaluated.
Final Complex Distribution: The composition of transformations yields a complex distribution that approximates the target data distribution.
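PyTorch's torch.distributions module can express these four pieces directly. The sketch below is an illustration with hypothetical parameters (loc=0.3, scale=0.5), not the article's model: a Normal base distribution, an AffineTransform followed by an ExpTransform as the bijective steps, each transform's .inv to run the chain backwards, and a TransformedDistribution as the final distribution. The nonlinear ExpTransform makes the result non-Gaussian (a log-normal), which allows a cross-check against LogNormal.

```python
import torch
from torch.distributions import LogNormal, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform

torch.manual_seed(0)

# Base distribution: a standard normal
base = Normal(0.0, 1.0)

# Transformations: bijective steps applied in order (hypothetical parameters)
transforms = [AffineTransform(loc=0.3, scale=0.5), ExpTransform()]

# Final distribution: the base distribution pushed through the transformations
flow_dist = TransformedDistribution(base, transforms)

# Sampling runs the transformations forward (base -> data)
x = flow_dist.sample((4,))

# Inverse transformations run the chain backwards (data -> base)
z = x
for t in reversed(transforms):
    z = t.inv(z)

# Exact density via change of variables; this particular chain is a log-normal
log_p = flow_dist.log_prob(x)
print(torch.allclose(log_p, LogNormal(0.3, 0.5).log_prob(x), atol=1e-5))  # True
```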
Applications of Normalizing Flows
Density Estimation: Normalizing Flows excel at density estimation. Because they compute the exact likelihood of each data point, they can model complex data distributions and score individual observations, which makes them valuable for anomaly detection and uncertainty estimation.
Data Generation: NFs can generate data samples that resemble real data closely. This ability is crucial in applications like image generation, text generation, and music composition.
Variational Inference: Normalizing Flows play a vital role in Bayesian machine learning, particularly in Variational Autoencoders (VAEs), where they enable more flexible and expressive approximate posteriors.
Data Augmentation: NFs can augment datasets by generating synthetic samples, useful when data is scarce.
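The density-estimation and data-generation uses above come down to two operations on a fitted flow: log_prob to score points and sample to draw new ones. The sketch below uses a fixed TransformedDistribution as a stand-in for a trained flow; its parameters and the anomaly threshold of -5.0 are hypothetical, and in practice the threshold would be chosen on validation data.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform

torch.manual_seed(0)

# Stand-in for a trained flow (hypothetical parameters): a log-normal density
flow_dist = TransformedDistribution(
    Normal(0.0, 1.0), [AffineTransform(loc=0.0, scale=0.5), ExpTransform()]
)

# Density estimation / anomaly detection: score points by exact log-likelihood
points = torch.tensor([0.8, 1.0, 1.3, 25.0])
scores = flow_dist.log_prob(points)
threshold = -5.0                 # hypothetical cut-off
is_anomaly = scores < threshold  # low likelihood -> flagged as anomalous

# Data generation / augmentation: draw synthetic samples from the same flow
synthetic = flow_dist.sample((100,))
print(is_anomaly.tolist())  # [False, False, False, True]
```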
Let’s Dive into Python: Implementing a Normalizing Flow
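We implement a simple 1D Normalizing Flow using Python and the PyTorch library. The example focuses on the mechanics: samples from a Gaussian base distribution are pushed through a stack of learnable affine transformations. (Affine layers alone can only shift and rescale a Gaussian, so a flow that produces a genuinely complex distribution also needs nonlinear invertible layers.)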
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define a bijective (invertible) affine transformation: y = scale * x + shift
class AffineTransformation(nn.Module):
    def __init__(self):
        super().__init__()
        # Start at the identity map; torch.Tensor(1) would leave the values uninitialized
        self.scale = nn.Parameter(torch.ones(1))
        self.shift = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Return only the transformed tensor so the layers can be chained in nn.Sequential
        return self.scale * x + self.shift

    def log_abs_det_jacobian(self):
        # log|dy/dx| = log|scale|, the correction term used when evaluating densities
        return torch.log(torch.abs(self.scale))

# Create a sequence of transformations
transformations = [AffineTransformation() for _ in range(5)]
flow = nn.Sequential(*transformations)

# Define the base distribution (Gaussian)
base_distribution = torch.distributions.Normal(0, 1)

# Sample from the complex distribution
samples = flow(base_distribution.sample((1000,))).squeeze()
```
Libraries Used
torch: This library is PyTorch, a popular deep-learning framework. It provides tools and modules for building and training neural networks. In the code, we use it to define neural network modules, create tensors, and efficiently perform various mathematical operations on tensors.
torch.nn: This submodule of PyTorch contains classes and functions for building neural networks. In the code, we use its nn.Module class, which serves as the base class for custom neural network modules, along with nn.Parameter and nn.Sequential.
torch.optim: This submodule of PyTorch provides optimization algorithms commonly used for training neural networks. The code imports it so that an optimizer can train the parameters of the AffineTransformation modules, but the example itself does not set up an optimizer or a training loop.
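Since the example stops short of training, here is one possible way to fit the flow by maximum likelihood with torch.optim. It is a sketch under several assumptions not in the original: each layer gets hypothetical inverse and log_abs_det_jacobian methods, the layers are held in an nn.ModuleList so the flow can be run backwards, the training data are synthetic draws from a normal distribution with mean 3 and standard deviation 2, and Adam with a learning rate of 0.01 runs for 2000 steps. Because every layer is affine, the best this flow can do is match the data's mean and standard deviation.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

class AffineTransformation(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))   # start at the identity map
        self.shift = nn.Parameter(torch.zeros(1))

    def forward(self, z):
        return self.scale * z + self.shift

    def inverse(self, x):
        # Hypothetical helper: undo the affine map (requires scale != 0)
        return (x - self.shift) / self.scale

    def log_abs_det_jacobian(self):
        return torch.log(torch.abs(self.scale))

layers = nn.ModuleList([AffineTransformation() for _ in range(5)])
base_distribution = torch.distributions.Normal(0.0, 1.0)

def log_prob(x):
    # Map data back to the base space, accumulating log|det J| of each layer
    z, log_det = x, torch.zeros(1)
    for layer in reversed(layers):
        z = layer.inverse(z)
        log_det = log_det + layer.log_abs_det_jacobian()
    # Change of variables: log p(x) = log p_Z(z) - sum_k log|scale_k|
    return base_distribution.log_prob(z) - log_det

def sample(n):
    z = base_distribution.sample((n,))
    for layer in layers:
        z = layer(z)
    return z

# Hypothetical training data: 2000 draws from N(3, 2^2)
data = 3.0 + 2.0 * torch.randn(2000)

optimizer = optim.Adam(layers.parameters(), lr=0.01)
for step in range(2000):
    optimizer.zero_grad()
    loss = -log_prob(data).mean()  # negative log-likelihood
    loss.backward()
    optimizer.step()

with torch.no_grad():
    generated = sample(5000)  # new samples from the trained flow
```

The composed affine layers act as a single affine map, so after training its overall shift and scale should sit close to the sample mean and standard deviation of the data.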
AffineTransformation Class
The AffineTransformation class is a custom PyTorch module representing one step in the sequence of transformations used in a Normalizing Flow. Let’s break down its components:
nn.Module: This class is the base class for all custom neural network modules in PyTorch. By inheriting from nn.Module, AffineTransformation becomes a PyTorch module itself, and it can contain learnable parameters (like self.scale and self.shift) and define a forward pass operation.
__init__(self): The class’s constructor method. When an instance of AffineTransformation is created, it initializes two learnable parameters: self.scale and self.shift. These parameters will be optimized during training.
self.scale and self.shift: These are PyTorch nn.Parameter objects: tensors that a module registers as learnable and that autograd tracks, which makes them available to an optimizer. Here, self.scale and self.shift represent the scaling and shifting factors applied to the input x.
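forward(self, x): This method defines the forward pass of the module. When you pass an input tensor x to an instance of AffineTransformation, it computes the affine operation self.scale * x + self.shift. The flow also needs log|self.scale|, the logarithm of the absolute value of the derivative dy/dx = self.scale (the Jacobian determinant in 1D). Taking the logarithm does not keep self.scale positive; this term is the correction that the change-of-variables formula subtracts when computing exact log-densities. Invertibility only requires self.scale to be nonzero; implementations that want a strictly positive scale typically learn log(scale) and exponentiate it.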
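A quick numerical check of the log-determinant term: for a single affine layer, the derivative dy/dx computed by autograd should equal scale, so log|dy/dx| should equal log|scale|. The sketch adds a hypothetical inverse method (not in the original class) to confirm the round trip and sets hypothetical values scale = -1.5 and shift = 0.25. The negative scale is deliberate: the map is still invertible, which is why the absolute value appears in the log-determinant.

```python
import torch
import torch.nn as nn

class AffineTransformation(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.shift = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.scale * x + self.shift

    def inverse(self, y):
        # Hypothetical helper: undo the affine map (requires scale != 0)
        return (y - self.shift) / self.scale

layer = AffineTransformation()
with torch.no_grad():
    layer.scale.fill_(-1.5)  # hypothetical parameter values
    layer.shift.fill_(0.25)

x = torch.tensor([0.7], requires_grad=True)
y = layer(x)
(dy_dx,) = torch.autograd.grad(y.sum(), x)  # derivative of the layer w.r.t. its input

log_det_autograd = torch.log(torch.abs(dy_dx))
log_det_formula = torch.log(torch.abs(layer.scale.detach()))
round_trip = layer.inverse(y).detach()  # should recover x
```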
In a Normalizing Flow, this AffineTransformation class represents a simple invertible step. A full flow composes many such steps, typically mixing in nonlinear invertible layers, so that together they reshape a simple distribution (e.g., Gaussian) into one that approximates the target data distribution. Because each step is invertible and has a tractable log-determinant, the composition supports both density estimation and data generation.
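A stack made only of affine layers, like the one in this example, collapses to a single affine map: the product of the scales times z, plus an accumulated shift. Its output is therefore still Gaussian. The sketch below checks this with randomly initialized parameters (an assumption made only so the layers are not the identity) in double precision.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class AffineTransformation(nn.Module):
    def __init__(self):
        super().__init__()
        # Random initial values, used here only to make the layers non-trivial
        self.scale = nn.Parameter(torch.randn(1))
        self.shift = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return self.scale * x + self.shift

flow = nn.Sequential(*[AffineTransformation() for _ in range(5)]).double()

with torch.no_grad():
    # Overall scale is the product of the layer scales; overall shift is flow(0)
    total_scale = torch.prod(torch.cat([layer.scale for layer in flow]))
    total_shift = flow(torch.zeros(1, dtype=torch.float64))
    z = torch.randn(1000, dtype=torch.float64)
    flow_out = flow(z)
    collapsed = total_scale * z + total_shift  # one affine map, same result
```

So the five layers produce a normal distribution with mean total_shift and standard deviation |total_scale|; depth only adds expressiveness once nonlinear bijections enter the stack.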
```python
# Create a sequence of transformations
transformations = [AffineTransformation() for _ in range(5)]
flow = nn.Sequential(*transformations)
```
In the above code section, we’re creating a sequence of transformations using the AffineTransformation class. This sequence represents the series of invertible transformations that will be applied to our base distribution to make it more complex.
What’s Happening?
Here’s what’s happening:
The list comprehension [AffineTransformation() for _ in range(5)] builds a list called transformations containing five independent AffineTransformation instances, each with its own scale and shift parameters.
nn.Sequential(*transformations) unpacks that list into a single module, flow, which applies the five transformations in order to our data and registers all of their parameters.
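Wrapping the list in nn.Sequential matters for training: a module only exposes parameters held in its registered submodules, and the list comprehension matters because it creates distinct objects. The sketch reuses a minimal version of the class; with five layers of two parameters each, the flow should expose ten parameter tensors. By contrast, [AffineTransformation()] * 5 repeats one object five times, so the "five" layers would share a single scale and shift.

```python
import torch
import torch.nn as nn

class AffineTransformation(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.shift = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.scale * x + self.shift

# List comprehension: five distinct layers, ten parameter tensors in total
transformations = [AffineTransformation() for _ in range(5)]
flow = nn.Sequential(*transformations)
print(len({id(t) for t in transformations}))  # 5
print(len(list(flow.parameters())))           # 10

# List repetition: the same layer five times, so only two distinct parameters
shared = [AffineTransformation()] * 5
print(len(list(nn.Sequential(*shared).parameters())))  # 2
```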
```python
# Define the base distribution (Gaussian)
base_distribution = torch.distributions.Normal(0, 1)
```
Here, we’re defining a base distribution as our starting point. In this case, we’re using a Gaussian distribution with a mean of 0 and a standard deviation of 1 (i.e., a standard normal distribution). This distribution represents the simple probability distribution from which we’ll start our sequence of transformations.
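The base distribution object provides both operations a flow needs: sample to draw starting points and log_prob to evaluate the base-density term of the change-of-variables formula. For vector-valued data, one common choice (an assumption here, beyond the 1D example) is a standard normal over each coordinate wrapped in Independent, so that log_prob sums over the coordinates and returns one value per sample.

```python
import math

import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)

base_distribution = Normal(0.0, 1.0)
z = base_distribution.sample((1000,))                    # shape (1000,)
log_p0 = base_distribution.log_prob(torch.tensor(0.0))  # density at the mode
expected = -0.5 * math.log(2 * math.pi)                 # about -0.9189

# Hypothetical 3-dimensional base distribution for vector-valued data
d = 3
base_nd = Independent(Normal(torch.zeros(d), torch.ones(d)), 1)
z_nd = base_nd.sample((1000,))     # shape (1000, 3)
log_p_nd = base_nd.log_prob(z_nd)  # shape (1000,): one log-density per sample
```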
```python
# Sample from the complex distribution
samples = flow(base_distribution.sample((1000,))).squeeze()
```
This section involves sampling data from the complex distribution that results from applying our sequence of transformations to the base distribution. Here’s the breakdown:
base_distribution.sample((1000,)): We use the sample method of the base_distribution object to generate 1000 samples from the base distribution. The sequence of transformations will transform these samples to create complex data.
flow(…): The flow object represents the sequence of transformations we created earlier. We apply these transformations in sequence by passing the samples from the base distribution through the flow.
squeeze(): This removes dimensions of size 1 from a tensor. Here the flow's output already has shape (1000,), so squeeze() leaves it unchanged; it would matter if the samples carried an extra size-1 dimension, for example when sampling with shape (1000, 1).
Conclusion
NFs are generative models that sculpt complex data distributions by progressively transforming a simple base distribution through a series of invertible operations. The article explores the core components of NFs, including base distributions, bijective transformations, and the invertibility that underpins their power. It highlights their pivotal role in density estimation, data generation, variational inference, and data augmentation.
Key Takeaways
The key takeaways from the article are:
Normalizing Flows are generative models that transform a simple base distribution into a complex target distribution through a series of invertible transformations.
They find applications in density estimation, data generation, variational inference, and data augmentation.
Normalizing Flows offer flexibility and exact likelihood computation, making them a powerful tool for capturing complex data distributions.
Implementing a Normalizing Flow involves defining bijective transformations and sequentially composing them.
Exploring Normalizing Flows unveils a versatile approach to generative modeling, offering new possibilities for creativity and understanding complex data distributions.
Frequently Asked Questions
Q1: Are Normalizing Flows limited to 1D data?
A. No. You can apply Normalizing Flows to high-dimensional data as well. Our example was in 1D for simplicity, but NFs are commonly used in image generation and other high-dimensional applications.
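Scaling flows beyond 1D requires layers whose inverse and Jacobian determinant stay cheap. One widely used design, not part of this article's code, is the RealNVP-style affine coupling layer: some coordinates pass through unchanged and parameterize a scale and shift for the rest, so the Jacobian is triangular and its log-determinant is just the sum of the log-scales. The sketch below assumes a small hypothetical conditioner network and a tanh bound on the log-scales; it is an illustration, not a tuned implementation.

```python
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    """RealNVP-style affine coupling layer (illustrative sketch)."""

    def __init__(self, dim, hidden=32):
        super().__init__()
        self.half = dim // 2  # size of the pass-through part
        # Hypothetical conditioner: predicts log-scale and shift for the other part
        self.net = nn.Sequential(
            nn.Linear(self.half, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * (dim - self.half)),
        )

    def _scale_shift(self, x1):
        log_s, t = self.net(x1).chunk(2, dim=1)
        return torch.tanh(log_s), t  # tanh keeps the scales in a moderate range

    def forward(self, x):
        x1, x2 = x[:, :self.half], x[:, self.half:]
        log_s, t = self._scale_shift(x1)
        y2 = x2 * torch.exp(log_s) + t
        # Triangular Jacobian: log|det J| is the sum of the log-scales
        return torch.cat([x1, y2], dim=1), log_s.sum(dim=1)

    def inverse(self, y):
        y1, y2 = y[:, :self.half], y[:, self.half:]
        log_s, t = self._scale_shift(y1)  # y1 equals x1, so scale and shift are recovered
        x2 = (y2 - t) * torch.exp(-log_s)
        return torch.cat([y1, x2], dim=1)

torch.manual_seed(0)
layer = AffineCoupling(dim=4)
x = torch.randn(8, 4)
y, log_det = layer(x)      # forward pass with per-sample log-determinants
x_rec = layer.inverse(y)   # should reproduce x
```

Stacking several such layers while alternating which coordinates pass through lets every coordinate be transformed.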
Q2: How do Normalizing Flows compare to GANs and VAEs?
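A. GANs generate data without providing an explicit likelihood, and VAEs are trained on a lower bound of the likelihood. Normalizing Flows compute the exact likelihood through the change-of-variables formula, which makes them well suited to density estimation while still supporting flexible data generation. They offer a different perspective on generative modeling.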
Q3: Are Normalizing Flows computationally expensive?
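A. The computational cost depends on the complexity of the transformations and the dimensionality of the data. The main expense is the Jacobian log-determinant, which is costly for arbitrary transformations of high-dimensional inputs, so practical architectures restrict their layers to structures with cheap determinants. Even so, NFs can be computationally expensive for high-dimensional data such as images.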
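Much of that cost comes from the log-determinant term. For a general d×d Jacobian, a routine such as torch.linalg.slogdet needs on the order of d^3 operations, whereas for a triangular Jacobian (the structure coupling and autoregressive layers are designed to produce) it reduces to a sum over the diagonal in O(d). The sketch checks that both routes agree on a hypothetical triangular matrix.

```python
import torch

torch.manual_seed(0)
d = 6

# Hypothetical lower-triangular Jacobian with a nonzero diagonal
J = (
    torch.tril(torch.randn(d, d, dtype=torch.float64), diagonal=-1)
    + torch.diag(torch.rand(d, dtype=torch.float64) + 0.5)
)

general = torch.linalg.slogdet(J).logabsdet            # works for any square matrix
triangular = torch.log(torch.diagonal(J).abs()).sum()  # valid only for triangular J
```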
Q4: Can Normalizing Flows handle discrete data?
A. NFs are primarily designed for continuous data. Adapting them for discrete data can be challenging and may require additional techniques.
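One such technique is uniform dequantization: adding uniform noise in [0, 1) to integer-valued data turns it into continuous data that a flow can model, and flooring recovers the original integers. The sketch assumes 8-bit values in {0, ..., 255}, as in image pixels, and rescales them to [0, 1); these are illustrative choices rather than a prescribed recipe.

```python
import torch

torch.manual_seed(0)

# Hypothetical 8-bit discrete data, e.g. pixel intensities
x_int = torch.randint(0, 256, (1000,))

# Uniform dequantization: spread each integer k over [k, k + 1), then rescale to [0, 1)
noise = torch.rand(1000, dtype=torch.float64)
x_cont = (x_int.double() + noise) / 256.0

# Flooring recovers the original integers
recovered = torch.floor(x_cont * 256.0).long()
```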
The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.
Competent and passionate professional holding over 3 years of Python, Data Science, Data Analytics, and ML experience with recent experience in Prompt Engineering. I love writing and one of my blogs at Analytics Vidhya was among the top-3 winners of the Data Science Blogathon, read by 700+ users.