Flash Attention for llama.cpp on RDNA3: 47% less KV VRAM than Vulkan f16 K, KLD almost lossless on F16 K / q4_0 V. Part 1.
Posted by DrBearJ3w@reddit | LocalLLaMA | 13 comments
The normal tradeoff in llama.cpp attention is: quantize your KV cache and lose quality, or keep fp16 and burn VRAM. On RDNA3 there's a third option (from now on)! Pack four 8-bit K values into a single 32-bit word and feed them directly to the GPU's native `sudot4` dot-product instruction. No lossy quantization of K. No fp16 K buffer sitting in memory. The kernel gets exactly the data layout it needs, and VRAM drops because you're storing 8-bit K payloads plus fp16 scales instead of full fp16 K tensors.
But the real gap shows at 128k context with an active MTP draft model running: now you're storing K and V for *two* full contexts (main + draft). Total VRAM measured via `rocm-smi`:
128k context, active MTP, q4_0 V on both sides:
| Vulkan f16 K | 23.18 GiB | 22.50 GiB |
| **ROCm packed16 K** | **21.76 GiB** |
That 1.42 GiB is the difference between fitting a 128k MTP session and not, depending on your other VRAM pressure. It's not a model-weight saving (those are identical); it comes purely from shrinking the K-cache memory footprint across both contexts.
Now the quality side. The packed16 K path still produces fp16-range K values after dequant — the 8-bit packing isn't a lossy quantization, it's a storage layout change. The only compression loss comes from the V side. Measured on WikiText-2 with the 27B model, ctx=512, chunks=4, comparing V=q4_0 and V=q8_0 against a V=fp16 baseline. K is packed16 I32 in all candidates:
**q4_0 V vs fp16 V:**
| Metric | Value |
|---|---|
| Mean PPL ratio | 1.0020 ± 0.0042 |
| Mean KLD | **0.00455** ± 0.00034 |
| Median KLD | **0.00182** |
| 99th percentile KLD | 0.0500 |
| Same top token | **97.06%** |
| RMS Δp | 1.98% |
**q8_0 V vs fp16 V:**
| Metric | Value |
|---|---|
| Mean PPL ratio | 1.0010 ± 0.0034 |
| Mean KLD | **0.00283** ± 0.00033 |
| Median KLD | **0.00086** |
| 99th percentile KLD | 0.0313 |
| Same top token | **97.94%** |
| RMS Δp | 1.68% |
For context on what these KLD numbers mean: Kullback-Leibler divergence measures how different two probability distributions are. Under ~0.01 is generally considered near-indistinguishable in practice for token-level distributions. Both V formats are comfortably under that, with q8_0 at roughly half the divergence of q4_0 (mean 0.0028 vs 0.0046, median 0.0009 vs 0.0018). If you're running q4_0 V to stay lean, you're paying ~0.0045 KLD for less KV VRAM than fp16 K+V. If you want tighter quality, q8_0 V gives you ~0.0028 KLD against the fp16 V baseline. The K saving is identical either way, since the V format doesn't change the packed16 K layout.
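To make the two headline metrics concrete, here is a minimal sketch of how per-token KLD and "same top token" can be computed from two sets of logits. The logit values are made up for illustration, and the helper names (`softmax`, `kl_divergence`, `same_top_token`) are hypothetical, not taken from llama.cpp.

```python
import math

def softmax(logits):
    """Turn raw logits into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) in nats: p is the baseline, q the candidate."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def same_top_token(p, q):
    """True when both distributions pick the same most-likely token."""
    return p.index(max(p)) == q.index(max(q))

# Illustrative logits only: a baseline and a slightly perturbed candidate.
baseline_logits = [4.0, 2.5, 1.0, 0.2, -1.0]
candidate_logits = [3.9, 2.6, 1.05, 0.1, -0.9]

p = softmax(baseline_logits)
q = softmax(candidate_logits)
kld = kl_divergence(p, q)
agree = same_top_token(p, q)
print(f"KLD = {kld:.5f}, same top token = {agree}")
```

A benchmark run would average this over every evaluated token position; the mean, median and 99th percentile rows in the tables are statistics over that per-token list.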
Why does packed16 K produce fp16-equivalent quality? Because the packing isn't quantization, it's repacking. The K tensor is fp16 at rest. The kernel reads each row, computes per-block fp16 scales (absmax), quantizes to int8 on the fly, packs four int8 values into one I32, and writes that payload plus the scales to the cache. On the attention pass, the kernel loads the I32 payload, calls `sudot4` (which does four INT8 multiplies and an accumulate in one instruction), multiplies by the Q and K scales, and proceeds through online softmax. The dequant is mathematically exact for the packed int8 range! The only information loss is the int8 rounding of K values, and that's bounded by the fp16 scale per block. The WikiText numbers confirm this: PPL ratio of 1.002 is well within the ±0.004 noise band.
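The steps above (absmax scale, int8 rounding, four-to-a-word packing, 4-way integer dot, rescale) can be sketched on the CPU as follows. This is a pure-Python stand-in, not the kernel from the branch: the block size of 32 is an assumption, the function names are hypothetical, and `dot4` treats both operands as signed and ignores any saturation behaviour the hardware instruction may have.

```python
import math
import struct

BLOCK = 32  # assumed block size; the post does not state it

def quantize_block(values):
    """Absmax-quantize one block of floats to int8 plus a single scale."""
    absmax = max(abs(v) for v in values)
    scale = absmax / 127.0 if absmax > 0.0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale

def pack4(q):
    """Pack groups of four int8 values into 32-bit little-endian words."""
    return [struct.unpack("<I", struct.pack("<4b", *q[i:i + 4]))[0]
            for i in range(0, len(q), 4)]

def dot4(a_word, b_word, acc=0):
    """Emulate a 4-way signed int8 dot product with integer accumulate."""
    a = struct.unpack("<4b", struct.pack("<I", a_word))
    b = struct.unpack("<4b", struct.pack("<I", b_word))
    return acc + sum(x * y for x, y in zip(a, b))

def block_dot(q_vals, k_vals):
    """Q.K for one block: integer dot on packed words, then apply both scales."""
    qq, q_scale = quantize_block(q_vals)
    kq, k_scale = quantize_block(k_vals)
    acc = 0
    for qw, kw in zip(pack4(qq), pack4(kq)):
        acc = dot4(qw, kw, acc)
    return acc * q_scale * k_scale

# Synthetic data, for illustration only.
q_vals = [math.cos(i * 0.11) for i in range(BLOCK)]
k_vals = [math.sin(i * 0.37) for i in range(BLOCK)]

exact = sum(a * b for a, b in zip(q_vals, k_vals))
approx = block_dot(q_vals, k_vals)
print(f"exact = {exact:.5f}, packed int8 = {approx:.5f}")
```

The packing and unpacking are lossless round trips; any gap between `exact` and `approx` comes from the int8 rounding step in `quantize_block`, which is why the error scales with the per-block absmax.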
Compare this to what Vulkan does: on Vulkan, the KV cache path stores K as full fp16. That's lossless for K but costs memory. The packed16 approach gets you the same effective K precision (int8 rounding with fp16 scale is effectively fp16-range) while cutting the K memory footprint to roughly one third: 8 bits per value plus scale overhead vs 16 bits. The V side is also halved. For effective q4_0 V you get 2.25 bits.
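As a rough check on the K footprint arithmetic, the sketch below computes bits per value for an int8 payload with one fp16 scale per block. The block size of 32 is an assumption (the post does not state it), and the helper name is hypothetical.

```python
FP16_BITS = 16
INT8_BITS = 8
BLOCK_SIZE = 32  # assumed number of K values sharing one fp16 scale

def bits_per_value(payload_bits, scale_bits, block_size):
    """Payload bits plus the per-block scale amortized over the block."""
    return payload_bits + scale_bits / block_size

packed_k = bits_per_value(INT8_BITS, FP16_BITS, BLOCK_SIZE)
ratio = packed_k / FP16_BITS
saving_pct = (1.0 - ratio) * 100.0

print(f"packed K: {packed_k} bits/value")
print(f"footprint vs fp16 K: {ratio:.3f} ({saving_pct:.1f}% smaller)")
```

Under that assumption the K cache lands at 8.5 bits per value, about 0.53 of the fp16 size, which lines up with the 47% figure in the title.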
https://github.com/DrBearJew/llama.cpp/tree/tbq4-rdna3-experiment
Thrumpwart@reddit
This is very interesting. Will give it a go and report back.
FierceDeity_@reddit
Does less VRAM used in this case also mean improved performance, due to fewer VRAM reads?
I use an RDNA3.5 Ryzen AI 395+ (Radeon 8060S) with 128GB RAM. Saving VRAM is not the biggest concern, but packing things for better reading could be, since the VRAM is LPDDR5-8000 and not the fastest by any means.
soyalemujica@reddit
This is an interesting one, compiling on CachyOS to give it a try on my 7900XTX
jake_that_dude@reddit
the number i'd want before merging this is tok/s split by context length with MTP off/on. the VRAM win is real at 128k, but can still lose if the pack/dequant work shows up in decode. run 8k/32k/128k with the same prompt, then report prefill tok/s, decode tok/s, and peak. if decode stays flat, this is way more interesting than another KV quant.
DrBearJ3w@reddit (OP)
I will remind you when it's time to merge. There will be part 2.
BringMeTheBoreWorms@reddit
Very interesting! I’ll try it out when I get the chance. Will this also work on rdna4?
DrBearJ3w@reddit (OP)
Sadly not. Needs source-side gating/porting. But I don't have an RDNA4 card to test it out.
BringMeTheBoreWorms@reddit
Ah.. I’ll try it on my 7900s. But would love to get a bit more space on the r9700 as well
DrBearJ3w@reddit (OP)
The MTP is not yet optimized. It works, but is marginally slower. The non-MTP version has normal speeds, but has a BIG advantage on longer context.
BringMeTheBoreWorms@reddit
Shame you can't get a hold of an r9700. What do you need to test? Mine see a lot of use each day so hard to share it out
DrBearJ3w@reddit (OP)
I could write the kernel blindly, but I am certain there would be bugs involved. So yeah. I need an RDNA4 card 😅
audioen@reddit
The sensible comparison point for a scheme that packs K cache to 8 bits with fp16 scale is K cache in q8_0, because it does something similar to your packed16 scheme. Internal contradictions in this text are just weird. One part of the text seems to claim it is not lossy ("not quantization but repacking", paraphrasing), another part of the text admits that some rounding and loss of K precision is involved.
You also aren't going to cut K overhead to one third if you go from 16 bits to 8 bits per value, basic math says the ratio of those values is 2:1, unless something else is involved here.
DrBearJ3w@reddit (OP)
I should've compared against q8_0 K, not fp16 K. That is why it's confusing. Same memory footprint, same ballpark. But the KLD isn't the same: packed16 keeps per-block fp16 scaling through the int8 payload in a way that preserves fp16-range fidelity. It's a half-bit heavier than plain q8_0 (8.5 vs 8) for effectively lossless quality. Half the memory for the same quality (2.2 percent off).