ArXiv Dives: 💃 Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling

Mathias
Jun 26, 2024

Modeling sequences with infinite context length is one of the dreams of large language models. Transformer-based LLMs suffer from quadratic computational complexity in sequence length, which makes very long contexts computationally infeasible. Other models, such as linear recurrent neural networks, are fast but lack the ability to remember details far back in their context.

Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling
Efficiently modeling sequences with infinite context length has long been a challenging problem. Previous approaches have either suffered from quadratic computational complexity or limited extrapolation ability in length generalization. In this work, we present Samba, a simple hybrid architecture that layer-wise combines Mamba, a selective State Space Model (SSM), with Sliding Window Attention (SWA). Samba selectively compresses a given sequence into recurrent hidden states while still maintaining the ability to precisely recall recent memories with the attention mechanism. We scale Samba up to 3.8B parameters with 3.2T training tokens and demonstrate that it significantly outperforms state-of-the-art models across a variety of benchmarks. Pretrained on sequences of 4K length, Samba shows improved perplexity in context lengths of up to 1M in zero-shot. When finetuned on 4K-length sequences, Samba efficiently extrapolates to a 256K context length with perfect memory recall on the Passkey Retrieval task, and exhibits superior retrieval extrapolation on the challenging Phonebook task compared to full-attention models. As a linear-time sequence model, Samba achieves a 3.73x higher throughput compared to Transformers with grouped-query attention for user prompts of 128K length, and a 3.64x speedup when generating 64K tokens with unlimited streaming. Our code for training on open source data is publicly available at https://github.com/microsoft/Samba.

Samba addresses these issues by combining Mamba layers with Sliding Window Attention (SWA) and Multi-Layer Perceptrons (MLPs) into a single hybrid model.

Each of these layers is like a Lego brick or building block. If you understand the motivation behind each layer, you can start stacking them and combining them in interesting ways to gain performance in terms of both accuracy and throughput. They scale SAMBA up to 3.8B parameters trained on 3.2 trillion tokens, showing that SAMBA outperforms state-of-the-art models based on pure attention or state spaces. SAMBA has 3.73x higher throughput than Transformers on 128K-token prompts and extrapolates to a 256K context length with perfect recall on Passkey Retrieval, while still improving perplexity at context lengths of up to 1M.

Breaking Down the Building Blocks

The hybrid model consists of three main layers:

  • Mamba-style state space layers - great for speed and for modeling recurrent sequences
  • Sliding Window Attention - precise retrieval of memory
  • Multi-Layer Perceptrons - factual knowledge
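
To make the "stacking building blocks" idea concrete, here is a minimal sketch of how a layer plan could be generated from a repeating pattern. The pattern shown (a Mamba layer and an SWA layer, each followed by an MLP) and the names `HYBRID_PATTERN` and `build_layer_plan` are assumptions for illustration only; the post itself only names the three layer types.

```python
from itertools import cycle, islice

# Hypothetical repeating pattern: the exact ordering is an assumption for
# illustration, not something stated in this post.
HYBRID_PATTERN = ["mamba", "mlp", "swa", "mlp"]

def build_layer_plan(num_layers, pattern=HYBRID_PATTERN):
    """Repeat `pattern` until `num_layers` layer names have been produced."""
    return list(islice(cycle(pattern), num_layers))

# Swapping the pattern gives a different configuration to compare against.
print(build_layer_plan(8))
print(build_layer_plan(4, pattern=["mamba", "mlp"]))
```

Changing only the pattern list is enough to describe a different stack, which is the sense in which the layers behave like interchangeable blocks.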

They compare 4 different configurations:

Mamba Layers

We will not dive too deep into the math here since we did in a previous dive.

https://www.oxen.ai/blog/mamba-linear-time-sequence-modeling-with-selective-state-spaces-arxiv-dives

You can think of it as an efficient recurrent network that scans one token at a time, selectively updating its hidden state.

Mamba layers are heavily inspired by GRU (Gated Recurrent Units) with some optimizations to make them very fast. They are from a family of architectures called State Space Models (SSMs).

The problem with pure Mamba layers is that you have to cram all of the context into a fixed-size hidden state and selectively decide what to keep or discard.
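
As a rough intuition for "selectively updating a hidden state", here is a toy scalar recurrence with an input-dependent gate. This is a GRU-flavored simplification and not Mamba's actual parameterization; `selective_scan`, `w_gate`, and `b_gate` are hypothetical names chosen for this sketch.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def selective_scan(xs, w_gate=1.0, b_gate=0.0):
    """Toy recurrence: h_t = (1 - g_t) * h_{t-1} + g_t * x_t.

    The gate g_t = sigmoid(w_gate * x_t + b_gate) depends on the current
    input, so the model decides per token how much to keep or overwrite.
    """
    h = 0.0  # the entire "memory" is this single number
    states = []
    for x in xs:
        g = sigmoid(w_gate * x + b_gate)
        h = (1.0 - g) * h + g * x
        states.append(h)
    return states

# One pass over the sequence, constant-size state regardless of length.
print(selective_scan([1.0, -2.0, 0.5, 3.0]))
```

However long the input is, everything the model remembers has to fit in `h`, which is exactly the bottleneck described above.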

Sliding Window Attention

They add SWA to help process longer range dependencies that get lost in the hidden state.

If you remember from our Mistral dive, SWA looks like this:

https://www.oxen.ai/blog/arxiv-dive-how-to-mistral-7b-works

For example, in the diagram above, if the window size were 3 and the sentence were "The cat sat on the mat and watched the dog," then within a single layer each token attends only to itself and the two tokens before it:

  • at “sat”, the window covers “The cat sat”
  • at “mat”, the window covers “on the mat”
  • at the second “the”, the window covers “and watched the”
  • … etc.
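
One way to see the window is to build the attention mask directly. The sketch below assumes the window counts the current token plus the `window - 1` tokens before it; implementations differ on whether the current token is included, and `sliding_window_mask` is a hypothetical helper name.

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True when query position i may attend to key position j.

    Causal (j <= i) and limited to the most recent `window` positions.
    """
    return [[(j <= i) and (i - j < window) for j in range(seq_len)]
            for i in range(seq_len)]

tokens = "The cat sat on the mat and watched the dog".split()
mask = sliding_window_mask(len(tokens), window=3)

# Show which tokens each position can see within one layer.
for i, tok in enumerate(tokens):
    visible = [tokens[j] for j in range(len(tokens)) if mask[i][j]]
    print(f"{tok:>8} attends to {visible}")
```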

Tokens outside the sliding window can still influence the next-word prediction. In each layer, information can move forward by W tokens, so with k layers it can move forward by up to k*W tokens.

For example, with W = 3 and k = 12, each token could in theory attend to 3 * 12 = 36 previous tokens. Within a layer, each token scores at most W keys, so the cost grows linearly with sequence length rather than quadratically as in vanilla attention.
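
The arithmetic is small enough to check in a few lines. This sketch computes the k*W bound and counts how many query-key pairs one causal layer scores with and without a window. The function names are hypothetical, and the 128K sequence length is only an example input.

```python
def receptive_field(window, num_layers):
    """Upper bound on how far back information can travel: k * W."""
    return window * num_layers

def scored_pairs(seq_len, window=None):
    """Count query-key pairs scored by one causal attention layer."""
    if window is None:
        # Full causal attention: position i scores i + 1 keys.
        return seq_len * (seq_len + 1) // 2
    # Sliding window: position i scores at most `window` keys.
    return sum(min(i + 1, window) for i in range(seq_len))

print(receptive_field(window=3, num_layers=12))  # 36

full = scored_pairs(131072)
swa = scored_pairs(131072, window=2048)
print(f"full: {full:,} pairs, SWA: {swa:,} pairs, ratio: {full / swa:.1f}x")
```

The full-attention count grows with the square of the sequence length, while the windowed count grows roughly as sequence length times W.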

In SAMBA, they use a window size of 2048 and apply RoPE for relative position encodings within the sliding window.

Mamba captures low-rank information about the sequences through recurrent compression. This leaves the attention layers in SAMBA to theoretically only need to focus on information retrieval.

Multi-Layer Perceptrons

They use MLPs with SwiGLU activations for all the models in this paper. They argue that these are the architecture’s primary mechanism for non-linear transformation and recall of factual knowledge.
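
For readers who have not seen SwiGLU before, here is a minimal pure-Python sketch of a gated MLP: one projection is passed through SiLU and used to gate a second projection, and the result is projected back down. The bias-free layout and the names `w_gate`, `w_up`, and `w_down` are assumptions borrowed from common Llama-style implementations, not details confirmed by this post.

```python
import math

def silu(z):
    """SiLU (a.k.a. Swish): z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def matvec(matrix, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def swiglu_mlp(x, w_gate, w_up, w_down):
    """out = W_down @ (silu(W_gate @ x) * (W_up @ x)), without biases."""
    gate = [silu(g) for g in matvec(w_gate, x)]
    up = matvec(w_up, x)
    hidden = [g * u for g, u in zip(gate, up)]  # elementwise gating
    return matvec(w_down, hidden)

# Tiny example with 2-d input and identity weights.
identity = [[1.0, 0.0], [0.0, 1.0]]
print(swiglu_mlp([1.0, 2.0], identity, identity, identity))
```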

Experiments

They train 4 models with different parameter sizes:

  • 421M
  • 1.3B
  • 1.7B
  • 3.8B

They also train a baseline Mamba, Llama-3 and Mistral to compare against.

Let’s take a second to look at the results 👀

SAMBA 3.8B outperforms Llama-3 8B on many tasks.

The 3.8B model was trained on the same dataset used by Phi3, with 3.2T tokens.

They also did a smaller training run on 230B tokens of the Phi2 dataset for a more apples-to-apples comparison 🍎. The chart above is not normalized for the number of training tokens.

Exploring Attention vs Linear Recurrence

There are other forms of linear recurrence that they benchmark against.

  • Llama-2-SWA
  • Sliding RetNet (Multi-Scale Retention Layers)
  • Sliding GLA (Gated Linear Attention)

You can see that SAMBA is competitive in training speed, while winning on perplexity at context lengths of 4096, 8192, 16384.

SAMBA has stable throughput compared to models with attention. You’ll notice Mistral has higher throughput, but that’s because it only uses SWA.

That being said, SAMBA outperforms Mistral on Passkey Retrieval up to 256k context length.

Passkey retrieval is a task that requires a language model to retrieve a five-digit random number (the passkey) hidden in a long, meaningless text sequence.
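
To give a feel for what such a test case looks like, here is a sketch that builds one passkey prompt. The filler sentences, the wording of the needle and question, and the helper names are all assumptions for illustration; the benchmark's exact prompt template may differ.

```python
import random

# Hypothetical filler text, repeated to pad the context.
FILLER = "The grass is green. The sky is blue. The sun is yellow. "

def make_passkey_example(num_filler_repeats, seed=0):
    """Hide a random five-digit passkey at a random depth in filler text."""
    rng = random.Random(seed)
    passkey = str(rng.randint(10000, 99999))
    insert_at = rng.randint(0, num_filler_repeats)
    needle = f"The pass key is {passkey}. Remember it. "
    prompt = (
        FILLER * insert_at
        + needle
        + FILLER * (num_filler_repeats - insert_at)
        + "What is the pass key? The pass key is"
    )
    return prompt, passkey

def is_correct(model_output, passkey):
    """Score a completion: the passkey must appear in the model's output."""
    return passkey in model_output

prompt, passkey = make_passkey_example(num_filler_repeats=1000, seed=42)
print(len(prompt), "characters; passkey hidden somewhere inside")
```

Increasing `num_filler_repeats` stretches the context, so the same generator can be used to probe recall at different lengths.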

Conclusion

The combination of building blocks within SAMBA gives it a nice balance of throughput, context length, and language modeling quality. It outperforms both pure attention-based and SSM-based models on a wide variety of benchmarks. It’s super promising for long-context retrieval tasks, and I’m excited to see where the researchers take it next!