Titan Transformer Explained: Learning to Memorize at Test Time

Aims of the Titans Architecture

  1. Learn a fixed-size representation of history. This lets us avoid the quadratic complexity of attention.
  2. Learn the memory continually, including at test time.
  3. Integrate such a memory into a Transformer.

Formulation

Aim (1) is a classic problem in the sequence-learning literature: for example, an RNN compresses the input and the history into a fixed-size hidden state $h_t$. The earlier LtLaTT (Learning to Learn at Test Time) paper explored the idea of using an intermediate neural network, trained online, in place of the KV cache. This paper bases its memory module on the same idea.

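The key insight for (2) is that we want to remember surprising things, so we need a way to gauge how surprising an input $x_t$ is. The model learns this surprise scoring: the memory $\mathcal{M}$ stores key-value associations, with learned projections $k_t = x_t W_K$ and $v_t = x_t W_V$, and surprise is measured through the associative-memory loss

$$\ell(\mathcal{M}_{t-1}; x_t) = \lVert \mathcal{M}_{t-1}(k_t) - v_t \rVert_2^2$$

Based on the amount of surprise (the gradient of this loss), we compute the change in the memory $\mathcal{M}$ as:

$$\mathcal{M}_t = \mathcal{M}_{t-1} - \theta_t \nabla \ell(\mathcal{M}_{t-1}; x_t)$$

Intuitively this makes sense: we want to move the updated memory $\mathcal{M}_t$ towards a state that is less surprised by the association $(k_t, v_t)$.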
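To make the update concrete, here is a minimal sketch that assumes a linear memory (a matrix `M`, so that $\mathcal{M}(k) = Mk$) instead of the deeper MLP memory used in Titans, a fixed step size `theta`, and hand-picked `k`, `v` instead of learned projections. All function names are illustrative. For this loss the gradient has the closed form $2(Mk - v)k^\top$.

```python
# Sketch (assumptions: linear memory, fixed theta, k and v given directly
# rather than produced by learned projections x W_K, x W_V).

def matvec(M, k):
    """Return M @ k for a list-of-rows matrix M and vector k."""
    return [sum(m_ij * k_j for m_ij, k_j in zip(row, k)) for row in M]


def surprise_loss(M, k, v):
    """Associative-memory loss ||M k - v||_2^2."""
    return sum((p - t) ** 2 for p, t in zip(matvec(M, k), v))


def surprise_grad(M, k, v):
    """Gradient of the loss w.r.t. M: 2 (M k - v) k^T."""
    err = [p - t for p, t in zip(matvec(M, k), v)]
    return [[2.0 * e * k_j for k_j in k] for e in err]


def memory_step(M, k, v, theta):
    """M_t = M_{t-1} - theta * grad l(M_{t-1}; x_t)."""
    G = surprise_grad(M, k, v)
    return [[m - theta * g for m, g in zip(rm, rg)] for rm, rg in zip(M, G)]


M0 = [[0.0, 0.0], [0.0, 0.0]]   # empty memory
k, v = [1.0, 0.0], [0.5, -1.0]  # one key/value association
print(surprise_loss(M0, k, v))  # 1.25: the association is surprising
M1 = memory_step(M0, k, v, theta=0.25)
print(surprise_loss(M1, k, v))  # 0.3125: the memory is less surprised now
```

With $\lVert k \rVert = 1$ and `theta=0.25`, one step halves the residual $Mk - v$, so the loss drops by a factor of four.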

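They then modify this slightly to include a stream of past surprises, much like momentum in gradient descent. With $\nabla \ell(\mathcal{M}_{t-1}; x_t)$ as the momentary surprise, the new equations are:

$$\mathcal{M}_t = (1 - \alpha_t)\,\mathcal{M}_{t-1} + S_t$$

$$S_t = \eta_t\, S_{t-1} - \theta_t \nabla \ell(\mathcal{M}_{t-1}; x_t)$$

  • $\alpha_t$: forgetting gate.
  • $\eta_t$: decay of past surprises.
  • $\theta_t$: scaling of momentary surprise.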
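A minimal sketch of the momentum-plus-forgetting update, again assuming a linear memory with loss $\lVert Mk - v \rVert_2^2$. As a further simplification, `alpha`, `eta` and `theta` are fixed scalars here, whereas in Titans they are data-dependent (computed from $x_t$). The helper names are hypothetical.

```python
# Sketch: S_t = eta * S_{t-1} - theta * grad,  M_t = (1 - alpha) * M_{t-1} + S_t.
# Assumptions: linear memory, constant gates, k and v given directly.

def grad(M, k, v):
    """Momentary surprise: gradient of ||M k - v||^2 w.r.t. M."""
    err = [sum(m * kj for m, kj in zip(row, k)) - vi for row, vi in zip(M, v)]
    return [[2.0 * e * kj for kj in k] for e in err]


def titans_step(M, S, k, v, alpha, eta, theta):
    """One memory update with surprise momentum S and forgetting gate alpha."""
    G = grad(M, k, v)
    S_new = [[eta * s - theta * g for s, g in zip(rs, rg)] for rs, rg in zip(S, G)]
    M_new = [[(1.0 - alpha) * m + s for m, s in zip(rm, rs)]
             for rm, rs in zip(M, S_new)]
    return M_new, S_new


def zeros(r, c):
    return [[0.0] * c for _ in range(r)]


M, S = zeros(2, 2), zeros(2, 2)
stream = [([1.0, 0.0], [0.5, -1.0]), ([0.0, 1.0], [1.0, 0.0])]
for k, v in stream:
    M, S = titans_step(M, S, k, v, alpha=0.0, eta=0.5, theta=0.25)
print(M)  # [[0.375, 0.5], [-0.75, 0.0]]
```

With `eta=0` and `alpha=0` this reduces to the plain gradient step. Note how the momentum term keeps pushing the first association during the second token: `M[0][0]` grows from 0.25 to 0.375 even though the second key is orthogonal to the first. Setting `alpha > 0` shrinks whatever was stored before the new surprise is added.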


One consequence is that, because the memory is updated by a continuous non-linear recurrence, we can no longer train in parallel over the sequence the way a Transformer does. The paper describes how to make this approach parallelizable, using a technique quite similar to the one proposed in the LtLaTT paper: instead of computing the gradient at $\mathcal{M}_{t-1}$, we compute it at $\mathcal{M}_{t'}$ with $t' = t - \mathrm{mod}(t, b)$, where $b$ is a fixed chunk size. Every gradient within a chunk then depends only on the memory at the chunk boundary, so the cumulative sum of gradients required to update $\mathcal{M}$ can be computed chunkwise.
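The chunking idea can be sketched as mini-batch gradient descent over the sequence. This is a simplified illustration that assumes a linear memory and drops the momentum and forgetting terms (the paper also folds those into the chunked computation, which is not shown here). Within a chunk, every gradient is taken at the memory from the start of the chunk, so those gradients are independent of one another, and the per-token memories are prefix sums of them. Function names are hypothetical.

```python
# Sketch of chunkwise (mini-batch) memory updates vs. the exact per-token rule.
# Assumptions: linear memory, loss ||M k - v||^2, no momentum / forgetting.

def grad(M, k, v):
    """Gradient of ||M k - v||^2 w.r.t. M."""
    err = [sum(m * kj for m, kj in zip(row, k)) - vi for row, vi in zip(M, v)]
    return [[2.0 * e * kj for kj in k] for e in err]


def axpy(M, G, a):
    """Return M + a * G elementwise."""
    return [[m + a * g for m, g in zip(rm, rg)] for rm, rg in zip(M, G)]


def sequential(M, tokens, theta):
    """Exact rule: each gradient uses the latest memory M_{t-1}."""
    out = []
    for k, v in tokens:
        M = axpy(M, grad(M, k, v), -theta)
        out.append(M)
    return out


def chunkwise(M, tokens, theta, b):
    """Chunked rule: all gradients in a chunk use the memory at the chunk start."""
    out = []
    for start in range(0, len(tokens), b):
        M_start = M
        # These gradient evaluations do not depend on each other,
        # so they could be computed in parallel (e.g. as one batched matmul).
        grads = [grad(M_start, k, v) for k, v in tokens[start:start + b]]
        acc = M_start
        for G in grads:  # cumulative sum of gradients inside the chunk
            acc = axpy(acc, G, -theta)
            out.append(acc)
        M = acc  # memory at the next chunk boundary
    return out


tokens = [([1.0, 0.0], [0.5, -1.0]), ([1.0, 1.0], [0.0, 1.0]),
          ([0.0, 1.0], [1.0, 0.0]), ([1.0, 0.0], [0.5, -1.0])]
M0 = [[0.0, 0.0], [0.0, 0.0]]
seq = sequential(M0, tokens, theta=0.25)
print(chunkwise(M0, tokens, theta=0.25, b=1) == seq)  # True: b = 1 is the exact rule
print(chunkwise(M0, tokens, theta=0.25, b=2) == seq)  # False: b > 1 approximates it
```

The chunk size trades exactness for parallelism: `b=1` recovers the sequential rule, while larger `b` lets more gradients be computed at once at the cost of using slightly stale memory.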



Incorporate Memory into Transformer

They also describe several ways of incorporating such a memory module into a Transformer-like architecture, combining sliding-window attention, memory layers, and LSTM-like gating. The paper presents three variants: Memory as a Context (MAC), Memory as a Gate (MAG), and Memory as a Layer (MAL).
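As a rough illustration of the ingredients, the sketch below implements single-head causal sliding-window attention and combines its output with a memory read-out through an elementwise sigmoid gate. The gate is a generic LSTM-style convex mix chosen as an assumption for illustration; it is not the exact gating used in the paper, and the memory read-out is just a placeholder list. All names are hypothetical.

```python
import math


def sliding_window_attention(q, k, v, window):
    """Single-head causal attention where position t attends only to the
    last `window` positions (t - window + 1 .. t). q, k, v are lists of
    float vectors, one per position."""
    d = len(q[0])
    out = []
    for t in range(len(q)):
        lo = max(0, t - window + 1)
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / math.sqrt(d)
                  for s in range(lo, t + 1)]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(x - m) for x in scores]
        z = sum(w)
        out.append([sum(wi / z * v[lo + i][j] for i, wi in enumerate(w))
                    for j in range(len(v[0]))])
    return out


def gated_combine(attn_out, mem_out, gate_logits):
    """Hypothetical convex gate per position and channel:
    g * attn + (1 - g) * mem, with g = sigmoid(gate_logit)."""
    result = []
    for a_row, m_row, g_row in zip(attn_out, mem_out, gate_logits):
        gates = [1.0 / (1.0 + math.exp(-x)) for x in g_row]
        result.append([g * a + (1.0 - g) * mm
                       for g, a, mm in zip(gates, a_row, m_row)])
    return result


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
attn = sliding_window_attention(x, x, x, window=2)
mem = [[0.5, 0.5] for _ in x]  # placeholder for the memory read-out
print(gated_combine(attn, mem, [[0.0, 0.0] for _ in x]))
```

The window bounds the attention cost per token, while the (fixed-size) memory is meant to carry information from beyond the window; the gate decides, per channel, how much of each to use.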


Advantages:

  1. Continual learning of memory in a fixed-size state.
  2. In theory, Titans can solve problems beyond the capabilities of Transformers and linear recurrent models (Theorem 4.1 in the paper).
  3. Titans serve as a strong baseline alternative to the Transformer, with compute time linear in sequence length.

Points to Ponder

  1. Just like in the LtLaTT paper, the outer gradient descent completely ignores the inner gradient descent update of $\mathcal{M}$. Jointly optimizing the inner and outer loop parameters is hard because we would immediately run into Hessian computations. But then the question of how $\mathcal{M}_0$ is initialized, and what that implies for performance, remains unanswered.
  2. I could be wrong, but the WikiText-103 comparison should include other efficient transformers that reportedly achieve much lower perplexity at similar parameter counts (see the leaderboard), such as Routing Transformer, Reformer, and Transformer-XL.
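To see where the Hessian comes from, take the plain update rule (dropping momentum and forgetting to keep the sketch short, and treating $\mathcal{M}$ as a flattened vector). Differentiating an outer loss $L$ through the inner steps needs the Jacobian of each memory state with respect to the previous one:

$$\mathcal{M}_t = \mathcal{M}_{t-1} - \theta_t \nabla \ell(\mathcal{M}_{t-1}; x_t) \quad\Longrightarrow\quad \frac{\partial \mathcal{M}_t}{\partial \mathcal{M}_{t-1}} = I - \theta_t H_t, \qquad H_t = \nabla^2 \ell(\mathcal{M}_{t-1}; x_t)$$

$$\frac{\partial L}{\partial \mathcal{M}_0} = \frac{\partial L}{\partial \mathcal{M}_T}\,(I - \theta_T H_T)\,(I - \theta_{T-1} H_{T-1}) \cdots (I - \theta_1 H_1)$$

This is exactly the quantity one would need to learn the initialization $\mathcal{M}_0$. For a linear memory with $\ell = \lVert \mathcal{M} k_t - v_t \rVert_2^2$, each $H_t$ acts as $\Delta \mapsto 2\,\Delta k_t k_t^\top$ and is cheap, but for a deep MLP memory every factor requires Hessian-vector products.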

https://www.youtube.com/watch?v=x8jFFhCLDJY [Video tutorial of Titans]

https://www.youtube.com/watch?v=FsflifJAWdc [Learning to Learn at Test Time]

Comments

SecondThread 2025-01-17 08:47 UTC

Theorem 4.1 in the paper appears out of nowhere. Can someone provide justification for this?

Theorem 4.1. Contrary to Transformers, diagonal linear recurrent models, and DeltaNet, all of which are limited to TC0 (Merrill, Petty, and Sabharwal 2024), Titans are capable of solving problems beyond TC0, meaning that Titans are theoretically more expressive than Transformers and most modern linear recurrent models in state tracking tasks.