Attention Sinks for LLM – Endless Generation 

Syed Abdul Gaffar Last Updated : 12 Apr, 2024
8 min read

Introduction

Imagine a world where large language models (LLMs) can seamlessly weave narratives, translate languages on the fly, and answer your questions with context extending beyond the prompt. This is the promise of attention sinks, a revolutionary method that unlocks endless generation for LLMs.


Learning Objectives

  • Recognizing the challenges associated with long conversations using traditional LLMs.
  • Understanding the concept of attention sinks and their role in addressing memory overload and limited understanding.
  • Exploring the benefits of attention sinks, including memory efficiency, computational savings, and enhanced fluency.
  • Grasping the implementation details of attention sinks, particularly in combination with the rolling KV cache.
  • Learning how attention sinks seamlessly integrate with existing transformer architectures.
  • Gaining practical insights into streaming LLM output with attention sinks.
  • Recognizing real-world applications of endless generation, such as in streaming chatbots, real-time translation, and open-ended storytelling.


This article was published as a part of the Data Science Blogathon.

What are Attention Sinks?

Using large language models (LLMs) for ongoing conversations (like chatbots) is great, but it presents two problems:

  • Memory overload
  • Limited understanding

A common solution called “window attention” only stores recent words, but this fails for long chats.

Key insight from the research abstract: large language models (LLMs) frequently allocate excessive attention to the initial tokens, which behave like a "sink" even when those tokens carry little semantic importance. The proposed solution is to keep the key-value states of these early tokens in memory, which leads to a notable improvement in LLM performance when window attention is used.

Source: https://arxiv.org/pdf/2309.17453.pdf

This opens the door to using LLMs effectively in long, flowing conversations without needing enormous amounts of memory. In short, traditional Transformer-based LLMs struggle with long sequences: they attend to every token, which leads to memory bottlenecks and to clunky, context-less, or hallucinated outputs. Attention sinks offer a paradigm shift.

Think of dropping a stone into a pond: the ripples spread outward, influencing the surrounding area. Similarly, attention sinks are the handful of initial tokens that absorb much of the LLM's focus. These "anchors" stabilize the attention computation, allowing the model to process and generate text efficiently without getting lost in a vast sea of words.
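To make this concrete, here is a minimal sketch (an illustration, not the library's actual implementation) of the retention policy the paper proposes: keep the first few "sink" tokens plus a rolling window of the most recent tokens, and evict everything in between.

def tokens_to_keep(cache_len: int, sink_size: int = 4, window_size: int = 1020) -> list:
    """Return the indices of cached tokens that survive eviction."""
    if cache_len <= sink_size + window_size:
        return list(range(cache_len))  # the cache still fits, nothing to evict
    sinks = list(range(sink_size))  # the first tokens act as attention sinks
    recent = list(range(cache_len - window_size, cache_len))  # rolling window of recent tokens
    return sinks + recent

keep = tokens_to_keep(2000)
print(keep[:6], "...", keep[-1])  # [0, 1, 2, 3, 980, 981] ... 1999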

Benefits of Attention Sinks

  • Memory Efficiency: Attention sinks dramatically reduce the memory footprint, enabling LLMs to handle much longer sequences. Imagine generating chapters of a novel without ever forgetting the plot!
  • Computational Savings: By focusing on key points, the LLM’s processing power is greatly optimized. This translates to faster generation and lower energy consumption, ideal for real-time applications.
  • Enhanced Fluency: Attention sinks ensure context awareness even in open-ended scenarios. The LLM retains the essence of previous interactions, leading to more coherent, contextual and natural-sounding dialogues and narratives.
  • Versatility: Adaptable to different positional encoding schemes, and works with existing LLMs without retraining, saving time and resources.

Overall, Streaming LLM offers a practical and efficient solution for unleashing the power of LLMs in real-time, open-ended interactions.

Rolling KV Cache with Attention Sinks

Source: Arxiv

The key idea is to combine two memory caches:

  • Attention sinks: These hold a few initial tokens (around four) and their key-value states (KV). These act as anchors, stabilizing the attention mechanism even when the rest of the conversation scrolls out of the main cache.
  • Rolling KV Cache: This holds the most recent tokens, similar to traditional window attention.

Crucial to Streaming LLM is how it handles positional information:

  • Instead of referencing positions in the original text, it uses relative positions within the combined cache.
  • This ensures the model understands the relationships between tokens even as the conversation flows.
  • For specific encoding schemes like RoPE and ALiBi, Streaming LLM adapts its caching and position transformation methods to seamlessly integrate.

For a deeper explanation, refer to the StreamingLLM paper (https://arxiv.org/pdf/2309.17453.pdf).
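As a rough sketch of this combined cache (assuming Hugging Face-style key/value tensors of shape [batch, heads, seq_len, head_dim]; the attention_sinks package handles this internally, including the RoPE/ALiBi position re-mapping mentioned above), the per-step truncation could look like this:

import torch

def truncate_kv(past_key_values, sink_size=4, window_size=1020):
    """Keep the sink tokens plus the most recent window in every layer's KV cache."""
    truncated = []
    for key, value in past_key_values:
        seq_len = key.size(2)
        if seq_len <= sink_size + window_size:
            truncated.append((key, value))
            continue
        # Concatenate the sink tokens with the most recent tokens and drop the middle.
        key = torch.cat([key[:, :, :sink_size], key[:, :, -window_size:]], dim=2)
        value = torch.cat([value[:, :, :sink_size], value[:, :, -window_size:]], dim=2)
        truncated.append((key, value))
    return truncated

# The kept entries are then treated as occupying positions 0 .. sink_size + window_size - 1
# inside the cache, rather than their original positions in the full conversation.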

Let’s Dive into Implementation

Attention sink modules seamlessly integrate with transformer architectures, offering an easy-to-use solution for streaming large language models. Their plug-and-play nature lets you leverage their benefits with minimal effort. Here’s a glimpse of how the attention sink module fits in:

  • Existing Transformer: Imagine your standard transformer setup.
  • Attention Sink Addition: Introduce the attention sink module alongside the transformer. It acts as a dedicated memory bank, holding onto those crucial initial tokens.
  • Enhanced Attention: During decoding, the transformer taps into both the rolling cache (recent tokens) and the attention sink (early anchors). This stabilizes the attention mechanism for longer dialogues.

Remember, attention sink modules require minimal code changes, making them a low-effort, high-impact upgrade for LLM streaming needs.
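The snippet below loads Mistral-7B through the drop-in AutoModelForCausalLM provided by the attention_sinks package, configures a four-token sink with a small rolling window, and then generates an effectively endless stream of text: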

import torch
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"

# Load the chosen model and corresponding tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # for efficiency:
    device_map="auto",
    torch_dtype=torch.float16,
    # `attention_sinks`-specific arguments:
    attention_sink_size=4,
    attention_sink_window_size=252, # <- Low for the sake of faster generation
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

# Our input text
text = "Data Science Blogathon - 39"

# Encode the text
input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)

with torch.no_grad():
    # A TextStreamer prints tokens as they're being generated
    streamer = TextStreamer(tokenizer)
    generated_tokens = model.generate(
        input_ids,
        generation_config=GenerationConfig(
            # use_cache=True is required, the rest can be changed up.
            use_cache=True,
            min_new_tokens=100_000,
            max_new_tokens=1_000_000,
            penalty_alpha=0.6,
            top_k=5,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
        streamer=streamer,
    )
    # Decode the final generated text
    output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
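In this snippet, attention_sink_size=4 pins the first four tokens in the KV cache, while attention_sink_window_size=252 bounds the rolling window of recent tokens, so the cache never grows beyond 256 entries. That fixed budget is what lets generation run far past the model's usual context length without exhausting memory.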

Streaming

Let’s see how we can stream the LLM output using attention sinks. We will use the script https://github.com/tomaarsen/attention_sinks/blob/main/demo/streaming.py.

import argparse
from pathlib import Path
from typing import Any, Dict, List

import torch
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from utils import FileStreamer
def create_prompts(samples: Dict[str, List[Any]]) -> Dict[str, Any]:
    return {"prompt": [prompt for prompts in samples["prompt"] for prompt in prompts]}


@torch.no_grad()
def greedy_generate(
    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, dataset: Dataset, log_file: str, max_new_tokens: int = 1000
):
    streamer = FileStreamer(tokenizer, log_file)
    past_key_values = None
    new_line_tokens = tokenizer("\n\n", return_tensors="pt", add_special_tokens=False).input_ids

    for prompt_index, prompt in enumerate(dataset["prompt"]):
        # Use the chat template initially, as it adds the system prompt if the model has one, and then use [INST] and [/INST]
        if prompt_index:
            prompt = f"[INST] {prompt} [/INST]"
        else:
            prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)

        streamer.put(input_ids)
        for _ in range(max_new_tokens):
            outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
            streamer.put(pred_token_idx)
            input_ids = pred_token_idx

            if pred_token_idx == tokenizer.eos_token_id:
                break

        streamer.put(new_line_tokens)

The function create_prompts flattens the nested prompt lists in the dataset into a single list of prompts. In greedy_generate, we first create the FileStreamer object, which writes generated text to the log file, and initialize past_key_values to None. We then iterate over the prompts: the very first prompt is formatted with the tokenizer’s chat template (which adds the system prompt if the model has one), while subsequent prompts are wrapped in [INST] and [/INST] tags. Each prompt is tokenized, moved to the model’s device, and passed to the streamer. Tokens are then generated one at a time, with past_key_values updated after every forward pass, and generation stops early if the end-of-sequence token appears. Finally, a newline token is pushed to the streamer to separate consecutive dialogues.
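As a hypothetical usage sketch (assuming a model and tokenizer loaded through attention_sinks, as in the main function shown next, and FileStreamer imported from the demo’s utils module), greedy_generate can be driven by a small in-memory dataset:

from datasets import Dataset

demo_dataset = Dataset.from_dict({
    "prompt": [
        "Summarize the idea behind attention sinks in two sentences.",
        "Continue the story: the ripples spread outward across the pond...",
    ]
})
greedy_generate(model, tokenizer, demo_dataset, log_file="streaming_demo.txt", max_new_tokens=256)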

In the main function, the experiment defaults to attention_sinks. You can change the model via model_name_or_path, or pass the path of a model you have trained yourself. If you want to use your own dataset, modify the data-loading code and the create_prompts function accordingly. Running the code displays a continuous stream of generated text in your terminal while also writing it to the log file.

def main():
    parser = argparse.ArgumentParser()

    # Which experiment to run?
    parser.add_argument(
        "--experiment", choices=["attention_sinks", "transformers", "windowed"], default="attention_sinks"
    )

    # Model args
    parser.add_argument("--model_name_or_path", type=str, default="mistralai/Mistral-7B-Instruct-v0.1")
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--trust_remote_code", action="store_true")

    # Dataset args, not recommended to change:
    parser.add_argument("--dataset_name", type=str, default="HuggingFaceH4/mt_bench_prompts")

    # Where to log
    parser.add_argument("--log_file", type=str, default=None)

    # Window size for windowed and attention_sinks
    parser.add_argument("--window_size", type=int, default=1024)

    # Attention Sinks-only settings
    # Attention Sink window size is calculated with args.window_size - args.attention_sink_size
    parser.add_argument("--attention_sink_size", type=int, default=4)

    args = parser.parse_args()

    # Initialize the model, either via transformers or via attention_sinks
    if args.experiment == "transformers":
        from transformers import AutoModelForCausalLM
    else:
        from attention_sinks import AutoModelForCausalLM
    kwargs = {}
    if args.experiment == "attention_sinks":
        kwargs = {
            "attention_sink_size": args.attention_sink_size,
            "attention_sink_window_size": args.window_size - args.attention_sink_size,  # default: 1020
        }
    elif args.experiment == "windowed":
        kwargs = {
            "attention_sink_size": 0,
            "attention_sink_window_size": args.window_size,
        }
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        revision=args.revision,
        trust_remote_code=bool(args.trust_remote_code),
        torch_dtype=torch.float16,
        device_map="auto",
        **kwargs,
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Set up the dataset
    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.map(create_prompts, batched=True, remove_columns=dataset.column_names)

    log_file = args.log_file or Path("demo") / "streaming_logs" / args.experiment / f"{args.model_name_or_path}.txt"
    greedy_generate(model, tokenizer, dataset, log_file=log_file)


if __name__ == "__main__":
    main()
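Assuming the file is saved as demo/streaming.py, as in the linked repository, it can be launched with, for example, python streaming.py --experiment attention_sinks --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 --window_size 1024. The generated text streams to the terminal and is also written to the log file under demo/streaming_logs/.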

Applications of Endless Generation

  • Streaming Chatbots: Imagine a chatbot that remembers your entire conversation history and seamlessly adapts to your changing needs. Attention sinks make this a reality, enabling rich and personalized interactions.
  • Real-time Translation: Imagine translating a live speech with perfect accuracy, even for extended conversations. Attention sinks bridge the gap between consecutive sentences, preserving context for flawless translation.
  • Open-ended Storytelling: Imagine scripting an epic novel one chapter at a time, with each chapter seamlessly building upon the last. Attention sinks unlock the potential for truly immersive and interconnected narratives.

The Future of LLMs

Attention sinks are not just a technological leap; they represent a shift in how we think about LLMs. Instead of static models, we can now conceive LLMs as dynamic entities, constantly learning and adapting within a flowing stream of information.

This opens up a lot of possibilities:

  • Collaborative writing tools that seamlessly weave together inputs from multiple users.
  • Personalized educational assistants that adapt their explanations based on your learning style and progress.
  • AI-powered creative partners that help you brainstorm ideas.

The possibilities are endless, and attention sinks pave the way for a future where LLMs are not just tools, but collaborators, companions, and catalysts for human creativity.

The field of attention sinks is rapidly evolving. If you’re interested in exploring this breakthrough further, the StreamingLLM paper (https://arxiv.org/pdf/2309.17453.pdf) and the attention_sinks repository (https://github.com/tomaarsen/attention_sinks) are good starting points.

Conclusion

In conclusion, attention sinks represent a groundbreaking solution to the challenges faced by large language models in handling long and dynamic conversations. The implementation of attention sinks, coupled with the rolling KV cache, enables LLMs to operate efficiently in real-time scenarios, offering benefits such as reduced memory footprint and enhanced contextual understanding.

Key Takeaways

  • Paradigm Shift: Attention sinks mark a paradigm shift in the capabilities of LLMs, transforming them from static models to dynamic entities adaptable to flowing streams of information.
  • Practical Applications: Endless generation facilitated by attention sinks opens the door to practical applications, including personalized chatbots, real-time translation, and immersive storytelling.
  • Future Possibilities: Attention sinks pave the way for collaborative writing tools, personalized educational assistants, and AI-powered creative partners, signaling a future where LLMs actively contribute to human creativity.
  • Resource Exploration: Readers are encouraged to explore additional resources, including blog posts, research papers, and open-source implementations, to stay informed about the evolving field of attention sinks.


Frequently Asked Questions

Q1. What are attention sinks, and how do they address challenges in large language models (LLMs)?

A. Attention sinks are the initial tokens of a sequence that act as anchors for the model’s attention during long conversations. They address challenges in LLMs, such as memory overload and limited understanding, by keeping these crucial early tokens (and their key-value states) in memory. This allows LLMs to efficiently process and generate text without getting lost in lengthy sequences.

Q2. How do attention sinks improve the efficiency of LLMs in long conversations?

A. Attention sinks dramatically reduce the memory footprint of LLMs, enabling them to handle much longer sequences. By strategically focusing on key points, attention sinks optimize the processing power of LLMs, resulting in faster generation and lower energy consumption. This makes them ideal for real-time applications.

Q3. Can attention sinks be integrated into existing LLMs without retraining?

A. Yes, attention sinks are designed to work seamlessly with existing LLMs, such as Transformers, without the need for retraining. They offer a plug-and-play solution, requiring minimal code changes. This makes attention sinks a practical and efficient upgrade for LLMs, saving both time and resources.

Q4. How do attention sinks contribute to the future of LLMs, and what possibilities do they unlock?

A. Attention sinks represent a shift in how we perceive LLMs. They open up possibilities for dynamic entities that constantly learn and adapt within a flowing stream of information. This evolution paves the way for collaborative writing tools, personalized educational assistants, and AI-powered creative partners, making LLMs more than just tools but collaborators and catalysts for human creativity.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.

I thrive on the thrill of the challenge, tackling complex problems and crafting innovative AI solutions that make a difference. Whether it's optimizing or building sustainable AI ecosystems, I believe in harnessing the power of AI for the greater good. Let's brainstorm, collaborate, and change the world, one byte at a time.
