Mikhail Breslav

Brief Notes on Attention Efficiency

As part of my ongoing review of LLMs, I revisited the core computation performed during self attention. As in my previous reviews, I focused on the three important learnable projections that map our token embeddings to queries, keys, and values, which are then used to re-represent (add context to) the token embeddings. One aspect of attention that I glossed over in the past is the efficiency of this computation.
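
To make this concrete, here is a minimal single-head sketch in NumPy (my own illustration, not the code of any particular library), with the head dimension equal to the model dimension \(d\) and no causal masking:

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention over token embeddings X of shape (n, d).

    W_q, W_k, W_v are the learnable (d, d) projection matrices. The output
    re-represents each token as a weighted combination of all value vectors.
    """
    Q = X @ W_q  # (n, d) queries
    K = X @ W_k  # (n, d) keys
    V = X @ W_v  # (n, d) values

    d = X.shape[-1]
    logits = Q @ K.T / np.sqrt(d)                    # (n, n) attention logits
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the keys

    return weights @ V                               # (n, d) contextualized embeddings
```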

How Efficient is Self Attention?

Looking at the Multi-Query Attention paper we can see that it reports multi-head attention as having a time complexity of \(\Theta(bnd^{2})\) and a memory complexity of \(O(bnd + bhn^{2} + d^{2})\), where \(b\) is the batch size, \(n\) is the sequence length, \(d\) is the model dimension, and \(h\) is the number of heads. Our goal in this section is to see whether this makes sense at a high level.

Let’s start by reviewing the time complexity, and I’ll make a few simplifying assumptions:

- We have a batch of \(b\) sequences of \(n\) tokens each, a model dimension of \(d\), and \(h\) attention heads, each of dimension \(d/h\).
- Multiplying an \(m \times k\) matrix by a \(k \times p\) matrix takes on the order of \(mkp\) operations.
- We ignore constant factors and cheaper operations like the softmax.

With these assumptions, projecting the \(n \times d\) token embeddings through the \(d \times d\) query, key, and value weight matrices costs on the order of \(bnd^{2}\), while computing the \(n \times n\) attention logits \(QK^{T}\) and applying the resulting weights to \(V\) costs on the order of \(bn^{2}d\) across all heads. Therefore:

\(\Theta(bnd^{2} + bn^{2}d)\)
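
As a rough sanity check, here is a back-of-the-envelope counter for the two terms (again my own sketch; it ignores constant factors such as the three separate projections, the output projection, and the softmax):

```python
def attention_flops(b, n, d, h):
    """Approximate multiply-accumulate counts for one multi-head attention layer.

    b: batch size, n: sequence length, d: model dimension,
    h: number of heads, each of dimension d // h.
    """
    projections = b * n * d * d              # X @ W: (n, d) times (d, d)
    logits      = b * h * n * n * (d // h)   # Q @ K^T, per head
    weighted_v  = b * h * n * n * (d // h)   # attention weights @ V, per head
    return projections, logits + weighted_v  # ~ b*n*d^2 and ~ b*n^2*d
```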

At this point we see a discrepancy between my analysis and what the paper reports: the paper only contains the first term and ignores the second. Why might this be? My guess is that the paper assumes \(d \gg n\), in which case the first term dominates and the second term can be ignored. While that may have been a good assumption at the time of the paper, it’s not obvious that it still holds now that context lengths have grown larger and larger.
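
To make the scale of the two terms concrete with some representative (made-up) numbers, take \(b = 1\) and \(d = 1024\). At \(n = 512\) the projection term \(nd^{2} \approx 5.4 \times 10^{8}\) is about twice the attention term \(n^{2}d \approx 2.7 \times 10^{8}\), but at \(n = 8192\) the attention term \(n^{2}d \approx 6.9 \times 10^{10}\) is roughly 8 times the projection term \(nd^{2} \approx 8.6 \times 10^{9}\). In other words, once the context length exceeds the model dimension, the quadratic-in-\(n\) term takes over.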

Another confusion I’ve had is seeing the complexity of attention commonly reported as quadratic in \(n\), while my computation above also includes the first term, which is quadratic in \(d\). This confusion was resolved when I realized that the commonly reported complexity only considers computing the attention weights and applying them to the values. My starting point, in contrast, was the paper above, which also includes the projection steps in its accounting. So while the attention computation is quadratic in \(n\), we must also consider the pre-step of calculating the queries, keys, and values, which is quadratic in \(d\).

As for memory complexity, it’s easier to see where it comes from by looking at the paper directly and noting the shapes of all the tensors that need to be stored during the computation. As with time complexity, we see that memory is also quadratic in \(n\) and in \(d\).
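
Roughly, my reading of those shapes is: the queries, keys, values, and outputs each have on the order of \(bnd\) entries, the attention logits and weights have \(bhn^{2}\) entries, and the projection weight matrices have on the order of \(d^{2}\) entries, which together give the \(O(bnd + bhn^{2} + d^{2})\) reported above.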

Conclusion & Lingering Questions

Given that we want LLMs to be able to handle long input sequences it is reasonable to be concerned by the overall quadratic dependence on sequence length. This leads to a few of my lingering questions…

As usual if you’ve made it this far thanks for the read. If you’re looking for some calming music check out the Peaceful Meditation playlist on Spotify.

References

These are the references I found to be helpful:

- Noam Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need”, 2019, arXiv:1911.02150 (the Multi-Query Attention paper referenced above).