Mistral AI
8 Jan 2024
Mixtral 8x7B is a sparse mixture-of-experts (SMoE) language model.
Each feedforward block picks from a set of 8 distinct groups of parameters (the experts).
At every layer, for every token, a router network chooses two of these groups to process the token and combines their outputs additively.
Mixtral is pretrained with multilingual data using a context size of 32k tokens.
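For reference, a minimal configuration sketch with the architecture hyperparameters reported in the Mixtral paper (the dict itself is only illustrative, not an actual API):

# Illustrative configuration sketch; values as reported for Mixtral 8x7B.
mixtral_8x7b_config = {
    "dim": 4096,           # model (embedding) dimension
    "n_layers": 32,        # number of transformer blocks
    "head_dim": 128,       # dimension per attention head
    "hidden_dim": 14336,   # FFN hidden dimension of each expert
    "n_heads": 32,         # query heads
    "n_kv_heads": 8,       # key/value heads (grouped-query attention)
    "context_len": 32768,  # pretraining context size
    "vocab_size": 32000,
    "num_experts": 8,      # experts per MoE layer
    "top_k_experts": 2,    # experts used per token
}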
The output of the expert layer is given by
\(\sum_{i=0}^{n-1} G(x)_i \cdot E_i(x)\)
\( G(x)_i \) denotes the \(i\)-th component of the \(n\)-dimensional output of the gating network, i.e. the gating weight assigned to the \(i\)-th expert.
\( E_i(x) \) is the output of the \(i\)-th expert network.
Because the gating output is sparse, the outputs of experts whose gate values are zero do not need to be computed.
\(G(x) := \text{Softmax}(\text{TopK}(x \cdot W_g))\)
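Here TopK keeps the \(K\) largest logits and sets the rest to \(-\infty\), so after the softmax only the selected experts receive non-zero weight:
\(\text{TopK}(\ell)_i := \ell_i\) if \(\ell_i\) is among the top-\(K\) coordinates of the logits \(\ell = x \cdot W_g \in \mathbb{R}^n\), and \(-\infty\) otherwise.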
The MoE layer is applied independently per token and replaces the feed-forward (FFN) sub-block of the transformer block.
Mixtral uses the SwiGLU architecture as the expert function \(E_i(x)\) and sets \(K=2\).
Each token is routed to two SwiGLU sub-blocks with different sets of weights.
The output \(y\) for an input token \( x \) is computed as:
\( y = \sum_{i=0}^{n-1} \text{Softmax}(\text{Top2}(x \cdot W_g))_i \cdot \text{SwiGLU}_i(x). \)
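Each expert is a standard SwiGLU feed-forward block. A minimal sketch in PyTorch (the layer names w1/w2/w3 follow the common Llama-style convention and are assumptions, not the exact Mixtral module names):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUExpert(nn.Module):
    # One expert: a SwiGLU feed-forward block (sketch).
    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)  # up projection
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (SiLU(x W1) * (x W3)) W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

The MoE forward pass below, in the style of the Hugging Face Mixtral implementation, selects the experts per token and combines their outputs; self.gate, self.experts, self.top_k, and self.n_experts are attributes of the enclosing MoE module: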
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # e.g. batch_size=4, seq_len=8, hidden_dim=16
    batch_size, seq_len, hidden_dim = hidden_states.shape
    # Each token in each batch may be routed to a different set of experts.
    # router_logits: (batch * sequence_length, n_experts) => (32, 4)
    hidden_states = hidden_states.view(-1, hidden_dim)
    router_logits = self.gate(hidden_states)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    # Use `torch.topk` to select the K experts with the highest routing weights.
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    # Renormalize the selected experts' weights so they sum to 1.
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # Pre-allocate the output tensor with the input's dtype and device.
    final_hidden_states = torch.zeros(
        (batch_size * seq_len, hidden_dim),
        dtype=hidden_states.dtype, device=hidden_states.device,
    )
    # Build the expert mask with one-hot encoding.
    # expert_mask: (32, 2, 4) => (4, 2, 32), n_tokens=32, selected_experts=2, n_experts=4
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.n_experts)
    expert_mask = expert_mask.permute(2, 1, 0)
    # Visit each expert and process only the tokens routed to it.
    for expert_idx in range(self.n_experts):
        expert_layer = self.experts[expert_idx]
        # top_x holds the positions (token indices) this expert has to process.
        # idx indicates whether this expert is the token's first or second choice.
        idx, top_x = torch.where(expert_mask[expert_idx])
        # Skip this expert if no tokens were routed to it.
        if top_x.shape[0] == 0:
            continue
        # Gather the tokens assigned to this expert according to top_x.
        curr_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        # Run the expert and weight its output by the routing weight.
        curr_hidden_states = expert_layer(curr_state) * routing_weights[top_x, idx, None]
        # Scatter the weighted expert outputs back into the output tensor.
        final_hidden_states.index_add_(0, top_x, curr_hidden_states.to(hidden_states.dtype))
    final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
    return final_hidden_states, router_logits
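A minimal sketch of the module this forward pass lives in; SparseMoEBlock is a hypothetical wrapper that just wires up the attributes the method expects (toy sizes matching the shape comments: 4 experts, top-2, hidden_dim=16):

class SparseMoEBlock(nn.Module):
    # Hypothetical wrapper providing self.gate, self.experts, self.top_k, self.n_experts.
    def __init__(self, hidden_dim: int = 16, ffn_dim: int = 64, n_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.n_experts, self.top_k = n_experts, top_k
        self.gate = nn.Linear(hidden_dim, n_experts, bias=False)
        self.experts = nn.ModuleList(
            SwiGLUExpert(hidden_dim, ffn_dim) for _ in range(n_experts)
        )

    # forward() as defined above is the forward pass of this module.
    # Usage sketch:
    #   x = torch.randn(4, 8, 16)              # (batch_size, seq_len, hidden_dim)
    #   y, router_logits = SparseMoEBlock()(x)
    #   y.shape == (4, 8, 16); router_logits.shape == (32, 4)

The router logits are returned alongside the output so that a load-balancing auxiliary loss can be computed from them during training, as in the Hugging Face implementation.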
If one increases \(n\) while keeping \(K\) fixed, one can increase the model's parameter count while keeping its computational cost effectively constant.
The model's total parameter count (sparse) grows with \(n\).
The number of parameters used for processing a token (active) grows with \(K\).
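A back-of-the-envelope sketch of this trade-off for Mixtral 8x7B, assuming the hyperparameters listed earlier; dense_params (attention, embeddings, norms, router) is a rough estimate, not an official figure:

# Rough parameter-count sketch for Mixtral 8x7B (illustrative numbers).
dim, ffn_dim, n_layers = 4096, 14336, 32
n_experts, top_k = 8, 2

expert_params = 3 * dim * ffn_dim   # one SwiGLU expert: w1, w2, w3 (~176M parameters)
dense_params = 1.6e9                # rough estimate: attention, embeddings, norms, router

total = dense_params + n_layers * n_experts * expert_params   # "sparse" count, grows with n
active = dense_params + n_layers * top_k * expert_params      # per-token count, grows with K
print(f"total ~ {total / 1e9:.0f}B, active ~ {active / 1e9:.0f}B")  # total ~ 47B, active ~ 13B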
The experts can be distributed across multiple GPUs through Expert Parallelism (EP).
Each token is routed to the GPU(s) hosting its selected experts for processing.
EP introduces load-balancing challenges: uneven routing can overload individual GPUs and create computational bottlenecks.
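A conceptual sketch of EP with a simple static placement (expert index modulo GPU count; the names are illustrative, not from any particular framework):

def expert_to_rank(expert_idx: int, world_size: int) -> int:
    # Static placement: expert i lives on GPU (i mod world_size).
    return expert_idx % world_size

# With 8 experts on 4 GPUs, each GPU hosts 2 experts.
world_size, n_experts = 4, 8
placement = {e: expert_to_rank(e, world_size) for e in range(n_experts)}
print(placement)  # {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}

# At runtime each token's hidden state is sent (e.g. via all-to-all) to the GPUs
# hosting its top-K experts, and the weighted results are gathered back.
# If the router favours a few experts, their GPUs become the bottleneck,
# which is why load balancing matters for EP.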
Mixtral 8x7B Instruct outperforms Claude-2.1, Gemini Pro, and GPT-3.5 Turbo.
Mixtral uses only 13B active parameters per token.
It outperforms the previous best model (Llama 2 70B), which uses 70B parameters per token.
Publicly available under the Apache 2.0 license.