Imagine a world where large language models (LLMs) can seamlessly weave narratives, translate languages on the fly, and answer your questions with context extending beyond the prompt. This is the promise of attention sinks, a revolutionary method that unlocks endless generation for LLMs.
This article was published as a part of the Data Science Blogathon.
Using large language models (LLMs) for ongoing conversations (like chatbots) is great, but it presents two problems:
1. Memory: during decoding, the model caches the key and value (KV) states of every previous token, so memory use keeps growing as the conversation gets longer.
2. Length: most LLMs cannot generalize to sequences longer than the context length they were trained on, so quality degrades once a chat exceeds that limit.
A common solution called “window attention” keeps only the KV states of the most recent tokens, but its performance collapses as soon as the very first tokens are evicted from the cache, so it fails for long chats.
Key insight from the research abstract: Large Language Models (LLMs) frequently allocate excessive attention to the initial tokens, which behave like a “sink” even when they carry little meaningful information. The proposed solution is to keep the KV states of these early tokens in memory, which leads to a notable improvement in performance, particularly when using window attention.
This opens the door to using LLMs effectively in long, flowing conversations without needing tons of memory. In short, traditional Transformer-based LLMs struggle with long sequences: they attend to every previous token, which leads to memory bottlenecks and clunky, context-less outputs, or to outright hallucination. Attention sinks offer a paradigm shift.
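To see why memory is the bottleneck, here is a rough back-of-the-envelope calculation. The layer, head, and dimension figures below are assumptions for a typical 7B-class model (such as Mistral-7B), so treat the numbers as illustrative rather than exact:

# Assumed 7B-class config: 32 layers, 8 KV heads, head dim 128, fp16 (2 bytes per value)
layers, kv_heads, head_dim, bytes_per_value = 32, 8, 128, 2
per_token = 2 * layers * kv_heads * head_dim * bytes_per_value  # keys + values
print(per_token / 1024, "KB of KV cache per token")                         # ~128 KB
print(per_token * 100_000 / 1024**3, "GB for a 100k-token conversation")    # ~12.2 GB
print(per_token * 1024 / 1024**2, "MB for a 1024-token sink+window cache")  # ~128 MB

Caching everything grows without bound, while a fixed sink-plus-window cache stays constant no matter how long the conversation runs.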
Think of sinking a stone in a pond: the ripples spread outward, influencing the surrounding area. Similarly, attention sinks are a handful of initial tokens that absorb a disproportionate share of the LLM’s focus. These “anchors” soak up the model’s surplus attention, allowing it to efficiently process and generate text without getting lost in a vast stream of words.
Overall, Streaming LLM offers a practical and efficient solution for unleashing the power of LLMs in real-time, open-ended interactions.
The key idea is to combine two memory caches:
1. Attention sinks: the KV states of the first few tokens (around four is enough in practice), which are kept permanently because the model keeps directing attention to them.
2. Rolling KV cache: the KV states of the most recent tokens, which slide forward as new tokens are generated while older, middle tokens are evicted.
Crucial to Streaming LLM is how it handles positional information: positions are assigned based on where a token sits inside the cache, not where it appeared in the original text. For example, if the cache currently holds tokens [0, 1, 2, 3] followed by tokens [6, 7, 8], they are given positions [0, 1, 2, 3, 4, 5, 6], keeping the rotary or relative position encodings within the range the model saw during training. A minimal sketch of this cache policy is shown below.
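Here is an illustrative sketch of the sink-plus-window cache policy. It is not the attention_sinks library’s internal code, and it assumes the legacy tuple-of-tensors KV cache layout (one (keys, values) pair per layer, each shaped batch × heads × seq_len × head_dim):

import torch

def trim_kv_cache(past_key_values, attention_sink_size=4, window_size=1020):
    # Keep the KV states of the first `attention_sink_size` tokens (the sinks)
    # plus the most recent `window_size` tokens; drop everything in between.
    trimmed = []
    for keys, values in past_key_values:  # one (keys, values) pair per layer
        seq_len = keys.size(2)
        if seq_len <= attention_sink_size + window_size:
            trimmed.append((keys, values))
            continue
        keys = torch.cat([keys[:, :, :attention_sink_size], keys[:, :, -window_size:]], dim=2)
        values = torch.cat([values[:, :, :attention_sink_size], values[:, :, -window_size:]], dim=2)
        trimmed.append((keys, values))
    return tuple(trimmed)

In the actual library, this trimming (and the re-assignment of positions within the cache) happens automatically inside the attention layers; the sketch only shows the slicing logic.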
For a deeper dive, refer to the original paper, “Efficient Streaming Language Models with Attention Sinks” (Xiao et al.).
Attention sink modules seamlessly integrate with transformer architectures, offering an easy-to-use solution for streaming large language models. Their plug-and-play nature lets you leverage their benefits with minimal effort. Here’s a glimpse of how the attention sink module fits in:
Remember, attention sink modules require minimal code changes, making them a low-effort, high-impact upgrade for LLM streaming needs.
import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252,  # <- low for the sake of faster generation
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Data Science Blogathon - 39"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # A TextStreamer prints tokens as they're being generated
    streamer = TextStreamer(tokenizer)
    generated_tokens = model.generate(
        input_ids,
        generation_config=GenerationConfig(
            # use_cache=True is required, the rest can be changed up.
            use_cache=True,
            min_new_tokens=100_000,
            max_new_tokens=1_000_000,
            penalty_alpha=0.6,
            top_k=5,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
        streamer=streamer,
    )

# Decode the final generated text
output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
Let’s see how we can stream LLM output using attention sinks. We will use the demo script from the attention_sinks repository: https://github.com/tomaarsen/attention_sinks/blob/main/demo/streaming.py.
import argparse
from pathlib import Path
from typing import Any, Dict, List

import torch
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from utils import FileStreamer


def create_prompts(samples: Dict[str, List[Any]]) -> Dict[str, Any]:
    return {"prompt": [prompt for prompts in samples["prompt"] for prompt in prompts]}


@torch.no_grad()
def greedy_generate(
    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, dataset: Dataset, log_file: str, max_new_tokens: int = 1000
):
    streamer = FileStreamer(tokenizer, log_file)
    past_key_values = None
    new_line_tokens = tokenizer("\n\n", return_tensors="pt", add_special_tokens=False).input_ids

    for prompt_index, prompt in enumerate(dataset["prompt"]):
        # Use the chat template initially, as it adds the system prompt if the model has one,
        # and then use [INST] and [/INST]
        if prompt_index:
            prompt = f"[INST] {prompt} [/INST]"
        else:
            prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)

        streamer.put(input_ids)
        for _ in range(max_new_tokens):
            outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
            streamer.put(pred_token_idx)
            input_ids = pred_token_idx

            if pred_token_idx == tokenizer.eos_token_id:
                break

        streamer.put(new_line_tokens)
The function create_prompts builds a flat prompt list from the dataset. In greedy_generate, we initialize the streamer object (which writes tokens to a log file as they are produced) and set past_key_values to None. We then iterate over the prompts: the first prompt goes through the tokenizer’s chat template (which adds the system prompt, if the model has one), while later prompts are wrapped in [INST] and [/INST] tags for streamed dialogue. Each prompt is tokenized and pushed to the streamer, then tokens are generated one by one, updating past_key_values after every step. Generation stops early if the end-of-sequence token appears, and a newline token is written to the streamer to separate consecutive dialogues.
In the main function, the experiment defaults to attention_sinks. You can change the model via model_name_or_path, or point it to a locally trained model’s path. If you want to use your own dataset, modify the data-loading step and the create_prompts function accordingly; a short sketch of swapping in your own prompts follows the script below. Running the code displays a continuous stream of generated text in your terminal while also writing it to the log file.
def main():
    parser = argparse.ArgumentParser()
    # Which experiment to run?
    parser.add_argument(
        "--experiment", choices=["attention_sinks", "transformers", "windowed"], default="attention_sinks"
    )
    # Model args
    parser.add_argument("--model_name_or_path", type=str, default="mistralai/Mistral-7B-Instruct-v0.1")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")
    # Dataset args, not recommended to change:
    parser.add_argument("--dataset_name", type=str, default="HuggingFaceH4/mt_bench_prompts")
    # Where to log
    parser.add_argument("--log_file", type=str, default=None)
    # Window size for windowed and attention_sinks
    parser.add_argument("--window_size", type=int, default=1024)
    # Attention Sinks-only settings
    # Attention Sink window size is calculated with args.window_size - args.attention_sink_size
    parser.add_argument("--attention_sink_size", type=int, default=4)
    args = parser.parse_args()

    # Initialize the model, either via transformers or via attention_sinks
    if args.experiment == "transformers":
        from transformers import AutoModelForCausalLM
    else:
        from attention_sinks import AutoModelForCausalLM

    kwargs = {}
    if args.experiment == "attention_sinks":
        kwargs = {
            "attention_sink_size": args.attention_sink_size,
            "attention_sink_window_size": args.window_size - args.attention_sink_size,  # default: 1020
        }
    elif args.experiment == "windowed":
        kwargs = {
            "attention_sink_size": 0,
            "attention_sink_window_size": args.window_size,
        }

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=bool(args.trust_remote_code),
        torch_dtype=torch.float16,
        device_map="auto",
        **kwargs,
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Set up the dataset
    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.map(create_prompts, batched=True, remove_columns=dataset.column_names)

    log_file = args.log_file or Path("demo") / "streaming_logs" / args.experiment / f"{args.model_name_or_path}.txt"
    greedy_generate(model, tokenizer, dataset, log_file=log_file)


if __name__ == "__main__":
    main()
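As noted above, you can point the demo at your own prompts instead of the default HuggingFaceH4/mt_bench_prompts dataset. Here is a minimal sketch; the prompt strings and log file name are hypothetical placeholders, and it reuses the model, tokenizer, and greedy_generate already defined in the script:

from datasets import Dataset

# Hypothetical prompts; each row just needs a "prompt" column.
my_prompts = [
    "Summarize the idea behind attention sinks in two sentences.",
    "Why does plain window attention break down on long chats?",
]
dataset = Dataset.from_dict({"prompt": my_prompts})

# Then stream over them with the same helper as before:
# greedy_generate(model, tokenizer, dataset, log_file="my_streaming_log.txt")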
Attention sinks are not just a technological leap; they represent a shift in how we think about LLMs. Instead of static models, we can now conceive LLMs as dynamic entities, constantly learning and adapting within a flowing stream of information.
This opens up a lot of possibilities: collaborative writing tools that keep track of an entire manuscript, personalized educational assistants that remember a whole tutoring session, and AI-powered creative partners that can sustain an open-ended conversation.
The field of attention sinks is rapidly evolving. If you’re interested in exploring this breakthrough, two good starting points are the paper “Efficient Streaming Language Models with Attention Sinks” and the attention_sinks GitHub repository (https://github.com/tomaarsen/attention_sinks) used in the examples above.
In conclusion, attention sinks represent a groundbreaking solution to the challenges faced by large language models in handling long and dynamic conversations. The implementation of attention sinks, coupled with the rolling KV cache, enables LLMs to operate efficiently in real-time scenarios, offering benefits such as reduced memory footprint and enhanced contextual understanding.
Q. What are attention sinks, and what problems do they solve?
A. Attention sinks are the initial tokens that act as anchors for LLMs during long conversations. They address challenges such as memory overload and loss of context by soaking up the model’s surplus attention. Keeping these crucial initial tokens in the cache allows LLMs to efficiently process and generate text without getting lost in lengthy sequences.
Q. How do attention sinks improve efficiency?
A. Attention sinks dramatically reduce the memory footprint of LLMs, enabling them to handle much longer sequences. By attending only to a few sink tokens plus a rolling window of recent tokens, the model does less work per step, resulting in faster generation and lower resource consumption. This makes them ideal for real-time applications.
Q. Can attention sinks be used with existing LLMs without retraining?
A. Yes. Attention sinks are designed to work with existing Transformer-based LLMs without retraining. They offer a plug-and-play solution requiring minimal code changes, making them a practical, low-effort upgrade that saves both time and resources.
Q. What do attention sinks mean for the future of LLMs?
A. Attention sinks represent a shift in how we perceive LLMs. They open up possibilities for dynamic systems that constantly learn and adapt within a flowing stream of information. This evolution paves the way for collaborative writing tools, personalized educational assistants, and AI-powered creative partners, making LLMs not just tools but collaborators and catalysts for human creativity.