Large Language Models (LLMs) can generate coherent and contextually relevant text because they are trained on extensive datasets and leverage billions of parameters. This immense scale gives LLMs emergent properties, such as nuanced understanding and generation capabilities across domains that surpass those of simpler models. However, these advantages come at the cost of high computational requirements during inference. To mitigate these challenges, let us look at an important technique called the mixture of experts (MoE), which optimizes resource usage without compromising model performance. We will also explore the Grok-1 architecture to understand how this technique is used in practice.
If you remember ensemble techniques in machine learning, the final prediction can be obtained by taking a weighted average of the predictions of multiple models.
The Mixture of Experts works similarly. Instead of passing the input through all the model parameters, we pass it through only a subset of the parameters, chosen based on the input token. That subset of parameters can be considered the ‘experts’ for that input.
This selective engagement of model parameters allows for more efficient computation and scalability without reducing model performance. Since only a few experts are selected for each token, this is also called the sparse MoE technique.
How does the model know which experts to select? In MOE, a component known as the router is trained to choose which experts to use for a given input token. We initialize the router’s weight matrix with constant values (e.g., zeros). As the model is trained with more data, the feed-forward router network adjusts these weights based on each expert’s performance, effectively learning which experts excel in handling specific types of inputs.
We keep the router weights of the top-K experts and set the remaining weights to negative infinity. Then we apply softmax to these weights, which outputs the weightage of the top-K experts used to process the input. We can denote the top-K and softmax operations with this simple equation.
P = Softmax(Top-K(W))
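To make the routing step concrete, here is a small illustrative JAX sketch of the equation above (the helper name topk_softmax and the shapes are made up for this example; it is not Grok-1’s router code):

import jax
import jax.numpy as jnp

def topk_softmax(router_logits: jax.Array, k: int) -> jax.Array:
    """Keep the top-k router weights per token, set the rest to minus infinity,
    then softmax so that only the selected experts get non-zero weightage.
    router_logits: [num_tokens, num_experts] -> returns the same shape."""
    topk_vals = jax.lax.top_k(router_logits, k)[0]    # [num_tokens, k], sorted descending
    threshold = topk_vals[:, -1:]                     # value of the k-th largest weight
    masked = jnp.where(router_logits >= threshold, router_logits, -jnp.inf)
    return jax.nn.softmax(masked, axis=-1)            # P = Softmax(Top-K(W))

# Example: 4 tokens routed over 8 experts, keeping 2 experts per token.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
probs = topk_softmax(logits, k=2)                     # each row has 2 non-zero entries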
Which components of the LLM can be chosen as experts? To find out, let’s briefly examine the calculations done in a typical LLM architecture.
Given a hidden_size of h for a token, the parameters of a single decoder layer can be broken down as follows.
As we can see, there are more parameters in the fully connected (MLP) layers than in the MHA layers.
So, we can replicate the MLP block into multiple expert copies and then route each token through only the top K of them using the routing mechanism, which preserves performance while keeping computation efficient.
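As a rough back-of-the-envelope check (assuming the common decoder layout of four h x h projections in the attention block and a feed-forward block that expands h to 4h and back; the numbers below are illustrative, not taken from any specific model):

h = 6144                       # hidden size, the value Grok-1 uses later in this article

attn_params = 4 * h * h        # W_Q, W_K, W_V and W_O, each h x h (biases ignored)
ffn_params = 2 * 4 * h * h     # up-projection h -> 4h plus down-projection 4h -> h

print(f"attention: {attn_params / 1e6:.0f}M, feed-forward: {ffn_params / 1e6:.0f}M")
# attention: 151M, feed-forward: 302M -> the MLP dominates the layer

# Replicating only the MLP into 8 experts and activating 2 per token keeps the
# per-token compute close to a dense layer while the stored capacity grows 8x.
moe_total_ffn = 8 * ffn_params     # parameters stored across all experts
moe_active_ffn = 2 * ffn_params    # parameters actually used for a given token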
Grok-1, xAI’s open-source model with 314 billion total parameters, is the largest open-source LLM based on a mixture of experts. Let’s see how MoE is implemented in Grok-1.
Here are the key specifications of Grok-1: 314 billion total parameters, 8 experts per MoE layer with 2 active for each token, 64 transformer layers, 48 attention heads for queries and 8 for keys and values, an embedding (hidden) size of 6144, and a maximum context length of 8192 tokens.
Compared to the typical LLM described above, there are a few differences in Grok-1.
There are 48 attention heads for queries but only 8 for keys and values. This is called Grouped-Query Attention.
As we can see from the above picture, in Multi-Head Attention the number of unique key and value vectors equals the number of query attention heads, while in Multi-Query Attention the number of unique key and value vectors equals 1.
While Multi-Query Attention reduces model parameters, it also reduces performance. Grouped-Query Attention balances the two: the number of unique key and value heads is a fraction of the number of query heads. In Grok-1, for 48 query heads there are 8 key/value heads.
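A minimal way to picture this sharing (illustrative JAX, not Grok-1’s attention implementation; the sequence length and placeholder activations are made up): each of the 8 key/value heads is repeated so that a group of 6 query heads attends over the same keys and values.

import jax
import jax.numpy as jnp

num_q_heads, num_kv_heads, head_dim, seq_len = 48, 8, 128, 16

# Placeholder per-head projections for one sequence.
q = jnp.ones((num_q_heads, seq_len, head_dim))
k = jnp.ones((num_kv_heads, seq_len, head_dim))
v = jnp.ones((num_kv_heads, seq_len, head_dim))

# Each key/value head serves num_q_heads // num_kv_heads = 6 query heads.
group_size = num_q_heads // num_kv_heads
k = jnp.repeat(k, group_size, axis=0)    # (8, ...) -> (48, ...)
v = jnp.repeat(v, group_size, axis=0)    # (8, ...) -> (48, ...)

scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(head_dim)
attn_out = jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)

Only the 8 key/value projections need to be stored, so the key/value parameter and KV-cache cost shrinks by a factor of 6 compared to full Multi-Head Attention.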
After the attention block, the head outputs are concatenated and projected, and the resulting hidden states are then upscaled by a widening factor in the feed-forward block.
Let’s look at the Grok-1 code to find out the widening factor.
def ffn_size(emb_size, widening_factor):
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
    logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
    return _ffn_size
Because of the * 2 // 3 adjustment, the effective widening factor works out to about 8/3, so the embedding size of 6144 is upscaled to 16384.
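For a quick sanity check, here is the arithmetic, assuming a nominal widening_factor of 4 (this value is implied by the 8/3 effective factor above rather than quoted from the Grok-1 config):

emb_size, widening_factor = 6144, 4              # widening_factor of 4 is an assumption here

ffn = int(widening_factor * emb_size) * 2 // 3   # 24576 * 2 // 3 = 16384
ffn = ffn + (8 - ffn) % 8                        # already a multiple of 8, stays 16384
print(ffn)                                       # 16384 = 6144 * 8 / 3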
Here is the code for the dense (feed-forward) block:
h_v = Linear(
    ffn_size(model_size, self.widening_factor),
    with_bias=False,
    mesh=self.mesh,
    sharding=P("data", "model"),
    name="linear_v",
)(inputs)
h_w1 = jax.nn.gelu(
    Linear(
        ffn_size(model_size, self.widening_factor),
        with_bias=False,
        mesh=self.mesh,
        sharding=P("data", "model"),
    )(inputs)
)
h_dense = Linear(
    model_size,
    with_bias=False,
    sharding=P("model", "data"),
    mesh=self.mesh,
    shard_axis=1,
)(h_w1 * h_v)
The input matrix with hidden size 6144 is upscaled in parallel by two linear layers to 16384, denoted h_v and h_w1. The GELU activation function is applied only to the second matrix (h_w1). Then, the two are multiplied element-wise, and the result is downscaled back to the model size of 6144.
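Stripped of the sharding details, this is a GELU-gated feed-forward block. Here is a minimal sketch of the same pattern in plain JAX (the random weights and their scaling are illustrative, not Grok-1’s initialization):

import jax
import jax.numpy as jnp

def gated_ffn(x, w_v, w_1, w_out):
    """x: [tokens, model_size]; w_v, w_1: [model_size, ffn_size]; w_out: [ffn_size, model_size]."""
    h_v = x @ w_v                    # first up-projection (the "values" branch)
    h_w1 = jax.nn.gelu(x @ w_1)      # second up-projection, passed through GELU (the gate)
    return (h_w1 * h_v) @ w_out      # element-wise gating, then project back down

model_size, ffn_size = 6144, 16384
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (4, model_size))
w_v = 0.02 * jax.random.normal(k2, (model_size, ffn_size))
w_1 = 0.02 * jax.random.normal(k3, (model_size, ffn_size))
w_out = 0.02 * jax.random.normal(k4, (ffn_size, model_size))
y = gated_ffn(x, w_v, w_1, w_out)    # -> shape (4, 6144)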
Grok-1 uses the JAX and Haiku libraries to build the model. The MoE layer in Grok-1 orchestrates a flexible and efficient way to leverage multiple expert networks, each specializing in different aspects of the input data. Its _inference_call method executes several key steps to achieve this:
def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
    routing_probs, _, _ = self.router.compute_routing_prob(
        inputs, padding_mask, self.num_experts
    )
    expert_gate, expert_index = jax.lax.top_k(routing_probs, k=self.router.num_selected_experts)
    tmp = jnp.reshape(inputs, (inputs.shape[0] * inputs.shape[1], inputs.shape[2]))
    broad_inputs = jnp.tile(tmp[:, jnp.newaxis, :], (1, self.router.num_selected_experts, 1))
    broad_inputs = jnp.reshape(
        broad_inputs, (broad_inputs.shape[0] * broad_inputs.shape[1], broad_inputs.shape[2])
    )
1. Routing probabilities are computed by the router.compute_routing_prob method, which takes the inputs and, optionally, a padding_mask. The probabilities are calculated as routing_probs = jax.nn.softmax(router_weights(inputs, num_experts)), where num_experts is 8.
2. The top-k experts (k = 2 for Grok-1) are selected for each input token using jax.lax.top_k. This ensures that each input is processed by the experts most likely to handle it effectively.
3. The inputs are then reshaped and tiled so that each selected expert receives its own copy of every token before the expert feed-forward computation is applied.
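The quoted snippet stops before the experts are actually applied. As a simplified, self-contained sketch of the same route-select-combine pattern (illustrative code, not Grok-1’s sharded implementation; the names moe_layer and expert_params and the tiny dimensions are made up):

import jax
import jax.numpy as jnp

def moe_layer(x, router_w, expert_params, k=2):
    """x: [tokens, d]; router_w: [d, num_experts];
    expert_params: list of (w_in, w_out) weight pairs, one pair per expert."""
    # 1. Routing probabilities for every token over every expert.
    routing_probs = jax.nn.softmax(x @ router_w, axis=-1)        # [tokens, E]
    # 2. Keep the top-k experts per token (k = 2 in Grok-1) and their gates.
    expert_gate, expert_index = jax.lax.top_k(routing_probs, k)  # [tokens, k] each

    # For clarity, run every expert on every token; a real MoE layer instead
    # dispatches each token only to its selected experts.
    expert_outs = jnp.stack(
        [jax.nn.gelu(x @ w_in) @ w_out for w_in, w_out in expert_params], axis=1
    )                                                            # [tokens, E, d]

    # 3. Pick the selected experts' outputs and combine them, weighted by the router.
    idx = jnp.broadcast_to(expert_index[:, :, None], expert_gate.shape + (x.shape[-1],))
    selected = jnp.take_along_axis(expert_outs, idx, axis=1)     # [tokens, k, d]
    return jnp.sum(expert_gate[:, :, None] * selected, axis=1)   # [tokens, d]

# Tiny usage example: 4 tokens, model dim 16, 8 experts of width 32, top-2 routing.
keys = jax.random.split(jax.random.PRNGKey(0), 17)
x = jax.random.normal(keys[0], (4, 16))
router_w = jnp.zeros((16, 8))   # router weights start at zero, as described earlier
experts = [
    (0.1 * jax.random.normal(keys[2 * i + 1], (16, 32)),
     0.1 * jax.random.normal(keys[2 * i + 2], (32, 16)))
    for i in range(8)
]
y = moe_layer(x, router_w, experts, k=2)   # -> shape (4, 16)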
In conclusion, Mixture of Experts (MoE) offers a promising avenue for enhancing the efficiency of Large Language Models (LLMs) by selectively engaging subsets of model parameters based on input characteristics. Through its router mechanism and sparse architecture, MoE conserves computational resources while maintaining high model performance. As exemplified by the Grok-1 architecture, MoE has the potential to make LLM inference far more scalable, paving the way for more efficient natural language processing solutions in the future.
Q1. How does the Mixture of Experts optimize computational resources in LLMs?
Ans. MoE optimizes computational resources by selectively activating subsets of model parameters based on input characteristics, enhancing efficiency without compromising performance.
Q2. What role does the router play in a Mixture of Experts model?
Ans. The router dynamically selects experts for each input based on routing probabilities learned during training. This ensures that inputs are processed by the most suitable experts, contributing to adaptive and resource-efficient computation.
Q3. What is distinctive about Grok-1’s feed-forward and expert layers?
Ans. Grok-1 uses two parallel upscaling networks, multiplies their outputs element-wise, and then downscales the result. Its MoE layers leverage multiple experts, two of which are selected per token, to handle different aspects of the input data, improving language understanding and generation capabilities.