[CUDA 优化实战] 纯手搓 flash decoding sm120 : 超越 flashinfer.single_decode_with_kv_cache
本文适用于有一定 CUDA 编程基础,熟悉 GEMM/multi-head-attention 优化,对进阶嵌入 PTX 指令性能调优感兴趣的读者阅读
完整 kernel 和测试代码可以点击 flash_decode 查看
# 0. 序 - decode 和 prefill attention : 完全不同的优化哲学
承接上篇 fmha 文章。上篇主要讨论 prefill 场景下的 flash attention,这一篇换到 decode 场景,看看单 query、长 KV cache 时 kernel 该怎么写。
本文的对比 baseline 有两个:
torch.compile后的 PyTorch native 实现,作为通用算子基线- flashinfer 的
single_decode_with_kv_cache,作为现成 decode kernel 基线
说明一下:flashinfer baseline 是本文后补的。为了能在我这台 26 SM 的 5060 上正常跑通 benchmark,我对它做了一个最小修复,后面单独说明。
# pytorch native
@torch.compile
def torch_native_decode(q, k, v, scale=None):
# q: [head, dim] -> [32, 128]
# k: [seq, head, dim] -> [4096, 32, 128]
# v: [seq, head, dim] -> [4096, 32, 128]
if scale is None:
scale = 1.0 / math.sqrt(q.shape[-1])
# 调整维度以适应 Batched GEMV
q_b = q.unsqueeze(1) # [32, 1, 128]
k_b = k.permute(1, 2, 0) # [32, 128, 4096]
v_b = v.transpose(0, 1) # [32, 4096, 128]
# S = Q @ K^T
attn_scores = torch.matmul(q_b, k_b) * scale # [32, 1, 4096]
attn_probs = torch.softmax(attn_scores, dim=-1)
# O = P @ V
out = torch.matmul(attn_probs, v_b) # [32, 1, 128]
return out.squeeze(1) # [32, 128]不要因为它是 PyTorch 实现就先入为主地觉得它慢。decode attention 本质上已经很接近 batched GEMV,PyTorch 会走到相当成熟的库实现,再叠加 torch.compile 的图优化,完全够资格做 baseline。自己的 kernel 不认真写,还真不一定打得过它。
# flashinfer baseline 的一个兼容性修复

flashinfer 的SingleDecodeWithKVCacheDispatched代码,不知为何cudaOccupancyMaxActiveBlocksPerMultiprocessor返回了 0,导致一步 0,步步 0.
我这里做了一个简单的修复,强制设置 num_blocks_per_sm=1(其默认双 buffer Ks/Vs 实现,占用 64KB smem,也只能是 1),然后把 max_num_kv_chunks 设为了 1(按照原本代码逻辑 block_per_sm * num_sm / heads= 1*26/ 32 就是会等于 0,我只能将其改为接近原代码意图的正整数)。
备注:其实看到 flashinfer 双 Ks/Vs buffer smem 的配置时,我就知道 flashinfer 输定了,Occupancy 比我低一半,这损失的性能不是它的 Vs 双 buffer 能挽回的了(我的 Ks 也是双 buffer)。
# 本文的 kernel 大纲
- flash_decode_tma_128 (BN=64,TMA + float4 向量化读取smem + online softmax)
- flash_decode_tma_dbf_k (BN=64,TMA + float4 向量化读取smem + online softmax,double Ks buffers)
支持 head_dim 为 128 情况下 mha 的 decode attention。
# 1. flash decoding
flash decoding 的思想大家肯定都学习过了。在 llm decoding 的阶段,由于 batchsize/q_seq_len 太小(甚至直接等于 1),attention 中对 q 的序列并行完全没了,无法充分利用所有 SM。因此考虑对 kv 的 seq 维度进行 chunk 切分,然后分两步完成 attention
- 一个 block 负责一个 q 和 kv chunk 的 attention,计算完 chunk 内的 m_i/d_i/acc_o,写回 gmem
- 读取上一步的中间结果,merge m_i/d_i/acc_o,得到最终的 o 输出
这张官方博客的示意图,相信关注的人看过没有十遍也有八遍了。但我们就不重复说原理性的东西,直直白白地讲清楚如何使用 c++ 代码纯手搓出来一个 flash decoding kernel.

为了方便理解,这里只考虑 q shape 为 [head, dim]、kv shape 为 [seq, head, dim] 的一次 decode 计算,也就是和上面那段 PyTorch 代码一一对应的版本。
先说 data tiling 策略:
- 搬运 kv 数据 tile,这里沿用了 flash attention 实现中的 BN=64,即 64x128, 为什么?
- 我的卡一共 100KB smem,一个 block 最多 48KB,这里即使省掉了原来 q 中 BM 对应的 16KB,也不够 Ks 和 Vs 分的,暂时也不想搞个奇奇怪怪的 BN 大小
也就是说 kv chunk loop 中每次循环加载 64x128 的 Ks/Vs tile。
再说 thread block/grid 配置:
- block 直接定为 128。这个不是拍脑袋:32、64 太小,可切换 warp 数太少;256 又太大,这个 kernel 的计算密度没高到需要那么多线程一起上。实测下来 128 最合适。
- grid 上,显然 q 失去了 seq 维度,无法并行。head 还是放在 y 维度上,再考虑对 kv 的 seq 进行切分放到 x 维度。这里只有一个切块大小的问题:
- 我的 5060 只有 26 个 SM,为了充分利用 SM,我们保障 block 数量为 SM 数量的整数倍,不用太多,2~4 倍即可,我这里就用了 26x2,因此先确定预期的总 chunk 数为 52 个左右
- 然后运行时用 head*seq/52,且向上对 2 的幂取整得到 chunk_size,则 grid.x = (seq + chunk_size - 1) / chunk_size;
把这些约束合起来,kernel launch 代码就基本定下来了。
这里我没有继续走 Tensor Core 路线。原因很直接:mma m16n8k16 要求 m=16,而 decode 里的 q 本质上只有一行,硬凑出 16 行 padding 只会徒增浪费。再加上这个问题本身更偏向带宽瓶颈,与其执着于 mma,不如把重点放在更高效地搬运和消费 K/V 数据上。
既然不走 mma,那 ldmatrix 和专门为其服务的 swizzle 也都可以先放下。TMA 这里只需要把一整块 64x128 的 K/V tile 原样搬进 shared memory,后面再用向量化读法把它吃满即可。
inline int get_chunk_size(int q_head, int kv_len, int num_sms) {
int target_blocks = num_sms * 2;
// Total_Blocks = q_head * (kv_len / chunk_size)
// chunk_size = (q_head * kv_len) / target_blocks
int chunk = (q_head * kv_len) / target_blocks;
if (chunk <= 256)
return 256;
if (chunk <= 512)
return 512;
if (chunk <= 1024)
return 1024;
return 2048;
}
#define CHECK_T(x) TORCH_CHECK(x.is_cuda() && x.is_contiguous(), #x " must be contiguous CUDA tensor")
template <typename T>
inline CUtensorMap create_3d_tensor_map(T *global_address,
uint64_t dim_d,
uint64_t dim_h,
uint64_t dim_s,
uint64_t stride_h,
uint64_t stride_s,
uint32_t box_d,
uint32_t box_s) // Each kernel load takes a (box_s x box_d) block
{
CUtensorMap tmap;
CUtensorMapDataType type =
std::is_same_v<T, __half> ? CU_TENSOR_MAP_DATA_TYPE_FLOAT16 : CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
// TMA dimensions: from fastest (0) to slowest (2)
uint64_t globalDim[3] = {dim_d, dim_h, dim_s};
// globalStrides are strides for dimensions 1, 2, must be in Bytes
uint64_t globalStrides[2] = {stride_h, stride_s};
uint32_t boxDim[3] = {box_d, 1, box_s};
uint32_t elementStrides[3] = {1, 1, 1};
CUresult res = cuTensorMapEncodeTiled(&tmap,
type,
3, // Rank = 3
global_address,
globalDim,
globalStrides,
boxDim,
elementStrides,
CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle,
CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS, "cuTensorMapEncodeTiled failed for 3D Tensor!");
return tmap;
}
#define DISPATCH_TMA_KERNEL(NAME, HEAD_DIM, CHUNK_SIZE) \
NAME##_kernel<BN, CHUNK_SIZE, HEAD_DIM, 128, __nv_bfloat16> \
<<<blocks_per_grid, 128, smem_bytes, stream>>>(reinterpret_cast<__nv_bfloat16 *>(q.data_ptr()), \
tma_k, \
tma_v, \
reinterpret_cast<float *>(ws_o.data_ptr()), \
reinterpret_cast<float *>(ws_lse.data_ptr()), \
kv_len, \
q_head, \
kv_head, \
scale);
#define binding_tiled_tma_func_gen(name, HEAD_DIM) \
void name##_##HEAD_DIM(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor o, float scale) { \
\
CHECK_T(q); \
CHECK_T(k); \
CHECK_T(v); \
CHECK_T(o); \
\
/* Extract dimension info dynamically from Tensor */ \
const int q_head = q.size(0); \
const int head_dim = q.size(1); \
const int kv_len = k.size(0); \
const int kv_head = k.size(1); \
\
/* Only validate that head_dim matches the compile-time constant */ \
TORCH_CHECK(head_dim == HEAD_DIM, "Head dim mismatch: expected ", HEAD_DIM); \
\
int elem_bytes = k.element_size(); \
uint64_t k_stride_h = k.stride(1) * elem_bytes; \
uint64_t k_stride_s = k.stride(0) * elem_bytes; \
uint64_t v_stride_h = v.stride(1) * elem_bytes; \
uint64_t v_stride_s = v.stride(0) * elem_bytes; \
\
const int BN = 64; \
const int num_sms = 26; \
const size_t smem_bytes = BN * head_dim * sizeof(__nv_bfloat16) * 2 + sizeof(mbarrier_t) * 2; \
const int chunk_size = get_chunk_size(q_head, kv_len, num_sms); \
CUtensorMap tma_k = create_3d_tensor_map<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 *>(k.data_ptr()), \
head_dim, \
kv_head, \
kv_len, \
k_stride_h, \
k_stride_s, \
head_dim, \
BN); \
\
CUtensorMap tma_v = create_3d_tensor_map<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 *>(v.data_ptr()), \
head_dim, \
kv_head, \
kv_len, \
v_stride_h, \
v_stride_s, \
head_dim, \
BN); \
\
TORCH_CHECK(q_head % kv_head == 0, "q_head must be divisible by kv_head"); \
const dim3 blocks_per_grid((kv_len + chunk_size - 1) / chunk_size, q_head); \
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(q.device()); \
auto ws_lse = torch::empty({q_head, blocks_per_grid.x}, options); \
auto ws_o = torch::empty({q_head, blocks_per_grid.x, head_dim}, options); \
/* launch kernel */ \
switch (chunk_size) { \
case 256: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 256); break; \
case 512: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 512); break; \
case 1024: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 1024); break; \
case 2048: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 2048); break; \
default: TORCH_CHECK(false, "Unsupported chunk size: ", chunk_size); \
} \
flash_decode_reduce_kernel<HEAD_DIM, 128, __nv_bfloat16> \
<<<q_head, 128, 0, stream>>>(reinterpret_cast<float *>(ws_o.data_ptr()), \
reinterpret_cast<float *>(ws_lse.data_ptr()), \
reinterpret_cast<__nv_bfloat16 *>(o.data_ptr()), \
blocks_per_grid.x); \
}
binding_tiled_tma_func_gen(flash_decode_tma, 128);
#define torch_pybinding_func(f) m.def(#f, &f, #f)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// flash_decode_tma_128
torch_pybinding_func(flash_decode_tma_128);
}使用一个 help func 拿到 chunk_size,根据 chunk size 用一个 dispatch 宏分发到不同的 kernel launch 上(因为我们需要编译期确定 chunk_size,以帮助 kernel 内部循环展开)
# 2. kernel 实现细节
从逻辑上来说,一次 tiling 下的数据(chunk 大小是 tiling 大小的整数倍),我们需要加载一行 Qs[1,128],然后和 Ks.T[128,64] 做向量矩阵乘法得到 s[1,64], 然后 softmax 完得到 Ps[1,64], 再乘以 Vs[64,128]。
现在的情况是,我们一个 block 有 4 个 warp,128 线程,如何来瓜分这些计算。
首先排除掉一个线程负责一整行/一整列点积的思路,这不是 cuda 并行编程的第一性原理,不能用串行思维去写代码,这样直接后果就是 bank conflict 爆炸,而且算完一行的点积结果后,还是要经过 block 线程间进行同步广播,否则无法参数接下来 Vs 的计算。
我们应该先从一个 warp 去考虑,比如一个 warp 负责一行计算,那么 128 的向量共 32 线程,每个线程就只需要负责 4 个元素。4 个 warp,每个 warp 负责 16 行,这样 16 行 16 次循环内只需要 warp 间同步,等 16 行计算完四组再进行 block 间同步。
更进一步,考虑到我们使用的是半精度,也就是说一个元素才 2 字节,我们为了最大化压榨带宽,肯定是用 float4 向量化指令,因此我们让一个线程负责 8 个元素,一个 warp 一次就可以算两行,只需要循环 8 次。
初步整理一下算法流程如下:
- kernel pass 1:
- 初始化 Ks、Vs 2 块 smem 和 2 个 tma mbarrier
- 初始化 acc_o[8], 每个 group(16 线程)私有化初始化历史状态 acc_o[8], m_i, d_i
- 加载 Qs[8](使用 float4 向量化加载到寄存器)
- 一个 block 负责一个 chunk,kv chunk 内 loop:
- tma 发起加载 Ks、Vs,并等待 Ks 加载完成
- [计算 S] 循环 8 次(Group 内每线程负责 8 行):
- float4 向量化读取 Ks 一行内的 8 个元素
- Qs[8] 和 reg_k[8] 进行点积计算
- Warp Reduce 求和得到单行的 Attention Score,并统计当前小块的 m_part
- wait Vs 加载完毕
- [计算 O] 循环 8 次:
- 根据 m_part 计算当前行的 Softmax 权重标量 p
- p 乘以 reg_v[8] 向量,累加得到当前小块的 part_o[8],并统计当前小块的 d_part
- [group 内部状态更新] 使用安全的 online softmax 逻辑计算 alpha,将当前的 part_o、d_part、m_part 融合进 group 维护的历史 acc_o、d_i、m_i 中
- block 内各 group 进行 block_reduce,合并各自的 acc_o、d_i、m_i
- 将最终合并的寄存器结果转化为 ws_o 和 d_i/m_i,写回 gmem
- kernel pass 2
- 读取上面 gmem 的 ws_o,和 m_i/d_i,online softmax 继续规约得到最终输出 o
- 结束
上面是最原汁原味的分块 online + “offline” softmax attention. 但参考一些教程,可以加入一个 trick 优化,将 m_i 和 d_i 合并为 lse(logsumexp),就是对 e 的指数和再取对数。当然由于我们还加入了 log2_scale 的技巧,因此 lse 变成了一个有点丑陋的东西:
lse(o) = m_i * ln(2) + ln(d_i)然后 pass 2 内通过 lse 和 ws_o 合并方式伪代码为:
float max_lse = max(lse_i);
float global_lse = max_lse + ln(sum(exp(lse_i - max_lse)));
o = sum(ws_o_i * exp(lse_i - global_lse));上 kernel 代码:
// flash decoding softmax(q @ k.T*scale) @ v
template <const int BN = 64,
const int CHUNK_SIZE = 256,
const int HEAD_DIM = 128,
const int THREADS_PER_BLOCK = 128,
typename T>
__global__ void flash_decode_tma_kernel(T *q,
const __grid_constant__ CUtensorMap tma_k,
const __grid_constant__ CUtensorMap tma_v,
float *ws_o, // [q_head, num_chunks, HEAD_DIM]
float *ws_lse, // [q_head, num_chunks]
int kv_len,
int q_head,
int kv_head,
float scale) {
static_assert(THREADS_PER_BLOCK == 128);
static_assert(BN == 64);
// 1. shared memory: K tile, V tile, mbarriers
extern __shared__ __align__(128) uint8_t smem_buf[];
T(*Ks)[HEAD_DIM] = reinterpret_cast<T(*)[HEAD_DIM]>(smem_buf);
T(*Vs)[HEAD_DIM] = reinterpret_cast<T(*)[HEAD_DIM]>(smem_buf + BN * HEAD_DIM * sizeof(T));
mbarrier_t *mbar_k = reinterpret_cast<mbarrier_t *>(smem_buf + BN * HEAD_DIM * sizeof(T) * 2);
mbarrier_t *mbar_v = mbar_k + 1;
// 2. coordinates
const int tid = threadIdx.x;
const int chunk_id = blockIdx.x;
const int q_head_id = blockIdx.y;
const int kv_group_size = q_head / kv_head;
const int kv_head_id = q_head_id / kv_group_size;
constexpr int THREADS_PER_ROW = 16;
constexpr int NUM_GROUPS = THREADS_PER_BLOCK / THREADS_PER_ROW;
constexpr int ROWS_PER_GROUP = BN / NUM_GROUPS;
const int group_id = tid / THREADS_PER_ROW;
const int lane_id = tid % THREADS_PER_ROW;
if (tid == 0) {
mbarrier_init(mbar_k, 1);
mbarrier_init(mbar_v, 1);
}
__syncthreads();
// 3. load q fragment
pack128 qs{FLOAT4(q[q_head_id * HEAD_DIM + lane_id * 8])};
// 4. init subgroup-local online softmax state
__align__(16) float acc_o[8] = {0.0f};
float m_i = -FLT_MAX;
float d_i = 0.0f;
int phase_k = 0;
int phase_v = 0;
const float scale_log2 = scale * 1.44269504f; // scale*log2(e)
const int num_chunks = gridDim.x;
const int chunk_start = chunk_id * CHUNK_SIZE;
const int chunk_end = min(chunk_start + CHUNK_SIZE, kv_len);
// 5. loop over KV tiles inside this chunk
for (int n = chunk_start; n < chunk_end; n += BN) {
int current_bn = min(BN, chunk_end - n);
// 5.1 TMA async load K/V
if (tid == 0) {
mbarrier_expect_tx(mbar_k, BN * HEAD_DIM * sizeof(T));
mbarrier_expect_tx(mbar_v, BN * HEAD_DIM * sizeof(T));
cp_async_bulk_tensor_3d(mbar_k, &tma_k, Ks, 0, kv_head_id, n);
cp_async_bulk_tensor_3d(mbar_v, &tma_v, Vs, 0, kv_head_id, n);
}
__syncthreads();
mbarrier_wait(mbar_k, phase_k);
phase_k ^= 1; // flip phase
// 5.2 compute S = Q * K^T, keep rows per subgroup in registers
const int row_begin = group_id * ROWS_PER_GROUP;
float acc_s[ROWS_PER_GROUP];
float m_part = -FLT_MAX;
#pragma unroll
for (int i = 0; i < ROWS_PER_GROUP; ++i) {
acc_s[i] = -FLT_MAX;
}
#pragma unroll
for (int i = 0; i < ROWS_PER_GROUP; ++i) {
const int row = row_begin + i;
float sum = 0.0f;
if (row < current_bn) {
pack128 ks{FLOAT4(Ks[row][lane_id * 8])};
#pragma unroll
for (int j = 0; j < 8; ++j) {
sum += static_cast<float>(qs.bf[j]) * static_cast<float>(ks.bf[j]);
}
}
sum = warp_reduce_sum<THREADS_PER_ROW>(sum);
if (row < current_bn) {
acc_s[i] = sum * scale_log2;
m_part = fmaxf(m_part, acc_s[i]);
}
}
// 5.3 accumulate subgroup-local O = P * V
mbarrier_wait(mbar_v, phase_v);
phase_v ^= 1;
float part_d = 0.0f;
float part_o[8] = {0.0f};
#pragma unroll
for (int i = 0; i < ROWS_PER_GROUP; ++i) {
const int row = row_begin + i;
if (row < current_bn) {
float p = exp2f(acc_s[i] - m_part);
part_d += p;
pack128 vs{FLOAT4(Vs[row][lane_id * 8])};
#pragma unroll
for (int j = 0; j < 8; ++j) {
part_o[j] += p * static_cast<float>(vs.bf[j]);
}
}
}
if (m_part != -FLT_MAX) {
const float m_new = fmaxf(m_i, m_part);
const float alpha_old = exp2f(m_i - m_new);
const float alpha_new = exp2f(m_part - m_new);
#pragma unroll
for (int i = 0; i < 8; ++i) {
acc_o[i] = acc_o[i] * alpha_old + part_o[i] * alpha_new;
}
d_i = d_i * alpha_old + part_d * alpha_new;
m_i = m_new;
}
}
// 6. merge subgroup states once per chunk, then write split results
const float m_chunk = block_reduce_max<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? m_i : -FLT_MAX);
const float alpha = d_i > 0.0f ? exp2f(m_i - m_chunk) : 0.0f;
const float d_chunk = block_reduce_sum<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? d_i * alpha : 0.0f);
#pragma unroll
for (int i = 0; i < 8; ++i) {
acc_o[i] = block_reduce_sum_by_lane<NUM_GROUPS, THREADS_PER_ROW>(acc_o[i] * alpha);
}
if (group_id == 0) {
int out_base_idx = (q_head_id * num_chunks + chunk_id) * HEAD_DIM + lane_id * 8;
float inv_d = __frcp_rn(d_chunk);
#pragma unroll
for (int i = 0; i < 8; i++) {
acc_o[i] *= inv_d;
}
pack128 out_pack0, out_pack1;
#pragma unroll
for (int i = 0; i < 4; ++i) {
out_pack0.f[i] = acc_o[i];
out_pack1.f[i] = acc_o[i + 4];
}
FLOAT4(ws_o[out_base_idx + 0]) = out_pack0.f4;
FLOAT4(ws_o[out_base_idx + 4]) = out_pack1.f4;
if (lane_id == 0) {
int scalar_idx = q_head_id * num_chunks + chunk_id;
ws_lse[scalar_idx] = m_chunk * 0.6931471805599453f + logf(d_chunk);
}
}
}
template <const int HEAD_DIM = 128, const int THREADS_PER_BLOCK = 128, typename T>
__global__ void flash_decode_reduce_kernel(float *ws_o, float *ws_lse, T *o, int num_chunks) {
const int q_head_id = blockIdx.x;
const int tid = threadIdx.x;
constexpr int NUM_WARPS = THREADS_PER_BLOCK / WARP_SIZE;
__shared__ float s_lse;
float lse_max = -FLT_MAX;
for (int chunk = tid; chunk < num_chunks; chunk += THREADS_PER_BLOCK) {
lse_max = fmaxf(lse_max, ws_lse[q_head_id * num_chunks + chunk]);
}
lse_max = block_reduce_max<NUM_WARPS, WARP_SIZE>(lse_max);
float lse_sum = 0.0f;
for (int chunk = tid; chunk < num_chunks; chunk += THREADS_PER_BLOCK) {
lse_sum += expf(ws_lse[q_head_id * num_chunks + chunk] - lse_max);
}
lse_sum = block_reduce_sum<NUM_WARPS, WARP_SIZE>(lse_sum);
if (tid == 0) {
s_lse = logf(lse_sum) + lse_max;
}
__syncthreads();
const int col = tid * 8;
if (col >= HEAD_DIM) {
return;
}
float out[8] = {0.0f};
for (int chunk = 0; chunk < num_chunks; ++chunk) {
const int scalar_idx = q_head_id * num_chunks + chunk;
const float weight = expf(ws_lse[scalar_idx] - s_lse);
const int base_idx = scalar_idx * HEAD_DIM + col;
pack128 partial0{FLOAT4(ws_o[base_idx + 0])};
pack128 partial1{FLOAT4(ws_o[base_idx + 4])};
#pragma unroll
for (int i = 0; i < 4; ++i) {
out[i] += partial0.f[i] * weight;
out[i + 4] += partial1.f[i] * weight;
}
}
pack128 out_pack;
#pragma unroll
for (int i = 0; i < 8; ++i) {
out_pack.bf[i] = __float2bfloat16_rn(out[i]);
}
FLOAT4(o[q_head_id * HEAD_DIM + col]) = out_pack.f4;
}其中 warp reduce、block reduce、tma copy 等都抽成函数出去了。
细心的朋友可能注意到了
- flash decoding 里没有 lazy rescale,因为没必要,原来 prefill 里 是用 16 次 Ps scale 去替换 64 次 acc_o 的 scale 是值得的,这里 acc_o 只有 8,和 p 乘以 inv_scale 的次数相同。
- 同样的,也没有 need_casual_mask 校验,因为对于当前的 q,所以 kv 都是可见的
唯一值得说明一下的是
# grouped warp/block reduce
因为我们一个 warp 负责两行,相当于把一个 warp 劈开成两半,分别去做 reduce 了,所以 warp reduce 的时候要指定宽度 16,用了个 group_size 模板参数
其实本版代码的实现还是很粗糙的,比如对于 ws_o/ws_lse 的写回还没有做优化,后续再看吧~(stay tuned,我也可能去学习一下 flashinfer 里如何实现等等)
# 3. flash_decode_tma_dbf_k
ok, 完成上述 kernel 后,我们总结一下。目前使用了 smem 32KB + 几个 barrier + block 同步占用的中间变量 num_group*group_size。这里在保住 Occupancy(经验表明,在 1 个 block 或 2 个 block 的选择下,保两个 block,让硬件调度总是会更优)的前提下,我们唯一的办法就是再增加 buffer 进行流水线操作,进一步用计算隐藏时延。 因此我做了如下两点优化:
# double Ks buffer
Ks 使用双 buffer,Vs 依然是单 buffer。为什么可以这么做呢,因为 attention 里 Ks 和 Vs 本来就是异步的,Vs 要等 Ks 的计算完才会用到,所以 Vs 本来就是被隐藏的,只要我们再增加一重 buffer 把 Ks 也隐藏掉理论上就很好。
具体流水线操作也很简单:
- 初始化 2 份 K tile
- prologue:在 kv chunk loop 之前先发起加载
Ks[0],并初始化read/write_idx - kv chunk loop:
- 如果还有 next Ks tile,就发起 TMA Ks[write_idx] 请求
- 发起 Vs TMA 请求
- wait Ks[read_idx] 加载完毕
- online softmax(Qs * Ks)
- wait Vs
- 计算输出
- 同步,并反转 read/write_idx
- epilogue
# epilogue 优化
原来的 epilogue 有 8 次循环的 block reduce,有点重。现在改成复用 Ks/Vs 的 smem buffer 进行中转,然后再用一个单独的 group(16 线程)读取统计 ws_o,最后写回。 这个优化其实影响也不大,改动前后级别没什么提升。不过写都写了,就留着吧
上代码:
// flash decoding softmax(q @k.T *scale) @v
template <const int BN = 64,
const int CHUNK_SIZE = 256,
const int HEAD_DIM = 128,
const int THREADS_PER_BLOCK = 128,
typename T>
__global__ void flash_decode_tma_dbf_k_kernel(T *q,
const __grid_constant__ CUtensorMap tma_k,
const __grid_constant__ CUtensorMap tma_v,
float *ws_o, // [q_head, num_chunks, HEAD_DIM]
float *ws_lse, // [q_head, num_chunks]
int kv_len,
int q_head,
int kv_head,
float scale) {
static_assert(THREADS_PER_BLOCK == 128);
static_assert(BN == 64);
// 1. shared memory: K tile, V tile, mbarriers
extern __shared__ __align__(128) uint8_t smem_buf[];
T(*Ks)[BN][HEAD_DIM] = reinterpret_cast<T(*)[BN][HEAD_DIM]>(smem_buf);
T(*Vs)[HEAD_DIM] = reinterpret_cast<T(*)[HEAD_DIM]>(smem_buf + BN * HEAD_DIM * sizeof(T) * 2);
mbarrier_t *mbar_k = reinterpret_cast<mbarrier_t *>(smem_buf + BN * HEAD_DIM * sizeof(T) * 3);
mbarrier_t *mbar_v = mbar_k + 2;
// 2. coordinates
const int tid = threadIdx.x;
const int chunk_id = blockIdx.x;
const int q_head_id = blockIdx.y;
const int kv_group_size = q_head / kv_head;
const int kv_head_id = q_head_id / kv_group_size;
constexpr int THREADS_PER_ROW = 16;
constexpr int NUM_GROUPS = THREADS_PER_BLOCK / THREADS_PER_ROW;
constexpr int ROWS_PER_GROUP = BN / NUM_GROUPS;
const int group_id = tid / THREADS_PER_ROW;
const int lane_id = tid % THREADS_PER_ROW;
// 3. load q fragment
pack128 qs{FLOAT4(q[q_head_id * HEAD_DIM + lane_id * 8])};
// 4. init subgroup-local online softmax state
__align__(16) float acc_o[8] = {0.0f};
float m_i = -FLT_MAX;
float d_i = 0.0f;
int phase_k[2] = {0};
int phase_v = 0;
const float scale_log2 = scale * 1.44269504f; // scale*log2(e)
const int num_chunks = gridDim.x;
const int chunk_start = chunk_id * CHUNK_SIZE;
const int chunk_end = min(chunk_start + CHUNK_SIZE, kv_len);
// preload Ks
if (tid == 0) {
mbarrier_init(mbar_k, 1);
mbarrier_init(mbar_k + 1, 1);
mbarrier_init(mbar_v, 1);
mbarrier_expect_tx(mbar_k, BN * HEAD_DIM * sizeof(T));
cp_async_bulk_tensor_3d(mbar_k, &tma_k, Ks[0], 0, kv_head_id, chunk_start);
}
__syncthreads();
int read_idx = 0, write_idx = 1;
// 5. loop over KV tiles inside this chunk
for (int n = chunk_start; n < chunk_end; n += BN) {
int current_bn = min(BN, chunk_end - n);
// 5.1 TMA async load K/V
if (tid == 0) {
if (n + BN < chunk_end) {
mbarrier_expect_tx(mbar_k + write_idx, BN * HEAD_DIM * sizeof(T));
cp_async_bulk_tensor_3d(mbar_k + write_idx, &tma_k, Ks[write_idx], 0, kv_head_id, n + BN);
}
mbarrier_expect_tx(mbar_v, BN * HEAD_DIM * sizeof(T));
cp_async_bulk_tensor_3d(mbar_v, &tma_v, Vs, 0, kv_head_id, n);
}
mbarrier_wait(mbar_k + read_idx, phase_k[read_idx]);
phase_k[read_idx] ^= 1; // flip phase
// 5.2 compute S = Q * K^T, keep rows per subgroup in registers
const int row_begin = group_id * ROWS_PER_GROUP;
float acc_s[ROWS_PER_GROUP];
float m_part = -FLT_MAX;
#pragma unroll
for (int i = 0; i < ROWS_PER_GROUP; ++i) {
acc_s[i] = -FLT_MAX;
}
#pragma unroll
for (int i = 0; i < ROWS_PER_GROUP; ++i) {
const int row = row_begin + i;
float sum = 0.0f;
if (row < current_bn) {
pack128 ks{FLOAT4(Ks[read_idx][row][lane_id * 8])};
#pragma unroll
for (int j = 0; j < 8; ++j) {
sum += static_cast<float>(qs.bf[j]) * static_cast<float>(ks.bf[j]);
}
}
sum = warp_reduce_sum<THREADS_PER_ROW>(sum);
if (row < current_bn) {
acc_s[i] = sum * scale_log2;
m_part = fmaxf(m_part, acc_s[i]);
}
}
// 5.3 accumulate subgroup-local O = P * V
mbarrier_wait(mbar_v, phase_v);
phase_v ^= 1;
float part_d = 0.0f;
float part_o[8] = {0.0f};
#pragma unroll
for (int i = 0; i < ROWS_PER_GROUP; ++i) {
const int row = row_begin + i;
if (row < current_bn) {
float p = exp2f(acc_s[i] - m_part);
part_d += p;
pack128 vs{FLOAT4(Vs[row][lane_id * 8])};
#pragma unroll
for (int j = 0; j < 8; ++j) {
part_o[j] += p * static_cast<float>(vs.bf[j]);
}
}
}
if (m_part != -FLT_MAX) {
const float m_new = fmaxf(m_i, m_part);
const float alpha_old = exp2f(m_i - m_new);
const float alpha_new = exp2f(m_part - m_new);
#pragma unroll
for (int i = 0; i < 8; ++i) {
acc_o[i] = acc_o[i] * alpha_old + part_o[i] * alpha_new;
}
d_i = d_i * alpha_old + part_d * alpha_new;
m_i = m_new;
}
// next round
__syncthreads();
read_idx ^= 1;
write_idx ^= 1;
}
// 6. epilogue: merge subgroup states once per chunk, then write split results
const float m_chunk = block_reduce_max<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? m_i : -FLT_MAX);
const float alpha = d_i > 0.0f ? exp2f(m_i - m_chunk) : 0.0f;
const float d_chunk = block_reduce_sum<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? d_i * alpha : 0.0f);
// reuse buffer
constexpr int O_PER_GROUP = 8 * THREADS_PER_ROW;
constexpr int O_GROUP_STRIDE = O_PER_GROUP + THREADS_PER_ROW;
float *sdata_o = reinterpret_cast<float *>(smem_buf);
#pragma unroll
for (int i = 0; i < 8; ++i) {
sdata_o[group_id * O_GROUP_STRIDE + i * THREADS_PER_ROW + lane_id] = acc_o[i] * alpha;
}
__syncthreads();
if (group_id == 0) {
#pragma unroll
for (int i = 0; i < 8; ++i) {
float val = 0.0f;
#pragma unroll
for (int group = 0; group < NUM_GROUPS; ++group) {
val += sdata_o[group * O_GROUP_STRIDE + i * THREADS_PER_ROW + lane_id];
}
acc_o[i] = val;
}
}
if (group_id == 0) {
int out_base_idx = (q_head_id * num_chunks + chunk_id) * HEAD_DIM + lane_id * 8;
float inv_d = __frcp_rn(d_chunk);
#pragma unroll
for (int i = 0; i < 8; i++) {
acc_o[i] *= inv_d;
}
pack128 out_pack0, out_pack1;
#pragma unroll
for (int i = 0; i < 4; ++i) {
out_pack0.f[i] = acc_o[i];
out_pack1.f[i] = acc_o[i + 4];
}
FLOAT4(ws_o[out_base_idx + 0]) = out_pack0.f4;
FLOAT4(ws_o[out_base_idx + 4]) = out_pack1.f4;
if (lane_id == 0) {
int scalar_idx = q_head_id * num_chunks + chunk_id;
ws_lse[scalar_idx] = m_chunk * 0.6931471805599453f + logf(d_chunk);
}
}
}# 4. benchmark
不多说,直接上 benchmark 结果:
####################################################################################################
decode, kv seq: 8192, head: 32, dim: 128
torch.compile mean time: 0.454655 ms, 295.24 GB/s
flash-infer mean time: 0.403291 ms, speedup: 1.13, GB/s: 332.85
flash_decode_tma_128 mean time: 0.408378 ms, speedup: 1.11, GB/s: 328.70
flash_decode_tma_dbf_k_128 mean time: 0.366698 ms, speedup: 1.24, GB/s: 366.06
####################################################################################################
decode, kv seq: 16384, head: 32, dim: 128
torch.compile mean time: 0.872882 ms, 307.55 GB/s
flash-infer mean time: 0.784423 ms, speedup: 1.11, GB/s: 342.23
flash_decode_tma_128 mean time: 0.735274 ms, speedup: 1.19, GB/s: 365.10
flash_decode_tma_dbf_k_128 mean time: 0.733273 ms, speedup: 1.19, GB/s: 366.10
####################################################################################################
decode, kv seq: 32768, head: 32, dim: 128
torch.compile mean time: 1.507921 ms, 356.04 GB/s
flash-infer mean time: 1.499479 ms, speedup: 1.01, GB/s: 358.05
flash_decode_tma_128 mean time: 1.495797 ms, speedup: 1.01, GB/s: 358.93
flash_decode_tma_dbf_k_128 mean time: 1.455790 ms, speedup: 1.04, GB/s: 368.79
####################################################################################################
decode, kv seq: 65536, head: 32, dim: 128
torch.compile mean time: 2.980080 ms, 360.31 GB/s
flash-infer mean time: 2.897006 ms, speedup: 1.03, GB/s: 370.64
flash_decode_tma_128 mean time: 2.856871 ms, speedup: 1.04, GB/s: 375.85
flash_decode_tma_dbf_k_128 mean time: 2.849400 ms, speedup: 1.05, GB/s: 376.84
####################################################################################################
decode, kv seq: 131072, head: 32, dim: 128
torch.compile mean time: 6.044398 ms, 355.29 GB/s
flash-infer mean time: 5.751600 ms, speedup: 1.05, GB/s: 373.37
flash_decode_tma_128 mean time: 5.663495 ms, speedup: 1.07, GB/s: 379.18
flash_decode_tma_dbf_k_128 mean time: 5.736955 ms, speedup: 1.05, GB/s: 374.33
####################################################################################################
decode, kv seq: 131073, head: 32, dim: 128
torch.compile mean time: 6.466227 ms, 332.11 GB/s
flash-infer mean time: 6.117131 ms, speedup: 1.06, GB/s: 351.07
flash_decode_tma_128 mean time: 5.701174 ms, speedup: 1.13, GB/s: 376.68
flash_decode_tma_dbf_k_128 mean time: 5.695415 ms, speedup: 1.14, GB/s: 377.06- 从结果看,两个自己实现的 kernel 都能稳定超过
torch.compile的 native baseline,而flash_decode_tma_dbf_k_128整体表现最好。 - double K buffer 的收益主要体现在较短序列上:这时流水线更容易影响实际带宽利用率,所以提升更明显。
- 序列继续变长后,几个实现都逐渐逼近带宽上限,彼此差距自然开始收敛,但我们的 kernel 仍然保持领先。
- 以逻辑带宽估算,最高达到
377.06 / 384 = 98.2%,已经很接近这张卡的理论峰值。 - flashinfer 默认实现是单 block(Occupancy 很低)+ double Ks/Vs buffer,实际表现要弱于我们 2block + 2Ks + 1Vs 的配置。再一次证明 Occupancy 的重要性(Occupancy 极低的情况)。
ncu report:

还能看到一些 uncoalesced global accesses,主要来自 ws_o 和 ws_lse 的读写。这部分还没专门优化,不过它们已经不在热点循环里,DRAM 带宽的硬件统计也已经来到 90%+,所以对总耗时影响不大。
# 5. 结束
以上就是我目前对 flash decoding 的所有理解啦,有一些瑕疵就留着吧,准备去写点别的~
如有错误,欢迎指正。如有建议,也欢迎讨论
完整 kernel 和测试代码可以点击 github vitamin-cuda 项目 flash_decode 查看
以上