Memory-Efficient Model Weight Loading in PyTorch

Janvi Kumari Last Updated : 21 Oct, 2024

I recently came across a post by Sebastian that caught my attention, and I wanted to dive deeper into its content. As models grow larger and more complex, efficiently managing memory during model loading becomes increasingly important, especially when working with limited GPU or CPU resources. In his post, Sebastian covers practical tips for loading larger pre-trained or fine-tuned models in constrained memory environments, which is particularly relevant when working with PyTorch.

This guide focuses on situations where models are saved using torch.save(model.state_dict(), "model.pth") and later need to be loaded for continued pre-training or further fine-tuning. While the examples use a large language model (LLM), Sebastian’s methods apply to any PyTorch model. They also show how to keep memory usage low while loading model weights in PyTorch.

Overview

  • Efficient memory management is crucial for loading large neural networks in PyTorch, especially on systems with limited GPU or CPU resources.
  • Normally, calling model.to(device) moves all of the model’s parameters to the device (such as a GPU) at once, which can consume significant memory. Instead of loading all the weights at once, you can load them incrementally.
  • PyTorch provides a “meta” device, which allows you to create tensors that have a shape and dtype but no memory allocated for their data.
  • By utilizing the meta device, you can load weights directly into GPU memory, bypassing the CPU and optimizing memory usage.

Initial Setup: Environment Check

Before diving into the specifics, let’s check that the required packages are installed. The snippet below prints the installed PyTorch version; add other package names (such as psutil, which is used later for CPU memory tracking) to the pkgs list to check them as well.

from importlib.metadata import version

pkgs = [
    "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

Benchmark Utilities for Memory Tracking

The first step is to set up a utility to track GPU memory (VRAM). Tracking memory usage helps in understanding how different methods impact memory load during model loading and inference. Later, we will also track the system’s RAM (CPU memory).

Here’s the utility code for GPU memory tracking:


import gc
import time
import torch

def start_memory_tracking():
    """Initialize GPU memory tracking."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    else:
        print("This notebook is intended for CUDA GPUs but CUDA is not available.")

def print_memory_usage():
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)  # Allow time for memory to clear
    torch.cuda.reset_peak_memory_stats()
    max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB")

These functions help track GPU memory usage before, during, and after model operations. The cleanup() function is especially useful for clearing unused memory to avoid running out of VRAM.

Model Setup

Next, we set up the model. For demonstration, we will use the “GPT-2 XL” model (selected via CHOOSE_MODEL below), though you can adjust the model size to suit your memory constraints. By changing the configuration, the model size can range from “gpt2-small” (124M parameters) to “gpt2-xl” (1558M parameters).

Here’s the configuration:

# GPTModel comes from the companion code's previous_chapters module; it is not part of PyTorch
from previous_chapters import GPTModel

BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

CHOOSE_MODEL = "gpt2-xl (1558M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

This configuration allows flexibility in choosing models based on available memory resources. For lower memory consumption, selecting a smaller variant (like gpt2-small) is advisable.

Once the model configuration is set up, the next steps will dive into loading, managing, and optimizing the model weights for efficient memory utilization.

Tracking GPU Memory During Model Loading

Let’s now put the GPU memory tracking utilities into action. First, we initialize memory tracking and load the model to observe memory consumption. The code below tracks GPU memory usage as we load and run a GPT model.

start_memory_tracking()

model = GPTModel(BASE_CONFIG)
device = torch.device("cuda")
model.to(device)

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB

This shows that creating the model and placing it on the GPU consumes around 6.4 GB of VRAM. Most of that is the model’s roughly 1.6 billion parameters stored as 32-bit floats, at 4 bytes each. However, this is just the initial setup.
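As a rough cross-check, the memory taken by a model’s weights is the number of stored values times the bytes per value (4 for float32). The sketch below computes this with a hypothetical helper, tensor_memory_gb, which is not part of PyTorch or the original code. It counts only parameters and buffers, so it will not match torch.cuda.max_memory_allocated() exactly (allocator rounding and temporary tensors are not included).

```python
import torch
import torch.nn as nn


def tensor_memory_gb(module: nn.Module) -> float:
    """Estimate the memory (in GB, i.e. 1024**3 bytes) held by a module's tensors.

    Hypothetical helper: sums numel * element_size over all parameters and
    buffers. Ignores allocator overhead, activations, and the CUDA context.
    """
    total_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
    total_bytes += sum(b.numel() * b.element_size() for b in module.buffers())
    return total_bytes / (1024 ** 3)


# Example: one 1024x1024 float32 linear layer has 1,049,600 values (weights + bias)
layer = nn.Linear(1024, 1024)
print(f"{tensor_memory_gb(layer):.4f} GB")  # 0.0039 GB
```

Calling tensor_memory_gb(GPTModel(BASE_CONFIG)) should give a number close to, but not exactly equal to, the 6.4 GB reported above.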

Running the Model

To verify that everything works correctly, we pass a simple input tensor to the model. Although we aren’t tracking memory during this step, it’s essential to check that the model operates as expected.

# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

Saving the Model

Now, imagine we are pretraining the model (or finetuning it). For this example, we skip the actual pretraining process and directly save the initialized model. The following code saves the model’s weights using torch.save().

# Training code would go here...

model.train()
torch.save(model.state_dict(), "model.pth")

Memory Cleanup

After saving the model, it’s important to free up GPU memory to ensure efficient resource management in subsequent operations. By deleting the model and the test input tensor, and then running our cleanup() function, we clear up VRAM.

del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB

At this point, the GPU memory usage is reset to zero, as expected.

Loading Pretrained Model Weights

The next step involves reloading the saved model weights to continue training or finetuning. However, loading pretrained weights the standard way requires more GPU memory than initializing a fresh model, because the weights briefly exist twice on the GPU: once as the model’s own (randomly initialized) parameters, and again as the state_dict that torch.load(..., map_location=device) returns.

# Start tracking memory
start_memory_tracking()

# Recreate the model architecture
model = GPTModel(BASE_CONFIG)
model.to(device)

# Load the saved state_dict
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model.eval()

print_memory_usage()
# Output: Maximum GPU memory allocated: 12.8 GB

The GPU memory usage has now doubled compared to the initial load, peaking at 12.8 GB. This happens because, for a short period, both the original model and the newly loaded weights are held in memory. Eventually, the loaded weights are copied into the model, and the temporary state_dict is discarded. However, this memory spike can cause issues when working with limited resources.
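The doubling can be reproduced on a small scale without a GPU. In the sketch below, a toy nn.Sequential stands in for GPTModel and an in-memory buffer stands in for “model.pth” (both are assumptions for illustration). It checks that torch.load returns tensors with their own storage and the same total size as the model’s weights, so both copies coexist until load_state_dict() returns and the loaded dictionary is released.

```python
import io

import torch
import torch.nn as nn

# Toy stand-in for GPTModel (an assumption so the sketch runs on CPU)
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Serialize to an in-memory buffer instead of "model.pth"
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# torch.load materializes a complete, separate copy of every saved tensor
loaded = torch.load(buffer, map_location="cpu", weights_only=True)

model_bytes = sum(t.numel() * t.element_size() for t in model.state_dict().values())
loaded_bytes = sum(t.numel() * t.element_size() for t in loaded.values())

# The loaded tensors have their own storage, as large as the model's weights,
# so until load_state_dict() returns and `loaded` is freed, memory holds both
print(loaded["0.weight"].data_ptr() != model[0].weight.data_ptr())  # True
print(loaded_bytes == model_bytes)                                   # True
print(f"peak while loading ~ {(model_bytes + loaded_bytes) / 1024**2:.1f} MB")

model.load_state_dict(loaded)
```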

Resetting GPU Memory

After loading the model weights, we test the model and then reset the GPU memory once again. Testing confirms that the model works as expected, and clearing memory gives us a clean state for the next approach.

# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB

This reset brings GPU memory usage back to zero, ensuring a clean state for future operations.

Loading Weights Sequentially

One effective workaround for the problem of double memory usage when loading model weights is sequential loading. Instead of loading both the model and weights simultaneously into GPU memory, we can load the model first, keep the weights in CPU memory, and then copy each parameter one by one to the GPU. This method significantly reduces the peak memory usage.

Here’s how to implement sequential weight loading:

Step-by-Step Breakdown:

  1. Load the Model onto the GPU: First, we load the model architecture into GPU memory, as usual.
  2. Load the Weights onto the CPU: The model weights are loaded onto CPU memory, avoiding the initial memory spike caused by moving both the model and the weights to the GPU.
  3. Copy Weights Parameter by Parameter: Each weight is then copied sequentially from the CPU to GPU, meaning that at no point do we have both the model and the full state_dict in GPU memory.

The code below demonstrates this approach:

start_memory_tracking()

# Load the model into GPU memory
model = GPTModel(BASE_CONFIG).to(device)

# Load the model's saved state_dict onto the CPU
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB

# Copy each parameter to GPU memory one by one
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in state_dict:
            param.copy_(state_dict[name].to(device))
        else:
            print(f"Warning: {name} not found in state_dict.")

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.7 GB

Memory Comparison:

  • Initially, the model alone occupies 6.4 GB of GPU memory.
  • As we copy each parameter sequentially, the memory increases slightly to 6.7 GB.

However, this is a much smaller peak compared to the 12.8 GB required when loading everything at once. By sequentially loading the weights, we avoid having both the full model and the full set of weights in GPU memory simultaneously.
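The copy loop above can be wrapped into a reusable helper. The sketch below is one possible version, using a hypothetical function name, copy_state_dict_sequentially, that is not part of PyTorch. It iterates over model.state_dict(), whose tensors share storage with the model’s parameters and buffers, so an in-place copy_ updates the model directly. Unlike a loop over named_parameters(), this also covers persistent buffers.

```python
import torch
import torch.nn as nn


def copy_state_dict_sequentially(model: nn.Module, state_dict: dict, device) -> list:
    """Copy tensors from state_dict into model, one tensor at a time.

    Hypothetical helper. Only one source tensor is moved to `device` at a
    time. Returns the names of model entries missing from state_dict.
    """
    missing = []
    with torch.no_grad():
        # model.state_dict() tensors share storage with the model's
        # parameters and persistent buffers, so copy_ writes into the model
        for name, target in model.state_dict().items():
            if name not in state_dict:
                missing.append(name)
                continue
            target.copy_(state_dict[name].to(device))
    return missing


# Usage on CPU with two small layers
source = nn.Linear(4, 3)
destination = nn.Linear(4, 3)
missing = copy_state_dict_sequentially(destination, source.state_dict(), device="cpu")
print(missing)                                         # []
print(torch.equal(destination.weight, source.weight))  # True
```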

Model Testing & Memory Reset:

After copying the weights, we test the model to ensure everything works as expected. Finally, we reset the GPU memory to clear any lingering objects, just as we did in previous steps.

# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

# Clean up GPU memory
del model, test_input, state_dict, param
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB

Loading the Model with Low CPU Memory

In the previous section, we reduced GPU memory usage by loading model weights into CPU memory first and then sequentially copying them into the GPU. But what if the machine has limited CPU memory and plenty of GPU memory? In that case, we can use PyTorch’s “meta” device, which is well suited to machines with constrained CPU resources.

Meta Device: A Smart Tradeoff

The “meta” device is a special device type in PyTorch that creates “meta” tensors. These tensors represent the shape and type of the data without allocating memory for the data itself. This allows us to define models without consuming CPU or GPU memory until necessary.

Using the meta device, we can first initialize the model without any memory allocation, and then load the model weights directly into GPU memory, bypassing the CPU.
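A minimal illustration of what the meta device does, using a single nn.Linear layer as a stand-in for a full model: tensors created under torch.device("meta") carry shapes and dtypes but no data, and to_empty() then allocates real, uninitialized storage on a target device (CPU here, so the sketch runs anywhere).

```python
import torch
import torch.nn as nn

# Construct a layer on the meta device: shapes and dtypes exist, storage does not
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.is_meta)  # True
print(layer.weight.shape)    # torch.Size([4096, 4096])

# to_empty() allocates real but uninitialized storage on the target device
layer = layer.to_empty(device="cpu")
print(layer.weight.is_meta)  # False
print(layer.weight.device)   # cpu
```

The values after to_empty() are whatever happened to be in that memory, so they must be overwritten with loaded weights before the model is used.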

Monitoring CPU Memory Usage

Before we dive into the meta device approach, we will define a utility function to track CPU memory usage:


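import os
import time
import psutil
from threading import Thread

def memory_usage_in_gb(func, *args, **kwargs):
    """Run func and return its peak CPU memory (RSS) increase in GB."""
    process = psutil.Process(os.getpid())
    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB
    mem_usage = []
    done = False

    def monitor_memory():
        # Sample the resident set size every 100 ms until func finishes
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB
            time.sleep(0.1)

    t = Thread(target=monitor_memory)
    t.start()

    try:
        func(*args, **kwargs)
    finally:
        # Stop the monitor even if func raises
        done = True
        t.join()

    # Fall back to the baseline if func finished before the first sample
    peak_mem_usage_gb = max(mem_usage, default=baseline_mem) - baseline_mem
    return peak_mem_usage_gb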

Now that we can measure CPU memory usage, let’s track the memory used during the sequential weight loading approach from the previous section:

def load_sequentially():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG).to(device)
    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

    print_memory_usage()

    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(load_sequentially)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

This approach outputs:

  • Maximum GPU memory allocated: 6.7 GB
  • Maximum CPU memory allocated: 6.3 GB

Meta Device Approach

To further reduce CPU memory usage, we can use the meta device to load the model without allocating memory until we actually need it. Here’s the implementation:

def load_sequentially_with_meta():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)
    state_dict = torch.load("model.pth", map_location=device, weights_only=True)

    print_memory_usage()

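    # Sequentially copy weights into the model. to_empty() leaves parameters
    # and buffers uninitialized, so copy persistent buffers (such as the causal
    # attention mask) from the state_dict as well.
    with torch.no_grad():
        for name, tensor in [*model.named_parameters(), *model.named_buffers()]:
            if name in state_dict:
                tensor.copy_(state_dict[name])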

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

Memory Usage with Meta Device:

  • Maximum GPU memory allocated: 12.8 GB
  • Maximum CPU memory allocated: 1.3 GB

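By using the meta device and loading the model weights directly into GPU memory, we drastically reduce CPU memory consumption from 6.3 GB to just 1.3 GB. The trade-off is that peak GPU memory rises back to 12.8 GB, because the full state_dict is loaded onto the GPU while the model’s own (empty) parameters already occupy it.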
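One detail to watch with this pattern: to_empty() leaves buffers uninitialized as well as parameters, so any buffers saved in the state_dict (for example, a causal attention mask) must be copied too, not just named_parameters(). The sketch below shows this on CPU with a hypothetical toy module, TinyBlock, that has one linear layer and one registered buffer; it is an illustration, not the article’s original code.

```python
import itertools

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical toy module with one parameterized layer and one persistent buffer."""

    def __init__(self, n=4):
        super().__init__()
        self.proj = nn.Linear(n, n)
        # Like a causal mask, this buffer is saved in the state_dict
        self.register_buffer("mask", torch.triu(torch.ones(n, n), diagonal=1))


reference = TinyBlock()
state_dict = reference.state_dict()

with torch.device("meta"):
    model = TinyBlock()
model = model.to_empty(device="cpu")  # parameters AND buffers are now uninitialized

with torch.no_grad():
    # Copy buffers as well as parameters; copying only named_parameters()
    # would leave the mask holding arbitrary memory contents
    for name, tensor in itertools.chain(model.named_parameters(), model.named_buffers()):
        if name in state_dict:
            tensor.copy_(state_dict[name])

print(torch.equal(model.mask, reference.mask))  # True
```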

Comparison with Baseline

Finally, let’s compare this method with the simple PyTorch weight loading method, where no meta device or sequential loading is used:

def baseline():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG)
    model.to(device)
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.eval()

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(baseline)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

For this approach:

  • Maximum GPU memory allocated: 12.8 GB
  • Maximum CPU memory allocated: 4.4 GB

Using mmap=True for Efficient Model Loading

For more advanced users of PyTorch, there’s an alternative approach to handling memory constraints when loading large models—using the mmap=True setting in torch.load(). This setting leverages memory-mapped file I/O, which allows the model to read data directly from disk without fully loading it into RAM. This is particularly useful on systems with limited CPU memory, as it minimizes the memory footprint during model loading.

What is mmap=True?

Memory-mapped I/O (mmap) is a mechanism that maps a file into a process’s virtual address space, so that the operating system reads its contents from disk only when they are accessed. Instead of loading the entire checkpoint into RAM up front, PyTorch can then read tensor data on demand, which reduces memory usage. This is particularly advantageous when dealing with large pretrained or finetuned models, such as GPT-2 XL or larger LLMs, on machines with limited resources.

The mmap=True option can be added when calling torch.load() to achieve this behavior.

Example Implementation of mmap=True

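Let’s see how the mmap=True option works in practice. Below is a sample implementation that combines it with the meta device. Because the model is created on the meta device, load_state_dict() is called with assign=True, which replaces the placeholder meta tensors with the loaded tensors instead of copying into them; an in-place copy into meta tensors would leave the model without real data.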

def best_practices():
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

Results with mmap=True

  • Maximum GPU memory allocated: 6.4 GB
  • Maximum CPU memory allocated: 5.9 GB

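Here, GPU memory usage stays at 6.4 GB, because assign=True places the loaded tensors directly into the model instead of copying them into a second set of parameters. CPU memory usage is comparatively high (5.9 GB) because the memory-mapped pages that were read count toward the process’s resident memory, and this machine has enough RAM to keep them. Those pages are backed by the file on disk, so on a system with limited CPU RAM the operating system can reclaim them as needed, and the mmap=True approach avoids holding the full model in RAM.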

When to Use mmap=True

The mmap=True option is especially helpful in the following scenarios:

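  • Limited CPU RAM: when the machine does not have enough RAM to hold the full state_dict, memory mapping lets the operating system page weights in from disk as they are needed.
  • Fast disk I/O: because weights are read from disk on demand, load time depends on storage speed, so this approach works best with fast storage such as an SSD.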

Performance Considerations

At first glance, the mmap=True approach might not look more efficient than the sequential weight loading approach, since its measured CPU memory usage is similar. However, on machines with limited CPU memory, mmap=True can make a real difference, providing an effective way to load large models without exhausting the available RAM.

By using mmap=True, you’re balancing disk access with memory availability, which can help in environments where memory is scarce but disk I/O is fast.

Other Methods for Model Weight Loading

In this article, we’ve focused on simple, built-in methods for efficiently loading model weights in PyTorch, particularly when memory (either GPU or CPU) is constrained. The recommended method for managing limited CPU memory is the mmap=True approach, as explained previously.

However, if you’re dealing with extreme memory limitations or need more control over the process, there’s another brute-force approach: saving and loading each weight tensor individually.

Saving Model Weights Individually

Instead of saving the entire state_dict as a single file, this method stores each model parameter (tensor) separately. This allows you to load each parameter one at a time, preventing the need to hold the entire model in memory simultaneously.

Here’s how you can save the model weights individually:

model = GPTModel(BASE_CONFIG)
# Assume `model` is your trained model
state_dict = model.state_dict()

# Create a directory to store individual parameter files
os.makedirs("model_parameters", exist_ok=True)

# Save each parameter tensor separately
for name, param in state_dict.items():
    torch.save(param.cpu(), f"model_parameters/{name}.pt")

del model  # Free up memory (this model was created on the CPU)

This breaks the model into individual components, saving each tensor to its own file in the “model_parameters” directory.

Loading Weights Individually

Now, let’s see how you can load these weights one-by-one to avoid overwhelming memory usage.

def load_individual_weights():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)

    print_memory_usage()
    param_dir = "model_parameters"

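    with torch.no_grad():
        # to_empty() leaves parameters and buffers uninitialized, and the saved
        # files include persistent buffers (such as the causal attention mask),
        # so load both
        for name, tensor in [*model.named_parameters(), *model.named_buffers()]:
            weight_path = os.path.join(param_dir, f"{name}.pt")
            if os.path.exists(weight_path):
                param_data = torch.load(weight_path, map_location="cpu", weights_only=True)
                tensor.copy_(param_data.to(device))  # Move tensor to GPU
                del param_data  # Free memory after copying
            else:
                print(f"Warning: {name} not found in {param_dir}.")

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(load_individual_weights)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")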

Results from Individual Weight Loading

  • Maximum GPU memory allocated: 6.4 GB
  • Maximum CPU memory allocated: 0.3 GB

The memory footprint here is significantly reduced on both the GPU and the CPU. Because weights are loaded individually, only one tensor at a time is held in CPU memory, which makes this approach well suited to extremely memory-limited environments.
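The per-tensor round trip can be tried on CPU with a small stand-in model before applying it to a large checkpoint. The sketch below uses a toy nn.Sequential and a temporary directory in place of GPTModel and “model_parameters” (assumptions for illustration). It iterates over model.state_dict() rather than named_parameters(), so persistent buffers would be restored too.

```python
import os
import tempfile

import torch
import torch.nn as nn

reference = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
param_dir = tempfile.mkdtemp()

# Save: one file per state_dict entry, named after its key (e.g. "0.weight.pt")
for name, tensor in reference.state_dict().items():
    torch.save(tensor.cpu(), os.path.join(param_dir, f"{name}.pt"))

# Load: build on meta, allocate empty storage, then fill one tensor at a time
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
model = model.to_empty(device="cpu")

missing = []
with torch.no_grad():
    for name, target in model.state_dict().items():
        path = os.path.join(param_dir, f"{name}.pt")
        if not os.path.exists(path):
            missing.append(name)
            continue
        # Only this one loaded tensor exists outside the model at a time
        target.copy_(torch.load(path, map_location="cpu", weights_only=True))

print(missing)  # []
print(all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                             reference.state_dict().values())))  # True
```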

When to Use This Method

  • Extreme memory limitations: when CPU and GPU memory are both highly constrained, this method offers precise control, ensuring that only one parameter tensor at a time is held in CPU memory outside the model.
  • Low resources: on machines where you cannot afford more than minimal CPU RAM, this brute-force method makes it possible to load large models, as long as the GPU can hold the model itself.

Performance Considerations

The trade-off here is performance. Since each tensor is loaded separately, this method incurs extra disk I/O, which may slow down the loading process compared to methods that load the entire model or larger chunks of data at once.

Conclusion

When working with large models, such as GPT variants or other deep learning models, memory efficiency is crucial. Techniques like sequential weight loading, using the meta device, and enabling mmap=True help reduce memory usage on both CPU and GPU. These methods are flexible and can be adapted to the specific constraints of your hardware, whether you have limited CPU RAM, GPU VRAM, or both.

By employing these techniques, you can work with large models even on constrained hardware, ensuring smooth model training and fine-tuning workflows.

Hope you like the article! If you want a single change to try first, load your checkpoint with torch.load(..., mmap=True) to lower RAM usage.

Q1. What is the importance of memory-efficient model loading in PyTorch?

As deep learning models grow larger (especially models like GPT-2, GPT-3), efficiently loading these models becomes essential to prevent running out of GPU or CPU memory. Memory-efficient loading allows you to work with large models even in constrained environments.

Q2. How can I track GPU memory usage during model loading in PyTorch?

You can use the functions torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() to track GPU memory usage before, during, and after loading or training models. The provided utility functions help monitor memory usage efficiently.

Q3. What is sequential weight loading in PyTorch, and how does it help?

Sequential weight loading involves loading the model architecture onto the GPU and then transferring weights one at a time from CPU to GPU. This reduces the peak memory usage compared to loading both the model and its weights at once, helping manage limited GPU memory.

Q4. How do I reduce memory usage in PyTorch?

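  • Use lower precision: float16 or bfloat16 weights, or mixed-precision training.
  • Optimize tensor operations: avoid unnecessary copies, prefer views, and choose efficient shapes.
  • Use gradient accumulation: train on smaller micro-batches and accumulate gradients over several steps before each weight update, keeping the effective batch size while lowering activation memory.
  • Reduce model size: prune connections, quantize weights, or use a smaller model.
  • Optimize data loading: use data loaders with prefetching and memory-mapped files.
  • Manage GPU memory: monitor usage, free memory that is no longer needed, and spread work across multiple GPUs.
  • Apply advanced techniques: knowledge distillation or low-rank approximation.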
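As a small example of the lower-precision option, casting a module’s floating-point weights from float32 (4 bytes per value) to bfloat16 (2 bytes per value) halves the memory they occupy. The sketch below measures this on a single nn.Linear layer with a hypothetical helper, param_bytes. Whether bfloat16 is appropriate depends on your hardware and on how sensitive your task is to reduced precision.

```python
import torch
import torch.nn as nn


def param_bytes(module: nn.Module) -> int:
    """Bytes held by a module's parameters (hypothetical helper)."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


model = nn.Linear(1024, 1024)     # float32 by default: 4 bytes per value
fp32_bytes = param_bytes(model)

model = model.to(torch.bfloat16)  # cast floating-point parameters: 2 bytes per value
bf16_bytes = param_bytes(model)

print(fp32_bytes, bf16_bytes)  # 4198400 2099200
```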

Q5. What is the “meta” device in PyTorch, and how does it help with memory constraints?

The “meta” device allows you to initialize models without allocating memory for their parameters. This is useful when you have limited CPU memory since you can later load weights directly into the GPU, bypassing the need for large memory allocations on the CPU.

Hi, I am Janvi, a passionate data science enthusiast currently working at Analytics Vidhya. My journey into the world of data began with a deep curiosity about how we can extract meaningful insights from complex datasets.
