This means that the total amount of computation required for attention grows quadratically with the total number of tokens. Suppose a 10-token prompt requires 414,720 attention operations. If each pair of tokens is compared once, so that the work scales with n(n-1)/2 for n tokens, then:
- Processing a 100-token prompt will require 45.6 million attention operations.
- Processing a 1,000-token prompt will require 4.6 billion attention operations.
- Processing a 10,000-token prompt will require 460 billion attention operations.
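
The figures above can be reproduced with a few lines of arithmetic. This is a minimal sketch, not a model of any particular system: it assumes every pair of tokens is compared exactly once and that each comparison has a fixed cost. The constant `OPS_PER_PAIR` and the function name `attention_ops` are illustrative; the constant is inferred from the 10-token figure (10 tokens form 45 pairs).

```python
# Assumption: 10 tokens form 45 unordered pairs, so 414,720 / 45 = 9,216 operations per pair.
OPS_PER_PAIR = 414_720 // 45


def attention_ops(n_tokens: int) -> int:
    """Total attention operations if every pair of tokens is compared once."""
    pairs = n_tokens * (n_tokens - 1) // 2
    return pairs * OPS_PER_PAIR


for n in (10, 100, 1_000, 10_000):
    print(f"{n:>6} tokens: {attention_ops(n):,} operations")
```

Each tenfold increase in prompt length multiplies the operation count by roughly 100, which is what quadratic growth means in practice.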
This is probably why Google charges twice as much, per token, for Gemini 1.5 Pro once the context gets longer than 128,000 tokens. Generating token number 128,001 requires comparisons with all 128,000 previous tokens, making it significantly more expensive than producing the first or 10th or 100th token.
A lot of effort has been put into optimizing attention. One line of research has tried to squeeze maximum efficiency out of individual GPUs.
As we saw earlier, a modern GPU contains thousands of execution units. Before a GPU can start doing math, it must move data from relatively slow memory shared across the whole chip (called high-bandwidth memory, or HBM) to a small amount of much faster memory inside a particular execution unit (called SRAM). Sometimes GPUs spend more time moving data around than performing calculations.
In a series of papers, Princeton computer scientist Tri Dao and several collaborators have developed FlashAttention, which calculates attention in a way that minimizes the number of these slow memory operations. Work like Dao’s has dramatically improved the performance of transformers on modern GPUs.
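
To give a flavor of how attention can be computed with fewer trips to slow memory, here is a small sketch of one idea used in this style of algorithm: visiting keys and values one block at a time while keeping a running softmax, so the full list of scores never has to be stored. This is not Dao's implementation. It is plain Python for a single query vector, it omits the usual scaling by the square root of the vector dimension, and the function names and `block_size` parameter are assumptions made for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_reference(q, keys, values):
    """Standard attention for one query: compute every score, then softmax."""
    scores = [dot(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def attention_blockwise(q, keys, values, block_size=2):
    """Same result, but keys and values are visited one small block at a time."""
    running_max = -math.inf          # largest score seen so far
    running_sum = 0.0                # softmax denominator, relative to running_max
    acc = [0.0] * len(values[0])     # weighted sum of values, relative to running_max
    for start in range(0, len(keys), block_size):
        k_block = keys[start:start + block_size]
        v_block = values[start:start + block_size]
        scores = [dot(q, k) for k in k_block]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results so they are relative to the new maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_block):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


# Made-up example data: one query, five keys, five values.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.7]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(attention_reference(q, keys, values))
print(attention_blockwise(q, keys, values, block_size=2))
```

The two functions should agree up to floating-point rounding. The blockwise version only ever holds one block of scores at a time, which is the property that lets a real GPU kernel keep its working data in fast memory.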
Another line of research has focused on efficiently scaling attention across multiple GPUs. One widely cited paper describes ring attention, which divides input tokens into blocks and assigns each block to a different GPU. It’s called ring attention because GPUs are organized into a conceptual ring, with each GPU passing data to its neighbor.
I once attended a ballroom dancing class where couples stood in a ring around the edge of the room. After each dance, women would stay where they were while men would rotate to the next woman. Over time, every man got a chance to dance with every woman. Ring attention works on the same principle. The “women” are query vectors (describing what each token is “looking for”) and the “men” are key vectors (describing the characteristics each token has). As the key vectors rotate through a sequence of GPUs, they get multiplied by every query vector in turn.
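
The rotation pattern in the dance analogy can be simulated in a few lines. This is only a sketch of the schedule, under simplifying assumptions: each device permanently holds one block of queries, starts with the matching block of keys, and hands its key block to the next device after every step. A real system would also pass along the corresponding value vectors and combine partial results; the function name `ring_schedule` is hypothetical.

```python
def ring_schedule(num_devices):
    """Return, for each device, the order in which key blocks visit it."""
    held = list(range(num_devices))           # held[d] = key block currently on device d
    seen = [[] for _ in range(num_devices)]   # seen[d] = key blocks met by query block d
    for _ in range(num_devices):
        for device, key_block in enumerate(held):
            # Query block `device` is multiplied by the key block it currently holds.
            seen[device].append(key_block)
        # Every device passes its key block to the next device in the ring.
        held = [held[(device - 1) % num_devices] for device in range(num_devices)]
    return seen


for device, order in enumerate(ring_schedule(4)):
    print(f"device {device} (query block {device}) sees key blocks {order}")
```

After as many steps as there are devices, every query block has met every key block exactly once, and no device ever needed to hold all of the keys at the same time.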