Mixtral of Experts
Mistral AI
8 Jan 2024
Introduction
-
Mixtral 8x7B is a sparse mixture-of-experts (SMoE) model.
-
The feedforward block picks from a set of 8 distinct groups of parameters.
-
At every layer, for every token, a router network selects two of these groups to process the token and combines their outputs additively.
-
Mixtral is pretrained with multilingual data using a context size of 32k tokens.
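A compact, illustrative recap of this setup; the dataclass and field names below are hypothetical, not the official Hugging Face config.
# Illustrative summary of the Mixtral 8x7B routing setup described above.
# Names are hypothetical; values come from the bullets in this section.
from dataclasses import dataclass

@dataclass
class MoeArchSketch:
    n_experts: int = 8         # distinct groups of FFN parameters per layer
    top_k: int = 2             # experts selected by the router for each token
    context_len: int = 32_768  # pretraining context size (32k tokens)

print(MoeArchSketch())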
Sparse Mixture of Experts
-
The output of the expert layer is given by
∑_{i=0}^{n-1} G(x)_i ⋅ E_i(x)
-
G(x)_i denotes the n-dimensional output of the gating network for the i-th expert.
-
E_i(x) is the output of the i-th expert network.
-
This sparsity avoids computing the outputs of experts whose gating values are zero.
G(x) := Softmax(TopK(x ⋅ W_g))
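A minimal sketch of this gating function for a single token, where TopK keeps the top-K logits and sets the rest to −∞ before the softmax (toy sizes; W_g here is just a random matrix, not trained weights).
# Sketch of G(x) = Softmax(TopK(x ⋅ W_g)) for one token (toy sizes, random weights).
import torch

n_experts, hidden_dim, k = 8, 16, 2
W_g = torch.randn(hidden_dim, n_experts)   # gating weights (random, for illustration)
x = torch.randn(hidden_dim)                # one token's hidden state

logits = x @ W_g                           # (n_experts,)
topk_vals, topk_idx = torch.topk(logits, k)
masked = torch.full_like(logits, float("-inf"))
masked[topk_idx] = topk_vals               # keep the top-k logits, -inf elsewhere
gates = torch.softmax(masked, dim=-1)      # exactly k non-zero gating values
print(gates)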
MoE Layers
-
The MoE layer is applied independently per token and replaces the feed-forward (FFN) sub-block of the transformer block.
-
Mixtral uses the SwiGLU architecture as the expert function E_i(x) and sets K = 2.
-
Each token is routed to two SwiGLU sub-blocks with different sets of weights.
-
The output y for an input token x is computed as:
y = ∑_{i=0}^{n-1} Softmax(Top2(x ⋅ W_g))_i ⋅ SwiGLU_i(x).
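Each expert is itself a SwiGLU feed-forward block. A minimal sketch of one expert is shown below; the class name is hypothetical and the layer sizes are illustrative, with w1/w2/w3 following the common Mistral/Llama naming convention.
# Sketch of a single SwiGLU expert: E_i(x) = w2(SiLU(w1(x)) * w3(x)).
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUExpert(nn.Module):
    def __init__(self, hidden_dim: int = 16, ffn_dim: int = 64):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # up projection
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

print(SwiGLUExpert()(torch.randn(5, 16)).shape)  # torch.Size([5, 16])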
Mixture of Experts Layer

Mixture of Experts Architecture

MixtralSparseMoeBlock (HF)
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # e.g. batch_size=4, seq_len=8, hidden_dim=16
    batch_size, seq_len, hidden_dim = hidden_states.shape
    # Every token in every batch may be routed to a different set of experts,
    # so the tokens are flattened before routing.
    # router_logits: (batch * sequence_length, n_experts) => (32, 4)
    hidden_states = hidden_states.view(-1, hidden_dim)
    router_logits = self.gate(hidden_states)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    # Use `torch.topk` to select the K experts with the highest weights
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    # Renormalize so the selected experts' weights sum to 1 for each token
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    # Pre-allocate the output tensor
    final_hidden_states = torch.zeros(
        (batch_size * seq_len, hidden_dim),
        dtype=hidden_states.dtype, device=hidden_states.device,
    )
    # Build the expert mask with one-hot encoding
    # expert_mask: (32, 2, 4) => (4, 2, 32), n_tokens=32, selected_experts=2, n_experts=4
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.n_experts)
    expert_mask = expert_mask.permute(2, 1, 0)
    # Visit each expert and process only the tokens assigned to it
    for expert_idx in range(self.n_experts):
        expert_layer = self.experts[expert_idx]
        # top_x holds the positions of the tokens this expert processes
        # idx is the rank of the current expert among each token's selections
        idx, top_x = torch.where(expert_mask[expert_idx])
        # Skip this expert if it has no tokens to process
        if top_x.shape[0] == 0:
            continue
        # Gather the tokens to process according to top_x
        curr_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        # Compute the expert output, weighted by its routing weight
        curr_hidden_states = expert_layer(curr_state) * routing_weights[top_x, idx, None]
        # Scatter-add the weighted hidden states back into the output tensor
        final_hidden_states.index_add_(0, top_x, curr_hidden_states)
    final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
    return final_hidden_states, router_logits
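The mask-and-gather step above is the least obvious part. A tiny standalone example (toy shapes, not from the paper) shows what expert_mask, idx, and top_x look like.
# Toy illustration of the masking trick: 3 tokens, 4 experts, top_k=2.
import torch

selected_experts = torch.tensor([[0, 2], [2, 1], [0, 3]])  # (n_tokens=3, top_k=2)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=4)
expert_mask = expert_mask.permute(2, 1, 0)                 # (n_experts=4, top_k=2, n_tokens=3)

# Expert 2 is token 1's first choice and token 0's second choice.
print(expert_mask[2])            # tensor([[0, 1, 0],
                                 #         [1, 0, 0]])
idx, top_x = torch.where(expert_mask[2])
print(idx, top_x)                # tensor([0, 1]) tensor([1, 0])
# routing_weights[top_x, idx] then picks expert 2's normalized weight for tokens 1 and 0,
# and final_hidden_states.index_add_(0, top_x, ...) scatters the weighted outputs back.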
Sparse & Active Parameters
-
If one increases n while keeping K fixed, one can increase the model's parameter count while keeping its computational cost effectively constant.
-
The model's total parameter count (sparse) grows with n.
-
The number of parameters used for processing a token (active) grows with K; see the back-of-envelope sketch below.
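A rough back-of-envelope sketch of this distinction, counting only the expert FFN weights (attention, embedding, and norm parameters are ignored) with Mixtral 8x7B's published hyperparameters.
# Back-of-envelope: sparse vs. active parameter counts, expert FFN weights only.
dim, ffn_dim, n_layers = 4096, 14336, 32
n_experts, top_k = 8, 2

params_per_expert = 3 * dim * ffn_dim            # SwiGLU: w1, w2, w3
sparse_ffn = n_layers * n_experts * params_per_expert
active_ffn = n_layers * top_k * params_per_expert

print(f"sparse FFN params: {sparse_ffn / 1e9:.1f}B")  # ~45.1B of the ~47B total
print(f"active FFN params: {active_ffn / 1e9:.1f}B")  # ~11.3B of the ~13B active per token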
MoE Parallelism
-
The experts can be distributed across multiple GPUs through Expert Parallelism (EP).
-
Each token is routed to the GPU that hosts its assigned experts for processing.
-
EP introduces load-balancing challenges: routing must avoid overloading individual GPUs and creating computational bottlenecks (see the toy sketch below).
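A highly simplified, single-process sketch of the idea; the expert-to-GPU mapping and token assignments below are made up for illustration, and real EP dispatches tokens with all-to-all communication across ranks.
# Conceptual sketch of Expert Parallelism: experts are sharded across GPUs,
# and each token is dispatched to the device holding its assigned expert.
from collections import defaultdict

n_experts, n_gpus = 8, 4
expert_to_gpu = {e: e % n_gpus for e in range(n_experts)}  # 2 experts per GPU

# token -> top-1 expert chosen by the router (toy values)
assignments = {0: 3, 1: 0, 2: 3, 3: 5, 4: 3, 5: 1}

per_gpu_tokens = defaultdict(list)
for token, expert in assignments.items():
    per_gpu_tokens[expert_to_gpu[expert]].append(token)

# Uneven routing leads to uneven load, which is the balancing challenge above.
for gpu, tokens in sorted(per_gpu_tokens.items()):
    print(f"GPU {gpu}: {len(tokens)} token(s) {tokens}")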
Parameters




Common Benchmarks

Comparison of Mixtral with Llama

Comparison of Mixtral with Llama 2 70B and GPT-3.5

Multilingual Benchmarks

Passkey Retrieval Task

Long Range Performance

Instruction Fine-Tuning

Routing Analysis

Percentage of Expert Assignment Repetitions

First Expert Choice

Conclusion
-
Mixtral 8x7B Instruct outperforms Claude-2.1, Gemini Pro, and GPT-3.5 Turbo on human evaluation benchmarks.
-
Mixtral only uses 13B active parameters per token.
-
It outperforms the previous best model (Llama 2 70B), which uses 70B parameters per token.
-
Publicly available under the Apache 2.0 license.
Mixtral
By Penut Chen (陳威廷)