Mixtral of Experts

 

Mistral AI

 

8 Jan 2024

Introduction

  • Mixtral 8x7B is a sparse mixture of experts (SMoE) language model.

  • Each layer's feedforward block picks from a set of 8 distinct groups of parameters (the "experts").

  • At every layer, for every token, a router network chooses two of these groups to process the token and combines their outputs additively.

  • Mixtral is pretrained with multilingual data using a context size of 32k tokens.

Sparse Mixture of Experts

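  • Given \(n\) expert networks \(\{E_0, \dots, E_{n-1}\}\), the output of the MoE layer for an input \(x\) is the gate-weighted sum of the expert outputs: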

\(\sum_{i=0}^{n-1} G(x)_i \cdot E_i(x)\)

  • \( G(x)_i \) denotes the \(i\)-th entry of the \(n\)-dimensional gating output \(G(x)\), i.e. the weight assigned to the \(i\)-th expert.

  • \( E_i(x) \) is the output of the \(i\)-th expert network.

  • If the gating vector \(G(x)\) is sparse, we can skip computing the outputs of experts whose gates are zero.

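\(G(x) := \text{Softmax}(\text{TopK}(x \cdot W_g))\)

  • Here \(\text{TopK}(\ell)_i := \ell_i\) if \(\ell_i\) is among the top-\(K\) coordinates of the logits \(\ell = x \cdot W_g \in \mathbb{R}^n\), and \(-\infty\) otherwise, so the Softmax gives every non-selected expert a weight of exactly zero.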
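The gating formula and the weighted sum above can be sketched in plain Python as follows. This is a minimal illustration, not Mistral's implementation: vectors are Python lists, `W_g` is assumed to have shape (d, n) so that the logits are \(x \cdot W_g\), each expert is assumed to be a callable mapping \(\mathbb{R}^d \to \mathbb{R}^d\), and the names `topk_softmax_gate` and `moe_layer` are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def topk_softmax_gate(logits: Sequence[float], k: int) -> Vector:
    """G(x) = Softmax(TopK(logits)); logits outside the top k act as -inf (weight 0)."""
    # Indices of the k largest logits (stable sort: ties go to the lower index)
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    exps = {i: math.exp(logits[i] - m) for i in top}
    total = sum(exps.values())
    # exp(-inf) = 0, so every non-selected expert gets exactly zero weight
    return [exps[i] / total if i in exps else 0.0 for i in range(len(logits))]


def moe_layer(
    x: Vector,
    w_g: List[List[float]],  # hypothetical layout: shape (d, n), so logits = x . W_g
    experts: Sequence[Callable[[Vector], Vector]],  # each E_i maps R^d -> R^d
    k: int,
) -> Vector:
    """y = sum_i G(x)_i * E_i(x), evaluating only experts with a nonzero gate."""
    n = len(experts)
    logits = [sum(x[r] * w_g[r][j] for r in range(len(x))) for j in range(n)]
    gates = topk_softmax_gate(logits, k)
    y = [0.0] * len(x)
    for g, expert in zip(gates, experts):
        if g == 0.0:
            continue  # sparse gate: this expert's output is never computed
        out = expert(x)
        y = [acc + g * o for acc, o in zip(y, out)]
    return y
```

With \(K = 2\) and \(n = 8\), each call to `moe_layer` evaluates only 2 of the 8 expert callables, which is the saving the sparse gate is meant to provide.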

MoE Layers

  • The MoE layer is applied independently per token and replaces the feed-forward (FFN) sub-block of the transformer block.

  • Mixtral uses the SwiGLU architecture as the expert function \(E_i(x)\) and sets \(K=2\).

  • Each token is routed to two SwiGLU sub-blocks with different sets of weights.

  • The output \(y\) for an input token \( x \) is computed as:

\( y = \sum_{i=0}^{n-1} \text{Softmax}(\text{Top2}(x \cdot W_g))_i \cdot \text{SwiGLU}_i(x). \)
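Each \(\text{SwiGLU}_i\) in the formula above is a gated feedforward block. The sketch below assumes the three-matrix form used by Hugging Face's `MixtralBlockSparseTop2MLP` (`w2(act_fn(w1(x)) * w3(x))`, no biases, SiLU activation); the helper names `matvec`, `silu` and `swiglu_expert` and the toy weights are hypothetical, and lists stand in for tensors.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (out_dim, in_dim)


def matvec(w: Matrix, v: Vector) -> Vector:
    """Computes W v for a row-major matrix W."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def silu(z: float) -> float:
    """SiLU (Swish): z * sigmoid(z), written to avoid overflow in exp."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)


def swiglu_expert(x: Vector, w1: Matrix, w2: Matrix, w3: Matrix) -> Vector:
    """SwiGLU(x) = W2 (SiLU(W1 x) * W3 x), where * is elementwise."""
    gate = [silu(z) for z in matvec(w1, x)]    # (ffn_dim,)
    up = matvec(w3, x)                          # (ffn_dim,)
    hidden = [g * u for g, u in zip(gate, up)]  # (ffn_dim,)
    return matvec(w2, hidden)                   # back to (dim,)


# Toy shapes: dim = 2, ffn_dim = 3 (Mixtral itself uses dim 4096, hidden_dim 14336)
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w3 = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
w2 = [[1.0, 1.0, 1.0], [1.0, -1.0, 0.0]]
print(swiglu_expert([1.0, -1.0], w1, w2, w3))  # approximately [1.0, 0.4621]
```

In the MoE layer, each of the \(n\) experts owns its own `w1`, `w2`, `w3`, and the router only decides which two of them run for a given token.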

Mixture of Experts Layer

Mixture of Experts Architecture

MixtralSparseMoeBlock (HF)

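```python
from typing import Tuple

import torch
import torch.nn.functional as F


# Simplified excerpt of the forward pass (a method of the MoE block). Assumes the
# block defines self.gate (nn.Linear(hidden_dim, n_experts, bias=False)),
# self.experts (nn.ModuleList of expert FFNs), self.top_k and self.n_experts.
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # e.g. batch_size=4, seq_len=8, hidden_dim=16
    batch_size, seq_len, hidden_dim = hidden_states.shape

    # Every token in every batch may use different experts, so flatten to one row per token
    # router_logits: (batch * sequence_length, n_experts) => (32, 4)
    hidden_states = hidden_states.view(-1, hidden_dim)
    router_logits = self.gate(hidden_states)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

    # Use `torch.topk` to select the K experts with the highest weights
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

    # Renormalize so the weights of the selected experts sum to 1
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # Cast back to the input dtype (e.g. bf16) so index_add_ below sees matching dtypes
    routing_weights = routing_weights.to(hidden_states.dtype)

    # Preallocate the output tensor on the same device and dtype as the input
    final_hidden_states = torch.zeros(
        (batch_size * seq_len, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

    # Build the expert mask with one-hot encoding
    # expert_mask: (32, 2, 4) => (4, 2, 32), n_tokens=32, selected_experts=2, n_experts=4
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.n_experts)
    expert_mask = expert_mask.permute(2, 1, 0)

    # Visit each expert and compute the tokens it has to process
    for expert_idx in range(self.n_experts):
        expert_layer = self.experts[expert_idx]

        # top_x: positions (indices) of the tokens this expert must process
        # idx: the rank of this expert for each of those tokens (0 = first choice)
        idx, top_x = torch.where(expert_mask[expert_idx])

        # Skip this expert if no token was routed to it
        if top_x.shape[0] == 0:
            continue

        # Gather the tokens to process using top_x
        curr_state = hidden_states[None, top_x].reshape(-1, hidden_dim)

        # Expert output scaled by the matching routing weight
        curr_hidden_states = expert_layer(curr_state) * routing_weights[top_x, idx, None]

        # Scatter-add the weighted hidden states back into their token positions
        final_hidden_states.index_add_(0, top_x, curr_hidden_states)

    final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)

    return final_hidden_states, router_logits
```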
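The `one_hot` / `permute(2, 1, 0)` / `torch.where` steps in the forward pass are the least obvious part. The torch-free sketch below is meant to mirror what they produce for each expert: the pair `(idx, top_x)` scanned in row-major order over a (top_k, n_tokens) mask. The function names `dispatch` and `moe_combine` are hypothetical, and hidden states are reduced to one float per token (hidden_dim = 1) to keep it small.

```python
from typing import Callable, Dict, List, Tuple


def dispatch(selected_experts: List[List[int]], n_experts: int) -> Dict[int, Tuple[List[int], List[int]]]:
    """Returns {expert: (idx, top_x)}, mimicking torch.where(expert_mask[expert]).

    idx[j]   = rank of this expert in token top_x[j]'s top-k list (0 = first choice)
    top_x[j] = token position routed to this expert
    """
    top_k = len(selected_experts[0]) if selected_experts else 0
    result = {}
    for e in range(n_experts):
        idx, top_x = [], []
        # expert_mask[e] has shape (top_k, n_tokens); torch.where scans it row by row
        for rank in range(top_k):
            for tok, choices in enumerate(selected_experts):
                if choices[rank] == e:
                    idx.append(rank)
                    top_x.append(tok)
        result[e] = (idx, top_x)
    return result


def moe_combine(
    hidden: List[float],
    selected: List[List[int]],
    weights: List[List[float]],  # routing weights, shape (n_tokens, top_k)
    experts: List[Callable[[float], float]],
) -> List[float]:
    out = [0.0] * len(hidden)
    for e, (idx, top_x) in dispatch(selected, len(experts)).items():
        for rank, tok in zip(idx, top_x):
            # Same role as index_add_: accumulate the weighted output into row tok
            out[tok] += experts[e](hidden[tok]) * weights[tok][rank]
    return out


sel = [[1, 0], [2, 1], [1, 3]]  # 3 tokens, top-2 choices each, 4 experts
print(dispatch(sel, 4)[1])  # ([0, 0, 1], [0, 2, 1])
```

Expert 1 is the first choice of tokens 0 and 2 and the second choice of token 1, which is exactly what `idx` and `top_x` encode.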

Sparse & Active Parameters

  • If one increases \(n\) while keeping \(K\) fixed, one can increase the model's parameter count while keeping its computational cost effectively constant.

  • The model's total parameter count (sparse) grows with \(n\).

  • The number of parameters used for processing a token (active) grows with \(K\).

MoE Parallelism

  • MoE layers can be distributed across multiple GPUs through Expert Parallelism (\(EP\)).

    • Each token is routed to the GPU that holds its assigned expert for processing.

  • \(EP\) introduces load-balancing challenges: work must be spread evenly so that no individual GPU is overloaded or becomes a computational bottleneck.
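To make the load-balancing concern concrete, the sketch below groups each token's top-k expert assignments by the GPU that hosts the expert and measures how uneven the resulting per-GPU load is. The round-robin placement (expert e on GPU e % n_gpus), the max/mean imbalance metric, and the function names are all hypothetical choices for illustration; the slides do not specify how Mixtral's experts are placed.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Assignment = Tuple[int, int]  # (token index, expert index)


def route_to_gpus(selected_experts: List[List[int]], n_gpus: int) -> Dict[int, List[Assignment]]:
    """Groups (token, expert) assignments by GPU.

    Hypothetical placement: expert e lives on GPU e % n_gpus.
    """
    buckets = defaultdict(list)
    for tok, experts in enumerate(selected_experts):
        for e in experts:
            buckets[e % n_gpus].append((tok, e))
    # Include GPUs that received no tokens at all
    return {g: buckets.get(g, []) for g in range(n_gpus)}


def load_imbalance(buckets: Dict[int, List[Assignment]]) -> float:
    """Max per-GPU load divided by mean load (1.0 means perfectly balanced)."""
    loads = [len(v) for v in buckets.values()]
    mean = sum(loads) / len(loads)
    return max(loads) / mean if mean > 0 else 1.0


# 4 tokens with top-2 routing; expert 0 is picked by every token
sel = [[0, 1], [0, 2], [0, 3], [1, 0]]
buckets = route_to_gpus(sel, 4)
print({g: len(v) for g, v in buckets.items()}, load_imbalance(buckets))  # {0: 4, 1: 2, 2: 1, 3: 1} 2.0
```

Here the GPU hosting expert 0 does twice the average work, so it would set the pace for the whole layer; skewed routing like this is what EP systems try to avoid.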

Parameters

Common Benchmarks

Comparison of Mixtral with Llama

Comparison of Mixtral with Llama 2 70B and GPT-3.5

Multilingual Benchmarks

Passkey Retrieval Task

Long Range Performance

Instruction Fine-Tuning

Routing Analysis

Percentage of Expert Assignment Repetitions

First Expert Choice

Conclusion

  • Mixtral 8x7B Instruct outperforms Claude-2.1, Gemini Pro, and GPT-3.5 Turbo on human evaluation benchmarks.

  • Mixtral only uses 13B active parameters per token.

  • It outperforms Llama 2 70B, the previous best model, which uses 70B parameters per token.

  • Publicly available under the Apache 2.0 license.
