WingEdge777

[CUDA Optimization in Practice] Hand-Writing Flash Decoding for sm120: Beating flashinfer.single_decode_with_kv_cache

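This article is aimed at readers who have some CUDA programming experience, are familiar with GEMM/multi-head-attention optimization, and are interested in advanced performance tuning with inline PTX instructions.

The complete kernels and test code are available via the flash_decode link.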

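0. Preface - decode and prefill attention: completely different optimization philosophies

This article follows the previous fmha article. That one mainly discussed flash attention in the prefill scenario; this one moves to the decode scenario and looks at how the kernel should be written for a single query with a long KV cache.

There are two baselines for comparison:

  • The PyTorch native implementation under torch.compile, as the general-purpose operator baseline
  • flashinfer's single_decode_with_kv_cache, as the off-the-shelf decode kernel baseline

A note: the flashinfer baseline was added to this article later. To get the benchmark to run on my 26-SM 5060, I applied a minimal fix to it, which is described separately below.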

pytorch native

import math

import torch


@torch.compile
def torch_native_decode(q, k, v, scale=None):
    # q: [head, dim] -> [32, 128]
    # k: [seq, head, dim] -> [4096, 32, 128]
    # v: [seq, head, dim] -> [4096, 32, 128]
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])

    # Reshape the operands for a batched GEMV
    q_b = q.unsqueeze(1)  # [32, 1, 128]
    k_b = k.permute(1, 2, 0)  # [32, 128, 4096]
    v_b = v.transpose(0, 1)  # [32, 4096, 128]

    # S = Q @ K^T
    attn_scores = torch.matmul(q_b, k_b) * scale  # [32, 1, 4096]
    attn_probs = torch.softmax(attn_scores, dim=-1)

    # O = P @ V
    out = torch.matmul(attn_probs, v_b)  # [32, 1, 128]

    return out.squeeze(1)  # [32, 128]

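Don't assume it is slow just because it is a PyTorch implementation. Decode attention is essentially very close to a batched GEMV, so PyTorch dispatches to quite mature library implementations, and with torch.compile's graph optimizations on top it fully qualifies as a baseline. A hand-written kernel that is not written carefully may well lose to it.

A compatibility fix for the flashinfer baseline

In flashinfer's SingleDecodeWithKVCacheDispatched code, cudaOccupancyMaxActiveBlocksPerMultiprocessor returned 0 for some reason, and every value derived from it then became 0 as well.

I applied a simple fix: force num_blocks_per_sm = 1 (its default implementation double-buffers Ks/Vs and uses 64KB of smem, so it can only be 1 anyway), and set max_num_kv_chunks to 1 (following the original logic, block_per_sm * num_sm / heads = 1 * 26 / 32 evaluates to 0 under integer division, so I could only change it to a positive integer close to the original intent).

Note: as soon as I saw flashinfer's double Ks/Vs smem buffer configuration, I knew flashinfer was going to lose. Its occupancy is half of mine, and that performance loss is not something its double Vs buffer can win back (my Ks is double-buffered too).

Kernel outline for this article

  • flash_decode_tma_128 (BN=64, TMA + float4 vectorized smem reads + online softmax)
  • flash_decode_tma_dbf_k (BN=64, TMA + float4 vectorized smem reads + online softmax, double Ks buffers)

Both support MHA decode attention with head_dim = 128.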

1. flash decoding

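The idea behind flash decoding should be familiar by now. In the LLM decoding phase, batchsize/q_seq_len is too small (often exactly 1), so the parallelism over the q sequence in attention disappears entirely and the SMs cannot all be kept busy. The kv seq dimension is therefore split into chunks, and attention is completed in two steps:

  • Each block handles the attention between one q and one kv chunk, computes the chunk-local m_i/d_i/acc_o, and writes them back to gmem
  • The intermediate results of the previous step are read back and m_i/d_i/acc_o are merged to obtain the final output o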
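
The two steps can be sketched in plain Python. This is a reference-level illustration on tiny lists, not the kernel itself; `partial_attention`, `merge_partials` and `split_kv_attention` are hypothetical names introduced only for this sketch.

```python
import math

def partial_attention(q, keys, values, scale):
    """Step 1: attention of one query over one KV chunk.
    Returns the chunk max m, the denominator d and the unnormalized output acc."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    d = sum(weights)
    acc = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return m, d, acc

def merge_partials(partials):
    """Step 2: rescale every chunk to the global max, then normalize once."""
    m_all = max(m for m, _, _ in partials)
    d_all = 0.0
    out = [0.0] * len(partials[0][2])
    for m, d, acc in partials:
        alpha = math.exp(m - m_all)  # correction factor for this chunk
        d_all += d * alpha
        out = [o + a * alpha for o, a in zip(out, acc)]
    return [o / d_all for o in out]

def split_kv_attention(q, keys, values, scale, chunk_size):
    """Split the KV sequence into chunks, run step 1 per chunk, then step 2."""
    partials = [
        partial_attention(q, keys[i:i + chunk_size], values[i:i + chunk_size], scale)
        for i in range(0, len(keys), chunk_size)
    ]
    return merge_partials(partials)
```

Calling `split_kv_attention` with different `chunk_size` values on the same inputs should give the same output up to floating-point rounding, which is what makes the split over the kv seq dimension safe.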

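Anyone following this topic has probably seen the diagram from the official blog post many times over. So I will not repeat the theory here, and will instead explain plainly how to hand-write a flash decoding kernel in C++.

To keep things easy to follow, only a single decode step is considered, with q of shape [head, dim] and kv of shape [seq, head, dim], which corresponds one-to-one to the PyTorch code above.

First, the data tiling strategy:

  • The kv data tile keeps BN=64 from the flash attention implementation, i.e. 64x128. Why?
  • My card has 100KB of smem in total and at most 48KB per block. Even after saving the 16KB that BM used for q, there is not enough room to give Ks and Vs a larger tile, and for now I do not want to use some odd BN size

In other words, each iteration of the kv chunk loop loads one 64x128 Ks tile and one 64x128 Vs tile.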
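
As a quick sanity check of the tile budget, here is the arithmetic written out. It is only a sketch: the 2 bytes per element assume bf16/fp16, and the 48KB limit is the per-block figure quoted above.

```python
BN, HEAD_DIM, ELEM_BYTES = 64, 128, 2  # 2 bytes per element assumes bf16 / fp16

tile_bytes = BN * HEAD_DIM * ELEM_BYTES  # one K tile or one V tile
ks_vs_bytes = 2 * tile_bytes             # one Ks tile plus one Vs tile
print(tile_bytes // 1024, ks_vs_bytes // 1024)  # prints "16 32"

# Doubling BN to 128 would need 64KB for Ks + Vs, above the 48KB per-block limit.
bigger = 2 * (128 * HEAD_DIM * ELEM_BYTES)
print(bigger > 48 * 1024)  # prints "True"
```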

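Next, the thread block/grid configuration:

  • The block size is fixed at 128 threads. This is not arbitrary: 32 and 64 are too small and leave too few warps to switch between, while 256 is too large, since the compute density of this kernel is not high enough to need that many threads. In my measurements 128 worked best.
  • For the grid, q has lost its seq dimension and offers no parallelism there. head stays on the y dimension, and the kv seq dimension is split and placed on the x dimension. The only question is the chunk size:
    • My 5060 has only 26 SMs. To keep them all busy, the number of blocks should be a multiple of the SM count; it does not need to be large, 2~4x is enough. I use 26x2 here, so the target total number of chunks (blocks) is about 52
    • At runtime, head*seq/52 is rounded up to a power of two (clamped to the range 256 to 2048) to get chunk_size, and then grid.x = (seq + chunk_size - 1) / chunk_size;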
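
To make the chunk-size rule concrete, here is a small Python port of the `get_chunk_size` helper from this article, together with the resulting grid. `grid_dims` is a hypothetical wrapper added only for this sketch, and the default of 26 SMs is the value for the card used here.

```python
def get_chunk_size(q_head, kv_len, num_sms):
    """Round q_head * kv_len / (2 * num_sms) up to 256, 512, 1024 or 2048."""
    target_blocks = num_sms * 2
    chunk = (q_head * kv_len) // target_blocks
    for size in (256, 512, 1024):
        if chunk <= size:
            return size
    return 2048  # cap: larger requests still use 2048-row chunks

def grid_dims(q_head, kv_len, num_sms=26):
    """Return (chunk_size, grid.x, grid.y) for one decode call."""
    chunk_size = get_chunk_size(q_head, kv_len, num_sms)
    grid_x = (kv_len + chunk_size - 1) // chunk_size
    return chunk_size, grid_x, q_head

print(grid_dims(32, 4096))   # prints "(2048, 2, 32)"
print(grid_dims(1, 16384))   # prints "(512, 32, 1)"
```

For the [4096, 32, 128] example this yields 2 x 32 = 64 blocks, a little above the 52-block target because the chunk size is capped at 2048.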

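Putting these constraints together, the kernel launch code is essentially determined.

I did not continue down the Tensor Core route here. The reason is straightforward: mma m16n8k16 requires m=16, while q in decode is essentially a single row, so padding it to 16 rows is pure waste. On top of that, this problem is largely bandwidth-bound, so rather than insisting on mma, it is better to focus on moving and consuming the K/V data more efficiently.

Since mma is not used, ldmatrix and the swizzle that exists to serve it can be set aside as well. TMA only needs to copy a whole 64x128 K/V tile into shared memory as is, and vectorized reads then consume it at full width.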

inline int get_chunk_size(int q_head, int kv_len, int num_sms) {
    int target_blocks = num_sms * 2;

    // Total_Blocks = q_head * (kv_len / chunk_size)
    // chunk_size = (q_head * kv_len) / target_blocks
    int chunk = (q_head * kv_len) / target_blocks;

    if (chunk <= 256)
        return 256;
    if (chunk <= 512)
        return 512;
    if (chunk <= 1024)
        return 1024;
    return 2048;
}

#define CHECK_T(x) TORCH_CHECK(x.is_cuda() && x.is_contiguous(), #x " must be contiguous CUDA tensor")

template <typename T>
inline CUtensorMap create_3d_tensor_map(T *global_address,
                                        uint64_t dim_d,
                                        uint64_t dim_h,
                                        uint64_t dim_s,
                                        uint64_t stride_h,
                                        uint64_t stride_s,
                                        uint32_t box_d,
                                        uint32_t box_s) // Each kernel load takes a (box_s x box_d) block
{
    CUtensorMap tmap;
    CUtensorMapDataType type =
        std::is_same_v<T, __half> ? CU_TENSOR_MAP_DATA_TYPE_FLOAT16 : CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;

    // TMA dimensions: from fastest (0) to slowest (2)
    uint64_t globalDim[3] = {dim_d, dim_h, dim_s};

    // globalStrides are strides for dimensions 1, 2, must be in Bytes
    uint64_t globalStrides[2] = {stride_h, stride_s};

    uint32_t boxDim[3] = {box_d, 1, box_s};
    uint32_t elementStrides[3] = {1, 1, 1};

    CUresult res = cuTensorMapEncodeTiled(&tmap,
                                          type,
                                          3, // Rank = 3
                                          global_address,
                                          globalDim,
                                          globalStrides,
                                          boxDim,
                                          elementStrides,
                                          CU_TENSOR_MAP_INTERLEAVE_NONE,
                                          swizzle,
                                          CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
                                          CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

    TORCH_CHECK(res == CUDA_SUCCESS, "cuTensorMapEncodeTiled failed for 3D Tensor!");
    return tmap;
}

#define DISPATCH_TMA_KERNEL(NAME, HEAD_DIM, CHUNK_SIZE)                                                                \
    NAME##_kernel<BN, CHUNK_SIZE, HEAD_DIM, 128, __nv_bfloat16>                                                        \
        <<<blocks_per_grid, 128, smem_bytes, stream>>>(reinterpret_cast<__nv_bfloat16 *>(q.data_ptr()),                \
                                                       tma_k,                                                          \
                                                       tma_v,                                                          \
                                                       reinterpret_cast<float *>(ws_o.data_ptr()),                     \
                                                       reinterpret_cast<float *>(ws_lse.data_ptr()),                   \
                                                       kv_len,                                                         \
                                                       q_head,                                                         \
                                                       kv_head,                                                        \
                                                       scale);

#define binding_tiled_tma_func_gen(name, HEAD_DIM)                                                                     \
    void name##_##HEAD_DIM(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor o, float scale) {          \
                                                                                                                       \
        CHECK_T(q);                                                                                                    \
        CHECK_T(k);                                                                                                    \
        CHECK_T(v);                                                                                                    \
        CHECK_T(o);                                                                                                    \
                                                                                                                       \
        /* Extract dimension info dynamically from Tensor */                                                           \
        const int q_head = q.size(0);                                                                                  \
        const int head_dim = q.size(1);                                                                                \
        const int kv_len = k.size(0);                                                                                  \
        const int kv_head = k.size(1);                                                                                 \
                                                                                                                       \
        /* Only validate that head_dim matches the compile-time constant */                                            \
        TORCH_CHECK(head_dim == HEAD_DIM, "Head dim mismatch: expected ", HEAD_DIM);                                   \
                                                                                                                       \
        int elem_bytes = k.element_size();                                                                             \
        uint64_t k_stride_h = k.stride(1) * elem_bytes;                                                                \
        uint64_t k_stride_s = k.stride(0) * elem_bytes;                                                                \
        uint64_t v_stride_h = v.stride(1) * elem_bytes;                                                                \
        uint64_t v_stride_s = v.stride(0) * elem_bytes;                                                                \
                                                                                                                       \
        const int BN = 64;                                                                                             \
        const int num_sms = 26;                                                                                        \
        const size_t smem_bytes = BN * head_dim * sizeof(__nv_bfloat16) * 2 + sizeof(mbarrier_t) * 2;                  \
        const int chunk_size = get_chunk_size(q_head, kv_len, num_sms);                                                \
        CUtensorMap tma_k = create_3d_tensor_map<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 *>(k.data_ptr()),       \
                                                                head_dim,                                              \
                                                                kv_head,                                               \
                                                                kv_len,                                                \
                                                                k_stride_h,                                            \
                                                                k_stride_s,                                            \
                                                                head_dim,                                              \
                                                                BN);                                                   \
                                                                                                                       \
        CUtensorMap tma_v = create_3d_tensor_map<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 *>(v.data_ptr()),       \
                                                                head_dim,                                              \
                                                                kv_head,                                               \
                                                                kv_len,                                                \
                                                                v_stride_h,                                            \
                                                                v_stride_s,                                            \
                                                                head_dim,                                              \
                                                                BN);                                                   \
                                                                                                                       \
        TORCH_CHECK(q_head % kv_head == 0, "q_head must be divisible by kv_head");                                     \
        const dim3 blocks_per_grid((kv_len + chunk_size - 1) / chunk_size, q_head);                                    \
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();                                                        \
        auto options = torch::TensorOptions().dtype(torch::kFloat32).device(q.device());                               \
        auto ws_lse = torch::empty({q_head, blocks_per_grid.x}, options);                                              \
        auto ws_o = torch::empty({q_head, blocks_per_grid.x, head_dim}, options);                                      \
        /* launch kernel */                                                                                            \
        switch (chunk_size) {                                                                                          \
            case 256: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 256); break;                                                 \
            case 512: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 512); break;                                                 \
            case 1024: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 1024); break;                                               \
            case 2048: DISPATCH_TMA_KERNEL(name, HEAD_DIM, 2048); break;                                               \
            default: TORCH_CHECK(false, "Unsupported chunk size: ", chunk_size);                                       \
        }                                                                                                              \
        flash_decode_reduce_kernel<HEAD_DIM, 128, __nv_bfloat16>                                                       \
            <<<q_head, 128, 0, stream>>>(reinterpret_cast<float *>(ws_o.data_ptr()),                                   \
                                         reinterpret_cast<float *>(ws_lse.data_ptr()),                                 \
                                         reinterpret_cast<__nv_bfloat16 *>(o.data_ptr()),                              \
                                         blocks_per_grid.x);                                                           \
    }

binding_tiled_tma_func_gen(flash_decode_tma, 128);

#define torch_pybinding_func(f) m.def(#f, &f, #f)

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // flash_decode_tma_128
    torch_pybinding_func(flash_decode_tma_128);
}

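A helper function returns chunk_size, and a dispatch macro then selects the kernel launch that matches it (chunk_size has to be known at compile time so that the loops inside the kernel can be unrolled).

2. Kernel implementation details

Logically, for the data of one tile (the chunk size is an integer multiple of the tile size), we load one row Qs[1,128], multiply it with Ks.T[128,64] as a vector-matrix product to get s[1,64], apply softmax to get Ps[1,64], and then multiply by Vs[64,128].

A block has 4 warps, i.e. 128 threads, so the question is how to divide this work among them.

First, rule out having one thread handle the dot product of an entire row/column. That goes against the basics of CUDA parallel programming; code should not be written with a serial mindset. The direct consequence would be severe bank conflicts, and after a row's dot product is finished the result would still have to be synchronized and broadcast across the threads of the block, otherwise it could not take part in the subsequent Vs computation.

It is better to start from a single warp. Suppose one warp handles one row: a 128-element vector across 32 threads means each thread handles only 4 elements. With 4 warps, each warp handles 16 rows, so the 16 iterations over those rows only need warp-level synchronization, and block-level synchronization is needed only after all four groups of 16 rows are done.

Going one step further, we are using half precision, so an element is only 2 bytes. To squeeze out as much bandwidth as possible we want float4 (16-byte) vectorized accesses, so each thread handles 8 elements. One warp can then process two rows at a time and only needs 8 iterations.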
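
The resulting work split can be checked with a few lines of Python. The constants mirror the ones used in the kernel in this article; `thread_work` is a hypothetical helper that only lists which rows and columns of a 64x128 tile a given thread touches.

```python
THREADS_PER_BLOCK, BN, HEAD_DIM = 128, 64, 128
ELEMS_PER_THREAD = 8                               # one 16-byte load of 2-byte elements
THREADS_PER_ROW = HEAD_DIM // ELEMS_PER_THREAD     # 16 threads share one row
NUM_GROUPS = THREADS_PER_BLOCK // THREADS_PER_ROW  # 8 groups, two per warp
ROWS_PER_GROUP = BN // NUM_GROUPS                  # 8 rows per group, i.e. 8 iterations

def thread_work(tid):
    """Rows and columns of the tile that thread `tid` reads."""
    group_id, lane_id = divmod(tid, THREADS_PER_ROW)
    rows = range(group_id * ROWS_PER_GROUP, (group_id + 1) * ROWS_PER_GROUP)
    cols = range(lane_id * ELEMS_PER_THREAD, (lane_id + 1) * ELEMS_PER_THREAD)
    return rows, cols

print(thread_work(17))  # prints "(range(8, 16), range(8, 16))"
```

Every element of the tile is read by exactly one thread, and the 16 threads of a group together cover the full 128-wide row, which is why the row reduction only needs to span 16 lanes.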

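A first pass at the algorithm flow:

  • kernel pass 1:
    • Initialize the two smem buffers Ks and Vs and the 2 TMA mbarriers
    • Each group (16 threads) initializes its own private running state acc_o[8], m_i, d_i
    • Load Qs[8] (a float4 vectorized load into registers)
    • One block handles one chunk; loop over the kv chunk:
      • Issue the TMA loads for Ks and Vs, and wait for Ks to arrive
      • [Compute S] loop 8 times (each thread in a group covers 8 rows):
        • float4 vectorized read of 8 elements from one row of Ks
        • Dot product of Qs[8] and reg_k[8]
        • Warp reduce (over the 16 threads of the group) to get the attention score of that row, and track the m_part of the current tile
      • Wait for Vs to arrive
      • [Compute O] loop 8 times:
        • Compute the scalar softmax weight p of the current row from m_part
        • Multiply p by the reg_v[8] vector and accumulate into part_o[8] of the current tile, and track the d_part of the current tile
      • [Group-local state update] Use the numerically safe online softmax logic to compute alpha, and fold the current part_o, d_part, m_part into the running acc_o, d_i, m_i kept by the group
    • The groups of the block run a block_reduce to merge their acc_o, d_i, m_i
    • Convert the merged register results into ws_o and d_i/m_i and write them back to gmem
  • kernel pass 2
    • Read ws_o and m_i/d_i from gmem, and continue the online softmax reduction to get the final output o
  • Done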
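
The inner loop of pass 1 can be modelled in plain Python for a single group. This is a sketch on small lists rather than the kernel: `online_softmax_chunk` and the `tile` parameter are names chosen for illustration, and the base-2 exponent mirrors the scale * log2(e) trick used in the kernel.

```python
import math

LOG2E = 1.4426950408889634  # log2(e), so that exp(x) == 2 ** (x * LOG2E)

def online_softmax_chunk(q, keys, values, scale, tile=4):
    """Walk one KV chunk tile by tile, keeping running m_i, d_i and acc_o."""
    scale_log2 = scale * LOG2E
    m_i, d_i = -math.inf, 0.0
    acc_o = [0.0] * len(values[0])
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        # [compute S] scores of this tile in the base-2 domain, plus the tile max
        s = [scale_log2 * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        m_part = max(s)
        # [compute O] tile-local weights, denominator and weighted sum of V rows
        p = [2.0 ** (x - m_part) for x in s]
        d_part = sum(p)
        part_o = [sum(w * v[j] for w, v in zip(p, v_tile)) for j in range(len(acc_o))]
        # [state update] rescale the old and the new part to their common max
        m_new = max(m_i, m_part)
        alpha_old = 2.0 ** (m_i - m_new)  # 0.0 on the first tile, where m_i is -inf
        alpha_new = 2.0 ** (m_part - m_new)
        acc_o = [a * alpha_old + b * alpha_new for a, b in zip(acc_o, part_o)]
        d_i = d_i * alpha_old + d_part * alpha_new
        m_i = m_new
    return m_i, d_i, acc_o
```

Dividing `acc_o` by `d_i` at the end gives the softmax-weighted average of the value rows for the whole chunk, regardless of the tile size used along the way.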

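The above is the plain chunked online + "offline" softmax attention. Following some tutorials, one more trick can be added: merge m_i and d_i into lse (logsumexp), i.e. the logarithm of the sum of exponentials. Since we also use the log2_scale trick (scores are kept in the base-2 domain), lse turns into something slightly ugly:

lse(o) = m_i * ln(2) + ln(d_i)

The pseudocode for merging lse and ws_o in pass 2 is then: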

float max_lse = max(lse_i);
float global_lse = max_lse + ln(sum(exp(lse_i - max_lse)));
o = sum(ws_o_i * exp(lse_i - global_lse));
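
The merge can be verified numerically with a short Python sketch. The helper names (`chunk_output_and_lse`, `merge_with_lse`, `lse_from_log2_state`) are hypothetical and exist only for this example; the per-chunk outputs are assumed to be already normalized, as ws_o is in the kernel.

```python
import math

LN2 = math.log(2.0)

def lse_from_log2_state(m_log2, d):
    """lse when the running max is kept in the base-2 domain: m * ln(2) + ln(d)."""
    return m_log2 * LN2 + math.log(d)

def chunk_output_and_lse(scores, values):
    """Normalized output and log-sum-exp for one chunk of already scaled scores."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    d = sum(w)
    o = [sum(wi * v[j] for wi, v in zip(w, values)) / d for j in range(len(values[0]))]
    return o, m + math.log(d)

def merge_with_lse(outs, lses):
    """Pass 2: weight every chunk output by exp(lse_i - global_lse) and sum."""
    max_lse = max(lses)
    global_lse = max_lse + math.log(sum(math.exp(l - max_lse) for l in lses))
    dim = len(outs[0])
    return [sum(o[j] * math.exp(l - global_lse) for o, l in zip(outs, lses)) for j in range(dim)]
```

Because each weight exp(lse_i - global_lse) is the share of the total softmax mass held by chunk i, the weighted sum of the normalized chunk outputs reproduces the output of attention over the full sequence.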

Here is the kernel code:

// flash decoding softmax(q @ k.T*scale) @ v
template <const int BN = 64,
          const int CHUNK_SIZE = 256,
          const int HEAD_DIM = 128,
          const int THREADS_PER_BLOCK = 128,
          typename T>
__global__ void flash_decode_tma_kernel(T *q,
                                        const __grid_constant__ CUtensorMap tma_k,
                                        const __grid_constant__ CUtensorMap tma_v,
                                        float *ws_o,   // [q_head, num_chunks, HEAD_DIM]
                                        float *ws_lse, // [q_head, num_chunks]
                                        int kv_len,
                                        int q_head,
                                        int kv_head,
                                        float scale) {
    static_assert(THREADS_PER_BLOCK == 128);
    static_assert(BN == 64);

    // 1. shared memory: K tile, V tile, mbarriers
    extern __shared__ __align__(128) uint8_t smem_buf[];
    T(*Ks)[HEAD_DIM] = reinterpret_cast<T(*)[HEAD_DIM]>(smem_buf);
    T(*Vs)[HEAD_DIM] = reinterpret_cast<T(*)[HEAD_DIM]>(smem_buf + BN * HEAD_DIM * sizeof(T));
    mbarrier_t *mbar_k = reinterpret_cast<mbarrier_t *>(smem_buf + BN * HEAD_DIM * sizeof(T) * 2);
    mbarrier_t *mbar_v = mbar_k + 1;

    // 2. coordinates
    const int tid = threadIdx.x;
    const int chunk_id = blockIdx.x;
    const int q_head_id = blockIdx.y;
    const int kv_group_size = q_head / kv_head;
    const int kv_head_id = q_head_id / kv_group_size;

    constexpr int THREADS_PER_ROW = 16;
    constexpr int NUM_GROUPS = THREADS_PER_BLOCK / THREADS_PER_ROW;
    constexpr int ROWS_PER_GROUP = BN / NUM_GROUPS;
    const int group_id = tid / THREADS_PER_ROW;
    const int lane_id = tid % THREADS_PER_ROW;

    if (tid == 0) {
        mbarrier_init(mbar_k, 1);
        mbarrier_init(mbar_v, 1);
    }
    __syncthreads();

    // 3. load q fragment
    pack128 qs{FLOAT4(q[q_head_id * HEAD_DIM + lane_id * 8])};

    // 4. init subgroup-local online softmax state
    __align__(16) float acc_o[8] = {0.0f};
    float m_i = -FLT_MAX;
    float d_i = 0.0f;

    int phase_k = 0;
    int phase_v = 0;
    const float scale_log2 = scale * 1.44269504f; // scale*log2(e)
    const int num_chunks = gridDim.x;
    const int chunk_start = chunk_id * CHUNK_SIZE;
    const int chunk_end = min(chunk_start + CHUNK_SIZE, kv_len);

    // 5. loop over KV tiles inside this chunk
    for (int n = chunk_start; n < chunk_end; n += BN) {
        int current_bn = min(BN, chunk_end - n);

        // 5.1 TMA async load K/V
        if (tid == 0) {
            mbarrier_expect_tx(mbar_k, BN * HEAD_DIM * sizeof(T));
            mbarrier_expect_tx(mbar_v, BN * HEAD_DIM * sizeof(T));
            cp_async_bulk_tensor_3d(mbar_k, &tma_k, Ks, 0, kv_head_id, n);
            cp_async_bulk_tensor_3d(mbar_v, &tma_v, Vs, 0, kv_head_id, n);
        }
        __syncthreads();
        mbarrier_wait(mbar_k, phase_k);
        phase_k ^= 1; // flip phase

        // 5.2 compute S = Q * K^T, keep rows per subgroup in registers
        const int row_begin = group_id * ROWS_PER_GROUP;
        float acc_s[ROWS_PER_GROUP];
        float m_part = -FLT_MAX;
#pragma unroll
        for (int i = 0; i < ROWS_PER_GROUP; ++i) {
            acc_s[i] = -FLT_MAX;
        }
#pragma unroll
        for (int i = 0; i < ROWS_PER_GROUP; ++i) {
            const int row = row_begin + i;
            float sum = 0.0f;
            if (row < current_bn) {
                pack128 ks{FLOAT4(Ks[row][lane_id * 8])};
#pragma unroll
                for (int j = 0; j < 8; ++j) {
                    sum += static_cast<float>(qs.bf[j]) * static_cast<float>(ks.bf[j]);
                }
            }
            sum = warp_reduce_sum<THREADS_PER_ROW>(sum);
            if (row < current_bn) {
                acc_s[i] = sum * scale_log2;
                m_part = fmaxf(m_part, acc_s[i]);
            }
        }

        // 5.3 accumulate subgroup-local O = P * V
        mbarrier_wait(mbar_v, phase_v);
        phase_v ^= 1;
        float part_d = 0.0f;
        float part_o[8] = {0.0f};
#pragma unroll
        for (int i = 0; i < ROWS_PER_GROUP; ++i) {
            const int row = row_begin + i;
            if (row < current_bn) {
                float p = exp2f(acc_s[i] - m_part);
                part_d += p;

                pack128 vs{FLOAT4(Vs[row][lane_id * 8])};
#pragma unroll
                for (int j = 0; j < 8; ++j) {
                    part_o[j] += p * static_cast<float>(vs.bf[j]);
                }
            }
        }
        if (m_part != -FLT_MAX) {
            const float m_new = fmaxf(m_i, m_part);
            const float alpha_old = exp2f(m_i - m_new);
            const float alpha_new = exp2f(m_part - m_new);
#pragma unroll
            for (int i = 0; i < 8; ++i) {
                acc_o[i] = acc_o[i] * alpha_old + part_o[i] * alpha_new;
            }
            d_i = d_i * alpha_old + part_d * alpha_new;
            m_i = m_new;
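        }
        // All threads must finish reading this Ks/Vs tile before thread 0 lets the
        // next iteration's TMA overwrite it (otherwise there is a write-after-read race).
        __syncthreads();
    }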

    // 6. merge subgroup states once per chunk, then write split results
    const float m_chunk = block_reduce_max<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? m_i : -FLT_MAX);
    const float alpha = d_i > 0.0f ? exp2f(m_i - m_chunk) : 0.0f;
    const float d_chunk = block_reduce_sum<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? d_i * alpha : 0.0f);
#pragma unroll
    for (int i = 0; i < 8; ++i) {
        acc_o[i] = block_reduce_sum_by_lane<NUM_GROUPS, THREADS_PER_ROW>(acc_o[i] * alpha);
    }

    if (group_id == 0) {
        int out_base_idx = (q_head_id * num_chunks + chunk_id) * HEAD_DIM + lane_id * 8;
        float inv_d = __frcp_rn(d_chunk);
#pragma unroll
        for (int i = 0; i < 8; i++) {
            acc_o[i] *= inv_d;
        }
        pack128 out_pack0, out_pack1;
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            out_pack0.f[i] = acc_o[i];
            out_pack1.f[i] = acc_o[i + 4];
        }
        FLOAT4(ws_o[out_base_idx + 0]) = out_pack0.f4;
        FLOAT4(ws_o[out_base_idx + 4]) = out_pack1.f4;

        if (lane_id == 0) {
            int scalar_idx = q_head_id * num_chunks + chunk_id;
            ws_lse[scalar_idx] = m_chunk * 0.6931471805599453f + logf(d_chunk);
        }
    }
}

template <const int HEAD_DIM = 128, const int THREADS_PER_BLOCK = 128, typename T>
__global__ void flash_decode_reduce_kernel(float *ws_o, float *ws_lse, T *o, int num_chunks) {
    const int q_head_id = blockIdx.x;
    const int tid = threadIdx.x;
    constexpr int NUM_WARPS = THREADS_PER_BLOCK / WARP_SIZE;

    __shared__ float s_lse;

    float lse_max = -FLT_MAX;
    for (int chunk = tid; chunk < num_chunks; chunk += THREADS_PER_BLOCK) {
        lse_max = fmaxf(lse_max, ws_lse[q_head_id * num_chunks + chunk]);
    }
    lse_max = block_reduce_max<NUM_WARPS, WARP_SIZE>(lse_max);

    float lse_sum = 0.0f;
    for (int chunk = tid; chunk < num_chunks; chunk += THREADS_PER_BLOCK) {
        lse_sum += expf(ws_lse[q_head_id * num_chunks + chunk] - lse_max);
    }
    lse_sum = block_reduce_sum<NUM_WARPS, WARP_SIZE>(lse_sum);
    if (tid == 0) {
        s_lse = logf(lse_sum) + lse_max;
    }
    __syncthreads();

    const int col = tid * 8;
    if (col >= HEAD_DIM) {
        return;
    }

    float out[8] = {0.0f};
    for (int chunk = 0; chunk < num_chunks; ++chunk) {
        const int scalar_idx = q_head_id * num_chunks + chunk;
        const float weight = expf(ws_lse[scalar_idx] - s_lse);
        const int base_idx = scalar_idx * HEAD_DIM + col;
        pack128 partial0{FLOAT4(ws_o[base_idx + 0])};
        pack128 partial1{FLOAT4(ws_o[base_idx + 4])};
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            out[i] += partial0.f[i] * weight;
            out[i + 4] += partial1.f[i] * weight;
        }
    }

    pack128 out_pack;
#pragma unroll
    for (int i = 0; i < 8; ++i) {
        out_pack.bf[i] = __float2bfloat16_rn(out[i]);
    }
    FLOAT4(o[q_head_id * HEAD_DIM + col]) = out_pack.f4;
}

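The warp reduce, block reduce, TMA copy and similar pieces have been factored out into helper functions.

Attentive readers may have noticed:

  • There is no lazy rescale in flash decoding, because it is not needed. In prefill it was worthwhile to replace 64 rescales of acc_o with 16 rescales of Ps, but here acc_o has only 8 elements, the same as the number of multiplications of p by inv_scale.
  • Likewise, there is no need_casual_mask (causal mask) check, because all kv positions are visible to the current q

The only part that deserves an explanation is

grouped warp/block reduce

Since one warp handles two rows, the warp is effectively split into two halves that each do their own reduce. The warp reduce therefore has to be given a width of 16, which is passed as a group_size template parameter.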
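
The article's warp_reduce_sum helper is not shown, so the following is only one common way such a width-16 reduction can work: an XOR butterfly with offsets 8, 4, 2, 1, which on the GPU could be built from `__shfl_xor_sync`. The Python function `grouped_reduce_sum` is a hypothetical simulation of the lanes of one warp.

```python
def grouped_reduce_sum(lane_values, width=16):
    """Simulate an XOR-butterfly sum restricted to groups of `width` lanes."""
    vals = list(lane_values)
    offset = width // 2
    while offset >= 1:
        # every lane adds the value held by its partner lane (lane ^ offset);
        # offsets below `width` never cross a group boundary
        vals = [vals[lane] + vals[lane ^ offset] for lane in range(len(vals))]
        offset //= 2
    return vals

result = grouped_reduce_sum(range(32), width=16)
print(result[0], result[16])  # prints "120 376"
```

After the last step every lane of a half-warp holds the sum of its own 16 lanes, and the two halves never mix, which is exactly what two independent rows per warp require.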

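This version is honestly still quite rough. For example, the write-back of ws_o/ws_lse has not been optimized yet; that is left for later (stay tuned, I may also go and study how flashinfer implements it).

3. flash_decode_tma_dbf_k

OK, with the kernel above done, let's take stock. It currently uses 32KB of smem, a few barriers, and the num_group*group_size intermediate variables used for block synchronization. While preserving occupancy (in my experience, when the choice is between 1 block and 2 blocks per SM, keeping two blocks and letting the hardware schedule them is always better), the only option left is to add another buffer for pipelining, so that more latency is hidden behind computation. I therefore made the following two optimizations:

double Ks buffer

Ks is double-buffered, while Vs stays single-buffered. Why does this work? In attention, Ks and Vs are naturally asynchronous: Vs is only needed after the computation on Ks has finished, so the Vs load is already hidden. Adding one more buffer to hide the Ks load as well should in theory be enough.

The pipeline itself is simple:

  • Allocate 2 K tiles
  • prologue: before the kv chunk loop, issue the load of Ks[0] and initialize read/write_idx
  • kv chunk loop:
    • If there is a next Ks tile, issue the TMA request into Ks[write_idx]
    • Issue the Vs TMA request
    • Wait for Ks[read_idx] to arrive
    • online softmax(Qs * Ks)
    • Wait for Vs
    • Compute the output
    • Synchronize, then swap read/write_idx
  • epilogue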
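
The order of events in this pipeline can be traced with a small Python model. It only records which tile goes into which buffer and when; `dbf_k_schedule` and the event names are hypothetical and no actual copies are performed.

```python
def dbf_k_schedule(num_tiles):
    """Return the (event, tile, buffer) sequence of the double-Ks-buffer loop."""
    events = [("issue_K", 0, 0)]  # prologue: tile 0 is loaded into Ks[0]
    read_idx, write_idx = 0, 1
    for t in range(num_tiles):
        if t + 1 < num_tiles:
            events.append(("issue_K", t + 1, write_idx))  # prefetch the next K tile
        events.append(("issue_V", t, 0))                  # single Vs buffer
        events.append(("wait_K", t, read_idx))
        events.append(("compute_S", t, read_idx))
        events.append(("wait_V", t, 0))
        events.append(("compute_O", t, 0))
        # block-wide sync happens here, then the buffer roles are swapped
        read_idx, write_idx = write_idx, read_idx
    return events

for event in dbf_k_schedule(2):
    print(event)
```

In the trace, the load of K tile t+1 is always issued before the wait on K tile t, so from the second tile onward the K load overlaps with the computation on the previous tile.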

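epilogue optimization

The original epilogue ran 8 iterations of block reduce, which is a bit heavy. It now stages the values through the reused Ks/Vs smem buffer, and a single group (16 threads) then reads and sums them into ws_o before the final write-back. This optimization does not change much either; performance is essentially the same before and after. But since it is already written, I am keeping it.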
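
The staging layout in the kernel pads each group's region by THREADS_PER_ROW floats (O_GROUP_STRIDE = 8 * 16 + 16). The reason is not stated; one plausible reading, assuming the usual 32 shared-memory banks of 4 bytes each, is that the padding keeps the two groups of a warp on different banks when they write. The Python check below uses the same index formula as the kernel; `warp_write_banks` is a hypothetical helper.

```python
THREADS_PER_ROW, NUM_GROUPS = 16, 8
O_PER_GROUP = 8 * THREADS_PER_ROW  # 128 floats of partial output per group

def warp_write_banks(warp_id, i, group_stride):
    """Banks hit by the 32 threads of one warp when they store acc_o[i]."""
    banks = []
    for tid in range(warp_id * 32, warp_id * 32 + 32):
        group_id, lane_id = divmod(tid, THREADS_PER_ROW)
        idx = group_id * group_stride + i * THREADS_PER_ROW + lane_id
        banks.append(idx % 32)  # assumes 32 banks, one float per bank slot
    return banks

padded = warp_write_banks(0, 0, O_PER_GROUP + THREADS_PER_ROW)
unpadded = warp_write_banks(0, 0, O_PER_GROUP)
print(len(set(padded)), len(set(unpadded)))  # prints "32 16"
```

With the padded stride the 32 stores of a warp land on 32 distinct banks, while an unpadded stride of 128 floats would map both groups of the warp onto the same 16 banks.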

Here is the code:

// flash decoding softmax(q @k.T *scale) @v
template <const int BN = 64,
          const int CHUNK_SIZE = 256,
          const int HEAD_DIM = 128,
          const int THREADS_PER_BLOCK = 128,
          typename T>
__global__ void flash_decode_tma_dbf_k_kernel(T *q,
                                              const __grid_constant__ CUtensorMap tma_k,
                                              const __grid_constant__ CUtensorMap tma_v,
                                              float *ws_o,   // [q_head, num_chunks, HEAD_DIM]
                                              float *ws_lse, // [q_head, num_chunks]
                                              int kv_len,
                                              int q_head,
                                              int kv_head,
                                              float scale) {
    static_assert(THREADS_PER_BLOCK == 128);
    static_assert(BN == 64);

    // 1. shared memory: K tile, V tile, mbarriers
    extern __shared__ __align__(128) uint8_t smem_buf[];
    T(*Ks)[BN][HEAD_DIM] = reinterpret_cast<T(*)[BN][HEAD_DIM]>(smem_buf);
    T(*Vs)[HEAD_DIM] = reinterpret_cast<T(*)[HEAD_DIM]>(smem_buf + BN * HEAD_DIM * sizeof(T) * 2);
    mbarrier_t *mbar_k = reinterpret_cast<mbarrier_t *>(smem_buf + BN * HEAD_DIM * sizeof(T) * 3);
    mbarrier_t *mbar_v = mbar_k + 2;

    // 2. coordinates
    const int tid = threadIdx.x;
    const int chunk_id = blockIdx.x;
    const int q_head_id = blockIdx.y;
    const int kv_group_size = q_head / kv_head;
    const int kv_head_id = q_head_id / kv_group_size;

    constexpr int THREADS_PER_ROW = 16;
    constexpr int NUM_GROUPS = THREADS_PER_BLOCK / THREADS_PER_ROW;
    constexpr int ROWS_PER_GROUP = BN / NUM_GROUPS;
    const int group_id = tid / THREADS_PER_ROW;
    const int lane_id = tid % THREADS_PER_ROW;

    // 3. load q fragment
    pack128 qs{FLOAT4(q[q_head_id * HEAD_DIM + lane_id * 8])};

    // 4. init subgroup-local online softmax state
    __align__(16) float acc_o[8] = {0.0f};
    float m_i = -FLT_MAX;
    float d_i = 0.0f;

    int phase_k[2] = {0};
    int phase_v = 0;

    const float scale_log2 = scale * 1.44269504f; // scale*log2(e)
    const int num_chunks = gridDim.x;
    const int chunk_start = chunk_id * CHUNK_SIZE;
    const int chunk_end = min(chunk_start + CHUNK_SIZE, kv_len);
    // preload Ks
    if (tid == 0) {
        mbarrier_init(mbar_k, 1);
        mbarrier_init(mbar_k + 1, 1);
        mbarrier_init(mbar_v, 1);

        mbarrier_expect_tx(mbar_k, BN * HEAD_DIM * sizeof(T));
        cp_async_bulk_tensor_3d(mbar_k, &tma_k, Ks[0], 0, kv_head_id, chunk_start);
    }
    __syncthreads();
    int read_idx = 0, write_idx = 1;

    // 5. loop over KV tiles inside this chunk
    for (int n = chunk_start; n < chunk_end; n += BN) {
        int current_bn = min(BN, chunk_end - n);

        // 5.1 TMA async load K/V
        if (tid == 0) {
            if (n + BN < chunk_end) {
                mbarrier_expect_tx(mbar_k + write_idx, BN * HEAD_DIM * sizeof(T));
                cp_async_bulk_tensor_3d(mbar_k + write_idx, &tma_k, Ks[write_idx], 0, kv_head_id, n + BN);
            }
            mbarrier_expect_tx(mbar_v, BN * HEAD_DIM * sizeof(T));
            cp_async_bulk_tensor_3d(mbar_v, &tma_v, Vs, 0, kv_head_id, n);
        }
        mbarrier_wait(mbar_k + read_idx, phase_k[read_idx]);
        phase_k[read_idx] ^= 1; // flip phase

        // 5.2 compute S = Q * K^T, keep rows per subgroup in registers
        const int row_begin = group_id * ROWS_PER_GROUP;
        float acc_s[ROWS_PER_GROUP];
        float m_part = -FLT_MAX;
#pragma unroll
        for (int i = 0; i < ROWS_PER_GROUP; ++i) {
            acc_s[i] = -FLT_MAX;
        }
#pragma unroll
        for (int i = 0; i < ROWS_PER_GROUP; ++i) {
            const int row = row_begin + i;
            float sum = 0.0f;
            if (row < current_bn) {
                pack128 ks{FLOAT4(Ks[read_idx][row][lane_id * 8])};
#pragma unroll
                for (int j = 0; j < 8; ++j) {
                    sum += static_cast<float>(qs.bf[j]) * static_cast<float>(ks.bf[j]);
                }
            }
            sum = warp_reduce_sum<THREADS_PER_ROW>(sum);
            if (row < current_bn) {
                acc_s[i] = sum * scale_log2;
                m_part = fmaxf(m_part, acc_s[i]);
            }
        }

        // 5.3 accumulate subgroup-local O = P * V
        mbarrier_wait(mbar_v, phase_v);
        phase_v ^= 1;
        float part_d = 0.0f;
        float part_o[8] = {0.0f};
#pragma unroll
        for (int i = 0; i < ROWS_PER_GROUP; ++i) {
            const int row = row_begin + i;
            if (row < current_bn) {
                float p = exp2f(acc_s[i] - m_part);
                part_d += p;

                pack128 vs{FLOAT4(Vs[row][lane_id * 8])};
#pragma unroll
                for (int j = 0; j < 8; ++j) {
                    part_o[j] += p * static_cast<float>(vs.bf[j]);
                }
            }
        }
        if (m_part != -FLT_MAX) {
            const float m_new = fmaxf(m_i, m_part);
            const float alpha_old = exp2f(m_i - m_new);
            const float alpha_new = exp2f(m_part - m_new);
#pragma unroll
            for (int i = 0; i < 8; ++i) {
                acc_o[i] = acc_o[i] * alpha_old + part_o[i] * alpha_new;
            }
            d_i = d_i * alpha_old + part_d * alpha_new;
            m_i = m_new;
        }

        // next round
        __syncthreads();
        read_idx ^= 1;
        write_idx ^= 1;
    }

    // 6. epilogue: merge subgroup states once per chunk, then write split results
    const float m_chunk = block_reduce_max<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? m_i : -FLT_MAX);
    const float alpha = d_i > 0.0f ? exp2f(m_i - m_chunk) : 0.0f;
    const float d_chunk = block_reduce_sum<NUM_GROUPS, THREADS_PER_ROW>(lane_id == 0 ? d_i * alpha : 0.0f);
    // reuse buffer
    constexpr int O_PER_GROUP = 8 * THREADS_PER_ROW;
    constexpr int O_GROUP_STRIDE = O_PER_GROUP + THREADS_PER_ROW;
    float *sdata_o = reinterpret_cast<float *>(smem_buf);
#pragma unroll
    for (int i = 0; i < 8; ++i) {
        sdata_o[group_id * O_GROUP_STRIDE + i * THREADS_PER_ROW + lane_id] = acc_o[i] * alpha;
    }
    __syncthreads();

    if (group_id == 0) {
#pragma unroll
        for (int i = 0; i < 8; ++i) {
            float val = 0.0f;
#pragma unroll
            for (int group = 0; group < NUM_GROUPS; ++group) {
                val += sdata_o[group * O_GROUP_STRIDE + i * THREADS_PER_ROW + lane_id];
            }
            acc_o[i] = val;
        }
    }

    if (group_id == 0) {
        int out_base_idx = (q_head_id * num_chunks + chunk_id) * HEAD_DIM + lane_id * 8;
        float inv_d = __frcp_rn(d_chunk);
#pragma unroll
        for (int i = 0; i < 8; i++) {
            acc_o[i] *= inv_d;
        }
        pack128 out_pack0, out_pack1;
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            out_pack0.f[i] = acc_o[i];
            out_pack1.f[i] = acc_o[i + 4];
        }
        FLOAT4(ws_o[out_base_idx + 0]) = out_pack0.f4;
        FLOAT4(ws_o[out_base_idx + 4]) = out_pack1.f4;

        if (lane_id == 0) {
            int scalar_idx = q_head_id * num_chunks + chunk_id;
            // m_chunk is in the log2 domain; multiply by ln(2) to store a natural-log LSE
            ws_lse[scalar_idx] = m_chunk * 0.6931471805599453f + logf(d_chunk);
        }
    }
}
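
The running-state update in step 5.3 and the epilogue's output format (normalized output plus natural-log LSE) can be sanity-checked on the host. The Python sketch below is only an illustration under simplifying assumptions: values are scalars instead of 128-wide vectors, scores are already scaled, and the helper names (`tile_state`, `merge_states`, `chunk_result`, `combine_chunks`) are hypothetical and do not appear in the kernel. It shows that merging log2-domain partial states and then combining per-chunk results by their LSE gives the same answer as one softmax over all positions.

```python
import math

LN2 = math.log(2.0)

def reference(scores, values):
    # Plain softmax-weighted average over all positions (natural-exp domain).
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

def tile_state(scores, values):
    # Partial state of one tile in the log2 domain:
    # (max, denominator, unnormalized output), like (m_part, part_d, part_o).
    s2 = [s / LN2 for s in scores]  # s * log2(e), the role scale_log2 is assumed to play
    m = max(s2)
    p = [2.0 ** (x - m) for x in s2]
    return m, sum(p), sum(pi * vi for pi, vi in zip(p, values))

def merge_states(a, b):
    # Rescale both states to the shared max, then add
    # (same shape as the alpha_old / alpha_new update in step 5.3).
    m_new = max(a[0], b[0])
    alpha_a = 2.0 ** (a[0] - m_new)
    alpha_b = 2.0 ** (b[0] - m_new)
    return m_new, a[1] * alpha_a + b[1] * alpha_b, a[2] * alpha_a + b[2] * alpha_b

def chunk_result(state):
    # What a chunk publishes: normalized output and natural-log LSE.
    m, d, o = state
    return o / d, m * LN2 + math.log(d)

def combine_chunks(results):
    # LSE-weighted average of the per-chunk normalized outputs.
    lse_max = max(lse for _, lse in results)
    w = [math.exp(lse - lse_max) for _, lse in results]
    return sum(wi * o for wi, (o, _) in zip(w, results)) / sum(w)

scores = [0.3, -1.2, 2.5, 0.0, 1.7, -0.4, 0.9, 3.1]
values = [1.0, 2.0, -1.0, 0.5, 4.0, -2.0, 3.0, 0.25]

# Chunk 0 covers tiles [0:2] and [2:4]; chunk 1 covers tiles [4:6] and [6:8].
chunk0 = merge_states(tile_state(scores[0:2], values[0:2]),
                      tile_state(scores[2:4], values[2:4]))
chunk1 = merge_states(tile_state(scores[4:6], values[4:6]),
                      tile_state(scores[6:8], values[6:8]))
combined = combine_chunks([chunk_result(chunk0), chunk_result(chunk1)])
expected = reference(scores, values)
print(abs(combined - expected) < 1e-9)
```

Storing the LSE in natural-log units means the consumer of `ws_lse` does not need to know that the kernel worked in the log2 domain internally.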

4. benchmark

Without further ado, here are the benchmark results:

####################################################################################################
decode, kv seq: 8192, head: 32, dim: 128
torch.compile                            mean time: 0.454655 ms, 295.24 GB/s
flash-infer                              mean time: 0.403291 ms, speedup: 1.13, GB/s: 332.85
flash_decode_tma_128                     mean time: 0.408378 ms, speedup: 1.11, GB/s: 328.70
flash_decode_tma_dbf_k_128               mean time: 0.366698 ms, speedup: 1.24, GB/s: 366.06
####################################################################################################
decode, kv seq: 16384, head: 32, dim: 128
torch.compile                            mean time: 0.872882 ms, 307.55 GB/s
flash-infer                              mean time: 0.784423 ms, speedup: 1.11, GB/s: 342.23
flash_decode_tma_128                     mean time: 0.735274 ms, speedup: 1.19, GB/s: 365.10
flash_decode_tma_dbf_k_128               mean time: 0.733273 ms, speedup: 1.19, GB/s: 366.10
####################################################################################################
decode, kv seq: 32768, head: 32, dim: 128
torch.compile                            mean time: 1.507921 ms, 356.04 GB/s
flash-infer                              mean time: 1.499479 ms, speedup: 1.01, GB/s: 358.05
flash_decode_tma_128                     mean time: 1.495797 ms, speedup: 1.01, GB/s: 358.93
flash_decode_tma_dbf_k_128               mean time: 1.455790 ms, speedup: 1.04, GB/s: 368.79
####################################################################################################
decode, kv seq: 65536, head: 32, dim: 128
torch.compile                            mean time: 2.980080 ms, 360.31 GB/s
flash-infer                              mean time: 2.897006 ms, speedup: 1.03, GB/s: 370.64
flash_decode_tma_128                     mean time: 2.856871 ms, speedup: 1.04, GB/s: 375.85
flash_decode_tma_dbf_k_128               mean time: 2.849400 ms, speedup: 1.05, GB/s: 376.84
####################################################################################################
decode, kv seq: 131072, head: 32, dim: 128
torch.compile                            mean time: 6.044398 ms, 355.29 GB/s
flash-infer                              mean time: 5.751600 ms, speedup: 1.05, GB/s: 373.37
flash_decode_tma_128                     mean time: 5.663495 ms, speedup: 1.07, GB/s: 379.18
flash_decode_tma_dbf_k_128               mean time: 5.736955 ms, speedup: 1.05, GB/s: 374.33
####################################################################################################
decode, kv seq: 131073, head: 32, dim: 128
torch.compile                            mean time: 6.466227 ms, 332.11 GB/s
flash-infer                              mean time: 6.117131 ms, speedup: 1.06, GB/s: 351.07
flash_decode_tma_128                     mean time: 5.701174 ms, speedup: 1.13, GB/s: 376.68
flash_decode_tma_dbf_k_128               mean time: 5.695415 ms, speedup: 1.14, GB/s: 377.06
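  • Both hand-written kernels beat the torch.compile native baseline at every sequence length, and flash_decode_tma_dbf_k_128 is the best overall: it is the fastest in five of the six configurations, and only at kv seq 131072 is flash_decode_tma_128 slightly ahead (5.66 ms vs 5.74 ms).
  • The gain from the double K buffer shows up mainly on shorter sequences, where pipelining has more influence on the achieved bandwidth. At kv seq 8192 it cuts the time from 0.408 ms to 0.367 ms.
  • As the sequence gets longer, all implementations approach the bandwidth limit and the gaps between them narrow, but our kernels stay ahead.
  • Measured by logical bandwidth, flash_decode_tma_dbf_k_128 peaks at 377.06 / 384 = 98.2% of this card's theoretical peak (flash_decode_tma_128 reaches 379.18 GB/s, about 98.7%, at kv seq 131072).
  • The default flashinfer implementation runs one block per SM (very low occupancy) with double Ks/Vs buffers, and in practice it is slower than our configuration of 2 blocks per SM with 2 Ks buffers and 1 Vs buffer. This shows once again how much occupancy matters, at least when occupancy is extremely low.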
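
The GB/s column can be reproduced with a short calculation. The sketch below is an assumption about how the logical bandwidth is counted, not the actual benchmark script: it charges one read of the K and V caches plus q and the output, all at 2 bytes per element (bf16, inferred from the `bf` fields used in the kernel), and divides by the measured time. The function name `logical_bandwidth_gbs` is hypothetical. Under these assumptions the result matches the figures in the table.

```python
def logical_bandwidth_gbs(seq_len, num_heads, head_dim, time_ms, bytes_per_elem=2):
    # K and V caches are each [seq, head, dim]; q and the output are each [head, dim].
    kv_bytes = 2 * seq_len * num_heads * head_dim * bytes_per_elem
    qo_bytes = 2 * num_heads * head_dim * bytes_per_elem
    # Convert ms to s, then bytes/s to decimal GB/s.
    return (kv_bytes + qo_bytes) / (time_ms * 1e-3) / 1e9

# torch.compile row at kv seq 8192
bw_8k = logical_bandwidth_gbs(8192, 32, 128, 0.454655)
# flash_decode_tma_dbf_k_128 row at kv seq 131073
bw_peak = logical_bandwidth_gbs(131073, 32, 128, 5.695415)
# 384 GB/s is the theoretical peak quoted in the text above
utilization = bw_peak / 384.0

print(f"{bw_8k:.2f} GB/s, {bw_peak:.2f} GB/s, {utilization:.1%}")
```

Because only the bytes the algorithm must touch are counted, this number can differ from the DRAM throughput that a profiler reports.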

ncu report:

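The report still shows some uncoalesced global accesses, mainly from reads and writes of ws_o and ws_lse. I have not optimized this part specifically, but these accesses are outside the hot loop, and the hardware DRAM bandwidth counter already reads 90%+, so the effect on total time is small.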

5. Wrap-up

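That is everything I currently understand about flash decoding. A few flaws remain, and I will leave them as they are while I go write something else.

Corrections are welcome if you spot mistakes, and so are suggestions and discussion.

The complete kernel and test code can be found under flash_decode in the vitamin-cuda project on GitHub.

That's all.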
