I recently came across a post by Sebastian that caught my attention, and I wanted to dive deeper into its content. As models grow larger and more complex, efficiently managing memory during model loading becomes increasingly important, especially when working with limited GPU or CPU resources. In his post, Sebastian covers practical tips for loading larger pre-trained or fine-tuned models in constrained memory environments, which is particularly relevant when working with PyTorch.
This guide focuses on situations where a model is saved using torch.save(model.state_dict(), "model.pth") and later needs to be loaded for continued pre-training or further fine-tuning. While the examples focus on a large language model (LLM), Sebastian’s methods are broadly applicable to any PyTorch model and offer useful insight into memory-efficient weight loading in PyTorch.
model.to(device) moves all the model’s parameters to the device (like a GPU), which can consume significant memory.
Before diving into the specifics, let’s ensure that the necessary packages and versions are available. Here’s a snippet that checks for the version of PyTorch and other useful tools.
from importlib.metadata import version
pkgs = [
    "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")
The first step is to set up a utility to track GPU memory (VRAM). Tracking memory usage helps in understanding how different methods impact memory load during model loading and inference. Later, we will also track the system’s RAM (CPU memory).
Here’s the utility code for GPU memory tracking:
import gc
import time
import torch
def start_memory_tracking():
    """Initialize GPU memory tracking."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    else:
        print("This notebook is intended for CUDA GPUs but CUDA is not available.")

def print_memory_usage():
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)  # Allow time for memory to clear
    torch.cuda.reset_peak_memory_stats()
    max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB")
These functions help track GPU memory usage before, during, and after model operations. The cleanup() function is especially useful for clearing unused memory to avoid running out of VRAM.
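The print_memory_usage() helper above assumes a CUDA device is present. As a minimal sketch, a guarded variant could return the peak value instead of printing it, so the same code can be imported on a CPU-only machine. The function name peak_gpu_memory_gb is hypothetical and not part of the original post.

```python
import torch


def peak_gpu_memory_gb():
    """Return the peak GPU memory allocated by tensors, in GB (1024**3 bytes).

    Returns None when CUDA is not available (hypothetical helper, not from the post).
    """
    if not torch.cuda.is_available():
        return None
    return torch.cuda.max_memory_allocated() / (1024 ** 3)


result = peak_gpu_memory_gb()
if result is None:
    print("CUDA is not available; no GPU memory statistics to report.")
else:
    print(f"Maximum GPU memory allocated: {result:.1f} GB")
```

Note that torch.cuda.max_memory_allocated() reports memory occupied by tensors; memory held by PyTorch's caching allocator is reported separately by torch.cuda.max_memory_reserved().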
Next, we set up the model. For demonstration, we will use the “GPT-2 XL” model, which is the variant selected in the configuration below (you can adjust the model size to suit your memory constraints). By changing the configuration, the model size can range from “gpt2-small” (124M parameters) to “gpt2-xl” (1558M parameters).
Here’s the configuration:
from previous_chapters import GPTModel
BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
CHOOSE_MODEL = "gpt2-xl (1558M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
This configuration allows flexibility in choosing models based on available memory resources. For lower memory consumption, selecting a smaller variant (like gpt2-small) is advisable.
Once the model configuration is set up, the next steps will dive into loading, managing, and optimizing the model weights for efficient memory utilization.
Let’s now put the GPU memory tracking utilities into action. First, we initialize memory tracking and load the model to observe memory consumption. The code below tracks GPU memory usage as we load and run a GPT model.
start_memory_tracking()
model = GPTModel(BASE_CONFIG)
device = torch.device("cuda")
model.to(device)
print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB
This shows that loading and placing the model onto the GPU consumes around 6.4 GB of VRAM, which is typical for larger models like GPT-2. However, this is just the initial setup.
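As a rough sanity check on this figure, the memory needed for the weights alone can be estimated as parameter count times bytes per parameter. The sketch below uses the nominal 1558M count from the configuration name; the helper name estimate_weight_memory_gb is hypothetical, and float32 (4 bytes per parameter) is assumed.

```python
def estimate_weight_memory_gb(num_params, bytes_per_param=4):
    """Estimate memory for model weights in GB (1024**3 bytes).

    Assumes every parameter uses `bytes_per_param` bytes (4 for float32).
    """
    return num_params * bytes_per_param / (1024 ** 3)


nominal_params = 1_558_000_000  # nominal "gpt2-xl (1558M)" parameter count
fp32_gb = estimate_weight_memory_gb(nominal_params)
fp16_gb = estimate_weight_memory_gb(nominal_params, bytes_per_param=2)

print(f"float32 estimate: {fp32_gb:.1f} GB")
print(f"float16 estimate: {fp16_gb:.1f} GB")
```

The float32 estimate lands somewhat below the measured 6.4 GB. The measured peak covers everything the model instance allocates on the GPU, which may include more than the nominal parameter count (for example buffers), though the post does not break this down.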
To verify that everything works correctly, we pass a simple input tensor to the model. Although we aren’t tracking memory during this step, it’s essential to check that the model operates as expected.
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
Now, imagine we are pretraining the model (or finetuning it). For this example, we skip the actual pretraining process and directly save the initialized model. The following code saves the model’s weights using torch.save().
# Training code would go here...
model.train()
torch.save(model.state_dict(), "model.pth")
After saving the model, it’s important to free up GPU memory to ensure efficient resource management in subsequent operations. By deleting the model and the test input tensor, and then running our cleanup() function, we clear up VRAM.
del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB
At this point, the GPU memory usage is reset to zero, as expected.
The next step involves reloading the saved model weights to continue training or finetuning. However, loading pretrained weights requires more GPU memory than initializing a fresh model because two copies of the weights briefly coexist in GPU memory: the parameters of the freshly instantiated model, and the state_dict that torch.load() places on the same device before its values are copied into the model.
# Start tracking memory
start_memory_tracking()
# Recreate the model architecture
model = GPTModel(BASE_CONFIG)
model.to(device)
# Load the saved state_dict
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()
print_memory_usage()
# Output: Maximum GPU memory allocated: 12.8 GB
The GPU memory usage has now doubled compared to the initial load, peaking at 12.8 GB. This happens because, for a short period, both the original model and the newly loaded weights are held in memory. Eventually, the loaded weights are copied into the model, and the temporary state_dict is discarded. However, this memory spike can cause issues when working with limited resources.
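The copy behaviour described above can be observed on a tiny model on the CPU. This is a minimal sketch, assuming a small nn.Linear layer as a stand-in for the GPT model and an in-memory buffer in place of "model.pth": after load_state_dict(), the model's weights have the same values as the loaded tensors but live in separate memory.

```python
import io

import torch
import torch.nn as nn

# Save a tiny model's state_dict to an in-memory buffer (stand-in for "model.pth")
original = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)

# Recreate the architecture and load the weights the default way
fresh = nn.Linear(4, 2)
loaded = torch.load(buffer, map_location="cpu", weights_only=True)
fresh.load_state_dict(loaded)

# The values match, but the model and the loaded state_dict use separate storage
same_values = torch.equal(fresh.weight, loaded["weight"])
shares_memory = fresh.weight.data_ptr() == loaded["weight"].data_ptr()
print(f"same values: {same_values}, shares memory: {shares_memory}")
```

Until the loaded state_dict is released, both copies exist, which is the source of the temporary doubling.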
After loading the model weights and testing it, it’s essential to reset the GPU memory once again. Testing the model ensures it works as expected, and clearing memory is crucial for efficient resource usage.
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB
This reset brings GPU memory usage back to zero, ensuring a clean state for future operations.
One effective workaround for the problem of double memory usage when loading model weights is sequential loading. Instead of loading both the model and weights simultaneously into GPU memory, we can load the model first, keep the weights in CPU memory, and then copy each parameter one by one to the GPU. This method significantly reduces the peak memory usage.
Here’s how to implement sequential weight loading:
start_memory_tracking()
# Load the model into GPU memory
model = GPTModel(BASE_CONFIG).to(device)
# Load the model's saved state_dict onto the CPU
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB
# Copy each parameter to GPU memory one by one
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in state_dict:
            param.copy_(state_dict[name].to(device))
        else:
            print(f"Warning: {name} not found in state_dict.")
print_memory_usage()
# Output: Maximum GPU memory allocated: 6.7 GB
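The peak rises only slightly, from 6.4 GB to 6.7 GB, because at any moment the GPU holds the model plus the single parameter tensor that is currently being copied. This is a much smaller peak compared to the 12.8 GB required when loading everything at once. By sequentially loading the weights, we avoid having both the full model and the full set of weights in GPU memory simultaneously.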
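The loop above iterates over named_parameters(), which does not include registered buffers. Whether that matters depends on the model; the post does not say whether GPTModel saves any buffers in its state_dict. As a hedged sketch, a hypothetical helper copy_weights_sequentially could iterate over the model's own state_dict() instead, which covers parameters and persistent buffers, and return the names it could not find.

```python
import torch
import torch.nn as nn


def copy_weights_sequentially(model, state_dict, device):
    """Copy tensors from `state_dict` into `model` one at a time.

    Iterates over model.state_dict(), so persistent buffers are included.
    Returns the list of names missing from `state_dict`.
    """
    missing = []
    with torch.no_grad():
        # The tensors returned by state_dict() share memory with the model,
        # so an in-place copy_ updates the model itself.
        for name, target in model.state_dict().items():
            if name not in state_dict:
                missing.append(name)
                continue
            target.copy_(state_dict[name].to(device))
    return missing


# Tiny stand-in model with a buffer (BatchNorm keeps running statistics)
source = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
source[1].running_mean.fill_(0.5)

target = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
missing = copy_weights_sequentially(target, source.state_dict(), torch.device("cpu"))
print(f"missing entries: {missing}")
```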
After copying the weights, we test the model to ensure everything works as expected. Finally, we reset the GPU memory to clear any lingering objects, just as we did in previous steps.
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()
with torch.no_grad():
    model(test_input)
# Clean up GPU memory
del model, test_input, state_dict, param
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB
In the previous section, we reduced GPU memory usage by loading model weights into CPU memory first and then sequentially copying them into the GPU. But what if the machine has limited CPU memory and larger GPU memory? To tackle this, we can use PyTorch’s “meta” device approach, which is ideal for machines with constrained CPU resources.
The “meta” device is a special device type in PyTorch that creates “meta” tensors. These tensors represent the shape and type of the data without allocating memory for the data itself. This allows us to define models without consuming CPU or GPU memory until necessary.
Using the meta device, we can first initialize the model without any memory allocation, and then load the model weights directly into GPU memory, bypassing the CPU.
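To make the idea concrete, here is a small sketch on a single nn.Linear layer (an assumed stand-in for the full GPT model), run on the CPU so it does not need a GPU. It assumes PyTorch 2.0 or newer, where torch.device can be used as a context manager.

```python
import torch
import torch.nn as nn

# Inside the "meta" context, parameters get a shape and dtype but no data
with torch.device("meta"):
    layer = nn.Linear(1024, 1024)

is_meta = layer.weight.is_meta
shape = tuple(layer.weight.shape)
print(f"is_meta: {is_meta}, shape: {shape}")

# to_empty() allocates uninitialized storage on the target device;
# the values are arbitrary until real weights are copied in
layer = layer.to_empty(device="cpu")
materialized = not layer.weight.is_meta
print(f"materialized: {materialized}")
```

After to_empty(), the parameters hold uninitialized memory, so real weights must be loaded before the model is used.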
Before we dive into the meta device approach, we will define a utility function to track CPU memory usage:
import os
import psutil
from threading import Thread
def memory_usage_in_gb(func, *args, **kwargs):
    process = psutil.Process(os.getpid())
    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB
    mem_usage = []
    done = False

    def monitor_memory():
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB
            time.sleep(0.1)

    t = Thread(target=monitor_memory)
    t.start()
    func(*args, **kwargs)
    done = True
    t.join()
    peak_mem_usage_gb = max(mem_usage) - baseline_mem
    return peak_mem_usage_gb
Now that we can measure CPU memory usage, let’s track the memory used during the sequential weight loading approach from the previous section:
def load_sequentially():
    start_memory_tracking()
    model = GPTModel(BASE_CONFIG).to(device)
    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
    print_memory_usage()
    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))
    print_memory_usage()

peak_memory_used = memory_usage_in_gb(load_sequentially)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
This approach reports a peak of about 6.3 GB of CPU memory, in addition to the GPU figures seen earlier.
To further reduce CPU memory usage, we can use the meta device to load the model without allocating memory until we actually need it. Here’s the implementation:
def load_sequentially_with_meta():
    start_memory_tracking()
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)
    model = model.to_empty(device=device)
    state_dict = torch.load("model.pth", map_location=device, weights_only=True)
    print_memory_usage()
    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])
    print_memory_usage()

peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
By using the meta device and directly loading the model weights into GPU memory, we drastically reduce CPU memory consumption from 6.3 GB to just 1.3 GB.
Finally, let’s compare this method with the simple PyTorch weight loading method, where no meta device or sequential loading is used:
def baseline():
    start_memory_tracking()
    model = GPTModel(BASE_CONFIG)
    model.to(device)
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.to(device)
    model.eval()
    print_memory_usage()

peak_memory_used = memory_usage_in_gb(baseline)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
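For this approach, the GPU peak matches the 12.8 GB measured earlier, since the model and the loaded state_dict again coexist briefly on the GPU.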
For more advanced users of PyTorch, there’s an alternative approach to handling memory constraints when loading large models—using the mmap=True setting in torch.load(). This setting leverages memory-mapped file I/O, which allows the model to read data directly from disk without fully loading it into RAM. This is particularly useful on systems with limited CPU memory, as it minimizes the memory footprint during model loading.
Memory-mapped I/O (mmap) is a mechanism that enables a file to be read directly from disk by mapping it into the virtual address space. Instead of loading the entire model into RAM, PyTorch can load parts of the model on demand, effectively reducing memory usage. This can be particularly advantageous when dealing with large pretrained or finetuned models, such as GPT-2 or GPT-3, on machines with limited resources.
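The mechanism itself is not specific to PyTorch. As an illustration only, the sketch below uses Python's standard mmap module on a throwaway 1 MiB file (a stand-in for a checkpoint) and reads a four-byte slice from the middle; only the pages that are touched need to be brought into memory by the operating system.

```python
import mmap
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.bin")

    # Write a 1 MiB file with a repeating 0..255 byte pattern
    with open(path, "wb") as f:
        f.write(bytes(range(256)) * 4096)

    # Map the file into the address space instead of reading it all into RAM
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
            total_size = len(mapped)
            chunk = mapped[512:516]  # slicing reads only this region

print(f"file size: {total_size} bytes, chunk: {chunk!r}")
```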
The mmap=True option can be added when calling torch.load() to achieve this behavior.
Let’s see how the mmap=True option works in practice. Below is a sample implementation where we load a model using this setting:
def best_practices():
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)
    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )
    print_memory_usage()

peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Here, we see that the GPU memory usage remains efficient (6.4 GB), and CPU memory usage is fairly high because the machine has enough CPU RAM to support it. However, on a system with limited CPU RAM, the mmap=True approach would use less memory by avoiding loading the full model into RAM.
The mmap=True option is especially helpful on systems with limited CPU memory and when loading large pretrained or finetuned checkpoints.
At first glance, the mmap=True approach might seem less efficient compared to the sequential weight loading approach. However, for machines with limited CPU memory, mmap=True can be a game-changer, providing an effective way to load large models without overwhelming the CPU’s available memory.
By using mmap=True, you’re balancing disk access with memory availability, which can help in environments where memory is scarce but disk I/O is fast.
In this notebook, we’ve focused on simple, built-in methods for efficiently loading model weights in PyTorch, particularly when memory (either GPU or CPU) is constrained. The recommended method for managing limited CPU memory is the mmap=True approach, as explained previously.
However, if you’re dealing with extreme memory limitations or need more control over the process, there’s another brute-force approach: saving and loading each weight tensor individually.
Instead of saving the entire state_dict as a single file, this method stores each model parameter (tensor) separately. This allows you to load each parameter one at a time, preventing the need to hold the entire model in memory simultaneously.
Here’s how you can save the model weights individually:
model = GPTModel(BASE_CONFIG)
# Assume `model` is your trained model
state_dict = model.state_dict()
# Create a directory to store individual parameter files
os.makedirs("model_parameters", exist_ok=True)
# Save each parameter tensor separately
for name, param in state_dict.items():
    torch.save(param.cpu(), f"model_parameters/{name}.pt")
del model # Free up GPU memory
This breaks the model into individual components, saving each tensor to its own file in the “model_parameters” directory.
Now, let’s see how you can load these weights one-by-one to avoid overwhelming memory usage.
def load_individual_weights():
    start_memory_tracking()
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)
    model = model.to_empty(device=device)
    print_memory_usage()
    param_dir = "model_parameters"
    with torch.no_grad():
        for name, param in model.named_parameters():
            weight_path = os.path.join(param_dir, f"{name}.pt")
            if os.path.exists(weight_path):
                param_data = torch.load(weight_path, map_location="cpu", weights_only=True)
                param.copy_(param_data.to(device))  # Move tensor to GPU
                del param_data  # Free memory after copying
            else:
                print(f"Warning: {name} not found in {param_dir}.")
    print_memory_usage()
The memory footprint here is significantly reduced—both on the GPU and CPU. By loading weights individually, you ensure that no unnecessary memory is consumed at any stage, making this approach ideal for extremely memory-limited environments.
When CPU and GPU memory are both highly constrained, this method offers precise control, ensuring that only one parameter tensor is loaded into memory at any given time.
On machines where you cannot afford to use more than minimal resources, this brute-force method provides a solution to ensure you can load even the largest models.
The trade-off here is performance. Since each tensor is loaded separately, this method incurs extra disk I/O, which may slow down the loading process compared to methods that load the entire model or larger chunks of data at once.
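Before relying on per-tensor files for a large model, it can be worth checking the round trip on something small. The sketch below does that on the CPU with a tiny nn.Sequential model, an assumed stand-in for GPTModel; the helper names save_individually and load_individually are hypothetical and simply mirror the save and load code above.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_individually(model, directory):
    """Save every state_dict entry to its own file in `directory`."""
    os.makedirs(directory, exist_ok=True)
    for name, tensor in model.state_dict().items():
        torch.save(tensor.cpu(), os.path.join(directory, f"{name}.pt"))


def load_individually(model, directory, device):
    """Load parameters one file at a time; return the names that were loaded."""
    loaded = []
    with torch.no_grad():
        for name, param in model.named_parameters():
            path = os.path.join(directory, f"{name}.pt")
            if os.path.exists(path):
                tensor = torch.load(path, map_location="cpu", weights_only=True)
                param.copy_(tensor.to(device))
                loaded.append(name)
                del tensor  # only one tensor is held at a time
    return loaded


def make_model():
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


with tempfile.TemporaryDirectory() as tmp:
    source = make_model()
    save_individually(source, tmp)

    # Build the target without allocating weights, then materialize it empty
    with torch.device("meta"):
        target = make_model()
    target = target.to_empty(device="cpu")

    loaded = load_individually(target, tmp, torch.device("cpu"))
    identical = all(
        torch.equal(p, q) for p, q in zip(source.parameters(), target.parameters())
    )

print(f"loaded {len(loaded)} tensors, identical: {identical}")
```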
When working with large models, such as GPT variants or other deep learning models, memory efficiency is crucial. Techniques like sequential weight loading, using the meta device, and enabling mmap=True help reduce memory usage on both CPU and GPU. These methods are flexible and can be adapted to the specific constraints of your hardware environment, whether you have limited CPU RAM, GPU VRAM, or both.
By employing these techniques, you can work with large models even on constrained hardware, ensuring smooth model training and fine-tuning workflows.
Hope you like the article! For a quick, memory-efficient weight loading example in PyTorch, try torch.load() with mmap=True to lower RAM usage.
As deep learning models grow larger (especially models like GPT-2, GPT-3), efficiently loading these models becomes essential to prevent running out of GPU or CPU memory. Memory-efficient loading allows you to work with large models even in constrained environments.
You can use the functions torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() to track GPU memory usage before, during, and after loading or training models. The provided utility functions help monitor memory usage efficiently.
Sequential weight loading involves loading the model architecture onto the GPU and then transferring weights one at a time from CPU to GPU. This reduces the peak memory usage compared to loading both the model and its weights at once, helping manage limited GPU memory.
Use lower precision: float16, mixed precision.
Optimize tensor operations: avoid copies, efficient shapes, views.
Gradient accumulation: update weights less frequently.
Reduce model size: prune connections, quantize weights, smaller models.
Optimize data loading: data loaders, prefetching, memory-mapped files.
GPU memory efficiency: monitor usage, free unused memory, multiple GPUs.
Advanced techniques: knowledge distillation, low-rank approximation.
The “meta” device allows you to initialize models without allocating memory for their parameters. This is useful when you have limited CPU memory since you can later load weights directly into the GPU, bypassing the need for large memory allocations on the CPU.