Optimize Resource Usage with the Mixture of Experts and Grok-1

Santhosh Reddy Dandavolu Last Updated : 28 Mar, 2024

Introduction

Large language models (LLMs) can generate coherent and contextually relevant text because they are trained on extensive datasets and have billions of parameters. This scale gives LLMs emergent properties, such as nuanced understanding and generation across domains, that simpler models lack. However, these advantages come at the cost of high computational requirements every time the model is used. To mitigate this, let us look at an important technique called Mixture of Experts (MoE), which reduces the computation per token while largely preserving model quality. We will also explore the Grok-1 architecture to see how this technique is used in practice.

Learning Objectives

  1. Understand how Mixture of Experts optimizes computational resources by selectively activating subsets of model parameters.
  2. Explore router mechanisms in MoE, facilitating efficient resource allocation based on input characteristics.
  3. Compare the MoE implementation in Grok-1 with a typical LLM, highlighting differences in attention mechanisms and dense block structures.
  4. Learn how an MoE layer is executed in Grok-1 during inference.

Mixture of Experts

Recall ensemble techniques in machine learning: the final prediction can be a weighted average of the predictions of multiple models.

Mixture of Experts works in a similar way. Instead of passing the input through all the model parameters, we pass it through only a subset of them, chosen based on the input token. That subset of parameters can be considered the ‘experts’ for that input.

This selective engagement of parameters lets the model grow in total size without a proportional increase in computation per token, while keeping performance comparable. Since only a few experts are selected for each token, this is also called sparse MoE.

Router

How does the model know which experts to select? In MoE, a component called the router is trained to choose the experts for each input token. The router's weight matrix is initialized with constant values (e.g., zeros). During training, gradient descent adjusts these weights, so the router learns which experts handle which kinds of inputs best.
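Why a zero-initialized router is a reasonable starting point: every expert receives the same logit, so the softmax assigns each expert equal probability and no expert is favored before training. The plain-Python sketch below uses toy dimensions (hidden size 4 is an arbitrary choice; 8 experts as in Grok-1), and `router_logits` is a hypothetical helper, not Grok-1's API.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def router_logits(x, w):
    """Hypothetical router: x (length h) times w (h rows, one column per expert)."""
    num_experts = len(w[0])
    return [sum(x[i] * w[i][e] for i in range(len(x))) for e in range(num_experts)]

hidden, num_experts = 4, 8
w_router = [[0.0] * num_experts for _ in range(hidden)]  # zero initialization
token = [0.3, -1.2, 0.5, 2.0]  # arbitrary example token representation
probs = softmax(router_logits(token, w_router))
print(probs)  # every expert gets 1/8 = 0.125
```

Once training moves the weights away from zero, the logits differ per token and the probabilities stop being uniform.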


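The router produces one score (logit) per expert. We keep the scores of the top-K experts and set all other scores to negative infinity. Applying softmax then gives the masked experts zero weight and turns the top-K scores into weights that sum to 1, which determine how much each selected expert contributes. With token representation $x$ and router weight matrix $W_r$, the top-K and softmax operations can be written as:

$$P = \mathrm{Softmax}\big(\mathrm{TopK}(x W_r)\big)$$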
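A minimal plain-Python sketch of the masking-then-softmax gate, assuming eight hypothetical router logits for one token and K = 2:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def top_k_mask(logits, k):
    """Keep the k largest logits and set the rest to -inf."""
    keep = set(sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k])
    return [z if i in keep else -math.inf for i, z in enumerate(logits)]

logits = [1.2, -0.3, 2.5, 0.0, 0.7, -1.1, 0.4, 1.9]  # hypothetical logits for 8 experts
gates = softmax(top_k_mask(logits, k=2))
print([round(g, 3) for g in gates])  # [0.0, 0.0, 0.646, 0.0, 0.0, 0.0, 0.0, 0.354]
```

Grok-1's code, discussed in the MoE layer section, applies softmax over all experts first and then takes the top-k. Because softmax preserves ordering, both orders pick the same experts; they differ only in whether the selected gate values are renormalized to sum to 1.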

Which components of the LLM can be chosen as experts? To find out, let’s examine the typical LLM architecture.

LLM Architecture

Let us briefly look at the calculations done in a typical LLM.

  1. The input text is tokenized, each token is mapped to an embedding vector, and positional embeddings are added.
  2. The input is multiplied by the Q, K, and V weight matrices to obtain each head's Q, K, and V matrices.
  3. Attention is calculated for each head as $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, where $d_k$ is the dimension of each head's key vectors.
  4. The outputs of all heads are concatenated and multiplied by the O (output) weight matrix to form the multi-head attention (MHA) output.
  5. The MHA output is upscaled (usually by a factor of 4) and then downscaled by fully connected MLP layers, with a nonlinear activation function such as ReLU or GELU in between.
  6. Steps 2 to 5 are repeated for each decoder layer.
  7. The final hidden state is passed through a linear output layer and a softmax to produce a probability distribution over the vocabulary for the next token.
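Step 3 can be sketched for a single head in plain Python. This is an illustration under simplifying assumptions: no causal mask, no batching, and toy 2×2 matrices chosen arbitrarily; real models compute this with batched tensor operations on accelerators.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(z - m) for z in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for a single head, without a causal mask."""
    d_k = len(q[0])
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v)

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out = attention(q, k, v)
print(out)  # each output row is a weighted average of the rows of v
```

Query 0 matches key 0 best, so its output leans toward the first row of v; query 1 leans toward the second.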

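For a hidden size of h, the weights of a single decoder layer (ignoring biases and normalization layers) are:

  • MHA: the Q, K, V, and O matrices are each h × h, for 4h² parameters.
  • MLP: the up-projection is h × 4h and the down-projection is 4h × h, for 8h² parameters.

The fully connected (MLP) layers therefore hold about twice as many parameters as the MHA layers.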
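The counts above can be checked numerically. This sketch assumes standard multi-head attention and a 4× MLP (not Grok-1's grouped-query attention or gated experts); h = 6,144 is used only to show the scale, and `decoder_layer_params` is a hypothetical helper.

```python
def decoder_layer_params(h, mlp_factor=4):
    """Weight counts for one standard decoder layer (biases and norms ignored)."""
    attention = 4 * h * h                 # W_Q, W_K, W_V, W_O are each h x h
    mlp = 2 * h * (mlp_factor * h)        # up-projection h x 4h, down-projection 4h x h
    return attention, mlp

attn, mlp = decoder_layer_params(6144)
print(attn, mlp, mlp / attn)  # 150994944 301989888 2.0
```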

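This makes the MLP a natural place for experts: we can replace the single MLP with several parallel MLPs (the experts) and use the routing mechanism to send each token through only the top-K of them. The total parameter count grows with the number of experts, while the computation per token grows only with K.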
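A quick way to see the trade-off is to compare stored and per-token active MLP weights. This sketch assumes every expert has the same shape as the dense 4× MLP from the previous example and ignores the router's small h × E weight matrix; the numbers are illustrative, not Grok-1's.

```python
def moe_ffn_params(h, num_experts, top_k, mlp_factor=4):
    """Total vs per-token active MLP weights when one MLP is replaced by experts."""
    per_expert = 2 * h * (mlp_factor * h)  # same shape as the dense MLP
    return num_experts * per_expert, top_k * per_expert

total, active = moe_ffn_params(6144, num_experts=8, top_k=2)
print(total // active)  # 4: 8 experts are stored, 2 are used per token
```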

At the time of writing, Grok-1 is the largest open-source LLM based on a mixture of experts. Let's see how the technique is implemented in Grok-1.

Grok-1 Architecture

Here are the specifications of Grok-1:

Specifications

  1. Parameters: 314B
  2. Architecture: Mixture of 8 Experts (MoE)
  3. Experts Utilization: 2 experts are used per token
  4. Layers: 64
  5. Attention Heads: 48 for queries, 8 for keys/values
  6. Embedding Size: 6,144
  7. Tokenization: SentencePiece tokenizer with a vocabulary of 131,072 tokens
  8. Additional Features
    • Rotary embeddings (RoPE)
    • Supports activation sharding and 8-bit quantization
  9. Maximum Sequence Length (context): 8,192 tokens
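The specifications above can be captured in a small config object to derive a few quantities used later. The class and field names are hypothetical (this is not Grok-1's own config class), and `head_dim` assumes the usual convention that the embedding is split evenly across the query heads, which gives 128.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Grok1Spec:
    # Values copied from the specification list; names are illustrative
    num_layers: int = 64
    emb_size: int = 6144
    num_q_heads: int = 48
    num_kv_heads: int = 8
    num_experts: int = 8
    num_selected_experts: int = 2
    vocab_size: int = 131072
    max_seq_len: int = 8192

    @property
    def head_dim(self) -> int:
        return self.emb_size // self.num_q_heads  # 6144 / 48

    @property
    def q_heads_per_kv_head(self) -> int:
        return self.num_q_heads // self.num_kv_heads  # 48 / 8

    @property
    def active_expert_fraction(self) -> float:
        return self.num_selected_experts / self.num_experts  # 2 / 8

spec = Grok1Spec()
print(spec.head_dim, spec.q_heads_per_kv_head, spec.active_expert_fraction)  # 128 6 0.25
```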

Compared with the typical LLM described above, Grok-1 differs in a few ways.

Attention Block

Grok-1 has 48 attention heads for queries but only 8 for keys and values. This is called Grouped-Query Attention (GQA).

In Multi-Head Attention (MHA), the number of unique key and value heads equals the number of query heads; in Multi-Query Attention (MQA), a single key head and a single value head are shared by all query heads.

Multi-Query Attention reduces the number of key/value parameters and the size of the key/value cache, but it can lower model quality. Grouped-Query Attention balances the two: the query heads are split into groups, and each group shares one key head and one value head. In Grok-1, the 48 query heads share 8 key/value heads, so each key/value head serves 6 query heads.
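The grouping and its effect on the key/value cache can be sketched as follows. Assumptions: contiguous grouping (query heads 0 to 5 share key/value head 0, and so on), a head dimension of 128 (6,144 / 48), and 64 layers; both function names are hypothetical.

```python
def kv_head_for_query_head(q_head, num_q_heads=48, num_kv_heads=8):
    """Index of the shared key/value head used by a query head under GQA."""
    group_size = num_q_heads // num_kv_heads  # 6 query heads per key/value head
    return q_head // group_size

def kv_cache_values_per_token(num_kv_heads, head_dim=128, num_layers=64):
    """Cached numbers per token: one key and one value vector per KV head per layer."""
    return 2 * num_kv_heads * head_dim * num_layers

print([kv_head_for_query_head(q) for q in range(12)])  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
# MHA-style cache (48 KV heads) vs GQA cache (8 KV heads)
print(kv_cache_values_per_token(48) // kv_cache_values_per_token(8))  # 6
```

With 8 instead of 48 key/value heads, the per-token cache is six times smaller, which matters most for long contexts.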

Dense Block

After the attention block, the hidden states enter the dense (feed-forward) block, where they are upscaled by a widening factor.

Let’s look at the Grok-1 code to find the widening factor.

From the Grok-1 GitHub repository:

import logging

logger = logging.getLogger(__name__)


def ffn_size(emb_size, widening_factor):
    # Hidden size of the gated feed-forward layer: widen, then scale by 2/3
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
    logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
    return _ffn_size

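In Grok-1's released configuration, widening_factor is 8, so ffn_size = int(8 × 6144) × 2 // 3 = 32,768. The embedding size of 6,144 is therefore upscaled to 32,768, an effective widening of 16/3 ≈ 5.33 (a widening factor of 4 would give 16,384). The 2/3 factor compensates for the gated design below, which uses three weight matrices instead of two, so the parameter count stays close to that of a plain two-matrix MLP with the same widening factor.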
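The arithmetic can be reproduced in plain Python. The function below repeats the computation from the excerpt without the logging call; the widening factors 8 and 4 and the toy input 100 are the only assumptions. The last lines are a rough sanity check against Grok-1's 314B total, assuming each expert holds three 6,144 × ffn_size matrices in each of the 64 layers.

```python
def ffn_size(emb_size, widening_factor):
    """Same arithmetic as the ffn_size excerpt, minus the logging call."""
    size = int(widening_factor * emb_size) * 2 // 3
    return size + (8 - size) % 8  # round up to the next multiple of 8

print(ffn_size(6144, 8))  # 32768 (widening_factor = 8)
print(ffn_size(6144, 4))  # 16384 (widening_factor = 4)
print(ffn_size(100, 1))   # 72: 66 is rounded up to a multiple of 8

# Weights in all expert FFNs: 64 layers x 8 experts x 3 matrices of 6144 x ffn_size
expert_weights = 64 * 8 * 3 * 6144 * ffn_size(6144, 8)
print(f"{expert_weights / 1e9:.1f}B")  # 309.2B
```

About 309B of the 314B parameters sit in the experts under this assumption; with a 16,384-wide FFN the experts would hold only about half of that.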

Here is the code for the dense block:

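# Excerpt from Grok-1's dense block (inside a method). Linear is the layer class
# defined in the Grok-1 code, and P is JAX's PartitionSpec.
import jax
from jax.sharding import PartitionSpec as P

h_v = Linear(  # first up-projection: model_size -> ffn_size, no activation
    ffn_size(model_size, self.widening_factor),
    with_bias=False,
    mesh=self.mesh,
    sharding=P("data", "model"),
    name="linear_v",
)(inputs)
h_w1 = jax.nn.gelu(  # second up-projection, followed by GELU
    Linear(
        ffn_size(model_size, self.widening_factor),
        with_bias=False,
        mesh=self.mesh,
        sharding=P("data", "model"),
    )(inputs)
)
h_dense = Linear(  # down-projection of the element-wise product back to model_size
    model_size,
    with_bias=False,
    sharding=P("model", "data"),
    mesh=self.mesh,
    shard_axis=1,
)(h_w1 * h_v)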

The input, with hidden size 6,144, is upscaled to ffn_size by two parallel linear layers, giving h_v and h_w1. The GELU activation function is applied only to h_w1. The two results are then multiplied element-wise, and the product is downscaled back to the model size of 6,144.
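The same gated computation can be sketched in plain Python with tiny, hand-picked weights (model size 2 and ffn size 3 are arbitrary toy choices). The `gelu` here uses the tanh approximation, which is the default form of `jax.nn.gelu`; sharding and the `Linear` class are left out.

```python
import math

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def linear(x, w):
    """x (length n_in) times w (n_in rows, n_out columns), no bias."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def gated_ffn(x, w_v, w_1, w_out):
    h_v = linear(x, w_v)                         # plain up-projection
    h_w1 = [gelu(z) for z in linear(x, w_1)]     # up-projection followed by GELU
    gated = [a * b for a, b in zip(h_w1, h_v)]   # element-wise product
    return linear(gated, w_out)                  # down-projection to the model size

w_v = [[1, 0, 1], [0, 1, 1]]
w_1 = [[1, 0, 0], [0, 1, 0]]
w_out = [[1, 0], [0, 1], [1, 1]]
out = gated_ffn([1.0, 2.0], w_v, w_1, w_out)
print(out)  # [gelu(1.0), 2 * gelu(2.0)] for these weights
```

The ungated branch h_v acts as a learned, input-dependent scale on the GELU branch, which is what distinguishes this block from a plain two-matrix MLP.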

MoE Layer

Grok-1 is built with the JAX and Haiku libraries. Its MoE layer routes each token to a few expert networks, each of which can specialize in different kinds of input. The _inference_call method performs the key steps:

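# Excerpt from Grok-1's MoE layer: the first part of _inference_call (the rest is omitted).
# jax, jnp (jax.numpy), and Optional (typing) are imported at module level.
def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):
    # Routing probabilities over all experts for every token
    routing_probs, _, _ = self.router.compute_routing_prob(
        inputs, padding_mask, self.num_experts
    )
    # Keep the top-k experts per token (k = 2 in Grok-1)
    expert_gate, expert_index = jax.lax.top_k(routing_probs, k=self.router.num_selected_experts)
    # (batch, seq, hidden) -> (batch * seq, hidden)
    tmp = jnp.reshape(inputs, (inputs.shape[0] * inputs.shape[1], inputs.shape[2]))
    # Repeat each token once per selected expert: (batch * seq, k, hidden)
    broad_inputs = jnp.tile(tmp[:, jnp.newaxis, :], (1, self.router.num_selected_experts, 1))
    # Flatten to (batch * seq * k, hidden)
    broad_inputs = jnp.reshape(
        broad_inputs, (broad_inputs.shape[0] * broad_inputs.shape[1], broad_inputs.shape[2])
    )
    # ... (remainder of the method omitted)
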
  1. It starts by calculating routing probabilities for each token, which determine how inputs are distributed across the available experts. This is done by the router.compute_routing_prob method, which takes the inputs and, optionally, a padding_mask. The routing probabilities are calculated as routing_probs = jax.nn.softmax(router_weights(inputs, num_experts)), where num_experts is 8.
  2. Based on the routing probabilities, the top-k experts (k = 2 for Grok-1) are selected for each token using jax.lax.top_k, so each token is processed by the experts with the highest routing probabilities.
  3. The remaining lines reshape the inputs: the batch and sequence axes are flattened into a single token axis, and each token is repeated once per selected expert, so every row pairs one token with one of its selected experts.
  4. Then, as in the dense block, the input to each selected expert passes through two parallel upscaling linear layers. GELU is applied to one of them (h_w1), the two results are multiplied element-wise, and the product is downscaled back to the original dimension of 6,144.
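The four steps can be sketched end to end in plain Python. This is a toy illustration, not Grok-1's implementation, under these assumptions: it loops over tokens instead of reshaping and tiling; the selected experts' outputs are combined as a sum weighted by their routing probabilities (the usual MoE combination, which the excerpt above does not show); the gate values are used without renormalization; and all sizes and weights are hypothetical random values.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(x, w):
    """x (length n_in) times w (n_in rows, n_out columns), no bias."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def expert_ffn(x, expert):
    """One expert: a gated GELU feed-forward block like the dense block."""
    h_v = matvec(x, expert["w_v"])
    h_w1 = [gelu(z) for z in matvec(x, expert["w_1"])]
    return matvec([a * b for a, b in zip(h_w1, h_v)], expert["w_out"])

def moe_layer(tokens, w_router, experts, k=2):
    """tokens: list of hidden vectors (batch and sequence already flattened)."""
    outputs = []
    for x in tokens:
        probs = softmax(matvec(x, w_router))  # step 1: probabilities over all experts
        top = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:k]  # step 2
        out = [0.0] * len(x)
        for e in top:  # step 4: run only the selected experts
            y = expert_ffn(x, experts[e])
            out = [o + probs[e] * yi for o, yi in zip(out, y)]  # weight by gate value
        outputs.append(out)
    return outputs

# Hypothetical toy sizes and random weights; Grok-1's are far larger
rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

hidden, ffn, num_experts = 4, 6, 8
experts = [
    {"w_v": rand_matrix(hidden, ffn), "w_1": rand_matrix(hidden, ffn), "w_out": rand_matrix(ffn, hidden)}
    for _ in range(num_experts)
]
w_router = rand_matrix(hidden, num_experts)
tokens = [[rng.uniform(-1.0, 1.0) for _ in range(hidden)] for _ in range(3)]
out = moe_layer(tokens, w_router, experts, k=2)
print(len(out), len(out[0]))  # 3 4
```

Only 2 of the 8 expert blocks run for each token, which is where the compute savings come from; the router itself is a single small matrix-vector product.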

Conclusion

In conclusion, Mixture of Experts (MoE) improves the efficiency of large language models (LLMs) by engaging only a subset of model parameters for each token. A trained router selects the experts, so the model can hold many more parameters than it uses for any single token while keeping performance high. The Grok-1 architecture, with 8 experts of which 2 are active per token, shows how MoE can make large-scale LLM inference more practical and paves the way for more scalable natural language processing systems.

Key Takeaways

  1. Mixture of Experts (MoE) optimizes large language models (LLMs) by activating only a subset of model parameters for each token, improving efficiency while largely preserving performance.
  2. The router in MoE selects experts for each token based on its input, allowing adaptive and resource-efficient computation.
  3. The Grok-1 architecture, with 314B parameters and 2 of 8 experts active per token, shows how MoE can be applied at scale in LLMs.
  4. Adopting MoE makes very large language models more practical to serve, supporting applications that require sophisticated language understanding and generation.

Frequently Asked Questions

Q1. What is the advantage of using a Mixture of Experts (MoE) in large language models (LLMs)?

Ans. MoE saves computation by activating only a subset of model parameters for each input token, improving efficiency while largely preserving performance.

Q2. How does the router mechanism work in MoE?

Ans. The router computes routing probabilities for each input token using weights learned during training, and the top-K experts are selected for that token. This sends each input to the most suitable experts, contributing to adaptive and resource-efficient computation.

Q3. What distinguishes Grok-1 architecture from traditional LLMs in utilizing MoE?

Ans. Grok-1 routes each token to 2 of its 8 experts and uses grouped-query attention (48 query heads, 8 key/value heads). Each expert is a gated feed-forward block: two parallel upscaling layers, GELU applied to one of them, an element-wise product of the two, and a downscaling layer. Together, these let Grok-1 use many experts to handle different aspects of the input while keeping per-token computation manageable.

I am working as an Associate Data Scientist at Analytics Vidhya, a platform dedicated to building the Data Science ecosystem. My interests lie in the fields of Natural Language Processing (NLP), Deep Learning, and AI Agents.
