Transformer:
High memory and compute requirements.
The lack of a single summary state makes inference slow.
RNN:
Summarize an arbitrarily long context in a single hidden state.
Costly to train since training cannot be parallelized across time steps.
Struggle with long-distance relationships, which the hidden state captures only to a limited extent.
SSMs are more efficient to train than RNNs and better at handling long-distance relationships.
Still lag behind the performance of comparably sized Transformer LMs.
Jamba combines Transformer layers with Mamba layers as well as MoE.
Jamba has 12B active params and 52B total available params.
When scaling Transformers to long contexts, the KV cache becomes a bottleneck.
Trading attention layers for Mamba layers reduces the KV cache.
Jamba provides an 8x smaller KV cache than a vanilla Transformer.
The KV cache stays small even with 256K-token contexts.
With long sequences, attention also dominates the compute.
Mamba layers are more compute-efficient.
Increasing the ratio of Mamba layers improves throughput.
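A rough back-of-the-envelope sketch (not from the paper) of where the 8x figure comes from: with an \(a : m\) ratio of 1:7, only 1 in 8 layers keeps keys and values, so the cache shrinks by the same factor. The head counts and dimensions below are placeholder assumptions, not the published configuration.

```python
def kv_cache_bytes(seq_len, n_attn_layers, n_kv_heads, head_dim,
                   bytes_per_elem=2, batch=1):
    """Size of the KV cache: 2 tensors (K and V) per attention layer."""
    return 2 * n_attn_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem * batch

# Placeholder dimensions (not the published config): 32 layers total,
# 8 KV heads of size 128, fp16 cache, 256K-token context.
full_attention = kv_cache_bytes(seq_len=256_000, n_attn_layers=32,
                                n_kv_heads=8, head_dim=128)
jamba_style    = kv_cache_bytes(seq_len=256_000, n_attn_layers=4,   # 1:7 ratio -> 1/8 of layers
                                n_kv_heads=8, head_dim=128)

print(f"full attention: {full_attention / 2**30:.1f} GiB")
print(f"1:7 hybrid:     {jamba_style / 2**30:.1f} GiB")
print(f"reduction:      {full_attention / jamba_style:.0f}x")  # 8x, from the layer count alone
```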
Each Jamba block is a combination of Mamba and attention layers.
Each layer contains either an attention or a Mamba module, followed by an MLP.
A Jamba block contains \(l\) layers, mixed at a ratio of \(a : m\), i.e., \(a\) attention layers for every \(m\) Mamba layers.
The MoE module may be applied to MLPs every \(e\) layers.
Positional embeddings or mechanisms like RoPE are not necessary.
\(l = 8\): the number of layers in each Jamba block.
\(a:m = 1:7\): the ratio of attention to Mamba layers.
\(e = 2\): how often to use MoE instead of a single MLP.
\(n = 16\): the total number of experts.
\(K = 2\): the number of top experts used at each token.
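A minimal sketch of how these hyperparameters translate into the layer list of one block. Where exactly the attention layer sits inside the block, and whether MoE counting starts at the first or second layer, are assumptions made here for illustration, not details taken from the paper.

```python
def jamba_block_layout(l=8, a=1, m=7, e=2):
    """Return one Jamba block's layers as (mixer, ffn) pairs.

    Each layer pairs a mixer (attention or Mamba) with a feed-forward module;
    every e-th layer uses MoE instead of a single MLP.
    """
    assert l == a + m, "the a:m ratio must account for all l layers"
    layers = []
    for i in range(l):
        mixer = "attention" if i < a else "mamba"   # a attention layers per m Mamba layers
        ffn = "moe" if (i + 1) % e == 0 else "mlp"  # MoE every e layers
        layers.append((mixer, ffn))
    return layers

for i, (mixer, ffn) in enumerate(jamba_block_layout()):
    print(f"layer {i}: {mixer:9} + {ffn}")
```

With the defaults above, this yields 1 attention layer, 7 Mamba layers, and MoE in 4 of the 8 layers.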
Varying batch size, a single 80 GB GPU, int8, 8K context, 512 output tokens.
Jamba allows processing of large batches, leading to a 3x increase in throughput over Mixtral despite having a similar number of active parameters.
Single batch, 4 A100 GPUs, FP16, varying context lengths, output 512 tokens.
With a 128K-token context, Jamba's throughput is 3x that of Mixtral.
The model was trained on NVIDIA H100 GPUs.
Used an in-house proprietary framework that enables efficient large-scale training, including FSDP, tensor parallelism, sequence parallelism, and expert parallelism.
In-house training dataset that contains text data from the Web, books, and code, with the last update in March 2024.
Data processing pipeline includes quality filters and deduplication.
Trained on context lengths of up to 1M tokens; supports contexts of up to 256K tokens.
Only 4 attention layers are enough.
Long-context ability is evaluated on QA benchmarks consisting of long inputs.
1B models trained for 250B tokens.
The hybrid Jamba model outperforms the pure attention or Mamba models.
7B models trained for 50B tokens.
The pure Mamba model lags slightly behind pure attention.
The pure Mamba model often does not follow the correct format.
Labels: "Positive" or "Negative", Mamba outputs: "Funny", "Bad", ..., etc.
Conjecture that the lack of an attention mechanism in the pure Mamba model makes it difficult for it to learn in-context.
7B parameters trained on 50B tokens.
The MoE variant has \(n = 16\) total experts, \(K = 2\) experts used at each token, and MoE is applied every \(e = 2\) layers.
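For concreteness, a minimal sketch of top-\(K\) routing with \(n = 16\) experts and \(K = 2\); the gating details (softmax placement, load-balancing losses, etc.) are assumptions for illustration rather than the exact formulation used in Jamba.

```python
import numpy as np

def top_k_moe(x, router_w, experts, k=2):
    """Route one token's hidden state x (shape (d,)) to the top-k experts.

    router_w: (n_experts, d) gating weights; experts: list of callables (d,) -> (d,).
    Only the k highest-scoring experts are evaluated; their outputs are mixed
    with softmax weights renormalized over the selected experts.
    """
    logits = router_w @ x
    top = np.argsort(logits)[-k:]                  # indices of the k best experts
    w = np.exp(logits[top] - logits[top].max())
    w /= w.sum()
    return sum(wi * experts[i](x) for wi, i in zip(w, top))

# Toy usage: hidden size 8, n = 16 experts (random linear maps), K = 2.
rng = np.random.default_rng(0)
d, n = 8, 16
experts = [lambda v, W=rng.normal(size=(d, d)): W @ v for _ in range(n)]
router_w = rng.normal(size=(n, d))
out = top_k_moe(rng.normal(size=d), router_w, experts, k=2)
print(out.shape)  # (8,)
```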
Large loss spikes were encountered when scaling to a larger model.
1.3B parameter models, 250B tokens.
Explicit positional information may not be required for the hybrid architecture.
The Mamba layers provide implicit position information.