Transformer vs RNN
-
Transformer:
-
High memory and compute requirements.
-
The lack of a single summary state entails slow inference, since the whole context must be attended to at every step.
-
RNN:
-
Summarize an arbitrarily long context in a single hidden state.
-
Costly to train since training cannot be parallelized across time steps.
-
Struggle with long-distance relationships, which the hidden state captures to only a limited extent.
-
State Space Model
-
SSMs are more efficient to train than RNNs and better at handling long-distance relationships.
-
Still lag behind the performance of comparably sized Transformer LMs.
Jamba
-
Jamba combines Transformer layers with Mamba layers as well as MoE.
-
Jamba has 12B active parameters and 52B total parameters.
Memory
-
When scaling Transformers to long contexts, the KV cache becomes a bottleneck.
-
Trading off attention layers for Mamba layers reduces the size of the KV cache.
-
Jamba provides an 8x smaller KV cache compared to a vanilla Transformer.
-
Jamba maintains a small KV cache even with 256K-token contexts (see the estimate below).
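To make the 8x figure concrete, here is a back-of-the-envelope KV-cache estimate in Python. The layer count, KV-head count, and head dimension below are illustrative assumptions rather than Jamba's exact settings; the point is that only attention layers contribute to the cache, so keeping 4 attention layers instead of 32 shrinks it by 8x.
```python
# Rough KV-cache size estimate (illustrative dimensions, not Jamba's exact ones).
# Only attention layers store keys/values, so swapping them for Mamba layers
# shrinks the cache proportionally.
def kv_cache_gib(n_attn_layers, context_len, n_kv_heads=8, head_dim=128, bytes_per_elem=2):
    # factor 2 = keys + values; bytes_per_elem=2 assumes fp16/bf16
    return 2 * n_attn_layers * context_len * n_kv_heads * head_dim * bytes_per_elem / 2**30

ctx = 256 * 1024
print(kv_cache_gib(n_attn_layers=32, context_len=ctx))  # all-attention Transformer: ~32 GiB
print(kv_cache_gib(n_attn_layers=4, context_len=ctx))   # Jamba keeps 1 in 8 attention layers: ~4 GiB
```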

Throughput
-
With long sequences, attention dominates the compute cost.
-
Mamba layers are more compute-efficient.
-
Increasing the ratio of Mamba layers improves throughput.
Jamba Blocks
-
Each Jamba block is a combination of Mamba and attention layers.
-
Each layer contains either an attention or a Mamba module, followed by an MLP (see the sketch after this list).
-
A Jamba block contains l layers, which are mixed at a ratio of a:m.
-
That is, a attention layers for every m Mamba layers.
-
The MoE module may be applied to MLPs every e layers.
-
Positional embeddings or mechanisms like RoPE are not necessary.
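A minimal PyTorch sketch of the layer structure described above: a mixer (attention or Mamba) followed by an MLP or MoE, each with a residual connection. The pre-norm placement, LayerNorm choice, and placeholder modules are assumptions for illustration, not the paper's actual implementation.
```python
import torch
import torch.nn as nn

class JambaLayer(nn.Module):
    """One layer: an attention-or-Mamba mixer, then an MLP (or MoE), each with a residual."""
    def __init__(self, d_model: int, mixer: nn.Module, ffn: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)  # normalization choice here is an assumption
        self.norm2 = nn.LayerNorm(d_model)
        self.mixer = mixer  # an attention or Mamba block; any module mapping (B, T, D) -> (B, T, D)
        self.ffn = ffn      # a plain MLP, or an MoE module on every e-th layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mixer(self.norm1(x))  # residual around attention / Mamba
        x = x + self.ffn(self.norm2(x))    # residual around MLP / MoE
        return x

# Usage with stand-in modules (nn.Identity replaces a real attention/Mamba block):
d = 64
layer = JambaLayer(d, mixer=nn.Identity(),
                   ffn=nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d)))
out = layer(torch.randn(2, 16, d))  # (batch, seq, d_model)
```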

Jamba Configuration
-
l=8: the number of layers in each Jamba block.
-
a:m=1:7: the ratio of attention to Mamba layers (expanded into a concrete layout below).
-
e=2: how often to use MoE instead of a single MLP.
-
n=16: total number of experts.
-
K=2: number of top experts used at each token.
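The configuration above expands into a concrete per-layer layout. The helper below is an illustrative sketch (which index holds the attention layer within each block is an assumption): it yields 1 attention and 7 Mamba layers per block, with MoE replacing the MLP on every 2nd layer.
```python
# Expand (l, a:m, e) into a per-layer layout for one Jamba block (illustrative only).
def jamba_block_layout(l=8, a=1, m=7, e=2):
    period = a + m  # a attention layers for every m Mamba layers
    layout = []
    for i in range(l):
        mixer = "attention" if i % period < a else "mamba"
        ffn = "moe" if (i + 1) % e == 0 else "mlp"  # MoE replaces the MLP every e layers
        layout.append((i, mixer, ffn))
    return layout

for i, mixer, ffn in jamba_block_layout():
    print(f"layer {i}: {mixer:9} + {ffn}")
# -> 1 attention layer, 7 Mamba layers; 4 of the 8 MLPs are MoE modules
```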


Throughput Analysis
-
Varying batch size, a single 80 GB GPU, int8, 8K context, 512 output tokens.
-
Jamba allows processing of large batches, leading to a 3x increase in throughput over Mixtral despite having a similar number of active parameters.

Throughput Analysis
-
Single batch, 4 A100 GPUs, FP16, varying context lengths, output 512 tokens.
-
With a 128K-token context, Jamba's throughput is 3x that of Mixtral.

Training Infrastructure and Dataset
-
The model was trained on NVIDIA H100 GPUs.
-
Used an in-house proprietary framework that enables efficient large-scale training, including FSDP and tensor, sequence, and expert parallelism.
-
In-house training dataset that contains text data from the Web, books, and code, with the last update in March 2024.
-
Data processing pipeline includes quality filters and deduplication.
-
Trained on context lengths of up to 1M tokens; the released model supports up to 256K tokens.
Academic Benchmarks


Needle-In-A-Haystack
-
Only 4 attention layers are enough.

Naturalistic Long-Context Evaluation
-
Evaluates the ability to handle long contexts using question-answering benchmarks with long inputs.

Ablation - Attention & Mamba (1B)
-
1B models trained for 250B tokens.
-
The hybrid Jamba model outperforms the pure attention or Mamba models.

Training Loss (1B)

Ablation - Attention & Mamba (7B)
-
7B models trained for 50B tokens.
-
The pure Mamba model lags slightly behind the pure attention model.

Training Loss (7B)

Why does the Combination Work?
-
The pure Mamba model often does not follow the correct format.
-
Labels: "Positive" or "Negative"; Mamba outputs: "Funny", "Bad", etc.
-
The conjecture is that the lack of an attention mechanism in the pure Mamba model makes in-context learning difficult.

The Effect of Mixture-of-Experts
-
7B-parameter models trained on 50B tokens.
-
The MoE variant has n=16 total experts, K=2 experts used at each token, and MoE applied every e=2 layers (a minimal routing sketch follows).
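A minimal top-K routing sketch in PyTorch (a simplified assumption, not Jamba's actual MoE code): a linear router scores the n experts for each token and only the K highest-scoring experts are evaluated, which is why only a fraction of the total parameters is active per token.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Route each token to its top-K experts and mix their outputs by softmaxed router scores."""
    def __init__(self, d_model: int, n_experts: int = 16, k: int = 2):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
            for _ in range(n_experts)
        )
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.router(x)                        # (batch, seq, n_experts)
        weights, idx = scores.topk(self.k, dim=-1)     # keep only the top-K experts per token
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        for slot in range(self.k):
            for ei, expert in enumerate(self.experts):
                mask = idx[..., slot] == ei            # tokens routed to expert ei in this slot
                if mask.any():
                    out[mask] += weights[..., slot][mask].unsqueeze(-1) * expert(x[mask])
        return out

moe = TopKMoE(d_model=64, n_experts=16, k=2)           # n=16 experts, K=2 active per token
y = moe(torch.randn(2, 8, 64))
```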

Stabilizing Mamba at Large Scale
-
Large loss spikes were encountered when scaling to larger models; adding RMSNorm to the internal activations of the Mamba layers stabilizes training.

Jamba Does Not Require Explicit Positional Information
-
1.3B parameter models, 250B tokens.
-
Explicit positional information may not be required for the hybrid architecture.
-
The Mamba layers provide implicit position information.

Jamba: A Hybrid Transformer-Mamba Language Model
By Penut Chen (陳威廷)