[CUDA in Practice] Hand-Rolled Flash Decoding on SM120: Beating flashinfer.single_decode_with_kv_cache
::: This article is intended for readers with a solid CUDA foundation, familiar with GEMM/multi-head-attention optimization, and interested in advanced inline PTX tuning.
Full kernel and test code: github vitamin-cuda flash_attn :::
# 0. Preface — Decode vs. Prefill Attention: Two Different Optimization Philosophies
This post follows the previous FMHA post. That one focused on prefill FlashAttention; this one switches to decode attention (single query + long KV cache) and discusses how to write a practical kernel for that regime.
Baselines used in this article:
- PyTorch native implementation after `torch.compile` (general-purpose baseline)
- flashinfer’s `single_decode_with_kv_cache` (existing specialized baseline)
A note on the flashinfer baseline: I added a minimal compatibility fix so it can run normally on my RTX 5060 Laptop (26 SMs). Details are below.
# PyTorch Native Baseline
```python
import math

import torch


@torch.compile
def torch_native_decode(q, k, v, scale=None):
    # q: [head, dim] -> [32, 128]
    # k: [seq, head, dim] -> [4096, 32, 128]
    # v: [seq, head, dim] -> [4096, 32, 128]
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])
    # reshape for batched GEMV-style compute
    q_b = q.unsqueeze(1)      # [32, 1, 128]
    k_b = k.permute(1, 2, 0)  # [32, 128, 4096]
    v_b = v.transpose(0, 1)   # [32, 4096, 128]
    attn_scores = torch.matmul(q_b, k_b) * scale     # [32, 1, 4096]
    attn_probs = torch.softmax(attn_scores, dim=-1)  # softmax over the KV axis
    out = torch.matmul(attn_probs, v_b)              # [32, 1, 128]
    return out.squeeze(1)                            # [32, 128]
```

Don’t underestimate this baseline just because it’s “PyTorch code.”
Decode attention is fundamentally close to batched GEMV, where mature library paths plus torch.compile graph optimization can be very strong.
# A Minimal Compatibility Fix for flashinfer
I found cudaOccupancyMaxActiveBlocksPerMultiprocessor unexpectedly returning 0 in flashinfer’s SingleDecodeWithKVCacheDispatched path on my setup.
My temporary fix:
- force `num_blocks_per_sm = 1` (its default double K/V buffering already uses about 64KB SMEM anyway)
- set `max_num_kv_chunks = 1` to avoid downstream zero-chunk behavior
Also, once I saw flashinfer’s double K/V-buffer SMEM footprint, I already expected occupancy loss to hurt it on this card.
# Kernel Outline in This Post
- `flash_decode_tma_128`: `BN=64`, TMA + `float4` vectorized SMEM loads + online softmax
- `flash_decode_tma_dbf_k`: same as above, plus double K buffers
Both target MHA decode attention with head_dim = 128.
# 1. Flash Decoding Recap
At the decoding stage, batch and q_seq_len are tiny (often effectively 1), so sequence parallelism on the Q side vanishes and SM utilization drops.
Classic flash-decoding idea:
- split KV sequence into chunks
- compute chunk-local online stats (`m_i`, `d_i`, partial `o`)
- merge chunk results in a second pass
For clarity, this post discusses one decode step with:
- `q` shape `[head, dim]`
- `k`, `v` shape `[seq, head, dim]`
# Tiling Strategy
- Keep `BN=64` tile size from prior FlashAttention experiments
- On my card (100KB total SMEM, 48KB per block), this is practical and stable
So each KV loop iteration loads 64 x 128 K/V tiles.
# Block/Grid Strategy
- block size = 128
  - 32/64 underutilize warp-level flexibility
  - 256 is unnecessary for this kernel’s compute density
  - 128 tested best
- grid:
  - `y`: head
  - `x`: chunked KV sequence
Chunk size heuristic:
- target about `2 * num_sms` total blocks for this card (26 * 2 = 52)
- derive runtime `chunk_size` from `head * seq / 52`
- round up to power-of-two-style buckets (`256/512/1024/2048`)
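The heuristic above can be sketched in a few lines. This is a minimal sketch, not the launcher from the repository: `pick_chunk_size` and `grid_dims` are hypothetical names, and clamping to the largest bucket when the ideal chunk exceeds 2048 is my assumption.

```python
def pick_chunk_size(num_heads, seq_len, num_sms, buckets=(256, 512, 1024, 2048)):
    """Hypothetical helper: pick a KV chunk size that targets ~2 blocks per SM."""
    target_blocks = 2 * num_sms
    # Chunk length for which num_heads * (seq_len / chunk) lands near target_blocks.
    ideal = num_heads * seq_len / target_blocks
    for bucket in buckets:
        if ideal <= bucket:
            return bucket  # round up to the next bucket
    return buckets[-1]     # assumption: clamp to the largest bucket


def grid_dims(num_heads, seq_len, chunk_size):
    """Grid as (x, y) = (number of KV chunks, number of heads)."""
    num_chunks = -(-seq_len // chunk_size)  # ceiling division
    return (num_chunks, num_heads)


if __name__ == "__main__":
    for seq in (256, 1024, 8192, 131073):
        chunk = pick_chunk_size(32, seq, 26)
        print(seq, chunk, grid_dims(32, seq, chunk))
```

Under these assumptions, long sequences always land in the 2048 bucket, so the block count grows with sequence length instead of staying near 52.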
# Why No Tensor Core Path Here?
Decode’s q is effectively one row.
mma m16n8k16 requires m=16; padding fake rows wastes work.
This problem is bandwidth-dominated, so efficient K/V movement and consumption matter more than forcing MMA.
So I skip the mma/ldmatrix path and instead:
- use TMA to move full `64 x 128` K/V tiles to SMEM
- use vectorized reads (`float4`) to maximize bandwidth usage
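The bandwidth-dominated claim can be checked with a rough count. The sketch below assumes fp16 K/V (2 bytes per element) and counts only K/V traffic plus the two multiply-adds each element takes part in; `decode_arithmetic_intensity` is a hypothetical helper name.

```python
def decode_arithmetic_intensity(seq_len, num_heads, head_dim, bytes_per_elem=2):
    """FLOPs per byte of K/V traffic for one decode step (rough estimate)."""
    elems = seq_len * num_heads * head_dim
    # Q*K^T touches every K element once, P*V touches every V element once;
    # each touch is one multiply-add, counted as 2 FLOPs.
    flops = 2 * elems + 2 * elems
    kv_bytes = 2 * elems * bytes_per_elem  # K and V are each read once
    return flops / kv_bytes


if __name__ == "__main__":
    print(decode_arithmetic_intensity(8192, 32, 128))
```

The ratio comes out at about one FLOP per byte regardless of sequence length, which is why the K/V data path matters more here than raw math throughput.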
# 2. Kernel Details
Within one tile/chunk:
- load `Q[1,128]`
- compute `S = Q * K^T` (`1 x 64`)
- apply softmax to get `P[1,64]`
- compute `O = P * V`
How to split this over 128 threads (4 warps)?
Bad idea first (rejected): one thread computes an entire row/column dot product. That violates CUDA parallel-first design and explodes both sync and bank-conflict pressure.
Chosen strategy:
- one warp handles row groups
- with half precision and `float4` vectorization, each thread handles 8 elements
- one warp can process two rows per step; loop depth is reduced
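The mapping implied by these numbers can be written out as index arithmetic. This is a sketch under assumptions: the constant names and `thread_work` are hypothetical, and assigning consecutive rows to consecutive 16-thread groups is one plausible layout, not necessarily the one the kernel uses.

```python
BLOCK_THREADS = 128
HEAD_DIM = 128
BN = 64
ELEMS_PER_THREAD = 8  # one 16-byte float4 load carries 8 half values

THREADS_PER_ROW = HEAD_DIM // ELEMS_PER_THREAD    # 16 threads cover one K/V row
ROWS_PER_WARP = 32 // THREADS_PER_ROW             # 2 rows per warp per step
ROWS_PER_STEP = BLOCK_THREADS // THREADS_PER_ROW  # 8 rows per block per step
STEPS_PER_TILE = BN // ROWS_PER_STEP              # 8 steps to consume a 64-row tile


def thread_work(tid, step):
    """Return (row, first_col) handled by thread `tid` at loop step `step`."""
    row = step * ROWS_PER_STEP + tid // THREADS_PER_ROW
    first_col = (tid % THREADS_PER_ROW) * ELEMS_PER_THREAD
    return row, first_col


if __name__ == "__main__":
    print(THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_STEP, STEPS_PER_TILE)
    print(thread_work(0, 0), thread_work(17, 0), thread_work(127, 7))
```

With this layout each 16-thread group owns one row per step, which is also why the reductions later in the post work on groups of 16 lanes.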
# Pass Structure
Pass 1 (chunk kernel):
- init K/V SMEM buffers + TMA mbarriers
- load `Q` fragment into registers
- maintain subgroup-local online softmax state (`acc_o`, `m_i`, `d_i`)
- loop through tiles inside chunk:
  - TMA load K/V
  - compute local attention scores
  - compute weighted partial outputs
  - online merge into subgroup state
- block-level merge subgroup states
- write split results:
  - `ws_o` (partial output)
  - `ws_lse` (`logsumexp` form)

Pass 2 (reduce kernel):
- merge all chunk outputs using `ws_lse` weights
- produce final output `o`
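The tile loop of pass 1 can be mirrored on the CPU for a single head. This is a reference sketch in plain Python, not the kernel: `chunk_pass` and `reference_attention` are hypothetical names, scores are tracked in the base-2 domain (an assumption consistent with the `ln(2)` factor used for `lse` later), and the partial output is assumed to be normalized by `d_i` before it is written.

```python
import math
import random

LOG2E = 1.0 / math.log(2.0)


def chunk_pass(q, K, V, scale, tile=64):
    """Online softmax over one KV chunk; returns (normalized output, lse)."""
    m_i = -math.inf          # running max of scores, base-2 domain
    d_i = 0.0                # running sum of 2**(score - m_i)
    acc_o = [0.0] * len(q)   # unnormalized output accumulator
    for start in range(0, len(K), tile):
        k_tile, v_tile = K[start:start + tile], V[start:start + tile]
        s = [scale * LOG2E * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        m_new = max(m_i, max(s))
        alpha = 2.0 ** (m_i - m_new)  # rescale factor; 0.0 on the first tile
        p = [2.0 ** (x - m_new) for x in s]
        d_i = d_i * alpha + sum(p)
        acc_o = [alpha * o + sum(pj * v[c] for pj, v in zip(p, v_tile))
                 for c, o in enumerate(acc_o)]
        m_i = m_new
    lse = m_i * math.log(2.0) + math.log(d_i)
    return [o / d_i for o in acc_o], lse


def reference_attention(q, K, V, scale):
    """Direct softmax(q K^T) V for comparison."""
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    d = sum(p)
    return [sum(pj * v[c] for pj, v in zip(p, V)) / d for c in range(len(q))]


random.seed(0)
dim, n = 8, 200
q = [random.uniform(-1, 1) for _ in range(dim)]
K = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(n)]
V = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(n)]
scale = 1.0 / math.sqrt(dim)
out, lse = chunk_pass(q, K, V, scale)
ref = reference_attention(q, K, V, scale)
print(max(abs(a - b) for a, b in zip(out, ref)))
```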
# Using lse Instead of Storing m_i + d_i Separately
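I use:

```
lse = m_i * ln(2) + ln(d_i)
```

The `ln(2)` factor converts `m_i` into natural-log units, which applies when the running max is tracked in the base-2 (`exp2`) domain.

Then in pass 2:

```
float max_lse = max(lse_i);
float global_lse = max_lse + ln(sum(exp(lse_i - max_lse)));
o = sum(ws_o_i * exp(lse_i - global_lse));
```

This makes cross-chunk merge cleaner.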
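To see that this merge reproduces a single full-length softmax, here is a small plain-Python check. It is a sketch under one assumption: each chunk's `ws_o` is already normalized by that chunk's own softmax denominator. `chunk_partial` and `merge_chunks` are hypothetical names.

```python
import math


def chunk_partial(scores, values):
    """Per-chunk softmax output (normalized) and its natural-log lse."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    d = sum(p)
    dim = len(values[0])
    o = [sum(pj * v[c] for pj, v in zip(p, values)) / d for c in range(dim)]
    return o, m + math.log(d)


def merge_chunks(ws_o, ws_lse):
    """Pass-2 style merge: weight each chunk by exp(lse_i - global_lse)."""
    max_lse = max(ws_lse)
    global_lse = max_lse + math.log(sum(math.exp(l - max_lse) for l in ws_lse))
    out = [0.0] * len(ws_o[0])
    for o_i, lse_i in zip(ws_o, ws_lse):
        w = math.exp(lse_i - global_lse)  # the weights sum to 1 across chunks
        for c, x in enumerate(o_i):
            out[c] += w * x
    return out


scores = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 1.0], [0.5, 0.5], [3.0, -1.0]]
parts = [chunk_partial(scores[:3], values[:3]), chunk_partial(scores[3:], values[3:])]
merged = merge_chunks([o for o, _ in parts], [l for _, l in parts])
full, _ = chunk_partial(scores, values)
print(merged, full)
```

Storing one `lse` per chunk instead of separate `m_i` and `d_i` halves the statistics written to the workspace and turns the merge weights into a plain softmax over the chunk `lse` values.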
# Notes
- No `lazy_rescale` here:
  - in prefill it replaced expensive repeated scaling on larger accumulators
  - here `acc_o` is small and trade-off is not favorable
- No `need_causal_mask`:
  - in decode, all current KV positions are visible to current query
# Grouped Warp/Block Reduce
Each warp effectively behaves like two 16-thread groups for reduction, so reduction helpers are parameterized with group width 16.
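A width-16 reduction inside a 32-lane warp can be simulated on the CPU. The sketch assumes the helpers follow the usual XOR-butterfly pattern (as one would build with `__shfl_xor_sync` in CUDA); `group_allreduce_sum` is a hypothetical name, and the repository's helpers may be structured differently.

```python
def group_allreduce_sum(lane_vals, width=16):
    """Simulate an XOR-butterfly sum over groups of `width` lanes in one warp."""
    vals = list(lane_vals)
    offset = width // 2
    while offset >= 1:
        # Every lane adds the value held by its partner lane (lane ^ offset).
        # Offsets stay below `width`, so partners never cross a group boundary.
        vals = [vals[lane] + vals[lane ^ offset] for lane in range(len(vals))]
        offset //= 2
    return vals


if __name__ == "__main__":
    result = group_allreduce_sum(range(32))
    print(result[0], result[15], result[16], result[31])
```

Starting the butterfly at offset 8 rather than 16 is the only change from a full-warp reduce, and it leaves both 16-lane groups with their own independent totals.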
This implementation is still rough around the edges (e.g., ws_o/ws_lse write/read patterns are not deeply optimized yet), but the main path is working and fast.
# 3. flash_decode_tma_dbf_k: Double-K Buffering
After finishing baseline flash_decode_tma_128, I optimized further.
Current resource picture:
- ~32KB SMEM + barriers + temporary reduction storage
- occupancy can still keep 2 resident blocks
Given that, the most practical next step is additional pipelining to hide latency.
# 3.1 Double K Buffer
Use double buffering for K only, keep V single-buffered.
Why this works:
- K and V are already consumed at different moments in the pipeline
- V latency is partially hidden by `QK`/softmax work
- adding another K buffer improves overlap on the K side too
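The SMEM side of this trade-off is simple arithmetic. The sketch below counts tile buffers only, assuming fp16 tiles of `64 x 128` and 100KB of SMEM per SM; it ignores mbarriers, reduction scratch, registers, and any per-block reservation, so the block counts are upper bounds rather than measured occupancy. The function names are hypothetical.

```python
KB = 1024


def tile_bytes(bn=64, head_dim=128, bytes_per_elem=2):
    """Bytes for one K or V tile."""
    return bn * head_dim * bytes_per_elem


def buffer_bytes(num_k_bufs, num_v_bufs):
    """Total tile-buffer bytes per block."""
    return (num_k_bufs + num_v_bufs) * tile_bytes()


def max_blocks_by_smem(num_k_bufs, num_v_bufs, smem_per_sm=100 * KB):
    """Upper bound on resident blocks per SM from tile buffers alone."""
    return smem_per_sm // buffer_bytes(num_k_bufs, num_v_bufs)


if __name__ == "__main__":
    for label, nk, nv in (("1K + 1V", 1, 1), ("2K + 1V", 2, 1), ("2K + 2V", 2, 2)):
        print(label, buffer_bytes(nk, nv) // KB, "KB,",
              "up to", max_blocks_by_smem(nk, nv), "block(s) per SM")
```

Under these assumptions a second K buffer brings the tile footprint to 48KB, which still leaves room for two blocks per SM, while doubling V as well would reach 64KB and allow only one.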
Pipeline outline:
- initialize 2 K tiles
- prologue: preload `Ks[0]`
- loop:
  - prefetch next K tile if available
  - issue V transfer
  - wait current K tile, run online softmax path
  - wait V, compute output update
  - sync and flip read/write indices
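The ordering in this outline can be traced with a small schedule generator. This is a CPU sketch of the control flow only: `dbf_k_schedule` and the event names are hypothetical, and no actual TMA or mbarrier semantics are modeled.

```python
def dbf_k_schedule(num_tiles):
    """Return the ordered events (op, tile, buffer) of a double-K-buffer loop."""
    events = []
    read, write = 0, 1
    events.append(("issue_K", 0, read))  # prologue: preload Ks[0]
    for t in range(num_tiles):
        if t + 1 < num_tiles:
            # Prefetch the next K tile into the buffer that is not being read.
            events.append(("issue_K", t + 1, write))
        events.append(("issue_V", t, 0))       # single V buffer
        events.append(("wait_K", t, read))
        events.append(("qk_softmax", t, read))  # overlaps with the V transfer
        events.append(("wait_V", t, 0))
        events.append(("pv_update", t, 0))
        # The sync keeps the next iteration from overwriting buffers still in use.
        events.append(("sync", t, None))
        read, write = write, read
    return events


if __name__ == "__main__":
    for event in dbf_k_schedule(3):
        print(event)
```

The trace makes the overlap visible: the K transfer for tile `t + 1` and the V transfer for tile `t` are both in flight while the scores for tile `t` are being computed.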
# 3.2 Epilogue Tweak
Original epilogue had repeated block-reduce loops. I switched to reusing K/V SMEM buffers as staging storage, then a dedicated 16-thread group performs final accumulation/writeback.
Gain is modest, but change is kept for cleaner staging behavior.
# 4. Benchmark
The flash-infer numbers below come from the locally patched version described above, not an untouched upstream build.
```
####################################################################################################
decode, kv seq: 8192, head: 32, dim: 128
torch.compile mean time: 0.454655 ms, 295.24 GB/s
flash-infer mean time: 0.403291 ms, speedup: 1.13, GB/s: 332.85
flash_decode_tma_128 mean time: 0.408378 ms, speedup: 1.11, GB/s: 328.70
flash_decode_tma_dbf_k_128 mean time: 0.366698 ms, speedup: 1.24, GB/s: 366.06
####################################################################################################
decode, kv seq: 16384, head: 32, dim: 128
torch.compile mean time: 0.872882 ms, 307.55 GB/s
flash-infer mean time: 0.784423 ms, speedup: 1.11, GB/s: 342.23
flash_decode_tma_128 mean time: 0.735274 ms, speedup: 1.19, GB/s: 365.10
flash_decode_tma_dbf_k_128 mean time: 0.733273 ms, speedup: 1.19, GB/s: 366.10
####################################################################################################
decode, kv seq: 32768, head: 32, dim: 128
torch.compile mean time: 1.507921 ms, 356.04 GB/s
flash-infer mean time: 1.499479 ms, speedup: 1.01, GB/s: 358.05
flash_decode_tma_128 mean time: 1.495797 ms, speedup: 1.01, GB/s: 358.93
flash_decode_tma_dbf_k_128 mean time: 1.455790 ms, speedup: 1.04, GB/s: 368.79
####################################################################################################
decode, kv seq: 65536, head: 32, dim: 128
torch.compile mean time: 2.980080 ms, 360.31 GB/s
flash-infer mean time: 2.897006 ms, speedup: 1.03, GB/s: 370.64
flash_decode_tma_128 mean time: 2.856871 ms, speedup: 1.04, GB/s: 375.85
flash_decode_tma_dbf_k_128 mean time: 2.849400 ms, speedup: 1.05, GB/s: 376.84
####################################################################################################
decode, kv seq: 131072, head: 32, dim: 128
torch.compile mean time: 6.044398 ms, 355.29 GB/s
flash-infer mean time: 5.751600 ms, speedup: 1.05, GB/s: 373.37
flash_decode_tma_128 mean time: 5.663495 ms, speedup: 1.07, GB/s: 379.18
flash_decode_tma_dbf_k_128 mean time: 5.736955 ms, speedup: 1.05, GB/s: 374.33
####################################################################################################
decode, kv seq: 131073, head: 32, dim: 128
torch.compile mean time: 6.466227 ms, 332.11 GB/s
flash-infer mean time: 6.117131 ms, speedup: 1.06, GB/s: 351.07
flash_decode_tma_128 mean time: 5.701174 ms, speedup: 1.13, GB/s: 376.68
flash_decode_tma_dbf_k_128 mean time: 5.695415 ms, speedup: 1.14, GB/s: 377.06
```

Takeaways:
- both custom kernels consistently beat the `torch.compile` baseline
- `flash_decode_tma_dbf_k_128` is best overall: it is the fastest in five of the six configurations, and at seq 131072 `flash_decode_tma_128` is slightly ahead
- double-K buffering gives bigger gains on shorter sequences
- for longer sequences, all kernels approach bandwidth ceiling and gaps naturally shrink
- peak logical bandwidth reaches `377.06 / 384 = 98.2%`, very close to this GPU’s theoretical peak
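The GB/s column can be reproduced from the shapes and timings. The formula below is my reconstruction, not taken from the benchmark script: it assumes fp16 tensors and counts K, V, `q`, and `o` each moved once, which matches the reported figures to two decimals for the rows I checked. `logical_bandwidth_gbps` is a hypothetical name.

```python
def logical_bandwidth_gbps(seq_len, num_heads, head_dim, time_ms, bytes_per_elem=2):
    """Logical bytes moved per decode step divided by measured time, in GB/s."""
    kv_bytes = 2 * seq_len * num_heads * head_dim * bytes_per_elem  # K and V
    qo_bytes = 2 * num_heads * head_dim * bytes_per_elem            # q in, o out
    return (kv_bytes + qo_bytes) / (time_ms * 1e-3) / 1e9


if __name__ == "__main__":
    print(round(logical_bandwidth_gbps(8192, 32, 128, 0.454655), 2))
    best = logical_bandwidth_gbps(131073, 32, 128, 5.695415)
    print(round(best, 2), round(100 * best / 384, 1))
```

Because this is logical traffic (bytes the algorithm must touch), the ratio to the 384 GB/s peak says how little time is left for anything other than streaming K and V.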
In this patched setup on my GPU, flashinfer effectively runs at single-block occupancy with double K/V buffers, which is weaker than my 2 blocks + 2K + 1V configuration here. This again shows how critical occupancy is when it gets too low.
NCU screenshots:
https://cdn.jsdelivr.net/gh/WingEdge777/CDN@main/images/vitamin_cuda/flash_decoding_summary.png

https://cdn.jsdelivr.net/gh/WingEdge777/CDN@main/images/vitamin_cuda/flash_decoding_detail.png
There are still some uncoalesced global accesses, mainly from ws_o/ws_lse I/O. They are outside the hot loop, and DRAM bandwidth is already 90%+, so total runtime impact is limited.
# 5. End
That’s my current understanding and implementation of flash decoding.
There are still rough edges, but I’ll leave them for now and move on to other experiments.
If you spot mistakes, feel free to correct me. Suggestions are also welcome.
Full kernel and test code: github vitamin-cuda flash_attn
That’s all.