Dense Models Scale Badly
A 70B dense model loads and computes with all 70B parameters for every token. Scaling a dense model to 400B parameters would raise compute and memory traffic per token roughly in proportion. Mixture of Experts (MoE) breaks this coupling: a model can hold 400B total parameters but activate only, say, 50B per token, so quality can track the total parameter count more closely while compute cost tracks the activated parameters.
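To put rough numbers on this, here is a back-of-the-envelope sketch. It assumes the common rule of thumb of about 2 FLOPs per active parameter per token for a forward pass (an assumption, not a figure from this article) and ignores the cost of attention over the context. The function name is illustrative.

```python
def forward_flops_per_token(active_params: float) -> float:
    """Rough forward-pass cost: one multiply and one add per active weight (assumed rule of thumb)."""
    return 2.0 * active_params


dense_70b = forward_flops_per_token(70e9)
dense_400b = forward_flops_per_token(400e9)
moe_400b_total_50b_active = forward_flops_per_token(50e9)  # only active weights count

print(f"dense 70B            : {dense_70b:.1e} FLOPs/token")
print(f"dense 400B           : {dense_400b:.1e} FLOPs/token")
print(f"MoE 400B (50B active): {moe_400b_total_50b_active:.1e} FLOPs/token")
print(f"dense 400B / MoE     : {dense_400b / moe_400b_total_50b_active:.0f}x")
```

Under these assumptions the hypothetical 400B MoE costs less per token than the 70B dense model, even though it stores almost six times as many parameters.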
The MoE Architecture
Each Transformer FFN layer is replaced by a set of N experts (independent FFNs, each with its own parameters) and a gating network, or router, that selects which k experts process each token:
- The gating network computes a score for each expert given the current token
- The top-k experts by score are selected (typically k=1 or k=2)
- Only those experts compute their outputs
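- Outputs are weighted by the normalized gate scores of the selected experts and summed
The gating network is a simple linear layer that produces one logit per expert: logits = W_g * x. A softmax turns the logits into weights, applied either over all N experts before selection or only over the k selected logits (the layer implementation in the next section does the latter). The sparsity comes from top-k selection: the remaining N - k experts are never evaluated for that token.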
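The routing steps above can be sketched for a single token in plain Python. The gate logits are made-up values and the function names are illustrative. The sketch also checks one detail: applying softmax to the top-k logits gives the same weights as applying softmax over all experts and then renormalizing the top-k probabilities.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def route_topk_then_softmax(gate_logits, top_k=2):
    """Select the top-k experts, then normalize over their logits only."""
    top = sorted(range(len(gate_logits)), key=lambda i: gate_logits[i], reverse=True)[:top_k]
    return top, softmax([gate_logits[i] for i in top])


def route_softmax_then_topk(gate_logits, top_k=2):
    """Softmax over all experts, keep the top-k probabilities, renormalize to sum to 1."""
    probs = softmax(gate_logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    kept = sum(probs[i] for i in top)
    return top, [probs[i] / kept for i in top]


# Hypothetical gate logits for one token and 8 experts.
logits = [0.1, 2.0, -1.0, 1.5, 0.0, -0.5, 0.3, 0.7]

idx_a, w_a = route_topk_then_softmax(logits)
idx_b, w_b = route_softmax_then_topk(logits)

print("selected experts:", idx_a)
print("weights:", [round(w, 3) for w in w_a])
print("same result either way:", idx_a == idx_b and all(abs(a - b) < 1e-12 for a, b in zip(w_a, w_b)))
```

If the full-softmax probabilities are used without renormalizing, as some designs do, the k weights no longer sum to 1, so the two orderings are only equivalent when renormalization is applied.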
Mixtral 8x7B: The Accessible MoE
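Mixtral 8x7B (arXiv:2401.04088) uses 8 experts per FFN layer with top-2 routing. Total parameters: 46.7B, not 8 × 7B = 56B, because only the FFN blocks are replicated; the attention layers and embeddings are shared by all experts. Active parameters per token: approximately 12.9B (the shared weights plus 2 of the 8 expert FFNs in each layer). The paper reports that Mixtral 8x7B matches or outperforms Llama 2 70B on most of the benchmarks it evaluates, at roughly the per-token compute of a 13B dense model.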
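Counting weights shows where the total and active figures come from. The sketch below assumes the hyperparameters listed in the Mixtral paper (hidden size 4096, 32 layers, expert hidden size 14336, 32 query heads and 8 key/value heads of dimension 128, vocabulary 32000), three weight matrices per SwiGLU expert, and an untied output head. It ignores normalization weights, which are negligible. Variable names are illustrative.

```python
# Hyperparameters assumed from the Mixtral 8x7B paper; verify before relying on them.
d_model = 4096
n_layers = 32
d_ff = 14336        # hidden size of each expert FFN
n_experts = 8
top_k = 2
n_heads = 32
n_kv_heads = 8      # grouped-query attention
head_dim = 128
vocab_size = 32000

# Per-layer weights.
expert = 3 * d_model * d_ff                              # SwiGLU: gate, up, down projections
attention = (2 * d_model * n_heads * head_dim            # query and output projections
             + 2 * d_model * n_kv_heads * head_dim)      # key and value projections
router = d_model * n_experts                             # one logit per expert

# Weights outside the layers: input embedding plus output head (assumed untied).
embeddings = 2 * vocab_size * d_model

shared = n_layers * (attention + router) + embeddings
total_params = shared + n_layers * n_experts * expert    # every expert is stored
active_params = shared + n_layers * top_k * expert       # only top_k experts run per token

print(f"total : {total_params / 1e9:.1f}B")   # total : 46.7B
print(f"active: {active_params / 1e9:.1f}B")  # active: 12.9B
```

The shared weights account for only about 1.6B parameters, which is why the active count sits slightly above one quarter of the total rather than exactly at 2/8 of it.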
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, n_experts=8, top_k=2):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        # Each expert is an independent two-layer FFN.
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])
        # Router: one logit per expert.
        self.gate = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x):
        # x: (batch * seq_len, d_model)
        gate_logits = self.gate(x)
        # Keep the top-k logits per token and normalize over those k only.
        weights, indices = torch.topk(gate_logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)
        output = torch.zeros_like(x)
        for expert_idx in range(self.n_experts):
            # Tokens that selected this expert in any of their k slots.
            mask = (indices == expert_idx).any(dim=-1)
            if mask.any():
                expert_out = self.experts[expert_idx](x[mask])
                # topk indices are distinct, so each masked token has exactly one matching slot.
                expert_weight = weights[mask][indices[mask] == expert_idx]
                output[mask] += expert_weight.unsqueeze(-1) * expert_out
        return output
Load Balancing Loss
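Without regularization, the gating network tends to collapse: a few popular experts receive almost all tokens while the others are rarely used (expert collapse). The effect is self-reinforcing, because experts that receive more tokens get more gradient updates and become even more likely to be selected. An auxiliary load balancing loss, added to the training objective, penalizes imbalanced routing by encouraging uniform expert utilization: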
L_balance = alpha * sum_i(f_i * P_i)
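where f_i is the fraction of tokens in the batch routed to expert i and P_i is the mean gate probability assigned to expert i over the same batch. With perfectly uniform routing, f_i = P_i = 1/N for all N experts. Because f_i comes from a discrete selection, gradients flow only through P_i. Some formulations also multiply the sum by N so that the balanced value equals alpha regardless of the number of experts.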
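A minimal sketch of this loss in plain Python, for illustration only. It assumes top-1 assignment when computing f_i and uses alpha = 0.01 as an arbitrary example value; the function name and the toy gate probabilities are hypothetical.

```python
def load_balance_loss(gate_probs, alpha=0.01):
    """Compute alpha * sum_i(f_i * P_i).

    gate_probs: one list of N gate probabilities per token.
    f_i uses top-1 assignment (an assumption of this sketch).
    """
    n_tokens = len(gate_probs)
    n_experts = len(gate_probs[0])
    counts = [0] * n_experts
    prob_sums = [0.0] * n_experts
    for probs in gate_probs:
        chosen = max(range(n_experts), key=lambda i: probs[i])
        counts[chosen] += 1
        for i, p in enumerate(probs):
            prob_sums[i] += p
    f = [c / n_tokens for c in counts]        # fraction of tokens routed to each expert
    P = [s / n_tokens for s in prob_sums]     # mean gate probability per expert
    return alpha * sum(fi * pi for fi, pi in zip(f, P))


# Four tokens, four experts: each token prefers a different expert.
balanced = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
# Four tokens that all prefer expert 0.
collapsed = [[0.7, 0.1, 0.1, 0.1] for _ in range(4)]

loss_balanced = load_balance_loss(balanced)
loss_collapsed = load_balance_loss(collapsed)
print(f"balanced : {loss_balanced:.4f}")
print(f"collapsed: {loss_collapsed:.4f}")
```

The collapsed batch scores higher because both f and P concentrate on the same expert, which is the pattern the loss is meant to discourage.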
DeepSeek MoE Fine-Grained Routing
DeepSeekMoE (arXiv:2401.06066) uses more experts of smaller size (for example 64 experts with top-6 selection, versus 8 experts with top-2) and splits them into shared experts, which are always activated, and routed experts, which the gate activates conditionally. The paper reports that this fine-grained routing allows better expert specialization and achieves higher quality at the same activated-parameter count.
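One way to see why finer granularity gives the router more room to specialize is to count how many distinct expert combinations a token can be sent to. The sketch below uses the expert counts from the paragraph above and leaves shared experts out, since they are always active. It measures routing flexibility only and says nothing about model quality.

```python
import math

# Number of distinct expert sets a single token can be routed to.
coarse = math.comb(8, 2)    # 8 experts, top-2
fine = math.comb(64, 6)     # 64 routed experts, top-6

print(f"8 experts, top-2 : {coarse:,} combinations")   # 28
print(f"64 experts, top-6: {fine:,} combinations")     # 74,974,368
```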
Efficiency in Practice
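For inference on Mixtral 8x7B, all 46.7B parameters must be resident in memory (about 93GB in fp16), but each token's forward pass only performs about 12.9B parameters' worth of computation. At small batch sizes this makes it faster per token than Llama 2 70B. Its memory footprint, however, is that of a 47B dense model, roughly two-thirds of the approximately 140GB a 70B model needs in fp16, and far above that of a 13B model.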
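The memory figures follow from simple arithmetic. The sketch below counts weights only, at 2 bytes per parameter in fp16, and leaves out the KV cache, activations, and framework overhead. The helper name is illustrative.

```python
BYTES_PER_PARAM_FP16 = 2


def weight_memory_gb(n_params, bytes_per_param=BYTES_PER_PARAM_FP16):
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


mixtral_resident = weight_memory_gb(46.7e9)  # every expert must be loaded
mixtral_per_token = weight_memory_gb(12.9e9) # weights actually used for one token
dense_70b = weight_memory_gb(70e9)

print(f"Mixtral 8x7B resident weights : {mixtral_resident:.1f} GB")   # 93.4 GB
print(f"Mixtral 8x7B used per token   : {mixtral_per_token:.1f} GB")  # 25.8 GB
print(f"70B dense resident weights    : {dense_70b:.1f} GB")          # 140.0 GB
```

The gap between the first two lines is the MoE trade-off in one place: capacity is paid for in memory, while per-token work is paid for only on the active subset.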