Sparsely-gated Mixture Of Experts (MoE)



In transformer models, the attention block is typically followed by a feed-forward layer (FF): a simple fully connected neural network with a single hidden layer and a nonlinearity. Here's the code for such a block that uses ReLU:

import numpy as np


def feed_forward_relu(x, Wh, Wo):
    """Feed-forward layer with ReLU activation.

    Args:
        x: Input tensor (B, N, D).
        Wh: Weights for the hidden layer (D, DH).
        Wo: Weights for the output layer (DH, D).

    Returns:
        Output tensor (B, N, D).
    """
    x = x @ Wh  # hidden layer (B, N, DH)
    x = np.maximum(0, x)  # ReLU activation (B, N, DH)
    x = x @ Wo  # output layer (B, N, D)
    return x
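
As a quick sanity check (not part of the original snippet, with made-up sizes), the layer maps a (B, N, D) input back to (B, N, D):

# Made-up sizes: B=2, N=3, D=4, DH=16.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
Wh = rng.normal(size=(4, 16))
Wo = rng.normal(size=(16, 4))
print(feed_forward_relu(x, Wh, Wo).shape)  # (2, 3, 4)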

This layer typically holds most of the weights in the transformer, because the hidden dimension (DH in this post, hidden_dim in some papers) is large - 4x the embedding depth D is common. Intuitively, this makes sense because this layer does the majority of the heavy lifting; while the attention block mixes the embeddings of tokens together to express their relationships to one another, the FF block does the actual "reasoning" on these tokens.

Transformer blocks are repeated dozens of times in a model, so the total size of these layers becomes problematic. One approach for improving efficiency that became very popular is called sparsely-gated mixture of experts (paper).

Mixture of Experts architecture (MoE)

The basic idea of MoE is this:

  • The large FF layer is split into a number (NEXP) of blocks called "experts". Each expert is still a FF layer. It takes a vector of size D and transforms it to another vector of size D.
  • There's an additional piece called "router" or "gate". This is just a fully-connected layer (D, NEXP) that takes a token and produces a score for each expert. The router is learned by the model, along with the experts themselves.
  • The K experts with the highest scores are selected for each token, and the token is only fed through these experts.
  • The scores are also used to calculate a weighted average from the experts' outputs, eventually producing an answer of size D.

Here's a diagram for a single token, assuming NEXP=8 and TOPK=2 (the two highest scoring experts are selected for each token, out of a total of eight):

mixture of experts diagram

Notes:

  • Experts #1 and #5 are selected because the router produced the highest scores for them among all the experts. The input token is routed to these experts, but not to the others.
  • The output of each expert is element-wise multiplied by a corresponding weight. The weights are computed by applying softmax to the scores of the selected experts, so that they sum to 1 for each token.
  • The weighted expert outputs are added up for the final output of the layer.
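
To make the weighting concrete, here's a tiny sketch with invented numbers (the scores are not from the post; they just reproduce the situation in the diagram):

# Invented router scores for one token across NEXP=8 experts.
scores = np.array([0.1, 2.3, -0.4, 0.0, 0.7, 1.9, -1.2, 0.3])

# TOPK=2: experts 1 and 5 win, with scores 2.3 and 1.9.
top_scores = np.array([2.3, 1.9])

# Softmax over just the two selected scores gives the mixing weights.
weights = np.exp(top_scores) / np.exp(top_scores).sum()
print(weights)  # ~[0.60, 0.40] -- the weights sum to 1

# The layer's output is then 0.60 * expert1(x) + 0.40 * expert5(x).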

The key point to understand about this architecture is: the experts that were not among the top K for a token aren't used at all - the computation required to propagate the token through these experts is skipped entirely (on both the forward and backward passes).

This is the goal of the MoE architecture - we increase the overall model size, but keep the computational cost in check by only using a portion of the parameters for every single token. This is also reflected in models' names; for example, the Mixtral model has size 8x7B: it has 8 experts, but it would be incorrect to estimate its per-token cost by simply multiplying the size of each expert by 8, because only a fraction of these parameters participates in the calculation of every token [1]. According to the Mixtral paper, the model only uses 13B active parameters for each token.

A summary of the idea behind MoE is:

MoE increases the model's capacity without proportionally increasing its computational cost.
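
As a rough back-of-the-envelope illustration (with made-up dimensions, counting only the two FF matrices used in this post and ignoring attention and the router):

# Made-up dimensions, just to illustrate the capacity vs. compute ratio.
D, DH, NEXP, TOPK = 1024, 4096, 8, 2

params_per_expert = 2 * D * DH               # Wh (D, DH) + Wo (DH, D)
total_ff_params = NEXP * params_per_expert   # stored in the model
active_ff_params = TOPK * params_per_expert  # actually used per token

print(total_ff_params // active_ff_params)   # 4 -- i.e. NEXP / TOPK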

Numpy implementation

Here's a well-commented implementation of the MoE layer using pure Numpy. First, some parameters:

from dataclasses import dataclass
from typing import List


# Parameters for a feed-forward layer with a fixed activation function.
@dataclass
class FFParams:
    Wh: np.ndarray
    Wo: np.ndarray


# Parameters for a Mixture of Experts (MoE) layer.
@dataclass
class MoEParams:
    # Embedding dimension of each token (a.k.a. model dimension, Dmodel)
    D: int

    # Hidden dimension in FF layers
    DH: int

    # Total number of experts
    NEXP: int

    # K in the top-k selection of top experts per token
    TOPK: int

    # List of experts: each expert is a feed-forward layer with FFParams.
    ff_weights: List[FFParams]

    # Router weights: a linear layer (D, NEXP) that maps input to expert scores.
    router_weights: np.ndarray

And now the implementation. Note that it takes a general (B, N, D) input, with batch dimension B and sequence length N:

def moe(x: np.ndarray, params: MoEParams):
    """Mixture of Experts (MoE) layer.

    Args:
        x: Input tensor (B, N, D).
        params: MoEParams.

    Returns:
        Output tensor (B, N, D).
    """
    # Run input through router to get expert scores for each token.
    expert_scores = x @ params.router_weights  # (B, N, NEXP)

    # Select the top-k expert scores and their indices for each token.
    top_scores, top_experts = topk_lastdim(expert_scores, params.TOPK)  # (B, N, TOPK)

    # Apply softmax to the top scores to get weights that sum to 1.
    weights = softmax_lastdim(top_scores)  # (B, N, TOPK)

    out = np.zeros_like(x)
    for b in range(x.shape[0]):
        for n in range(x.shape[1]):
            # Unvectorized implementation: run each token through its
            # selected experts and accumulate the weighted outputs.
            for expert_idx, weight in zip(top_experts[b, n], weights[b, n]):
                expert = params.ff_weights[expert_idx]
                out[b, n] += weight * feed_forward_relu(x[b, n], expert.Wh, expert.Wo)

    return out

Calculating the experts themselves is not vectorized here - it is done token by token. MoE is inherently sparse: different tokens in the same sequence (and batch) may go through different sets of experts, so vectorizing this efficiently is tricky in general and depends on the hardware we run the model on [2]. For a popular approach on GPUs, see the MegaBlocks paper from 2022. This remains an active area of research.
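
That said, one simple way to cut down on the Python-level looping - not from the original code, just a sketch built on the definitions above (and the helpers defined below) - is to loop over experts instead of tokens, batching all the tokens routed to each expert into a single FF call:

def moe_grouped(x: np.ndarray, params: MoEParams):
    """Same computation as moe(), but loops over experts instead of tokens."""
    B, N, D = x.shape
    expert_scores = x @ params.router_weights  # (B, N, NEXP)
    top_scores, top_experts = topk_lastdim(expert_scores, params.TOPK)
    weights = softmax_lastdim(top_scores)  # (B, N, TOPK)

    x_flat = x.reshape(-1, D)  # (B*N, D)
    experts_flat = top_experts.reshape(-1, params.TOPK)  # (B*N, TOPK)
    weights_flat = weights.reshape(-1, params.TOPK)  # (B*N, TOPK)
    out_flat = np.zeros_like(x_flat)

    for e, expert in enumerate(params.ff_weights):
        # token_idx: flattened tokens routed to expert e;
        # slot_idx: which of their TOPK slots selected it (to find the weight).
        token_idx, slot_idx = np.nonzero(experts_flat == e)
        if token_idx.size == 0:
            continue
        # One batched FF call per expert instead of one call per token.
        y = feed_forward_relu(x_flat[token_idx], expert.Wh, expert.Wo)  # (T, D)
        out_flat[token_idx] += weights_flat[token_idx, slot_idx][:, None] * y

    return out_flat.reshape(B, N, D)

This produces the same result as moe() above; note that each expert only sees roughly TOPK/NEXP of the tokens, which is exactly the shrinking-batch effect described in footnote [2].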

All that's left are some helper functions:

def topk_lastdim(x, k):
    """Get the top k elements and their indices across the last dimension.

    x is an arbitrary array with at least two dimensions. The returned arrays
    have the same shape as x except for the last dimension, which has size k:
    one holds the top k elements of x across the last dimension, the other
    holds their indices. Note that np.argpartition doesn't guarantee any
    particular order among the top k elements themselves; this is fine here,
    since each score stays paired with its index.
    """
    idx = np.argpartition(x, -k, axis=-1)[..., -k:]
    return np.take_along_axis(x, idx, axis=-1), idx

def softmax_lastdim(x):
    """Compute softmax across last dimension of x.

    x is an arbitrary array with at least two dimensions. The returned array has
    the same shape as x, but its elements sum up to 1 across the last dimension.
    """
    # Subtract the max for numerical stability
    ex = np.exp(x - np.max(x, axis=-1, keepdims=True))
    # Divide by sums across last dimension
    return ex / np.sum(ex, axis=-1, keepdims=True)
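
Putting it all together, here's a quick smoke test with made-up sizes (not from the post) that exercises the whole layer:

# Made-up sizes: 2 sequences of 4 tokens, D=8, DH=16, 4 experts, top-2 routing.
B, N, D, DH, NEXP, TOPK = 2, 4, 8, 16, 4, 2
rng = np.random.default_rng(42)

params = MoEParams(
    D=D,
    DH=DH,
    NEXP=NEXP,
    TOPK=TOPK,
    ff_weights=[
        FFParams(
            Wh=rng.normal(size=(D, DH)),
            Wo=rng.normal(size=(DH, D)),
        )
        for _ in range(NEXP)
    ],
    router_weights=rng.normal(size=(D, NEXP)),
)

x = rng.normal(size=(B, N, D))
print(moe(x, params).shape)  # (2, 4, 8)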

Additional considerations

A major area of focus with MoE architectures is load balancing among experts. Without special provisions, the model may learn to prefer certain experts over others, which leads to inefficient utilization of the model's weights. There are various approaches to tackle this, for example:

  • Adding noise to the top-k selection process to inject randomness
  • Defining a special loss function during training that encourages experts to receive a roughly equal number of training samples (see the sketch right after this list)
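
As an illustration of the second approach, here's a rough numpy sketch of an auxiliary load-balancing loss in the spirit of the one used in the Switch Transformer paper (this is not from the original post, and real implementations differ in the details):

def load_balancing_loss(expert_scores, top_experts, nexp):
    """Auxiliary loss that is minimized when tokens spread evenly over experts.

    For each expert, computes the fraction of routing assignments it received
    and the mean router probability it was given, then sums their products
    (scaled by the number of experts). Both factors grow for over-used
    experts, so the loss pushes the router toward a uniform assignment.

    expert_scores: raw router scores (B, N, NEXP).
    top_experts: indices of the selected experts per token (B, N, TOPK).
    """
    probs = softmax_lastdim(expert_scores)            # (B, N, NEXP)
    mean_prob = probs.reshape(-1, nexp).mean(axis=0)  # (NEXP,)

    # Fraction of routing slots assigned to each expert; sums to 1.
    counts = np.bincount(top_experts.ravel(), minlength=nexp)
    frac = counts / top_experts.size                  # (NEXP,)

    return nexp * np.sum(frac * mean_prob)

In a real model this term would be added, with a small coefficient, to the training loss; since this post only implements a numpy forward pass, the function is purely illustrative.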

Code

The full code for this post is available on GitHub.


[1] Another way to think about MoE is that each "expert" specializes in a certain area of the model's capability. For example, one expert would be good at math, another at prose, etc. This is a very rough approximation, though, because transformer models consist of dozens of repeating blocks, and all these different experts end up thoroughly intermixed as tokens flow through the entire model.

[2] In the sparsely-gated mixture of experts paper, this is referred to as The Shrinking Batch Problem:

"In modern CPUs and GPUs, large batch sizes are necessary for computational efficiency, so as to amortize the overhead of parameter loads and updates. If the gating network chooses k out of n experts for each example, then for a batch of b examples, each expert receives a much smaller batch of approximately kb/n << b examples. This causes a naive MoE implementation to become very inefficient as the number of experts increases"

