How to Fine-tune LLMs to 1.58 bits?

Badrinarayan M Last Updated : 03 Oct, 2024

Introduction

Large Language Models (LLMs) keep growing in size and complexity, which makes it increasingly difficult to keep their computational and energy costs in check. One popular way to reduce these costs is quantization: the precision of the parameters is lowered from the standard 16-bit floating point (FP16) or 32-bit floating point (FP32) to lower-bit formats such as 8-bit or 4-bit. This reduces memory use and speeds up computation, but it trades off accuracy: when precision is reduced too far, the model loses crucial information and its performance drops. In this article, we will look at how to fine-tune LLMs to 1.58 bits.

Overview

  • Quantization reduces LLM costs by lowering precision but often comes with a tradeoff in accuracy.
  • BitNet introduces a 1.58-bit LLM that achieves comparable performance to full-precision models while drastically cutting energy consumption and computation costs.
  • Using ternary precision, BitNet replaces traditional layers with BitLinear, leveraging STE to handle non-differentiable weights.
  • Fine-tuning BitNet models from pre-trained Llama3 8B weights improves performance but faces challenges in preserving information through quantization.
  • Though some performance gaps remain, dynamic lambda scheduling and alternative quantization methods help improve fine-tuning outcomes.
  • BitNet demonstrates the potential to create efficient, cost-effective LLMs, offering a new paradigm for future large-scale model training and hardware optimization.

BitNet for 1-bit Large Language Models (LLMs)

New advances in research, like BitNet, are opening the door for 1-bit Large Language Models (LLMs) to become the norm. The BitNet researchers presented BitNet b1.58, a 1-bit LLM variant in which every parameter is ternary, taking a value in {-1, 0, 1}.

Note: Perplexity is a metric used to evaluate how well a language model (LLM) predicts the next word in a sequence.
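To make the note above concrete, here is a minimal sketch of how perplexity is usually computed: the exponential of the average negative log-probability the model assigns to each correct next token. The function name and the probability values are hypothetical and only serve as an illustration.

```python
import math

def perplexity(token_probs):
    """Perplexity = exp of the average negative log-probability per token."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))

# Hypothetical probabilities a model assigned to each correct next token.
uniform_guess = perplexity([0.25, 0.25, 0.25, 0.25])  # guessing among 4 options
confident = perplexity([0.9, 0.8, 0.95, 0.85])        # mostly correct predictions
print(uniform_guess, confident)
```

A lower perplexity means the model is, on average, less "surprised" by the text: a model that guesses uniformly among four options lands at a perplexity of about 4, while a confident and correct model approaches 1.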

In terms of perplexity and end-task performance, BitNet b1.58 is comparable to a full-precision (FP16 or BF16) Transformer LLM with the same model size and training tokens, while being substantially more economical in latency, memory, throughput, and energy usage. More importantly, the 1.58-bit LLM establishes a new scaling law and training recipe for future generations of high-performing and reasonably priced LLMs. It also makes it possible to construct hardware specifically optimized for 1-bit LLMs and create a new paradigm for computation.
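The "1.58" in the name comes from the information content of a ternary value: a weight with three possible states carries log2(3) bits. The short calculation below illustrates this, together with a rough weight-memory estimate. The 8e9 parameter count is an assumed round number, and the estimate ignores any layers kept in higher precision.

```python
import math

# A ternary weight has 3 possible states, so it carries log2(3) bits.
bits_per_ternary_weight = math.log2(3)
print(round(bits_per_ternary_weight, 2))

# Rough weight-memory estimate for a hypothetical 8e9-parameter model.
num_params = 8e9
fp16_gb = num_params * 16 / 8 / 1e9
ternary_gb = num_params * bits_per_ternary_weight / 8 / 1e9
print(fp16_gb, round(ternary_gb, 2))
```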

One limitation is that the model has to be trained from scratch. The results are very good, but not everyone has the budget to pre-train an LLM. To overcome this limitation, the authors of the fine-tuning experiments described in the rest of this article explored a few tricks that allow fine-tuning an existing model to 1.58 bits.

Because the weights are ternary, matrix multiplication in this architecture needs only INT8 additions, in contrast to the FP16 additions and multiplications of the LLaMA LLM. As a result, BitNet b1.58 saves 71.4 times the arithmetic-operation energy for matrix multiplication compared to the Llama baseline.
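The sketch below illustrates why ternary weights remove the multiplications from a matrix product: a weight of +1 adds the activation, -1 subtracts it, and 0 skips it. It is a NumPy illustration of the idea, not the actual BitNet kernel; the function name and the random test data are assumptions.

```python
import numpy as np

def ternary_matvec(weights, x):
    """Multiply a ternary matrix by an integer vector using only add/subtract."""
    out = np.zeros(weights.shape[0], dtype=np.int32)
    for i, row in enumerate(weights):
        # +1 weights add the activation, -1 weights subtract it, 0 weights skip it.
        out[i] = x[row == 1].sum() - x[row == -1].sum()
    return out

rng = np.random.default_rng(0)
W = rng.integers(-1, 2, size=(4, 8))   # ternary weights in {-1, 0, 1}
x = rng.integers(-128, 128, size=8)    # activations in the INT8 value range

# The add/subtract version matches an ordinary matrix-vector product.
assert np.array_equal(ternary_matvec(W, x), W @ x)
```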

Energy consumption of BitNet b1.58 compared to LLaMA LLM at 7nm process nodes. On the left are the components of arithmetic operations energy. On the right is the end-to-end energy cost across different model sizes.

What does BitNet do?

BitNet replaces the traditional Linear layers in Multi-Head Attention and Feed-Forward Networks with specialized layers called BitLinear. A BitLinear layer uses ternary precision (or binary precision in the initial version). One big obstacle when training with ternary precision is that the weights are discretized with a round() function. Rounding is non-differentiable: its gradient is zero almost everywhere, so the weights would not learn during backpropagation. Hence, BitNet uses a technique called STE (Straight-Through Estimator).

What is STE?

Straight-Through Estimator (STE): The paper provides a detailed study of STE, a technique used to deal with the non-differentiable functions that arise in quantized neural networks (QNNs). The STE allows the gradient to “pass through” discrete variables during backpropagation by approximating their gradients. This is especially crucial in the context of QNNs, where weights and activations are often quantized to lower precision, making them non-differentiable.

A further way to look at it is that the STE allows the gradient to continue as if rounding had never happened, allowing weight updates using conventional gradient-based optimisation methods.
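The idea can be sketched without an autograd framework. In the toy example below, the forward pass uses ternarized weights, and the gradient computed with respect to those quantized weights is applied unchanged to the latent full-precision weights, which is the straight-through step. The absmean ternarization, the toy regression task, and all names are assumptions made for illustration; in PyTorch the same effect is commonly obtained with an expression such as w + (quantize(w) - w).detach().

```python
import numpy as np

def ternarize(w):
    """Absmean ternary quantization (assumed scheme): returns scale * {-1, 0, 1}."""
    scale = np.abs(w).mean() + 1e-8
    return scale * np.clip(np.round(w / scale), -1, 1)

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8))
target_w = np.array([1.0, -1.0, 0.0, 1.0, 0.0, -1.0, 1.0, 0.0])
y = x @ target_w

w = rng.normal(scale=0.1, size=8)         # latent full-precision weights
initial_loss = np.mean((x @ ternarize(w) - y) ** 2)

for _ in range(200):
    w_q = ternarize(w)                    # forward pass uses quantized weights
    error = x @ w_q - y
    grad_w_q = 2 * x.T @ error / len(x)   # gradient w.r.t. the quantized weights
    # STE: treat d(w_q)/d(w) as the identity, so the same gradient updates w.
    w -= 0.05 * grad_w_q

final_loss = np.mean((x @ ternarize(w) - y) ** 2)
print(initial_loss, final_loss)
```

Without the straight-through step, the gradient of round() would be zero almost everywhere and the latent weights would never move.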

(a) The computation flow of BitLinear. (b) The architecture of BitNet consists of the stacks of attentions and FFNs, where matrix multiplication is implemented as BitLinear.

Trying Out Pre-Training in 1.58b Quantization

The authors first attempted to reproduce the results from the BitNet paper, starting with a small dataset, TinyStories, and a Llama3 8B model. Their implementation confirmed that adding a normalization function improves performance, and they found that training was stable. For example, after 2000 steps of training, the perplexity on the validation set was 6.3 without normalization and 5.9 with normalization.

This approach reduces cost while maintaining accuracy, but few organizations can afford to pre-train a model. Other groups had reported that fine-tuning results were not very promising, so the authors tested that as well.

Fine-Tuning using 1.58bit Quantization

When they began fine-tuning from the pre-trained Llama3 8B weights, the model performed slightly better, but not as well as they expected.

To understand why, they inspected the weight distributions of the randomly initialized and pre-trained models, as well as the scale values of the two distributions. They found that the pre-trained model starts with more information, and that adding more BitLinear layers overwhelms the model, so it loses all of its prior information.

Hence, to improve the fine-tuning results, they tried per-row and per-column quantization instead of per-tensor quantization, so as to keep more of the information already present in Llama 3. However, they observed that the model still lost information when quantized. To investigate how much information was lost, they experimented with per-group quantization.

As a sanity check, they first set the group size to 1, which essentially means no quantization. In this scenario, the loss started at 1.45, the same as they saw during normal fine-tuning. However, when they increased the group size to 2, the loss jumped to around 11. This indicates that even with a minimal group size of 2, the model still loses nearly all of its information. To address this issue, they considered introducing quantization gradually instead of applying it abruptly.
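The sanity check can be reproduced on a toy tensor. The sketch below assumes absmean ternary quantization with one scale per group; the authors' exact grouping code is not shown in this article, so the function and its layout are illustrative. With a group size of 1, each weight becomes its own scale and is reconstructed exactly, while any larger group introduces reconstruction error. The printed numbers are weight reconstruction errors, not the training losses quoted above.

```python
import numpy as np

def quantize_per_group(w, group_size):
    """Ternary absmean quantization with one scale per group (assumed scheme)."""
    groups = w.reshape(-1, group_size)
    absmean = np.abs(groups).mean(axis=1, keepdims=True)
    safe = np.maximum(absmean, 1e-8)               # avoid division by zero
    ternary = np.clip(np.round(groups / safe), -1, 1)
    return (ternary * absmean).reshape(w.shape)    # dequantized weights

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 8))

# group_size=8 is per-row for this shape; group_size=w.size is per-tensor.
for group_size in (1, 2, 8, w.size):
    error = np.abs(quantize_per_group(w, group_size) - w).mean()
    print(group_size, error)
```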

To do this, they introduced a lambda value to control the process. When lambda = 0, no quantization is done, and when lambda = 1, full quantization is done. Initially, they tested discrete lambda values of 0.25, 0.50, 0.75, and 1, but the results were not significantly better, because even at lambda = 0.25 the loss started very high.
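One way to read the role of lambda is as a linear interpolation between the full-precision weights and their quantized version. The article does not spell out how lambda is applied, so the blend below is an assumption, as are the absmean ternarization and the function names.

```python
import numpy as np

def ternarize(w):
    """Per-tensor absmean ternary quantization (assumed scheme)."""
    scale = np.abs(w).mean() + 1e-8
    return scale * np.clip(np.round(w / scale), -1, 1)

def blend(w, lambda_):
    """Interpolate between full precision (lambda_=0) and quantized (lambda_=1)."""
    return w + lambda_ * (ternarize(w) - w)

w = np.array([0.9, -0.2, 0.05, -1.3])
for lambda_ in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(lambda_, blend(w, lambda_))
```

At lambda_ = 0 the weights are untouched, and at lambda_ = 1 they are fully ternary, matching the two endpoints described above.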

Hence, they decided to experiment with a dynamic lambda value that adjusts based on training steps. 

lambda_ = training_step / total_training_steps

Using this lambda value led to better loss convergence, but the perplexity was not satisfactory. This was because the model was not trained for long enough with lambda = 1. Hence, to address this, they used the dynamic lambda value below.  

lambda_ = min(2 * training_step / total_training_steps, 1)

With this configuration, after 2000 steps, the fine-tuning method shows better convergence overall. There is a slight increase in the loss curve around 1000 steps, but the loss then improves, leading to a perplexity of approximately 4. They then tested the quantized model on the larger WikiText dataset (rather than TinyStories, which was used for fine-tuning). This resulted in a high perplexity, which indicates that fine-tuning in low-bit mode causes the model to lose its general knowledge. To overcome this issue, they used a larger dataset, FineWeb-edu, with the dynamic lambda value below.

lambda_ = min(training_step/1000, 1)

They chose this lambda value because it was a good starting point for warming up the model. They used a learning rate of 1e-4 for 5,000 steps on the FineWeb-edu dataset. The training used a batch size (BS) of 2 million tokens, for a total of 10 billion tokens (5,000 steps × 2 million tokens per step). Finding the right learning rate and the right decay was challenging; it seems to be a crucial factor in the model’s performance.
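The three linear lambda schedules quoted so far can be compared side by side. The function names below are hypothetical wrappers around the one-line formulas, and the 2000-step total is taken from the earlier TinyStories runs.

```python
def linear_lambda(step, total_steps):
    # Reaches 1 only at the very last training step.
    return step / total_steps

def fast_linear_lambda(step, total_steps):
    # Reaches 1 halfway through, then stays at full quantization.
    return min(2 * step / total_steps, 1)

def warmup_lambda(step, warmup_steps=1000):
    # Reaches 1 after a fixed warmup, independent of the run length.
    return min(step / warmup_steps, 1)

total_steps = 2000
for step in (0, 500, 1000, 2000):
    print(step,
          linear_lambda(step, total_steps),
          fast_linear_lambda(step, total_steps),
          warmup_lambda(step))
```

With a 5,000-step run, the fixed 1000-step warmup leaves the remaining 4,000 steps at lambda = 1, so most of the training happens under full quantization.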

After fine-tuning on the FineWeb-edu dataset, the perplexity on the WikiText dataset reached 12.2 using only 10 billion tokens, which is very good.

The loss curve shows a sharp increase as lambda approaches 1. To smooth this out, they considered lambda schedulers that grow exponentially at first and then level off as they get closer to 1.

def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k

For different values of k, with the total number of warmup steps normalized to 1, the curves look like the following:

[Figure: exponential scheduler curves for different values of k]

They ran 4 experiments using the best performing learning rate 1e-4, testing values of k in [4, 6, 8, 10]. 

This did smooth the curve, but the perplexity was not great, staying around 15, and the performance on downstream tasks was no better either. There is also a spike at the beginning, from which the model struggles to recover. To avoid the spike, they tried a different scheduler, a sigmoid, which starts slowly, rises sharply towards 1, and then levels off as it approaches 1.

import numpy as np

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-like curve: slow start, fast middle, slow end
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k * (normalized_step - 0.5)))

For different values of k, we have the following curves:

[Figure: sigmoid scheduler curves for different values of k]
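Since the plots are not reproduced here, the shape of both schedulers can be checked numerically. The sketch below re-declares the two functions from this section (using math.exp instead of NumPy, and the hypothetical name exp_scheduler for the first one) and prints lambda at a few steps. The 1000-step total is an assumption borrowed from the linear warmup used earlier.

```python
import math

def exp_scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step) ** k

def sigmoid_scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 / (1 + math.exp(-k * (normalized_step - 0.5)))

total_steps = 1000  # assumed warmup length
for step in (0, 100, 500, 900, 1000):
    row = [round(exp_scheduler(step, total_steps, 4), 4),
           round(sigmoid_scheduler(step, total_steps, 15), 4),
           round(sigmoid_scheduler(step, total_steps, 100), 4)]
    print(step, row)
```

Two properties stand out. The exponential scheduler front-loads quantization, already exceeding 0.9 at the halfway point for k = 4. The sigmoid scheduler never reaches exactly 0 or 1 for finite k, and for large k it behaves almost like a step function centered on the midpoint of the warmup.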

They ran 5 experiments this time, with k in [15, 20, 25, 40, 100].

The sharp increase in lambda caused instability around the 500th step and did not fix the initial convergence issue. For k = 100, they did observe some improvement in downstream tasks, although the perplexity remained around 13.5. Despite this, the sigmoid scheduler did not show a clear performance boost over a linear scheduler.

They even experimented with training models from scratch using random weights and various learning rates. This allowed them to compare the effectiveness of the fine-tuning approach against traditional pre-training methods.

None of the models trained from random weights performed better than the fine-tuned model. The best perplexity achieved with those models was 26, which falls short of the results from the fine-tuning approach.

Scaling to Fine-tuning the Model With 100B Tokens

They then tried longer training runs, starting from the best-performing checkpoint of the shorter runs with the linear scheduler and continuing it until 45,000 steps. The model came close to the Llama 3 model on some metrics but, in general, lagged behind.

Experimenting on Smaller Models

On smaller models, they observed that warmup quantization did not greatly affect the outcome. This suggests that the effectiveness of warmup quantization may be related to model size and complexity. For example, they tried warmup quantization and full quantization on the SmolLM 135M model: the curves closely align, resulting in the same perplexity.

Accessing using Hugging Face

Models in ternary precision are packed with 2 bits per weight. You can load them directly using from_pretrained, provided the quantization method is specified as BitNet in the config.json.
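To illustrate what "2 bits per weight" means in practice, the sketch below packs four ternary values into a single byte and unpacks them again. The bit layout and the function names are assumptions chosen for clarity; the layout actually used by the library may differ.

```python
import numpy as np

def pack_ternary(values):
    """Pack ternary values {-1, 0, 1} into uint8, four 2-bit codes per byte."""
    codes = (np.asarray(values) + 1).astype(np.uint8)   # map to {0, 1, 2}
    codes = codes.reshape(-1, 4)                         # length must be a multiple of 4
    shifts = np.array([0, 2, 4, 6], dtype=np.uint8)
    return (codes << shifts).sum(axis=1).astype(np.uint8)

def unpack_ternary(packed):
    """Recover the ternary values from the packed bytes."""
    shifts = np.array([0, 2, 4, 6], dtype=np.uint8)
    codes = (packed[:, None] >> shifts) & 0b11
    return codes.reshape(-1).astype(np.int8) - 1

weights = np.array([-1, 0, 1, 1, 0, 0, -1, 1], dtype=np.int8)
packed = pack_ternary(weights)
print(packed.nbytes, "bytes for", weights.size, "weights")
assert np.array_equal(unpack_ternary(packed), weights)
```

Two bits can encode four states, so one code is unused; this is why the packed storage is 2 bits per weight rather than the theoretical 1.58.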

Installing Dependencies

# start by installing the transformers version with the correct configuration to load bitnet models
!pip install git+https://github.com/huggingface/transformers.git@refs/pull/33410/head

Hugging Face CLI Login

!huggingface-cli login

Enter your HF Token to authenticate and log in.

Import Necessary Libraries

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from IPython.display import Markdown

Load the Fine-tuned model

In the code below, we use the fine-tuned Llama3 8B model, which was fine-tuned with 1.58-bit quantization on 100B tokens. This is the final model from the section on scaling fine-tuning to 100B tokens.

model = AutoModelForCausalLM.from_pretrained("HF1BitLLM/Llama3-8B-1.58-100B-tokens", device_map="cuda", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token

Create a Prompt and Generate Output

input_text = """
Which of the following is the capital of France?
A) Berlin
B) Madrid
C) Paris
D) Rome
"""
input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

Output

Markdown(generated_text)

Conclusion

Compared to baseline techniques, BitNet provides good performance, particularly at lower bit levels. The study claims that BitNet obtains results comparable to 8-bit models at far lower inference cost. Because activations are more difficult to quantise, approaches that quantise only the weights of 4-bit models perform better than those that quantise both weights and activations. BitNet, which uses 1.58-bit weights, outperforms both weight-only and weight-and-activation quantisation techniques.

Moreover, the outcomes for several metrics after fine-tuning Llama3 8B on 10B tokens are shown in the table below. To give a thorough overview of performance, these results are compared to those from various model designs (all evaluations were carried out using Lighteval on the Nanotron format model).

[Table: metrics for the fine-tuned Llama3 8B model compared with other models]

The model shows outstanding performance after fine-tuning on only 10 billion tokens using ternary weights, particularly compared to other models that underwent more extensive training. For example, it performs better than the BitNet 7B model, even though the latter was trained on a far larger dataset of 100 billion tokens. It also outperforms the FBI LLM (Fully Binarized LLM) model, which was refined on an even larger scale of 1.26 trillion tokens. This demonstrates the model’s efficacy and efficiency despite the comparatively small scale of the fine-tuning process.

Frequently Asked Questions

Q1. What is quantization in the context of LLMs?

Ans. Quantization reduces the precision of model parameters, like weights, from 16-bit or 32-bit floating points to lower-bit formats (8-bit, 4-bit, or even 1-bit), reducing memory usage and speeding up computation at the cost of some accuracy.

Q2. What is BitNet, and how does it differ from traditional LLMs?

Ans. BitNet is a new 1.58-bit quantized LLM, where each model parameter is represented as {-1, 0, 1}. It achieves similar performance to full-precision models while significantly reducing memory, energy, and computational costs.

Q3. What is STE (Straight-Through Estimator), and why is it used in BitNet?

Ans. STE allows gradients to pass through non-differentiable functions (like rounding) in quantized neural networks, enabling effective training and weight updates even when using low-precision parameters.

Q4. How does BitNet handle fine-tuning using 1.58-bit quantization?

Ans. Fine-tuning starts from pretrained Llama3 models, using techniques like dynamic lambda scheduling to gradually introduce quantization, which helps prevent loss of information and improves convergence.

Q5. What are the advantages of BitNet over traditional 8-bit models?

Ans. BitNet offers similar perplexity and downstream performance while dramatically reducing energy consumption and computational costs, making it a more efficient alternative for large-scale LLMs.

Data science Trainee at Analytics Vidhya, specializing in ML, DL and Gen AI. Dedicated to sharing insights through articles on these subjects. Eager to learn and contribute to the field's advancements. Passionate about leveraging data to solve complex problems and drive innovation.
