Large Language Models keep growing in size and complexity, and finding ways to reduce their computational and energy cost is getting harder. One popular technique is quantization: reducing the precision of parameters from the standard 16-bit floating point (FP16) or 32-bit floating point (FP32) to lower-bit formats such as 8-bit or 4-bit. This shrinks memory usage and speeds up computation, but it trades off accuracy: pushing precision too low causes the model to lose crucial information, and performance drops. In this article, we will look at how to fine-tune LLMs to 1.58 bits.
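As a quick refresher on what quantization looks like in practice, here is a minimal, generic sketch of symmetric absmax quantization to INT8 (for illustration only; it is not the method used later in this article):

import torch

def absmax_quantize_int8(w: torch.Tensor):
    # Symmetric "absmax" quantization: scale by the largest magnitude,
    # then round to the nearest integer in [-127, 127].
    scale = 127 / w.abs().max().clamp(min=1e-5)
    w_q = (w * scale).round().clamp(-127, 127).to(torch.int8)
    return w_q, scale

w = torch.randn(4, 4)
w_q, scale = absmax_quantize_int8(w)
w_dequant = w_q.float() / scale  # approximate reconstruction, with rounding error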
New research such as BitNet is opening the door for 1-bit Large Language Models (LLMs) to become the norm. The authors presented BitNet b1.58, a 1-bit LLM variant in which every parameter is ternary, taking values in {-1, 0, 1}.
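Concretely, the paper quantizes each weight matrix with an "absmean" scheme: scale by the mean absolute value, then round and clip to {-1, 0, 1}. A minimal PyTorch sketch of that idea (illustrative, not the official implementation):

import torch

def absmean_ternary_quantize(w: torch.Tensor, eps: float = 1e-5):
    # Scale by the mean absolute value, then round and clip to {-1, 0, 1}.
    gamma = w.abs().mean().clamp(min=eps)
    w_ternary = (w / gamma).round().clamp(-1, 1)
    return w_ternary, gamma

w = torch.randn(8, 8)
w_ternary, gamma = absmean_ternary_quantize(w)
print(torch.unique(w_ternary))  # tensor([-1., 0., 1.])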
Note: Perplexity is a metric used to evaluate how well a language model predicts the next word in a sequence; lower is better.
In terms of perplexity and end-task performance, it matches a full-precision (FP16 or BF16) Transformer LLM with the same model size and training tokens, while being substantially more economical in latency, memory, throughput, and energy usage. More importantly, the 1.58-bit LLM establishes a new scaling law and training recipe for future generations of high-performing and cost-effective LLMs. It also makes it possible to build hardware specifically optimized for 1-bit LLMs, opening up a new computation paradigm.
One limitation is that this requires training a model from scratch. The results are very good, but not everyone has the budget to pre-train an LLM. To overcome this limitation, the authors explored a few tricks that allow fine-tuning an existing model down to 1.58 bits.
This architecture uses INT8 addition for matrix multiplication, in contrast to LLaMA's FP16 addition and multiplication operations. As a result, BitNet b1.58 saves 71.4x the arithmetic-operation energy for matrix multiplication compared to the LLaMA baseline.
Energy consumption of BitNet b1.58 compared to LLaMA LLM at 7nm process nodes. On the left are the components of arithmetic operations energy. On the right is the end-to-end energy cost across different model sizes.
BitNet replaces the traditional Linear layers in Multi-Head Attention and the Feed-Forward Network with specialized layers called BitLinear, which use ternary precision (or even binary in the initial version). One big obstacle when training with ternary precision is that the weights are discretized with a round() function, which is non-differentiable; without a gradient, the weights cannot be updated during backpropagation. BitNet therefore relies on a technique called the Straight-Through Estimator (STE).
Straight-Through Estimator (STE): The paper provides a detailed study of STE, a technique used to deal with non-differentiable functions that arise in quantized neural networks (QNNs). The STE allows the gradient to “pass-through” discrete variables during backpropagation by approximating their gradients. This is especially crucial in the context of QNNs, where weights and activations are often quantized to lower precision, making them non-differentiable.
Another way to look at it: the STE lets the gradient flow as if the rounding had never happened, so the weights can still be updated with conventional gradient-based optimization methods.
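In PyTorch, the STE is commonly implemented with a small detach trick; here is a minimal sketch of the idea (illustrative, not BitNet's exact code):

import torch

def ste_round(w: torch.Tensor) -> torch.Tensor:
    # Forward pass uses the rounded value; backward pass treats rounding as the
    # identity, because the detached term contributes no gradient.
    return w + (w.round() - w).detach()

w = torch.randn(4, requires_grad=True)
ste_round(w).sum().backward()
print(w.grad)  # all ones: the round() step is invisible to the gradient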
(a) The computation flow of BitLinear. (b) The architecture of BitNet consists of the stacks of attentions and FFNs, where matrix multiplication is implemented as BitLinear.
The authors then attempted to reproduce the results from the BitNet paper, starting with a small dataset, TinyStories, and a Llama3 8B model. Upon implementation, they confirmed that adding a normalization function improves performance, and they found training to be stable. For example, after 2,000 training steps, the perplexity on the validation set was 6.3 without normalization and 5.9 with normalization.
This approach reduces cost while maintaining accuracy, but not many organizations can afford it. Other groups had reported that fine-tuning results were not very promising, so the authors tested that route as well.
When they began fine-tuning from the pre-trained Llama3 8B weights, the model performed slightly better, but not as well as they expected.
To understand why this happens, they inspected the weight distributions of the randomly initialized and pre-trained models, along with the scale values of the two distributions. They found that the pre-trained model starts with more information, and that swapping in BitLinear layers overwhelms it: the model loses all of its prior information.
Hence, to improve the fine-tuning results, they tried per-row and per-column quantization instead of per-tensor quantization, which preserves more of the information already present in Llama 3. However, they observed that the model still lost information during quantization, so to investigate how much was being lost they experimented with per-group quantization.
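As an illustration of per-group quantization (a minimal sketch, not the authors' exact code; the group_size knob generalizes per-tensor, per-row, and per-column scaling), the idea looks like this in PyTorch:

import torch

def per_group_absmean_quantize(w: torch.Tensor, group_size: int, eps: float = 1e-5):
    # One scale per group of `group_size` weights: quantize each group to {-1, 0, 1}
    # and return the dequantized tensor the model effectively "sees".
    groups = w.reshape(-1, group_size)
    gamma = groups.abs().mean(dim=1, keepdim=True).clamp(min=eps)
    q = (groups / gamma).round().clamp(-1, 1)
    return (q * gamma).reshape(w.shape)

w = torch.randn(16, 16)
print((w - per_group_absmean_quantize(w, group_size=1)).abs().max())  # ~0: lossless
print((w - per_group_absmean_quantize(w, group_size=2)).abs().max())  # much larger error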
As a sanity check, they first set the group size to 1, which essentially means no quantization. In this scenario, the loss started at 1.45, the same as during normal fine-tuning. But when they increased the group size to just 2, the loss jumped to around 11. This indicates that even with a minimal group size of 2, the model loses nearly all of its information. So, to address this issue, they considered introducing quantization gradually instead of applying it abruptly.
To do this, they introduced a lambda value to control the process: when lambda = 0, no quantization is done, and when lambda = 1, full quantization is applied. They initially tested discrete lambda values such as 0.25, 0.50, 0.75, and 1, but the results were not significant, largely because at lambda = 0.25 the loss already started very high.
Hence, they decided to experiment with a dynamic lambda value that adjusts based on training steps.
lambda_ = training_step / total_training_steps
Using this lambda value led to better loss convergence, but the perplexity was still not satisfactory, because the model had not spent enough training time at lambda = 1. To address this, they used the dynamic lambda value below.
lambda_ = min(2 * training_step / total_training_steps, 1)
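To make the warmup concrete, here is a minimal, illustrative sketch (not the authors' exact BitLinear implementation) of a linear layer that blends full-precision and ternary-quantized weights according to lambda_:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WarmupQuantLinear(nn.Linear):
    # Illustrative linear layer: blends full-precision and ternary-quantized
    # weights according to lambda_ (0 = no quantization, 1 = full quantization).
    def forward(self, x, lambda_: float = 1.0):
        gamma = self.weight.abs().mean().clamp(min=1e-5)
        w_quant = (self.weight / gamma).round().clamp(-1, 1) * gamma
        # Straight-Through Estimator so gradients still reach the latent weights.
        w_quant = self.weight + (w_quant - self.weight).detach()
        w_eff = (1 - lambda_) * self.weight + lambda_ * w_quant
        return F.linear(x, w_eff, self.bias)

layer = WarmupQuantLinear(64, 64)
x = torch.randn(2, 64)
step, total_steps = 500, 2000
lambda_ = min(2 * step / total_steps, 1)
out = layer(x, lambda_=lambda_)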
With this configuration, after 2000 steps:
This fine-tuning configuration shows better convergence overall. There is a slight bump in the loss curve around 1,000 steps, but it recovers, ending with a perplexity of approximately 4. However, when they evaluated the quantized model on the larger WikiText dataset (rather than TinyStories, which was used for fine-tuning), the perplexity was very high, indicating that fine-tuning in low-bit mode causes the model to lose its general knowledge. To overcome this, they switched to the larger FineWeb-edu dataset, with the dynamic lambda value below.
lambda_ = min(training_step/1000, 1)
They chose this lambda schedule because it was a good starting point for warming up the model. They used a learning rate of 1e-4 for 5,000 steps on the FineWeb-edu dataset, with a batch size of 2 million tokens, totaling 10 billion tokens. Finding the right learning rate and the right decay was challenging; they seem to be crucial factors in the model's performance.
After the completion of fine-tuning on the Fineweb-Edu dataset, the perplexity on the WikiText dataset reached 12.2 using only 10 billion tokens, which is very good.
You can see a sharp increase as lambda approaches 1. To smooth this out, they considered lambda schedulers that grow quickly at first and then level off as they get closer to 1.
def scheduler(step, total_steps, k):
    # Polynomial warmup: rises quickly at first, then levels off as it approaches 1.
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step) ** k
For different values of k, with total_steps normalized to 1, the curves look like the following:
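If you want to reproduce curves like these, a quick plotting sketch (assuming matplotlib is available; the k values match the experiments described next):

import numpy as np
import matplotlib.pyplot as plt

steps = np.linspace(0, 1, 200)              # total_steps normalized to 1
for k in [4, 6, 8, 10]:
    lambdas = 1 - (1 - steps) ** k          # same formula as scheduler(step, 1, k)
    plt.plot(steps, lambdas, label=f"k={k}")
plt.xlabel("normalized training step")
plt.ylabel("lambda")
plt.legend()
plt.show()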
They ran 4 experiments using the best-performing learning rate of 1e-4, testing k values in [4, 6, 8, 10].
It did smooth the curve, but the perplexity wasn't great and stayed around 15, and performance on downstream tasks did not improve either. The spike at the beginning is still noticeable, and the model struggles to recover from it. To avoid that spike, they tried a different scheduler: a sigmoid that starts slowly, rises sharply toward 1, and then levels off as it approaches 1.
import numpy as np

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-like curve: slow start, fast middle, slow end
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k * (normalized_step - 0.5)))
For different k values, the curves look like this:
They ran 5 experiments this time, with k in [15, 20, 25, 40, 100]:
The sharp increase in lambda caused instability around step 500 and did not fix the initial convergence issue. For k = 100, however, they did observe some improvement on downstream tasks, although the perplexity remained around 13.5. Even so, it did not show a clear performance boost over the linear scheduler.
They even experimented with training models from scratch using random weights and various learning rates. This allowed them to compare the effectiveness of the fine-tuning approach against traditional pre-training methods.
None of the models trained from random weights performed better than the fine-tuned model. The best perplexity they achieved with those models was 26, which falls short of the results from the fine-tuning approach.
They also tried longer training runs, continuing the best-performing checkpoint from the shorter runs with the linear scheduler for 45,000 steps. The model came close to Llama 3 on some metrics, but in general it still lagged behind.
They observed that warmup quantization did not greatly affect the outcome in this setting, suggesting that its effectiveness may depend more on model size and complexity. For example, when they tried warmup quantization and abrupt full quantization on the SmolLM 135M model, the curves closely aligned and ended with the same perplexity.
Models in ternary precision are packed with 2 bits per weight. You can load them directly using from_pretrained, provided the quantization method is specified as BitNet in the config.json.
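To get an intuition for the 2-bit packing, here is a minimal illustration of fitting four ternary weights into one byte (illustrative only; it is not the exact layout used by the library):

import torch

def pack_ternary(w_ternary: torch.Tensor) -> torch.Tensor:
    # Map {-1, 0, 1} -> {0, 1, 2}, then pack four 2-bit values into each uint8.
    vals = (w_ternary.flatten() + 1).to(torch.uint8).reshape(-1, 4)
    return vals[:, 0] | (vals[:, 1] << 2) | (vals[:, 2] << 4) | (vals[:, 3] << 6)

def unpack_ternary(packed: torch.Tensor) -> torch.Tensor:
    # Extract the four 2-bit fields from each byte and shift back to {-1, 0, 1}.
    vals = torch.stack([(packed >> s) & 0b11 for s in (0, 2, 4, 6)], dim=1)
    return vals.flatten().to(torch.int8) - 1

w = torch.randint(-1, 2, (16,))           # 16 ternary weights
packed = pack_ternary(w)                  # stored in 4 bytes instead of 16
assert torch.equal(unpack_ternary(packed), w.to(torch.int8))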
# start by installing the transformers version with the correct configuration to load bitnet models
!pip install git+https://github.com/huggingface/transformers.git@refs/pull/33410/head
!huggingface-cli login
Enter your HF Token to authenticate and log in.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from IPython.display import Markdown
In the code below, we use the fine-tuned Llama3 8B model quantized to 1.58 bits. It was fine-tuned on 100 billion tokens; this is the final checkpoint from the experiments on scaling fine-tuning to 100B tokens.
model = AutoModelForCausalLM.from_pretrained("HF1BitLLM/Llama3-8B-1.58-100B-tokens", device_map="cuda", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token
input_text = """
Which of the following is the capital of France?
A) Berlin
B) Madrid
C) Paris
D) Rome
"""
input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
Markdown(generated_text)
Compared to baseline techniques, BitNet delivers strong performance, particularly at lower bit levels. The study reports that BitNet matches 8-bit models at far lower inference cost. Because activations are more difficult to quantize, 4-bit approaches that quantize only the weights perform better than those that quantize both weights and activations; yet BitNet, using 1.58-bit weights, outperforms both weight-only and weight-and-activation quantization techniques. I hope this makes it clear how to fine-tune LLMs to 1.58 bits.
Moreover, the table below shows the outcomes on several metrics from the 10B-token fine-tuning of Llama3 8B. To give a thorough overview of performance, these results are compared with those of various model designs (all evaluations were carried out using Lighteval on the Nanotron-format model).
The model shows outstanding performance after fine-tuning on only 10 billion tokens with ternary weights, particularly compared to other models that underwent far more extensive training. For example, it performs better than the Bitnet 7B model, even though the latter was trained on a much larger dataset of 100 billion tokens. It also outperforms the FBI LLM (Fully Binarized LLM), which was trained on an even larger scale of 1.26 trillion tokens. This demonstrates the model's efficacy and efficiency despite the comparatively small scale of the fine-tuning process.
Q1. What is quantization in LLMs?
Ans. Quantization reduces the precision of model parameters, like weights, from 16-bit or 32-bit floating point to lower-bit formats (8-bit, 4-bit, or even 1-bit), reducing memory usage and speeding up computation at the cost of some accuracy.
Q2. What is BitNet b1.58?
Ans. BitNet b1.58 is a 1.58-bit quantized LLM in which each model parameter is represented as {-1, 0, 1}. It achieves performance similar to full-precision models while significantly reducing memory, energy, and computational costs.
Q3. What is the Straight-Through Estimator (STE)?
Ans. STE allows gradients to pass through non-differentiable operations (like rounding) in quantized neural networks, enabling effective training and weight updates even when using low-precision parameters.
Q4. How are existing LLMs fine-tuned to 1.58 bits?
Ans. Fine-tuning starts from pre-trained Llama3 weights and uses techniques like dynamic lambda scheduling to introduce quantization gradually, which helps prevent loss of information and improves convergence.
Q5. What are the advantages of BitNet over traditional LLMs?
Ans. BitNet offers similar perplexity and downstream performance while dramatically reducing energy consumption and computational costs, making it a more efficient alternative for large-scale LLMs.