In artificial intelligence, understanding the inner workings of language models has proven both important and difficult. Google has taken a significant step toward tackling this problem by releasing Gemma Scope, a comprehensive suite of tools that helps researchers peer inside the “black box” of AI language models. This article looks at what Gemma Scope is, why it matters, and how it aims to advance the field of mechanistic interpretability.
Gemma Scope is a collection of hundreds of publicly available, open sparse autoencoders (SAEs) trained on Google’s lightweight open models, Gemma 2 9B and Gemma 2 2B. These tools act as a “microscope” for researchers, letting them analyze the internal representations of language models and gain insight into how the models process information and make decisions.
To appreciate Gemma Scope’s significance, you first need to understand the concept of mechanistic interpretability. When researchers build AI language models, they create systems that learn patterns from enormous volumes of data rather than following hand-written rules. As a result, the inner workings of these models are frequently opaque, even to their creators.
Mechanistic interpretability is the research field devoted to reverse-engineering these inner workings. By studying it, researchers can develop a deeper understanding of how language models actually compute their outputs.
Gemma Scope uses sparse autoencoders to interpret a model’s activations as it processes text. In simple terms, the process works like this: the SAE takes the activations at a chosen layer, encodes them into a much larger but sparse set of features (only a few of which are active for any given token), and then decodes those features to reconstruct the original activation. Because individual features often correspond to human-interpretable concepts, they give researchers a handle on what the model is representing internally.
Let’s dive into the technical details of implementing Gemma Scope, using code examples to illustrate key concepts:
First, we need to load the Gemma 2 model:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
We load Gemma 2 2B, the smallest model that Gemma Scope supports. We load the base model rather than the instruction-tuned (chat) model, because that is what the SAEs were trained on; in practice, the SAEs also appear to transfer reasonably well to the instruction-tuned variants.
To obtain the model weights, you first need to authenticate with Hugging Face:
notebook_login()
torch.set_grad_enabled(False) # avoid blowing up mem
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2-2b",
device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
Now that we’ve loaded the model, let’s try running it! We give it the prompt “Just a drop in the ocean, a change in the weather, I was praying that you and me might end up together. It’s like wishing for the rain as I stand in the desert.” and print the generated output.
from IPython.display import display, Markdown
prompt = "Just a drop in the ocean A change in the weather,I was praying that you and me might end up together. Its like wiching for the rain as I stand in the desert."
# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to("cuda")
display(Markdown(f"**Encoded inputs:**\n```\n{inputs}\n```"))
# Pass it in to the model and generate text
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
generated_text = tokenizer.decode(outputs[0])
display(Markdown(f"**Generated text:**\n\n{generated_text}"))
So we have Gemma 2 loaded and can sample from it to get sensible results.
Now, let’s load one of our SAE files.
Gemma Scope includes hundreds of SAEs, but for now we’ll load just one, trained on the residual stream at the end of layer 20.
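The SAE weights live on the Hugging Face Hub, so we first download the parameter file for this SAE. This is the layer-20, width-16k SAE (average L0 of roughly 71) from the google/gemma-scope-2b-pt-res repository, the same one the news-headline example later in this article reuses:
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
)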
Loading the SAE parameters and moving them to the GPU:
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}
Next, we define the SAE’s forward pass ourselves, for illustration.
Gemma Scope is a collection of JumpReLU SAEs, which look like a typical two-layer (one hidden layer) neural network, but with a JumpReLU activation function: a ReLU with a discontinuous jump. Concretely, the pre-activations pass through a ReLU, and any value that does not exceed a learned per-feature threshold is zeroed out.
import torch.nn as nn

class JumpReLUSAE(nn.Module):
    def __init__(self, d_model, d_sae):
        super().__init__()
        # We initialise these to zeros because we're loading pre-trained weights below;
        # training an SAE from scratch would need a proper initialisation scheme.
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        pre_acts = input_acts @ self.W_enc + self.b_enc
        # JumpReLU: keep only activations that exceed the learned per-feature threshold
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
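As a quick sanity check (added here for illustration), we can print the shapes of the loaded weights; for this width-16k SAE, the encoder should map the model’s residual-stream dimension to 16,384 SAE features:
# Inspect the loaded SAE parameter shapes.
print({name: tuple(param.shape) for name, param in sae.state_dict().items()})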
Now let’s run the model and gather some activations at the SAE’s target site. We’ll start by showing how to do this “manually” using PyTorch hooks. Note that this is not especially good practice; it is usually more convenient to use a library such as TransformerLens to handle plugging an SAE into a model’s forward pass. However, seeing how it’s done by hand is valuable for illustration.
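For reference, the library-based route might look roughly like the snippet below. Treat this purely as a sketch: it assumes a recent TransformerLens release with Gemma 2 support, and the model alias and hook name should be checked against the library’s documentation before use.
# Rough sketch of grabbing the same residual-stream activations with TransformerLens
# (assumes a TransformerLens version that supports Gemma 2; check names against its docs).
from transformer_lens import HookedTransformer

tl_model = HookedTransformer.from_pretrained("google/gemma-2-2b")
_, cache = tl_model.run_with_cache(prompt)
resid_post_20 = cache["blocks.20.hook_resid_post"]  # residual stream at the end of layer 20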
Back to the manual approach: we can collect activations at a particular site by registering a PyTorch hook. To keep this self-contained, we wrap it in a function that registers the hook, runs the model while recording the intermediate activation, and then removes the hook.
def gather_residual_activations(model, target_layer, inputs):
    target_act = None

    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act  # make sure we can modify target_act from the outer scope
        target_act = outputs[0]
        return outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    _ = model.forward(inputs)
    handle.remove()
    return target_act
target_act = gather_residual_activations(model, 20, inputs)
sae.cuda()
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)
Let’s double-check that the SAE looks sensible by verifying that its reconstruction explains a decent chunk of the variance in the activations:
1 - torch.mean((recon[:, 1:] - target_act[:, 1:].to(torch.float32)) **2) / (target_act[:, 1:].to(torch.float32).var())
This should look reasonable. This SAE was trained to an average L0 (number of active features per token) of roughly 70, so let’s also check that:
(sae_acts > 1).sum(-1)
There is one catch: our SAEs were not trained on the BOS token, because it tended to be a huge outlier activation and destabilised training. As a result, if we feed the BOS position through the SAE it tends to produce garbage, and we have to be careful not to do this by accident. You can see this above: the BOS token is a huge outlier in terms of L0!
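In practice, this just means dropping the first (BOS) position before computing statistics over SAE activations, as the variance check above already did. A minimal sketch:
# Exclude the BOS position (index 0) before computing feature statistics.
sae_acts_no_bos = sae_acts[:, 1:]
print((sae_acts_no_bos > 1).sum(-1))  # per-token L0 without the BOS outlier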
Let’s take a look at the features that activate most strongly at each token position in this input text.
values, inds = sae_acts.max(-1)
inds
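To make this easier to read, we can pair each token with the index of its most strongly activating feature (a small loop added here for illustration):
# Print each token alongside its top SAE feature and that feature's activation strength.
for token_id, feat_idx, act_val in zip(inputs[0].tolist(), inds[0].tolist(), values[0].tolist()):
    print(f"{tokenizer.decode([token_id])!r:>20} -> feature {feat_idx} (activation {act_val:.1f})")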
Looking these features up, we notice that one of the most strongly activating features on this prompt fires on notions connected to time travel!
Let’s visualize the features in a more interactive way using the Neuronpedia dashboard.
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
def get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=10004)
IFrame(html, width=1200, height=600)
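The same helper can be pointed at any feature index, for example the feature that fired most strongly on the last token of our prompt:
# Open the Neuronpedia dashboard for one of the features we found earlier.
html = get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=int(inds[0, -1]))
IFrame(html, width=1200, height=600)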
To show Gemma Scope’s practical use, let’s analyze a set of recent news headlines and see which SAE features Gemma 2 activates when it processes different kinds of news content.
First, we’ll prepare our environment by importing the necessary libraries and loading the Gemma 2 2B model and its tokenizer.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
# Load Gemma 2 2B model and tokenizer
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
Next, we’ll implement the JumpReLU Sparse Autoencoder (SAE) and load pre-trained parameters:
# Define JumpReLU SAE
class JumpReLUSAE(torch.nn.Module):
    def __init__(self, d_model, d_sae):
        super().__init__()
        self.W_enc = torch.nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = torch.nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = torch.nn.Parameter(torch.zeros(d_sae))
        self.b_enc = torch.nn.Parameter(torch.zeros(d_sae))
        self.b_dec = torch.nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        return acts @ self.W_dec + self.b_dec
# Load pre-trained SAE parameters
path_to_params = hf_hub_download(
repo_id="google/gemma-scope-2b-pt-res",
filename="layer_20/width_16k/average_l0_71/params.npz",
)
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}
# Initialize and load SAE
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae.cuda()
# Function to gather activations
def gather_residual_activations(model, target_layer, inputs):
    target_act = None

    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act
        target_act = outputs[0]

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    _ = model(inputs)
    handle.remove()
    return target_act
We’ll create a function to analyze headlines using Gemma Scope:
# Analyze a headline with Gemma Scope
def analyze_headline(headline, top_k=5):
    inputs = tokenizer.encode(headline, return_tensors="pt", add_special_tokens=True).to("cuda")
    # Gather residual-stream activations at layer 20
    target_act = gather_residual_activations(model, 20, inputs)
    # Apply the SAE
    sae_acts = sae.encode(target_act.to(torch.float32))
    # Get the top activated features, summed over token positions
    # (skipping position 0, since the SAE was not trained on the BOS token)
    values, indices = torch.topk(sae_acts[:, 1:].sum(dim=1), k=top_k)
    return indices[0].tolist()
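As a quick check before running the full batch, we can call the function on a single headline and confirm it returns a list of feature indices:
# Try the analyzer on one headline first.
print(analyze_headline("Tech giant unveils revolutionary quantum computer"))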
For our analysis, we’ll use a diverse set of news headlines:
# Sample news headlines
headlines = [
"Global temperatures reach record high in 2024",
"Tech giant unveils revolutionary quantum computer",
"Historic peace treaty signed in Middle East",
"Breakthrough in renewable energy storage announced",
"Major cybersecurity attack affects millions worldwide"
]
To make our analysis easier to read, we’ll bucket the activated features into broad topics. Note that this mapping is purely illustrative: real SAE feature indices are not organised by topic, and genuine feature interpretations come from inspecting each feature (for example, on Neuronpedia).
# Predefined feature categories (for demonstration purposes)
feature_categories = {
1000: "Climate and Environment",
2000: "Technology and Innovation",
3000: "Global Politics",
4000: "Energy and Sustainability",
5000: "Cybersecurity and Digital Threats"
}
def categorize_feature(feature_id):
    category_id = (feature_id // 1000) * 1000
    return feature_categories.get(category_id, "Uncategorized")
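For example, with this illustrative mapping, any feature index in the 2000–2999 range gets labelled “Technology and Innovation”, and indices outside the defined ranges fall back to “Uncategorized”:
# Illustrative only: features are bucketed by their thousands digit, not by any real semantics.
print(categorize_feature(2471))   # -> "Technology and Innovation"
print(categorize_feature(10004))  # -> "Uncategorized"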
Now, let’s analyze each headline and interpret the results:
# Analyze headlines
for headline in headlines:
    print(f"\nHeadline: {headline}")
    top_features = analyze_headline(headline)
    print("Top activated feature categories:")
    for feature in top_features:
        category = categorize_feature(feature)
        print(f"- Feature {feature}: {category}")
    print(f"For detailed feature interpretation, visit: https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{top_features[0]}")

# Generate a summary report
print("\n--- Summary Report ---")
print("This analysis demonstrates how Gemma Scope can be used to understand the underlying concepts")
print("that the model activates when processing different types of news headlines.")
print("By examining the activated features, we can gain insights into the model's interpretation")
print("of various news topics and potentially identify biases or focus areas in its training data.")
This analysis illustrates how the Gemma 2 model represents different news topics internally. For example, we might find that climate-related headlines consistently activate one set of features while technology headlines activate another; looking those features up on Neuronpedia then tells us which concepts they actually represent.
Gemma Scope is an important achievement in the field of mechanistic interpretability, and its potential impact on AI research and development is broad. At the same time, while it offers a huge step forward in language model interpretability, plenty of obstacles and open questions remain for future research.
Gemma Scope is a major step forward in mechanistic interpretability for language models. By giving researchers powerful tools to examine the inner workings of AI systems, Google has opened up new paths for studying, improving, and safeguarding these increasingly important technologies.
Q1. What is Gemma Scope?
Ans. Gemma Scope is a collection of open sparse autoencoders (SAEs) for Google’s lightweight open model family, Gemma 2 9B and Gemma 2 2B. It allows researchers to analyze the internal processes of language models and gain insights into how they work.
Q2. Why does mechanistic interpretability matter?
Ans. Mechanistic interpretability helps researchers understand the fundamental workings of AI models, enabling more robust systems, better safeguards against hallucinations, and protection against risks such as deception or manipulation by autonomous AI agents.
Q3. What are sparse autoencoders (SAEs)?
Ans. SAEs are the neural networks Gemma Scope uses to decompose a model’s activations into a sparse set of features, revealing the underlying concepts the language model is representing.
Q4. Does the article show how to use Gemma Scope in practice?
Ans. Yes. The implementation involves loading the Gemma 2 model, running it on text input, and analyzing the resulting activations with sparse autoencoders; the sample code above walks through the detailed steps.