Jamba: A Hybrid Transformer-Mamba Language Model

 

AI21 Labs

 

28 Mar 2024

 

Paper | HF Hub

Transformer vs RNN

  • Transformer:

    • High memory and compute requirements.

    • The lack of a single summary state makes inference slow.

  • RNN:

    • Summarize an arbitrarily long context in a single hidden state.

    • Costly to train since training cannot be parallelized across time steps.

    • Struggle with long-distance relationships, which the hidden state captures only to a limited extent.

State Space Model

  • SSMs are more efficient to train than RNNs and are more capable of handling long-distance relationships.

  • Still lag behind the performance of comparably sized Transformer LMs.

Jamba

  • Jamba combines Transformer layers with Mamba layers as well as MoE.

  • Jamba has 12B active params and 52B total available params.

Memory

  • When scaling Transformers to long contexts, the KV cache becomes a bottleneck.

  • Trading off attention layers for Mamba layers reduces the size of the KV cache.

  • Jamba provides an 8x smaller KV cache compared to a vanilla Transformer.

  • It maintains a small KV cache even with 256K-token contexts (a back-of-envelope sketch follows this list).
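
  • The sketch below shows where the 8x can come from, assuming illustrative dimensions (32 layers, 8 KV heads of size 128, fp16 values); these numbers are assumptions for illustration, not Jamba's published configuration. Attention layers cache a key and a value per token, while Mamba layers keep only a fixed-size state, so replacing 7 of every 8 attention layers shrinks the cache roughly 8x.

# Back-of-envelope KV-cache size; the dimensions are illustrative assumptions.
def kv_cache_bytes(num_attn_layers, context_len, kv_heads=8, head_dim=128, bytes_per_val=2):
    # Each attention layer caches one key and one value vector per token and KV head.
    return num_attn_layers * context_len * kv_heads * head_dim * 2 * bytes_per_val

layers, context = 32, 256_000
vanilla = kv_cache_bytes(num_attn_layers=layers, context_len=context)
hybrid = kv_cache_bytes(num_attn_layers=layers // 8, context_len=context)  # 1:7 attention:Mamba

print(f"vanilla Transformer: {vanilla / 2**30:.1f} GiB")
print(f"1/8 attention hybrid: {hybrid / 2**30:.1f} GiB ({vanilla / hybrid:.0f}x smaller)")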

Throughput

  • With long sequences, attention dominates the compute.

  • Mamba layers are more compute-efficient.

  • Increasing the ratio of Mamba layers improves throughput.
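
  • A rough illustration of why (a sketch with assumed sizes, not Jamba's real dimensions): the sequence-length-dependent term of attention grows quadratically with context length, while Mamba's recurrent scan grows linearly.

# Illustrative scaling of the sequence-length-dependent compute; the sizes below
# are assumptions, and projections/MLPs are ignored.
d_model, d_state = 4096, 16

def attention_score_flops(n):
    # QK^T plus the attention-weighted sum over V: ~2 * n^2 * d_model, quadratic in n.
    return 2 * n * n * d_model

def mamba_scan_flops(n):
    # Recurrent state update per token: ~n * d_model * d_state, linear in n.
    return n * d_model * d_state

for n in (8_000, 128_000):
    print(f"n={n:>7,}: attention ~{attention_score_flops(n):.1e} FLOPs, "
          f"Mamba scan ~{mamba_scan_flops(n):.1e} FLOPs")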

Jamba Blocks

  • Each Jamba block is a combination of Mamba and attention layers (a minimal sketch follows this list).

  • Each layer contains either an attention or a Mamba module, followed by an MLP.

  • A Jamba block contains \(l\) layers, which are mixed at a ratio of \(a : m\).

  • Meaning \(a\) attention layers for every \(m\) Mamba layers.

  • The MoE module may be applied to MLPs every \(e\) layers.

  • Positional embeddings or mechanisms like RoPE are not necessary.
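
  • A minimal sketch of the arrangement described above. The helper below is hypothetical (not AI21's code), and the position of the attention layer within each \(a + m\)-layer period is an assumption.

# Hypothetical sketch of one Jamba block's layer schedule; not AI21's implementation.
def jamba_block_schedule(l, a, m, e):
    """Return (mixer, mlp) type pairs for the l layers of one Jamba block."""
    schedule = []
    period = a + m                                    # a attention layers per m Mamba layers
    for i in range(l):
        # Where the attention layer sits inside the period is an assumption here.
        mixer = "attention" if (i % period) >= m else "mamba"
        mlp = "moe" if (i % e) == (e - 1) else "mlp"  # MoE replaces the MLP every e layers
        schedule.append((mixer, mlp))
    return schedule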

Jamba Configuration

  • \(l = 8\): the number of layers in each Jamba block.

  • \(a:m = 1:7\): the ratio of attention to Mamba layers.

  • \(e = 2\): how often to use MoE instead of a single MLP.

  • \(n = 16\): total number of experts.

  • \(K = 2\): number of top experts used at each token.
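
  • Plugging this configuration into the hypothetical jamba_block_schedule sketch above yields one attention layer and seven Mamba layers per block, with MoE in every second MLP slot.

for i, (mixer, mlp) in enumerate(jamba_block_schedule(l=8, a=1, m=7, e=2)):
    print(f"layer {i}: {mixer:9s} + {mlp}")
# -> 7 "mamba" layers and 1 "attention" layer per 8-layer block; "moe" in every other MLP.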

Throughput Analysis

  • Varying batch size, a single 80 GB GPU, int8, 8K context, 512 output tokens.

  • Jamba allows processing of large batches, leading to a 3x increase in throughput over Mixtral despite having a similar number of active parameters.

Throughput Analysis

  • Single batch, 4 A100 GPUs, FP16, varying context lengths, output 512 tokens.

  • With a 128K-token context, Jamba's throughput is 3x that of Mixtral.

Training Infrastructure and Dataset

  • The model was trained on NVIDIA H100 GPUs.

  • Used an in-house proprietary framework for efficient large-scale training, including FSDP and tensor, sequence, and expert parallelism.

  • In-house training dataset that contains text data from the Web, books, and code, with the last update in March 2024.

  • Data processing pipeline includes quality filters and deduplication. 

  • Trained on context lengths of up to 1M tokens; the released model supports contexts of up to 256K tokens.

Academic Benchmarks

Needle-In-A-Haystack

  • Jamba retrieves the needle reliably even though it has only 4 attention layers.

Naturalistic Long-Context Evaluation

  • Evaluates the ability to handle long contexts on question-answering datasets with long inputs.

Ablation - Attention & Mamba (1B)

  • 1B models trained for 250B tokens.

  • The hybrid Jamba model outperforms the pure attention or Mamba models.

Training Loss (1B)

Ablation - Attention & Mamba (7B)

  • 7B models trained for 50B tokens.

  • The pure Mamba model lags slightly behind the pure attention model.

Training Loss (7B)

Why does the Combination Work?

  • The pure Mamba model often does not follow the correct format.

    • Labels: "Positive" or "Negative"; Mamba outputs: "Funny", "Bad", etc.

  • Conjecture: the lack of an attention mechanism in the pure Mamba model makes in-context learning difficult.

The Effect of Mixture-of-Experts

  • 7B-parameter models trained on 50B tokens.

  • The MoE variant has \(n = 16\) total experts, \(K = 2\) experts used at each token, and MoE is applied every \(e = 2\) layers.
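
  • A generic top-K routing sketch (\(n = 16\) experts, \(K = 2\) per token), showing why only a fraction of the expert parameters is active for any given token, which is how 52B total parameters yield roughly 12B active ones. This is a textbook-style MoE gate, not AI21's implementation.

# Generic top-K MoE routing sketch; not AI21's implementation.
import torch
import torch.nn.functional as F

def moe_forward(x, experts, router, k=2):
    """x: (tokens, d_model). Route each token to its top-k experts and mix their outputs."""
    logits = router(x)                             # (tokens, n_experts)
    weights, idx = torch.topk(logits, k, dim=-1)   # pick the K highest-scoring experts per token
    weights = F.softmax(weights, dim=-1)           # normalize over the chosen experts only
    out = torch.zeros_like(x)
    for slot in range(k):
        for e, expert in enumerate(experts):
            mask = idx[:, slot] == e               # tokens whose slot-th choice is expert e
            if mask.any():
                out[mask] += weights[mask, slot].unsqueeze(-1) * expert(x[mask])
    return out

d_model, n_experts = 64, 16
experts = torch.nn.ModuleList(torch.nn.Linear(d_model, d_model) for _ in range(n_experts))
router = torch.nn.Linear(d_model, n_experts)
print(moe_forward(torch.randn(10, d_model), experts, router).shape)  # torch.Size([10, 64])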

Stabilizing Mamba at Large Scale

  • Encountered large loss spikes when scaling to larger models; adding RMSNorm to internal activations of the Mamba layers stabilized training.

Jamba Does Not Require Explicit Positional Information

  • 1.3B parameter models, 250B tokens.

  • Explicit positional information may not be required for the hybrid architecture.

  • The Mamba layers provide implicit position information.
