by ocean_moist on 5/13/25, 3:29 AM with 32 comments
by jbellis on 5/13/25, 9:45 AM
[3.3] For saving the KV cache, only the intermediate latent representation needs to be stored: a vector of dimension r per token, where r is much smaller than n_h · d_h.
[background] In traditional multi-head attention you must cache the full key and value matrices of size T x (n_h · d_h), where T is the sequence length in tokens, n_h is the number of attention heads, and d_h is the dimensionality of each individual head.
sounds like a big win for memory-constrained environments like local inference
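A back-of-the-envelope comparison of per-sequence cache entries; all the sizes below are my own illustrative assumptions, not values from the paper:

    # Rough KV-cache size comparison; all sizes are illustrative assumptions.
    n_h, d_h = 32, 128          # attention heads, per-head dimension
    r = 512                     # latent dimension, r << n_h * d_h
    T = 4096                    # tokens in the sequence

    mha_entries = T * 2 * n_h * d_h   # full keys AND values, each T x (n_h * d_h)
    mla_entries = T * r               # one latent vector of size r per token

    print(f"MHA cache: {mha_entries:,} entries")        # 33,554,432
    print(f"MLA cache: {mla_entries:,} entries")        # 2,097,152
    print(f"reduction: {mha_entries // mla_entries}x")  # 16x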
by killerstorm on 5/13/25, 12:54 PM
by karmakaze on 5/13/25, 4:37 PM
by olq_plo on 5/13/25, 5:54 AM
by magicalhippo on 5/13/25, 11:35 AM
Perhaps a trivial insight, but I feel a lot of progress often comes in the form of generalizations, where existing approaches can be seen as special cases. Here the authors show that Group Query Attention (GQA) and Multi-Query Attention (MQA) fall out as special cases of their new model.
edit:
Adding my own summary, as I understand it.
The key to what they're doing, no pun intended, is to rely on the fact that large, high-dimensional matrices may contain a lot of redundant information. Thus one may be able to find a good approximation with less redundant information by going through an intermediary stage which has fewer dimensions.
An n-by-m matrix A takes n-dimensional vectors and transforms them into m-dimensional vectors. The trick here is to replace the matrix A by two matrices, L and R, which are n-by-r and r-by-m respectively, where r is smaller than n and m. This is called a low-rank approximation.
In a sense you're "straining the matrix", by forcing the information to pass through an intermediary, low-dimensional vector.
The memory savings come from the fact that matrix A has n*m entries, while L and R have n*r and r*m entries respectively. Say n = m = 100 and r = 20, that means A has 100*100 = 10k entries, while L and R have just 100*20 + 20*100 = 4k entries in total.
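A minimal numpy sketch of that counting argument, using the same sizes (my own toy example, nothing from the paper):

    import numpy as np

    n, m, r = 100, 100, 20

    # Dense matrix versus its low-rank factorization L @ R.
    A = np.random.randn(n, m)
    L = np.random.randn(n, r)
    R = np.random.randn(r, m)

    print(A.size)             # 10000 entries
    print(L.size + R.size)    # 4000 entries

    # An input is "strained" through an r-dimensional intermediate vector.
    x = np.random.randn(n)
    y = (x @ L) @ R           # first an r-dim vector, then an m-dim output
    print(y.shape)            # (100,)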
The trick itself is not new; for example, it is also used in LoRA, where an additional low-rank approximation matrix is used to tweak the output of an existing model. The low rank means there are far fewer matrix entries, aka parameters, to train than if one had used a regular fully dense matrix.
The extra expressiveness of MLA comes from the fact that in GQA, in order to save memory, some of the matrices are actually built by gluing copies of a narrower matrix together. This means the information in the glued-up matrices is very redundant and fixed in a certain way, and so they are restricted in how they can transform the inputs.
By using the low-rank approximation instead, the information in the full, reconstructed matrices is not fixed in the same way as in the glued-up result. Thus the inputs can be transformed in a less restrictive way, leading to the increase in expressiveness.
The GQA method saves a bit more memory compared to MLA as the narrower matrices are even smaller than the low-rank matrices in MLA, but at the cost of expressiveness.
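To make that contrast concrete, here is a rough sketch of how the effective key-projection matrix ends up structured in GQA versus MLA. All shapes and variable names are my own illustrative choices, not the paper's:

    import numpy as np

    d_model, n_h, d_h = 256, 8, 32   # illustrative model width, heads, head dim
    n_kv = 2                         # GQA: number of shared key/value heads
    r = 64                           # MLA: latent dimension

    # GQA: a narrow projection is "glued" together -- each shared KV head is
    # copied so that every one of the 8 query heads has a key head to use.
    W_narrow = np.random.randn(d_model, n_kv, d_h)
    W_gqa = np.repeat(W_narrow, n_h // n_kv, axis=1).reshape(d_model, n_h * d_h)

    # MLA: the full-width projection is the product of two thin matrices, so
    # the reconstructed columns are low-rank but not forced to be copies.
    W_down = np.random.randn(d_model, r)
    W_up = np.random.randn(r, n_h * d_h)
    W_mla = W_down @ W_up

    print(W_gqa.shape, W_mla.shape)      # both (256, 256)
    print(np.linalg.matrix_rank(W_gqa))  # 64: only n_kv * d_h distinct columns
    print(np.linalg.matrix_rank(W_mla))  # 64: rank limited by r, columns untied

With these toy sizes both matrices come out rank 64, but the GQA one is built from literal repeats of a narrower block, which is exactly the restriction described above; the MLA one spends the same rank budget without tying columns to be identical.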
by wiz21c on 5/13/25, 8:56 AM
Answering my own question: https://www.reddit.com/r/MachineLearning/comments/1hpg91o/d_...
by kavalg on 5/13/25, 6:28 AM
by EGreg on 5/13/25, 9:17 AM