Beam search is a powerful decoding algorithm used extensively in natural language processing (NLP) and machine learning, especially in sequence generation tasks such as text generation, machine translation, and summarization. It strikes a balance between exploring the search space efficiently and generating high-quality output. In this blog, we will dive deep into how beam search works, why it matters for decoding, and a hands-on implementation, while also exploring its real-world applications and challenges.
Beam search is a heuristic search algorithm used to decode sequences from models such as transformers, LSTMs, and other sequence-to-sequence architectures. It generates text by maintaining a fixed number of the most probable sequences at each step, known as the "beam width". Unlike greedy search, which picks only the single most likely next token at each step, beam search considers multiple hypotheses at once. The final sequence is therefore not just locally fluent but also scores well under the model as a whole; note, though, that as a heuristic it approximates rather than guarantees the globally optimal sequence.
For example, in machine translation, there might be multiple valid ways to translate a sentence. Beam search allows the model to explore these possibilities by keeping track of multiple candidate translations simultaneously.
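Before building beam search from scratch, it is worth knowing that the Hugging Face transformers library exposes it directly through generate(). The snippet below is a minimal sketch (the model choice and generation settings are illustrative, not prescriptive) contrasting greedy decoding with beam search on the same prompt:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
inputs = tokenizer("I have a dream", return_tensors='pt')

# Greedy decoding: follow the single most probable token at every step
greedy_ids = model.generate(**inputs, max_new_tokens=20)

# Beam search: keep the 5 most probable hypotheses alive at every step
beam_ids = model.generate(**inputs, max_new_tokens=20, num_beams=5, early_stopping=True)

print(tokenizer.decode(greedy_ids[0], skip_special_tokens=True))
print(tokenizer.decode(beam_ids[0], skip_special_tokens=True))

The rest of this article re-implements the same idea by hand, which makes the mechanics much easier to see.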
Beam search works by exploring a graph where nodes represent tokens and edges represent the probabilities of transitioning from one token to the next. At each step:
- Each sequence currently in the beam is extended with its most probable next tokens.
- Every extended candidate is scored by its cumulative log-probability.
- Only the top candidates, up to the beam width, are kept; the rest are pruned.
The “beam width” determines how many candidate sequences are retained at each step. A larger beam width allows for exploring more sequences but increases computational cost. Conversely, a smaller beam width is faster but risks missing better sequences due to limited exploration.
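To see this trade-off concretely, here is a small self-contained sketch of beam search over a hand-written transition table (the vocabulary and probabilities are invented purely for illustration). With beam_width=1 it reduces to greedy search and misses the higher-scoring sequence that beam_width=2 finds:

import math

# Invented conditional probabilities P(next token | current token)
transitions = {
    '<s>': {'the': 0.6, 'a': 0.4},
    'the': {'cat': 0.55, 'dog': 0.45},
    'a':   {'dog': 0.9, 'cat': 0.1},
}

def toy_beam_search(start, steps, beam_width):
    # Each hypothesis is a (token list, cumulative log-probability) pair
    beams = [([start], 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            for token, p in transitions[tokens[-1]].items():
                candidates.append((tokens + [token], score + math.log(p)))
        # Prune: keep only the beam_width highest-scoring hypotheses
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams

for width in (1, 2):
    tokens, score = toy_beam_search('<s>', 2, width)[0]
    print(f"beam_width={width}: '{' '.join(tokens[1:])}' (probability {math.exp(score):.2f})")

Greedy search commits to "the" because it looks best locally (0.6 vs 0.4), ending at "the cat" with probability 0.33; the width-2 beam keeps "a" alive long enough to discover "a dog" with probability 0.36.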
Beam search is vital in decoding for several reasons:
- Higher-quality output: by keeping several hypotheses alive, it can recover strong sequences that greedy search throws away after one locally suboptimal choice.
- Handling ambiguity: tasks like machine translation admit several valid outputs, and beam search weighs these alternatives against each other.
- A tunable trade-off: the beam width lets you exchange decoding speed for thoroughness of search.
Below is a practical example of a beam search implementation. The algorithm builds a search tree, evaluates cumulative scores, and selects the best sequence:
# Install transformers and graphviz
!sudo apt-get install -y graphviz graphviz-dev
!pip install transformers pygraphviz
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from tqdm import tqdm
import matplotlib.colors as mcolors
System commands: install the graphviz system libraries and development headers (needed by pygraphviz for graph layout), followed by the transformers and pygraphviz Python packages.
Imported libraries:
- transformers: provides the GPT-2 model and tokenizer.
- torch: runs the model and handles tensor operations.
- matplotlib and networkx: draw the beam search tree.
- numpy: numeric helpers for converting scores.
- tqdm: displays a progress bar during the search.
# Load model and tokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model.eval()
# Input text
text = "I have a dream"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
A helper converts the model's raw logits into log-probabilities and returns the score of a chosen token:
def get_log_prob(logits, token_id):
    # Convert raw logits into a probability distribution over the vocabulary
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Work in log space so per-token scores can be summed along a sequence
    log_probabilities = torch.log(probabilities)
    return log_probabilities[token_id].item()
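As an aside, PyTorch also offers log_softmax, which fuses the softmax and the log into a single numerically stable operation; an equivalent helper would be a sketch like:

def get_log_prob_stable(logits, token_id):
    # log_softmax avoids the underflow that softmax followed by log can hit for very negative logits
    return torch.nn.functional.log_softmax(logits, dim=-1)[token_id].item()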
The next function implements a recursive beam search for text generation with the GPT-2 model. Note that this didactic version expands every retained branch at every step, walking a full tree with beams^length leaves, rather than pruning across branches as a production decoder would.
def beam_search(input_ids, node, bar, length, beams, temperature=1.0):
    if length == 0:
        return
    # Run the model to get logits for every position (no gradients needed at inference)
    with torch.no_grad():
        outputs = model(input_ids)
    predictions = outputs.logits
    # Get logits for the next token, scaled by temperature (1.0 leaves them unchanged)
    logits = predictions[0, -1, :] / temperature
    top_token_ids = torch.topk(logits, beams).indices
    for j, token_id in enumerate(top_token_ids):
        bar.update(1)
        # Compute the score of the predicted token
        token_score = get_log_prob(logits, token_id)
        cumulative_score = graph.nodes[node]['cumscore'] + token_score
        # Add the predicted token to the list of input ids
        new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0).unsqueeze(0)], dim=-1)
        # Store the decoded token and its scores on the matching tree node
        token = tokenizer.decode(token_id, skip_special_tokens=True)
        current_node = list(graph.successors(node))[j]
        graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
        graph.nodes[current_node]['cumscore'] = cumulative_score
        graph.nodes[current_node]['sequencescore'] = cumulative_score / len(new_input_ids.squeeze())
        graph.nodes[current_node]['token'] = token + f"_{length}_{j}"
        # Recursive call: extend this hypothesis until `length` tokens are generated
        beam_search(new_input_ids, current_node, bar, length - 1, beams, temperature)
The following function finds the best sequence generated during the search by comparing the length-normalized cumulative scores (sequencescore) of the leaf nodes.
def get_best_sequence(G):
    # Find all leaf nodes (completed sequences)
    leaf_nodes = [node for node in G.nodes if G.out_degree(node) == 0]
    # Find the best leaf node based on sequence score
    max_score_node = max(leaf_nodes, key=lambda n: G.nodes[n]['sequencescore'])
    max_score = G.nodes[max_score_node]['sequencescore']
    # Retrieve the path from the root (node 0) to this leaf
    path = nx.shortest_path(G, source=0, target=max_score_node)
    # Construct the sequence by concatenating the tokens along the path
    sequence = "".join([G.nodes[node]['token'].split('_')[0] for node in path])
    return sequence, max_score
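The division by sequence length that produces sequencescore is a standard length normalization: cumulative log-probabilities only decrease as sequences grow, so unnormalized beam search systematically favors shorter outputs. A toy comparison with invented per-token probabilities makes the effect visible:

import math

short = [0.5, 0.5]              # a two-token hypothesis
long = [0.6, 0.6, 0.6, 0.6]     # a four-token hypothesis
raw_short = sum(math.log(p) for p in short)
raw_long = sum(math.log(p) for p in long)
print(raw_short > raw_long)            # True: the raw score prefers the shorter sequence
print(raw_short / 2 > raw_long / 4)    # False: the normalized score prefers the longer one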
Visualizes the tree-like beam search graph.
def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3 + 1.2 * beams**length, max(5, 2 + length)), dpi=300, facecolor='white')
    # Create positions for each node
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")
    # Normalize the colors along the range of token scores
    scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin, vmax = min(scores), max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)
    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1, linewidths=4,
                           node_color=scores, cmap=cmap)
    # Draw the edges
    nx.draw_networkx_edges(graph, pos)
    # Draw the labels
    labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%"
              for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)
    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    plt.show()
# Parameters
length = 5
beams = 2
# Create a balanced tree graph
graph = nx.balanced_tree(beams, length, create_using=nx.DiGraph())
bar = tqdm(total=len(graph.nodes))
# Initialize graph attributes
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['cumscore'] = 0
    graph.nodes[node]['sequencescore'] = 0
    graph.nodes[node]['token'] = text
# Perform beam search
beam_search(input_ids, 0, bar, length, beams)
# Get the best sequence
sequence, max_score = get_best_sequence(graph)
print(f"Generated text: {sequence}")
# Plot the graph
plot_graph(graph, length, beams, 'token')
Parameters: length (5) is the number of tokens to generate, i.e. the depth of the tree, and beams (2) is the beam width, i.e. the branching factor.
Graph initialization: a balanced tree with beams^length leaves is built up front, and every node starts with a placeholder token score of 100, a cumulative score of 0, and the input text as its token; beam_search then fills in the real values as it descends.
Output: the script prints the highest-scoring generated sequence and plots the search tree, with each node colored by its token probability.
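Since this didactic version expands every retained branch, the number of model calls grows exponentially with the generation length. A quick sanity check of the tree size for a few settings (pure arithmetic, no model needed):

for beams, length in [(2, 5), (3, 5), (2, 10)]:
    leaves = beams ** length
    nodes = sum(beams ** i for i in range(length + 1))
    print(f"beams={beams}, length={length}: {leaves} leaves, {nodes} nodes")

This prints 32 leaves for the settings above but 1024 for length 10, which is why production decoders prune across branches instead of expanding all of them.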
Despite its advantages, beam search has some limitations:
- Repetition: it can produce repetitive or overly generic sequences, particularly in open-ended generation.
- Length bias: raw cumulative log-probabilities favor shorter outputs, which is why length normalization (as in sequencescore above) is commonly applied.
- Tuning and cost: the beam width must be tuned per task, and larger beams increase computation and latency.
Beam search is a cornerstone of modern NLP and sequence generation. By maintaining a balance between exploration and computational efficiency, it enables high-quality decoding in tasks ranging from machine translation to creative text generation. Despite its challenges, beam search remains a preferred choice due to its flexibility and ability to produce coherent and meaningful outputs.
Understanding and implementing beam search equips you with a powerful tool to enhance your NLP models and applications. Whether you’re working on language models, chatbots, or translation systems, mastering beam search will significantly elevate the performance of your solutions.
Q. How does beam search differ from greedy search?
A. Beam search maintains multiple candidate sequences at each step, while greedy search only selects the most probable token. This makes beam search more robust and often more accurate.
Q. How should I choose the beam width?
A. The optimal beam width depends on the task and computational resources. Smaller beam widths are faster but risk missing better sequences, while larger beam widths explore more possibilities at the cost of speed.
Q. Is beam search well suited to tasks with multiple valid outputs?
A. Yes, beam search is particularly effective in tasks with multiple valid outputs, such as machine translation. It explores multiple hypotheses and selects the most probable one.
Q. What are the limitations of beam search?
A. Beam search can produce repetitive sequences, favor shorter outputs, and require careful tuning of parameters like beam width.