Mixtral of Experts
Mistral AI
8 Jan 2024
Introduction
-
Mixtral 8x7B, a Sparse Mixture of Experts (SMoE) model.
-
The feedforward block picks from a set of 8 distinct groups of parameters.
-
At every layer, for every token, a router network chooses two of these groups to process the token and combines their outputs additively.
-
Mixtral is pretrained with multilingual data using a context size of 32k tokens.
Sparse Mixture of Experts
-
The output of the expert layer is given by
\(\sum_{i=0}^{n-1} G(x)_i \cdot E_i(x)\)
-
\( G(x)_i \) denotes the \(i\)-th entry of the \(n\)-dimensional output of the gating network, i.e. the gating weight assigned to the \(i\)-th expert.
-
\( E_i(x) \) is the output of the \(i\)-th expert network.
-
To avoid computing the outputs of experts whose gating weights are zero, the gating vector is made sparse by keeping only the top-\(K\) logits before the softmax:
\(G(x) := \text{Softmax}(\text{TopK}(x \cdot W_g))\)
where \(\text{TopK}(\ell)_i = \ell_i\) if \(\ell_i\) is among the top-\(K\) entries of \(\ell\) and \(-\infty\) otherwise, so non-selected experts receive a weight of exactly zero.
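As a minimal sketch of this gating for a single token (toy sizes; the names x, W_g, and k are illustrative, not taken from the released weights), the non-top-K logits can be masked with \(-\infty\) before the softmax:

import torch

# Gating sketch: G(x) = Softmax(TopK(x . W_g)) for one token.
n_experts, hidden_dim, k = 8, 16, 2
x = torch.randn(hidden_dim)
W_g = torch.randn(hidden_dim, n_experts)

logits = x @ W_g                                 # (n_experts,)
topk_vals, topk_idx = torch.topk(logits, k)      # keep only the top-k logits
masked = torch.full_like(logits, float("-inf"))  # remaining logits -> -inf
masked[topk_idx] = topk_vals
gates = torch.softmax(masked, dim=-1)            # non-selected experts get weight 0
print(gates)                                     # exactly two non-zero entries, summing to 1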
MoE Layers
-
The MoE layer is applied independently per token and replaces the feed-forward (FFN) sub-block of the transformer block.
-
Mixtral uses the same SwiGLU architecture as the expert function \(E_i(x)\) and sets \(K=2\).
-
Each token is routed to two SwiGLU sub-blocks with different sets of weights.
-
The output \(y\) for an input token \( x \) is computed as:
\( y = \sum_{i=0}^{n-1} \text{Softmax}(\text{Top2}(x \cdot W_g))_i \cdot \text{SwiGLU}_i(x). \)
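A compact per-token sketch of this layer is shown below (toy dimensions; SwiGLU and Top2MoE are illustrative module names, not the paper's or Hugging Face's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """One expert: the SwiGLU feed-forward block w2(silu(w1(x)) * w3(x))."""
    def __init__(self, dim, hidden):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class Top2MoE(nn.Module):
    """Per-token top-2 MoE: y = sum_i Softmax(Top2(x W_g))_i * SwiGLU_i(x)."""
    def __init__(self, dim=16, hidden=32, n_experts=8, k=2):
        super().__init__()
        self.gate = nn.Linear(dim, n_experts, bias=False)   # W_g
        self.experts = nn.ModuleList(SwiGLU(dim, hidden) for _ in range(n_experts))
        self.k = k

    def forward(self, x):                        # x: (dim,) -- a single token
        vals, idx = torch.topk(self.gate(x), self.k)
        weights = torch.softmax(vals, dim=-1)    # renormalized top-2 weights
        return sum(w * self.experts[int(i)](x) for w, i in zip(weights, idx))

moe = Top2MoE()
y = moe(torch.randn(16))
print(y.shape)  # torch.Size([16])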
Mixture of Experts Layer
Mixture of Experts Architecture
MixtralSparseMoeBlock (HF)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # e.g. batch_size=4, seq_len=8, hidden_dim=16
    batch_size, seq_len, hidden_dim = hidden_states.shape
    # Each token in every batch may be routed to a different set of experts.
    # router_logits: (batch_size * seq_len, n_experts) => (32, 4)
    hidden_states = hidden_states.view(-1, hidden_dim)
    router_logits = self.gate(hidden_states)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    # Use `torch.topk` to select the K experts with the highest weights.
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    # Renormalize the selected experts' weights so they sum to 1, then cast back to the input dtype.
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    # Pre-allocate the output tensor.
    final_hidden_states = torch.zeros(
        (batch_size * seq_len, hidden_dim),
        dtype=hidden_states.dtype, device=hidden_states.device,
    )
    # Build the expert mask with one-hot encoding.
    # expert_mask: (32, 2, 4) => (4, 2, 32), n_tokens=32, selected_experts=2, n_experts=4
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts)
    expert_mask = expert_mask.permute(2, 1, 0)
    # Visit each expert and process only the tokens routed to it.
    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        # top_x: positions (token indices) this expert has to process.
        # idx: whether this expert is the token's top-1 or top-2 choice.
        idx, top_x = torch.where(expert_mask[expert_idx])
        # Skip this expert if no tokens were routed to it.
        if top_x.shape[0] == 0:
            continue
        # Gather the tokens this expert has to process.
        curr_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        # Apply the expert and scale its output by the routing weight.
        curr_hidden_states = expert_layer(curr_state) * routing_weights[top_x, idx, None]
        # Scatter the weighted outputs back to their original token positions.
        final_hidden_states.index_add_(0, top_x, curr_hidden_states)
    final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
    return final_hidden_states, router_logits
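As a usage sketch, the block can be exercised end-to-end through the Hugging Face classes; the import path and config field names below (num_local_experts, num_experts_per_tok) are assumed from the transformers library and may vary slightly between versions:

import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Tiny config matching the shapes in the comments above (real Mixtral uses
# hidden_size=4096, intermediate_size=14336, 8 experts, top-2 routing).
config = MixtralConfig(hidden_size=16, intermediate_size=32,
                       num_attention_heads=4, num_key_value_heads=2,
                       num_local_experts=4, num_experts_per_tok=2)
moe = MixtralSparseMoeBlock(config)

hidden_states = torch.randn(4, 8, config.hidden_size)   # (batch, seq_len, hidden_dim)
out, router_logits = moe(hidden_states)
print(out.shape)            # torch.Size([4, 8, 16])
print(router_logits.shape)  # torch.Size([32, 4]) -> (batch*seq_len, n_experts)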
Sparse & Active Parameters
-
If one increases \(n\) while keeping \(K\) fixed, one can increase the model's parameter count while keeping its computational cost effectively constant.
-
The model's total parameter count (sparse) grows with \(n\).
-
The number of parameters used for processing a token (active) grows with \(K\).
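A back-of-the-envelope calculation (approximate; using the published Mixtral dimensions: hidden size 4096, FFN size 14336, 32 layers, 8 experts, top-2 routing, 8 KV heads, 32k vocabulary) shows why the sparse count sits near 47B while the active count stays near 13B:

# Sparse vs. active parameter sketch for Mixtral 8x7B (approximate).
hidden, ffn, layers, n_experts, k, vocab = 4096, 14336, 32, 8, 2, 32000

expert = 3 * hidden * ffn                        # one SwiGLU expert (w1, w2, w3)
attn = 2 * hidden * hidden + 2 * hidden * 1024   # q/o plus GQA k/v projections
shared = layers * attn + 2 * vocab * hidden      # attention + embeddings / LM head

total  = shared + layers * n_experts * expert    # grows with n
active = shared + layers * k * expert            # grows with K
print(f"total  ~ {total / 1e9:.1f}B")            # ~ 46.7B
print(f"active ~ {active / 1e9:.1f}B")           # ~ 12.9B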
MoE Parallelism
-
The experts can be distributed across multiple GPUs through Expert Parallelism (EP).
-
Each token is routed to the GPU hosting its assigned expert for processing.
-
EP introduces load-balancing challenges: the routing must avoid overloading individual GPUs or creating computational bottlenecks, as sketched below.
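A conceptual sketch of EP dispatch (not the actual Mixtral serving code; the expert-to-GPU mapping and random assignments are purely illustrative) shows how uneven routing translates directly into uneven GPU load:

import torch

# Each GPU owns n_experts // n_gpus experts; tokens go to the GPU
# that hosts their assigned expert.
n_tokens, n_experts, n_gpus = 32, 8, 4
assignments = torch.randint(n_experts, (n_tokens, 2))   # top-2 expert ids per token

expert_to_gpu = {e: e * n_gpus // n_experts for e in range(n_experts)}
load = [0] * n_gpus
for e in assignments.flatten().tolist():
    load[expert_to_gpu[e]] += 1
print(load)   # tokens handled per GPU; imbalance here is the EP load-balancing problem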
Parameters
Common Benchmarks
Comparison of Mixtral with Llama
Comparison of Mixtral with Llama 2 70B and GPT-3.5
Multilingual Benchmarks
Passkey Retrieval Task
Long Range Performance
Instruction Fine-Tuning
Routing Analysis
Percentage of Expert Assignment Repetitions
First Expert Choice
Conclusion
-
Mixtral 8x7B Instruct outperforms Claude-2.1, Gemini Pro, and GPT-3.5 Turbo on human evaluation benchmarks.
-
Mixtral only uses 13B active parameters per token.
-
It outperforms the previous best model (Llama 2 70B), which uses 70B parameters per token.
-
Publicly available under the Apache 2.0 license.
By Penut Chen (陳威廷)