[CUDA 优化实战] hgemm sm120 - 100KB SMEM 中的“微雕”战争:Tensor-core、TMA、ldmatrix、mma
文章描述
对不起朋友们,本来我说 gemm 系列不会有后续,但我食言了,今天依然是 hgemm,不过我们要拥抱 RTX 5060 laptop 上的一切,TMA + ldmatrix + mma,挑战极限
本文适用于有一定 CUDA 编程基础,熟悉 GEMM 优化,对进阶 tensor core / 嵌入 PTX 指令 性能调优感兴趣的读者阅读
完整 kernel 和测试代码可以点击 hgemm_sm120 查看
# 0. 序 - sm120 - 被阉割的 blackwell
众所周知,我们 geforce 50 系列消费级显卡的 sm120 架构虽然也叫 blackwell,但和 sm100 的 B 系列 完全不是一个品种,一刀又一刀,阉割得啥都没了。tcgen5 指令没有,wgmma 没有。那有啥,有 TMA(Tensor Memory accelerator),另外 NV 还贴心的拓展了 mma 指令支持 fp8/6/4 等精度。我们今天的重点是使用 TMA,TMA 是从 hopper 架构引进专门用于加速张量数据 copy 的专用硬件,只需要一条指令即可异步的搬运一小块指定的矩阵,总之就是速度快,节省指令,异步,还自带 swizzle,配合 wgmma 简直绝了。
可惜,我们没有 wgmma。
据我所知,cutlass 都没有实现 tma + mma 这鬼畜搭配的 hgemm,仅有量化版低精度的 gemm,感兴趣参考 cutlass example 79/87,所以我的这个 tma 移花接木 ldmatrix + mma.sync 实现的 hgemm,不敢说全网独一份,但肯定算稀有动物,请看官们一赏~
好,章接上文,本篇将以 M=N=K=4096(MxKxN, cuBLAS 最擅长的中等规模)的 GEMM fp16/bf16 为例,在 RTX 5060 laptop 上,使用 TMA(自带 swizzle) 移花接木 ldmatrix(手动 swizzle)+mma
实现 hgemm。并且用 cublas 和上篇文章中的 hgemm_bcf_dbf_rw_kernel 作为 baseline,进行对比。
这里先给结论,在我反复测试下,用上 tma 三级流水线 + 双缓冲寄存器 ldmatrix + mma 的 终版 kernel 和 我手写的纯 sm80 架构 kernel,在性能上,只能说可能具备极小的微弱的优势。在锁定最高显存和显卡频率的前提下,跑 10 次 benchmarks, 和 hgemm_bcf_dbf_rw_kernel 的速度对比互有胜负,大概 6/4 开
是不是有点沮丧。(其实我是有一点的,毕竟我花了很多的时间 debug 和测试,才跑通 tma 的代码)
不过换个角度想,我也了却上一篇文章的遗憾,基本实现了一个我觉得完美的 kernel,hgemm_bcf_dbf_rw_kernel 中出现的 Uncoalesced Shared Accesses, 这里不见了,哈哈~
只有一个 math-pipeline wait stall 提示(tensor-core 的计算速度也跟不上了) 和一个 warp Occupancy 不足提示(没有 smem 了)。其他都是完美。
本文将会给出 4 个 kernel 实现,
kernel 大纲如下(第一个是 cuBLAS kernel)
- hgemm_cublas bf16/fp16 版
- hgemm_bcf_dbf_rw bf16/fp16 版 (ldmatrix + mma, As/Bs swizzle bcf, double buffer, coalesced r/w gmem, 重构版,抽象出 copy,ldmatrix,mma compute 等函数)
- hgemm_k_stages bf16/fp16 版 (基于 hgemm_bcf_dbf_rw 改造的 kernel,可支持 3 级流水线,smem 上限了)
- hgemm_tma_r_k_stages bf16/fp16 版 (基于 hgemm_k_stages 改造的 kernel,将 cp.async 读取 gmem 替换为 TMA copy)
- hgemm_tma_rw_k_stages (这个是 todo,用 TMA copy 回 gmem,但我其实没心气做了,因为预期没有多少收益)
# 1. hgemm_bcf_dbf_rw
这个 kernel 在上一篇文章已经详细介绍过了,如何从基础版演化为终版 kernel,感兴趣的朋友请移步 .
因此,这里只是做了一些重构工作,目的是为了让 kernel 结构更简短清晰一些,方便改造为多级流水线。 所以直接上代码:
// a block calculate c[128][128]
template <const int BM = 128, const int BN = 128, const int BK = 32, typename T>
__global__ void hgemm_bcf_dbf_rw_kernel(T *a, T *b, T *c, int m, int n, int k) {
// grid swizzling
int linear_id = blockIdx.y * gridDim.x + blockIdx.x;
const int SWIZZLE_W = 8; // 将执行块设置为 8 的宽度
int bx = (linear_id % SWIZZLE_W) + (linear_id / (SWIZZLE_W * gridDim.y)) * SWIZZLE_W;
int by = (linear_id / SWIZZLE_W) % gridDim.y;
int tid = threadIdx.x; // 0~255
int warp_id = tid / WARP_SIZE;
int lane_id = tid % WARP_SIZE;
// 搬运映射
int load_a_row = tid / 4; // 0~63
int load_a_col = (tid % 4) * 8; // 0,8,16,24
int load_b_row = tid / 16; // 0~15 (K 维度)
int load_b_col = (tid % 16) * 8; // 0,8,16 ... 120 (N 维度)
// A/B 都行优先,用 union 复用同一块内存,写法优雅
__shared__ __align__(128) union {
// 前半段计算用的 A 和 B
struct {
T As[2][BM][BK];
T Bs[2][BK][BN];
};
// 后半段写回用的 C
T Cs[BM][BN];
} smem;
// warp tiling
// 每个 warp 负责 64 x 32 的 C 矩阵块
int warp_id_m = warp_id / 4; // 0, 1
int warp_id_n = warp_id % 4; // 0, 1, 2, 3
// 寄存器总量:M 维 4 块 * N 维 4 块 * 每块 4 个寄存器 = 64
float sum[4][4][4] = {0.f};
T *global_a_ptr = &a[(by * BM + load_a_row) * k + load_a_col];
T *global_b_ptr = &b[load_b_row * n + bx * BN + load_b_col];
// ----------------------------- Prologue 先加载一次 As/Bs
// 内部已包含跨行加载逻辑,确保覆盖全部 128x32/32x128 元素
cp_async_load_A<BK>(smem.As[0], load_a_row, load_a_col, global_a_ptr, k);
cp_async_load_B<BK, BN>(smem.Bs[0], load_b_row, load_b_col, global_b_ptr, n);
CP_ASYNC_COMMIT_GROUP();
cp_async_wait_group<0>();
__syncthreads();
int read_idx = 0;
int write_idx = 1;
// 主循环
for (int bk = 32; bk < k; bk += BK) {
// 推进指针
global_a_ptr += BK;
global_b_ptr += BK * n;
// 1. cp.async load A/B
cp_async_load_A<BK>(smem.As[write_idx], load_a_row, load_a_col, global_a_ptr, k);
cp_async_load_B<BK, BN>(smem.Bs[write_idx], load_b_row, load_b_col, global_b_ptr, n);
CP_ASYNC_COMMIT_GROUP();
// 2. Tensor Core 计算阶段 (k 维度分两次,一次 16 个 k)
#pragma unroll
for (int k_step = 0; k_step < 2; ++k_step) {
int k_offset = k_step * 16;
uint32_t reg_a[4][4];
uint32_t reg_b[4][2];
// 4 次 ldmatrix A (4 * 16 = 64 行)
ldmatrix_A<BK>(reg_a, smem.As[read_idx], warp_id_m, lane_id, k_offset);
// 4 次 ldmatrix B (4 * 8 = 32 列)
ldmatrix_B<BN, BK>(reg_b, smem.Bs[read_idx], warp_id_n, lane_id, k_offset);
// MMA 核心运算:4x4 次 m16n8k16
mma_compute<T>(sum, reg_a, reg_b);
}
read_idx ^= 1;
write_idx ^= 1;
cp_async_wait_group<0>();
__syncthreads();
}
// ------------------- Epilogue 最后计算一次再写回
#pragma unroll
for (int k_step = 0; k_step < 2; ++k_step) {
int k_offset = k_step * 16;
uint32_t reg_a[4][4];
uint32_t reg_b[4][2];
// 4 次 ldmatrix A (4 * 16 = 64 行)
ldmatrix_A<BK>(reg_a, smem.As[read_idx], warp_id_m, lane_id, k_offset);
// 4 次 ldmatrix B (4 * 8 = 32 列)
ldmatrix_B<BN, BK>(reg_b, smem.Bs[read_idx], warp_id_n, lane_id, k_offset);
// MMA 核心运算:4x4 次 m16n8k16
mma_compute<T>(sum, reg_a, reg_b);
}
write_c_via_smem<BM, BN>(c, by, bx, n, sum, warp_id_m, warp_id_n, lane_id, tid, smem.Cs);
}其中 cp_async_load_A,cp_async_load_B,ldmatrix_A,ldmatrix_B,mma_compute,write_c_via_smem 都改成了函数,具体不贴了,免得文章冗长,详细代码还请移步 github 查看。
# 2. hgemm_k_stages
我们稍微说明一下上一个 kernel 的 smem 的部分细节,BMxBNxBK 是 128x128x32,双 buffer 流水线,所以一共 是 128x32x2x2x2 = 32KB,如果增加一级流水线,那么就需要 128322*2 = 16KB,加起来 48KB,正好是我一个 block 的所能使用的 smem 上限(总上限 100KB,单个 block 上限 48KB)。
因此我想先实现一个三级流水线 kernel,看看是否有收益。 三级流水线流程
- prologue:cp.async 预先发起加载两个 stage 的 buffer 的指令,即 commit 两个 group,wait 直到第一个 group 完成加载
- 主循环 main loop
- 发起加载最后一个 stage 的 buffer,commit group
- 开始 ldmatrix + mma 计算
- 全局指针步进,流水线往下推进一级,同样等待最早的那个 group 加载完毕
- epilogue:
- 等待两个 groupcopy 完成
- 执行计算
- 利用 smem bufer 中转写回 gmem
代码:
// 这里 As/Bs 的 Tiling 策略维持在 128x128x32,为了压进单个 Block 48KB 的 SMEM 限制,我们只能做到 3 Stage,这是 SM120 上的物理极限。
template <const int BM = 128, const int BN = 128, const int BK = 32, const int STAGES = 3, typename T>
__global__ void hgemm_k_stages_kernel(T *a, T *b, T *c, int m, int n, int k) {
// grid swizzling
int linear_id = blockIdx.y * gridDim.x + blockIdx.x;
const int SWIZZLE_W = 8; // 将执行块设置为 8 的宽度
int bx = (linear_id % SWIZZLE_W) + (linear_id / (SWIZZLE_W * gridDim.y)) * SWIZZLE_W;
int by = (linear_id / SWIZZLE_W) % gridDim.y;
int tid = threadIdx.x; // 0~255
int warp_id = tid / WARP_SIZE;
int lane_id = tid % WARP_SIZE;
// 搬运映射
int load_a_row = tid / 4; // 0~63
int load_a_col = (tid % 4) * 8; // 0,8,16,24
int load_b_row = tid / 16; // 0~15 (K 维度)
int load_b_col = (tid % 16) * 8; // 0,8,16 ... 120 (N 维度)
// A/B 都行优先,用 union 复用同一块内存,写法优雅
__shared__ __align__(128) union {
// 前半段计算用的 A 和 B
struct {
T As[STAGES][BM][BK];
T Bs[STAGES][BK][BN];
};
// 后半段写回用的 C
T Cs[BM][BN];
} smem;
// warp tiling
// 每个 warp 负责 64 x 32 的 C 矩阵块
int warp_id_m = warp_id / 4; // 0, 1
int warp_id_n = warp_id % 4; // 0, 1, 2, 3
// 寄存器总量:M 维 4 块 * N 维 4 块 * 每块 4 个寄存器 = 64
float sum[4][4][4] = {0.f};
T *global_a_ptr = &a[(by * BM + load_a_row) * k + load_a_col];
T *global_b_ptr = &b[load_b_row * n + bx * BN + load_b_col];
// 1. prologue: 加载 stages-1 个 As/Bs 块
#pragma unroll
for (int i = 0; i < STAGES - 1; ++i) {
cp_async_load_A<BK>(smem.As[i], load_a_row, load_a_col, global_a_ptr, k);
cp_async_load_B<BK, BN>(smem.Bs[i], load_b_row, load_b_col, global_b_ptr, n);
CP_ASYNC_COMMIT_GROUP();
global_a_ptr += BK;
global_b_ptr += BK * n;
}
// commit 了两个 group, 允许 1 个 group 后台在 cp.async, 即等最早加载的 load 完毕
cp_async_wait_group<STAGES - 2>();
__syncthreads();
// 状态指针初始化
int load_stage = STAGES - 1; // 下一个要 Load 的位置
int compute_stage = 0; // 当前要 Compute 的位置
// 2. main loop
for (int bk = (STAGES - 1) * BK; bk < k; bk += BK) {
// 1. 先发起 cp.async load As/Bs 到 load_stage
cp_async_load_A<BK>(smem.As[load_stage], load_a_row, load_a_col, global_a_ptr, k);
cp_async_load_B<BK, BN>(smem.Bs[load_stage], load_b_row, load_b_col, global_b_ptr, n);
CP_ASYNC_COMMIT_GROUP();
// 2. Tensor Core 计算阶段 (k 维度分两次,一次 16 个 k)
#pragma unroll
for (int k_step = 0; k_step < 2; ++k_step) {
int k_offset = k_step * 16;
uint32_t reg_a[4][4];
uint32_t reg_b[4][2];
// 4 次 ldmatrix A (4 * 16 = 64 行)
ldmatrix_A<BK>(reg_a, smem.As[compute_stage], warp_id_m, lane_id, k_offset);
// 4 次 ldmatrix B (4 * 8 = 32 列)
ldmatrix_B<BN, BK>(reg_b, smem.Bs[compute_stage], warp_id_n, lane_id, k_offset);
// MMA 核心运算:4x4 次 m16n8k16
mma_compute<T>(sum, reg_a, reg_b);
}
// 推进指针
global_a_ptr += BK;
global_b_ptr += BK * n;
// 流水线往下推一级
load_stage = (load_stage + 1 == STAGES) ? 0 : load_stage + 1;
compute_stage = (compute_stage + 1 == STAGES) ? 0 : compute_stage + 1;
// 保障最早的 group load 好
cp_async_wait_group<STAGES - 2>();
__syncthreads();
}
// 3. epilogue 最后计算 stages-1 次 再写回
cp_async_wait_group<0>();
__syncthreads();
#pragma unroll
for (int i = 0; i < STAGES - 1; ++i) {
#pragma unroll
for (int k_step = 0; k_step < 2; ++k_step) {
int k_offset = k_step * 16;
uint32_t reg_a[4][4];
uint32_t reg_b[4][2];
// 4 次 ldmatrix A (4 * 16 = 64 行)
ldmatrix_A<BK>(reg_a, smem.As[compute_stage], warp_id_m, lane_id, k_offset);
// 4 次 ldmatrix B (4 * 8 = 32 列)
ldmatrix_B<BN, BK>(reg_b, smem.Bs[compute_stage], warp_id_n, lane_id, k_offset);
// MMA 核心运算:4x4 次 m16n8k16
mma_compute<T>(sum, reg_a, reg_b);
}
compute_stage = (compute_stage + 1 == STAGES) ? 0 : compute_stage + 1;
}
write_c_via_smem<BM, BN>(c, by, bx, n, sum, warp_id_m, warp_id_n, lane_id, tid, smem.Cs);
}ok,到现在我们还有大动作,所以没有什么特别可说的,下一小节见
# 3. hgemm_tma_r_k_stages_kernel
在我刚开始准备手搓 TMA copy + ldmatrix,我面临着两个问题
- 手搓 TMA copy 如何实现?
- TMA swizzle 如何能和我自己手写的 swizzle 对齐
第一个问题,经过一番坎坷的学习过程(研究 tensorMap 和 cp.async.bulk + mbarrier),我终于了解了整个流程,首先要在 host 端创建 CUtensorMap,用来描述 gmem 中的矩阵的 tiling size、内/外层循环的 dim/stride 等等,然后在 kernel 端使用 cp.aysnc.bulk 发起 TMA copy 请求 + mbarrier 相关命令进行同步。
# 3.1 host 端
这里给一下 host 端 tensorMap 创建方法,我们会用到 cuTensorMapEncodeTiled,其定义和解释如下:
CUresult CUDAAPI
cuTensorMapEncodeTiled(CUtensorMap *tensorMap, // 你要创建的 tensormap 指针
CUtensorMapDataType tensorDataType, //数据类型,bf16/fp16
cuuint32_t tensorRank, //矩阵的秩 (2D 矩阵传 2)
void *globalAddress, //矩阵首地址
const cuuint64_t *globalDim, // 全局形状。注意:必须把最内层连续维度放在第 0 位!行优先 A(MxK) 传 {K, M}
const cuuint64_t *globalStrides, // 步幅数组 (长度为 Rank-1)。2D 矩阵只需传 1 个元素:行跨度的字节数 (如 K * sizeof(half)
const cuuint32_t *boxDim, // TMA 一次搬运的 Tile 大小。维度顺序必须与 globalDim 严格对齐 (如 {BK, BM})
const cuuint32_t *elementStrides, //元素间跨度。对于密集矩阵,各维度全为 1,传 {1, 1}
CUtensorMapInterleave interleave,// 数据交错模式,一般选 NONE
CUtensorMapSwizzle swizzle, // swizzle 类型,可选 None,32B,64B,128B,还有几种特殊的没仔细研究
CUtensorMapL2promotion l2Promotion, // L2 缓存驻留粒度,推荐与缓存行对齐的 128B
CUtensorMapFloatOOBfill oobFill); // 越界填充策略。选 NONE 时,TMA 硬件会自动在边界外填充 0,省去繁琐的越界判断硬件 TMA 最强大的地方就在于,如果你请求的 boxDim 越过了 globalDim 的边界(比如矩阵边缘 Padding),TMA 硬件会自动帮你把越界的地方塞满 0,完全不需要你在 Kernel 里写 if (x < M && y < N) 这种恶心的边界判断分支,极大地释放了 ALU 算力!写起 kernel 也极其丝滑(前提是 smem 有多余空间),我的卡就没这个福了
好,接下来我们要用这个函数创建 a/b 矩阵的 tensorMap。
问题来了,TMA 引擎在开启 128B Swizzle(这是避免 Bank Conflict 的命脉)时,有一个冷酷的硬件限制——它要求你请求的 boxDim 中,最内层的连续维度(Fastest Changing Dimension)大小必须不多不少,正好是 128 字节!
但我们的 Tiling 策略是 BN = 128。我们用的是 fp16/bf16,每个元素 2 字节,这意味着 B 矩阵一个 Tile 的行宽高达 256 字节!如果直接把 {BN, BK} 塞给 TMA,它会当场罢工,抛出 CUDA_ERROR_INVALID_VALUE。(As 的行是 BK,不论 32 还是 64 都没有超过 128B,所以没有这个问题)
怎么办?为了满足硬件的要求,我们这里用了一个小技巧:把 B 的 Tile 物理上劈成两半(Chunking),用两次 TMA 发射来完成一次逻辑上的搬运。 看代码:
template <typename T, const int rowBytes = 128>
inline CUtensorMap
create_tensor_map(T *global_address, uint64_t fast_dim, uint64_t slow_dim, uint32_t fast_box, uint32_t slow_box) {
CUtensorMap tmap;
CUtensorMapDataType type =
std::is_same_v<T, __half> ? CU_TENSOR_MAP_DATA_TYPE_FLOAT16 : CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
CUtensorMapSwizzle swizzle = rowBytes == 128 ? CU_TENSOR_MAP_SWIZZLE_128B : CU_TENSOR_MAP_SWIZZLE_64B;
// TMA 的核心逻辑:第 0 维永远是内存里最连续的维度 (Fastest Changing Dimension)
uint64_t globalDim[2] = {fast_dim, slow_dim};
uint64_t globalStrides[1] = {fast_dim * sizeof(T)}; // 外层维度的跨度(字节)
uint32_t boxDim[2] = {fast_box, slow_box};
uint32_t elementStrides[2] = {1, 1};
CUresult res = cuTensorMapEncodeTiled(&tmap,
type,
2, // Tensor Rank (二维矩阵)
global_address,
globalDim,
globalStrides,
boxDim,
elementStrides,
CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle, // 对应 swizzle
CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
TORCH_CHECK(res == CUDA_SUCCESS, "cuTensorMapEncodeTiled failed!");
return tmap;
}
// ---------------- tma func binding
#define binding_tiled_tma_func_gen(name, BK) \
void name##_##BK(torch::Tensor a, torch::Tensor b, torch::Tensor c) { \
CHECK_T(a); \
CHECK_T(b); \
CHECK_T(c); \
const int M = a.size(0); \
const int K = a.size(1); \
const int N = b.size(1); \
const int BM = 128; \
const int BN = 128; \
const int threads_per_block = 256; \
const dim3 blocks_per_grid((N + BN - 1) / BN, (M + BM - 1) / BM); \
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const int smem_size = BM * BK * 2 * 3 * 2 + 24; \
if (a.dtype() == torch::kHalf) { \
CUtensorMap tma_a = \
create_tensor_map<__half, BK * 2>(reinterpret_cast<__half *>(a.data_ptr()), K, M, BK, BM); \
CUtensorMap tma_b = create_tensor_map<__half>(reinterpret_cast<__half *>(b.data_ptr()), N, K, BN / 2, BK); \
\
cudaFuncSetAttribute( \
name##_kernel<BM, BN, BK, 3, __half>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
name##_kernel<BM, BN, BK, 3><<<blocks_per_grid, threads_per_block, smem_size, stream>>>( \
tma_a, tma_b, reinterpret_cast<__half *>(c.data_ptr()), M, N, K); \
} else { \
CUtensorMap tma_a = create_tensor_map<__nv_bfloat16, BK * 2>( \
reinterpret_cast<__nv_bfloat16 *>(a.data_ptr()), K, M, BK, BM); \
CUtensorMap tma_b = \
create_tensor_map<__nv_bfloat16>(reinterpret_cast<__nv_bfloat16 *>(b.data_ptr()), N, K, BN / 2, BK); \
cudaFuncSetAttribute( \
name##_kernel<BM, BN, BK, 3, __nv_bfloat16>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \
name##_kernel<BM, BN, BK, 3><<<blocks_per_grid, threads_per_block, smem_size, stream>>>( \
tma_a, tma_b, reinterpret_cast<__nv_bfloat16 *>(c.data_ptr()), M, N, K); \
} \
}为了复用代码,我写了一个宏和一个模板函数。但我们主要看 tma_b 的创建,可以看到我传了 boxDim{BN/2,BK}进去,这样一行就是 128B 了。
# 3.2 kernel 端
那么在 kernel 内如何分配 Bs smem 呢,我使用了 Bs[2][BK][BN/2]。
// 使用动态共享数组
extern __shared__ __align__(128) uint8_t smem_buf[];
T(*As)[BM][BK] = reinterpret_cast<T(*)[BM][BK]>(smem_buf);
T(*Bs)[2][BK][BN / 2] = reinterpret_cast<T(*)[2][BK][BN / 2]>(smem_buf + STAGES * BM * BK * sizeof(T));
T(*Cs)[BN] = reinterpret_cast<T(*)[BN]>(smem_buf);为什么要写成 Bs[2][BK][BN/2]? 从字面意义上看,它是把一块大内存物理地切成了左右两半(Chunk 0 和 Chunk 1)。
- Chunk 0 负责装载第 0 ~ 63 列。
- Chunk 1 负责装载第 64 ~ 127 列。
这样切分后,每一个 Chunk 的行宽正好是 64 个元素(即 128 字节)。我们用数组声明的方式,强行在逻辑上契合了老黄刻在硅片底层的 128B 物理边界!当然,这种切分会对后续的 TMA 写入和 ldmatrix 读取带来寻址上的麻烦,但别慌,后文我们会用地址映射来化解这个问题。
# 3.2.1 cp.async.bulk 和 mbarrier
在全面开始 kernel 讲解之前,先了解一下我们会用到的 ptx 指令,主要是:
- mbarrier
- 我们 3 级流水线,所以需要 3 个 mbarrier,mbarrier 是 smem 上的 8 字节变量,用于线程同步,一般 init 和 arrive 对称出现
- cp.async.bulk
- 用于发起 TMA 拷贝,给出矩阵 tile 的左上角基地址,tma 就会开始自动异步搬运一整块 tile
具体黑魔法 ptx 代码如下,我加了一些我们这个 case 用到的详细说明:
// MBarrier 类型定义 (硬件要求 8 字节对齐)
typedef uint64_t mbarrier_t;
// 在 smem 上初始化 64 位的屏障变量 (只需在 prologue 中由单线程调用)
// 在 TMA 模式下,真正的“数据生产者”是硬件 DMA 引擎。我们的单线程只负责下达搬运指令,mbarrier 只需要等待【TMA 硬件这 1 个实体】把数据搬完并自动打卡,所以 expected_count 设置为 1。
__device__ __forceinline__ void mbarrier_init(mbarrier_t *mbar, uint32_t expected_count) {
asm volatile("mbarrier.init.shared.b64 [%0], %1;\n" ::"r"(static_cast<uint32_t>(__cvta_generic_to_shared(mbar))),
"r"(expected_count));
}
// 设定 TMA 传输的预期字节数,需要向指定的 mbar 汇报
// 传统同步是等“线程”arrive,这里是等“字节”arrive。硬件搬完这么多字节后,会自动触发 arrive 翻转 phase。
__device__ __forceinline__ void mbarrier_expect_tx(mbarrier_t *mbar, uint32_t tx_bytes) {
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;\n" ::"r"(
static_cast<uint32_t>(__cvta_generic_to_shared(mbar))),
"r"(tx_bytes));
}
// 计算线程消费完数据后,提交到达信号 (翻转 Phase)。我们用不到,因为没有 warp specialized,或者说所有 warp 都是消费者,直接使用 __syncthreads() 同步了
// 在 warp specialized 编程中,消费者线程会向对应 mbar 汇报数据消费完了
__device__ __forceinline__ void mbarrier_arrive(mbarrier_t *mbar) {
asm volatile("mbarrier.arrive.shared.b64 _, [%0];\n" ::"r"(static_cast<uint32_t>(__cvta_generic_to_shared(mbar))));
}
// 计算线程同步等待 TMA 数据就绪 (自带休眠,不占 ALU 算力)
// 有点天书,主要逻辑是 :
// 申请一个名为 p 的临时谓词寄存器(布尔值),用于存储 mbarrier 状态检查的结果。
// mbarrier.try_wait.parity.shared::cta.b64 p, [%0], %1, %2:判断 mbar 内部 phase 标志是否和给定的 phase 一致,不一致说明 tma 还没完成 copy,则挂起 ticks 周期,让渡出 ALU 给其他 warp 调度
// @p bra DONE:如果 p 为 true,即说明 mbarrier phase 已经反转,tma 完成 copy,直接跳跃到 DONE 标签之后的指令
// bra LAB_WAIT: 如果走到这里,说明是那一千万个时钟周期超时了(极小概率),那就跳回 LAB_WAIT 继续休眠。
// 最后的"memory" 是编译器内存屏障(破坏性描述符)。它防止编译器过度优化读到脏数据,作用是提示编译器:TMA 硬件刚刚在后台偷偷篡改了共享内存(As, Bs 矩阵),该指令之后所有寄存器缓存的共享内存变量必须强制失效,重新读取。
_device__ __forceinline__ void mbarrier_wait(uint64_t* mbar, uint32_t phase) {
uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
// 设定一个极大的挂起超时周期(0x989680 = 10,000,000 个时钟周期)
uint32_t ticks = 0x989680;
asm volatile(
"{\n\t"
".reg .pred p; \n\t"
"LAB_WAIT: \n\t"
// 注意这里的第三个参数 %2
"mbarrier.try_wait.parity.shared::cta.b64 p, [%0], %1, %2; \n\t"
"@p bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}\n"
:
: "r"(mbar_addr), "r"(phase), "r"(ticks)
: "memory"
);
}
// global to shared::cta 2d TMA 搬运
__device__ __forceinline__ void cp_async_bulk_tensor_2d(
mbarrier_t *mbar, const void *tmap, const void *smem_ptr, int32_t fast_coord, int32_t slow_coord) {
uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%2, %3}], [%4];\n" ::"r"(smem_addr),
"l"(tmap),
"r"(fast_coord),
"r"(slow_coord),
"r"(mbar_addr)
: "memory");
}所以 kernel 的主要 TMA 相关逻辑就是(在我们这个非 warp 特化的开发模式下,让线程 0 负责 TMA 调度)
- copy 线程(线程 0):
- 初始化 3 个 stage 的 mbarrier 变量,
- 并向 mbarrier 汇报一个 stage 要搬运的数据量
- 使用 tensorMap,发起 TMA copy 请求(cp.async.bulk)
- 计算线程(所有线程):
- 共同轮询等待(mbarrier_wait),对应 stage 的 mbarrier 翻转为指定 phase 时,线程全都被唤醒
- 开始 ldmatrix+mma 计算
# 3.3 kernel 设计
知道大概流程后,可以着手设计 kernel 了,其实上面还有一个问题我还没回复,就是 TMA 的 swizzle 和我的手写 swizzle 如何对齐,我没有花精力找到相关资料确定 TMA 硬件 swizzle 方式,所以只是猜,但猜就有概率问题。
为了减少测试量。我做出了一个违背“祖宗”的决定。我放弃了 BMxBNxBK=128x128x32 的 tiling,而是将 BK 设置为 64!这样 a 矩阵的一个 tile 一行也是 128B,b 矩阵经过我们劈开一个 chunk 的一行也是 128B,只要对齐 128B 的 TMA 和 ldmatrix swizzle,就能跑通了。
代价就是我需要开 12864232 = 96 KB 的 smem,只能驻留一个 block,还要用 cuda 魔法 cudaFuncSetAttribute + 动态 smem 数组。
此外,由于 TMA 只需要一条指令,解放了 unroll 的所有地址变量寄存器,所以我激进地加入了 双 buffer 寄存器读取 As/Bs(虽然作用不大)
这里给出完整 kernel 代码:
template <const int BK, typename T>
__device__ __forceinline__ void
ldmatrix_A_tma(uint32_t reg_a[4][4], T (*As)[BK], int warp_id_m, int lane_id, int k_offset) {
// 4 次 ldmatrix A (4 * 16 = 64 行)
#pragma unroll
for (int m_idx = 0; m_idx < 4; ++m_idx) {
int a_row = warp_id_m * 64 + m_idx * 16 + (lane_id % 16);
int a_col = k_offset + (lane_id / 16) * 8;
if constexpr (BK == 32) {
uint32_t smem_addr =
static_cast<uint32_t>(__cvta_generic_to_shared(&As[a_row][SWIZZLE_64B_TMA(a_row, a_col)]));
LDMATRIX_X4(reg_a[m_idx][0], reg_a[m_idx][1], reg_a[m_idx][2], reg_a[m_idx][3], smem_addr);
} else {
uint32_t smem_addr =
static_cast<uint32_t>(__cvta_generic_to_shared(&As[a_row][SWIZZLE_128B_TMA(a_row, a_col)]));
LDMATRIX_X4(reg_a[m_idx][0], reg_a[m_idx][1], reg_a[m_idx][2], reg_a[m_idx][3], smem_addr);
}
}
}
template <const int BN, const int BK, typename T>
__device__ __forceinline__ void
ldmatrix_B_tma(uint32_t reg_b[4][2], T (*Bs)[BK][BN / 2], int warp_id_n, int lane_id, int k_offset) {
#pragma unroll
for (int n_idx = 0; n_idx < 4; ++n_idx) {
int b_row = k_offset + (lane_id % 16);
int b_col = warp_id_n * 32 + n_idx * 8;
// 这里要区分 chunk
int chunk_idx = b_col / (BN / 2);
int local_col = b_col % (BN / 2);
uint32_t smem_addr =
static_cast<uint32_t>(__cvta_generic_to_shared(&Bs[chunk_idx][b_row][SWIZZLE_128B_TMA(b_row, local_col)]));
LDMATRIX_X2_TRANS(reg_b[n_idx][0], reg_b[n_idx][1], smem_addr);
}
}
// ------------------- tma r + mma -------------------
// a block calculate c[128][128]
template <const int BM = 128, const int BN = 128, const int BK = 64, const int STAGES = 3, typename T>
__global__ void hgemm_tma_r_k_stages_kernel(
__grid_constant__ const CUtensorMap tma_a, __grid_constant__ const CUtensorMap tma_b, T *c, int m, int n, int k) {
// grid swizzling
int linear_id = blockIdx.y * gridDim.x + blockIdx.x;
const int SWIZZLE_W = 8; // 将执行块设置为 8 的宽度
int bx = (linear_id % SWIZZLE_W) + (linear_id / (SWIZZLE_W * gridDim.y)) * SWIZZLE_W;
int by = (linear_id / SWIZZLE_W) % gridDim.y;
int tid = threadIdx.x; // 0~255
int warp_id = tid / WARP_SIZE;
int lane_id = tid % WARP_SIZE;
// 使用动态共享数组
extern __shared__ __align__(128) uint8_t smem_buf[];
T(*As)[BM][BK] = reinterpret_cast<T(*)[BM][BK]>(smem_buf);
T(*Bs)[2][BK][BN / 2] = reinterpret_cast<T(*)[2][BK][BN / 2]>(smem_buf + STAGES * BM * BK * sizeof(T));
T(*Cs)[BN] = reinterpret_cast<T(*)[BN]>(smem_buf);
// 把 mbar 放在末尾 ( 8 字节对齐,3 个 stages)
mbarrier_t *mbar = reinterpret_cast<mbarrier_t *>(smem_buf + BM * BK * sizeof(T) * STAGES * 2);
// 初始化 MBarrier (仅需 tid 0 执行,期待到达次数为 1,因为只有 TMA 会给它发信号)
if (tid == 0) {
for (int i = 0; i < STAGES; ++i)
mbarrier_init(&mbar[i], 1);
}
__syncthreads(); // 保证 MBarrier 初始化完毕
// warp tiling
// 每个 warp 负责 64 x 32 的 C 矩阵块
int warp_id_m = warp_id / 4; // 0, 1
int warp_id_n = warp_id % 4; // 0, 1, 2, 3
// 寄存器总量:M 维 4 块 * N 维 4 块 * 每块 4 个寄存器 = 64
float sum[4][4][4] = {0.f};
// 每次 TMA 需要搬运的总字节数
const uint32_t tx_bytes = (BM * BK + BK * BN) * sizeof(T);
// 只保留一个极其简单的坐标跟踪变量 (因为 TMA 的 Host 描述符里已经知道了跨度)
int load_k_coord = 0;
// 1. prologue 加载 STAGES - 1 块
for (int i = 0; i < STAGES - 1; ++i) {
if (tid == 0) {
// 设定这个 mbarrier 需要等多少字节的数据落盘
mbarrier_expect_tx(&mbar[i], tx_bytes);
cp_async_bulk_tensor_2d(&mbar[i], &tma_a, As[i], load_k_coord, by * BM);
// Bs 要 copy 两次,分成两个 chunk
cp_async_bulk_tensor_2d(&mbar[i], &tma_b, Bs[i][0], bx * BN, load_k_coord);
cp_async_bulk_tensor_2d(&mbar[i], &tma_b, Bs[i][1], bx * BN + BN / 2, load_k_coord);
}
load_k_coord += BK;
}
int load_stage = STAGES - 1;
int compute_stage = 0;
int wait_phase = 0; // MBarrier 天然的 0/1 交替相位开关
int total_k_step = BK / 16; // 根据 BK 自适应 step
// 2. main loop
for (int bk = (STAGES - 1) * BK; bk < k; bk += BK) {
// 发起下一轮的 TMA (依然只有 tid 0 干活)
if (tid == 0) {
mbarrier_expect_tx(&mbar[load_stage], tx_bytes);
cp_async_bulk_tensor_2d(&mbar[load_stage], &tma_a, As[load_stage], load_k_coord, by * BM);
cp_async_bulk_tensor_2d(&mbar[load_stage], &tma_b, Bs[load_stage][0], bx * BN, load_k_coord);
cp_async_bulk_tensor_2d(&mbar[load_stage], &tma_b, Bs[load_stage][1], bx * BN + BN / 2, load_k_coord);
}
load_k_coord += BK;
// 所有线程:轮询等待当前 compute_stage 的数据被 TMA 搬运完毕
mbarrier_wait(&mbar[compute_stage], wait_phase);
// 寄存器双 buffer: ldmatrix + mma
uint32_t reg_a[2][4][4], reg_b[2][4][2];
ldmatrix_A_tma<BK>(reg_a[0], As[compute_stage], warp_id_m, lane_id, 0);
ldmatrix_B_tma<BN, BK>(reg_b[0], Bs[compute_stage], warp_id_n, lane_id, 0);
int read_idx = 0, write_idx = 1;
#pragma unroll
for (int k_step = 0; k_step < total_k_step; ++k_step) {
if (k_step < total_k_step - 1) {
int next_k_offset = (k_step + 1) * 16;
ldmatrix_A_tma<BK>(reg_a[write_idx], As[compute_stage], warp_id_m, lane_id, next_k_offset);
ldmatrix_B_tma<BN, BK>(reg_b[write_idx], Bs[compute_stage], warp_id_n, lane_id, next_k_offset);
}
mma_compute<T>(sum, reg_a[read_idx], reg_b[read_idx]);
read_idx ^= 1;
write_idx ^= 1;
}
// 直接同步,没有 warp 特化,不需要 arrive
__syncthreads();
// 状态轮转
load_stage = (load_stage + 1 == STAGES) ? 0 : load_stage + 1;
compute_stage = (compute_stage + 1 == STAGES) ? 0 : compute_stage + 1;
// 完成一次三级流水线,三个 mbarrier 的 phase 都反转了,我们也要反转给定的 wait_phase
if (compute_stage == 0)
wait_phase ^= 1;
}
// 3. epilogue 计算 stages-1 次
#pragma unroll
for (int i = 0; i < STAGES - 1; ++i) {
// 继续等 TMA
mbarrier_wait(&mbar[compute_stage], wait_phase);
// 寄存器双 buffer
uint32_t reg_a[2][4][4], reg_b[2][4][2];
ldmatrix_A_tma<BK>(reg_a[0], As[compute_stage], warp_id_m, lane_id, 0);
ldmatrix_B_tma<BN, BK>(reg_b[0], Bs[compute_stage], warp_id_n, lane_id, 0);
int read_idx = 0, write_idx = 1;
#pragma unroll
for (int k_step = 0; k_step < total_k_step; ++k_step) {
if (k_step < total_k_step - 1) {
int next_k_offset = (k_step + 1) * 16;
ldmatrix_A_tma<BK>(reg_a[write_idx], As[compute_stage], warp_id_m, lane_id, next_k_offset);
ldmatrix_B_tma<BN, BK>(reg_b[write_idx], Bs[compute_stage], warp_id_n, lane_id, next_k_offset);
}
mma_compute<T>(sum, reg_a[read_idx], reg_b[read_idx]);
read_idx ^= 1;
write_idx ^= 1;
}
compute_stage = (compute_stage + 1 == STAGES) ? 0 : compute_stage + 1;
if (compute_stage == 0)
wait_phase ^= 1;
}
// 4. 写回
write_c_via_smem<BM, BN>(c, by, bx, n, sum, warp_id_m, warp_id_n, lane_id, tid, Cs);
}ldmatrix Bs 的时候,要先区分一下 chunk,还好我们只是完整的切成两大块,并不复杂。其他的就没有什么了,代码中也基本有注释帮助解释。
此外,细心的朋友可能还注意到我的 kernel 传入 tensorMap 用了__grid_constant__, 这个约束关键字是让 gpu 强制传入这个变量到常量存储。
# 3.3.1 踩坑血泪史:不可忽视的 __grid_constant__
事实上一开始,我在 tensorMap 的构建上就卡住了很久。我遇到了一个诡异的问题:一跑就挂。经过痛苦的排查,我发现是因为 tensorMap 本质上是一个占据 128 字节的复杂结构体(Descriptor)。按照 CUDA 的默认传参规则,这么大的结构体作为参数传入时,很容易被编译器分配到 Local Memory 中。而 TMA 硬件引擎要求 tensorMap 必须存在于全局常量区。
解决办法就是在声明 Kernel 的参数时,对 tensorMap 变量 加上 __grid_constant__ 关键字。这个修饰符,会强制编译器将 TensorMap 存放在 Constant Memory 中,并保证其在整个 Grid 生命周期的可见性与只读性。
初玩 TMA 的朋友,这个坑务必警醒!
# 4. 殊途同归:TMA 与 ldmatrix 的 Swizzle 握手
总之,经过一番痛苦的 debug 过程,我终于跑通了上面那个 kernel, 也通过了 diff_check,说明 TMA 128B swizzle 和我之前对 Bs 进行的 swizzle 一致 :((col) ^ (((row) & 0x7) << 3))
所以说啊,gpu 设计和进化也要讲基本法,我手动 ldmatrix 的 swizzle 既然和 TMA 的一样,说明底层逻辑其实一样的。
我就猜想 64B 的 swizzle 是不是也一样呢?
于是我又对代码做了一点调整,兼容 BK=32/64 两种情况,ldmatrix As 就用我原来的 swizzle_a。然后设置 BK=32,这样 smem 用量减半,又可以有两个 block 活跃,完美。
改完后,一跑,就通过 diff check 了!哈哈,果然皇天不负有心人。TMA 的黑盒 swizzle 逻辑我不知道,但在严格精度 diff 测试下,我的手动 Swizzle 逻辑与 TMA 硬件底层的黑盒行为完美契合,误差为 0。
这是纯粹的逆向工程浪漫——既然硬件不说,我们就用结果去反推逻辑。
# 5.benchmark、ncu report 和分析
本次测试,我预期终版 kernel 会是一个和 baseline 不相上下的结果,所以我进行了相对严格的对比,尽量关闭所有程序+显存显卡锁定频率,然后在低温环境进行测试。
直接贴 benchmark 结果:
####################################################################################################
n: 4096, m: 4096, k: 4096
torch mean time: 4.011336 ms, 34.26 tflops
hgemm_cublas mean time: 4.258727 ms, speedup: 0.94, tflops: 32.27
hgemm_bcf_dbf_rw mean time: 4.042202 ms, speedup: 0.99, tflops: 34.00
hgemm_k_stages mean time: 4.131343 ms, speedup: 0.97, tflops: 33.27
hgemm_tma_r_k_stages_64 mean time: 4.287896 ms, speedup: 0.94, tflops: 32.05
hgemm_tma_r_k_stages_32 mean time: 4.005909 ms, speedup: 1.00, tflops: 34.31我承认,这是我挑了一组终版 kernel 胜利的数据,毕竟花了这么多精力,还放个不如 hgemm_bcf_dbf_rw (实际胜负比 6/4 开吧),多煞风景。
不过在我心里,hgemm_tma_r_k_stages_32 和 hgemm_bcf_dbf_rw 就是相同等级的水平,都已经把 tensor-core 压榨到极限了(虽然前者要复杂得多)
ncu report

这是我最满意的一份 ncu report,虽然终版 kernel 性能不能稳压hgemm_bcf_dbf_rw一头,但是 ncu summary 很干净,没有讨厌的 Uncoalesced Shared Accesses,只有 tensor-core 算力不足,和 register/smem 限制导致的 Occupancy 提示。
hgemm_tma_r_k_stages_640 bank conflict
hgemm_tma_r_k_stages_32也是 0 bank conflict(ncu 还是报了一点点 bank conflict,毕竟双 block 活跃,有一些 warp 间的冲突,但 swizzle 不背锅)
# 一些讨论
- 为什么单纯的 2 级流水改为 3 级流水线性能没提高反而下降了
- 流水线实质就是空间换时间,我推测是 2 级流水线寄存器压力已经很大了,3 级理论要更多的寄存器同时还要保障 2 个 block 活跃,必定重排了指令,牺牲了更多的指令级并行度
- TMA 专用硬件搬运数据确实很有用,尤其是降低寄存器用量,从我的 ncu 报告能看出
- 128x128x32 的 tiling 策略下,即使 3 级流水线,寄存器用量已经从
hgemm_bcf_dbf_rw的 128 降低到了 96,这还是我用了双寄存器缓冲的情况下啊! - 这意味着在 SM120 上,瓶颈已经彻底从‘访存调度’转移到了‘算力吞吐’。TMA 已经做到了它能做的一切,剩下的差距是老黄割掉的那几刀算力。
- 128x128x32 的 tiling 策略下,即使 3 级流水线,寄存器用量已经从
- 但光有 TMA,没有大 smem 和足够的 tensor-core 算力,也没用呀
- ncu report 提示
Math Pipe Throttle Stalls, tensor core 都已经算不过来了 - 再给我 46KB 的 smem,再加些算力,我能再多一个 block,算力吞吐能提升 50%。只能说老黄刀法好啊
- ncu report 提示
- BK 从 64 降低为 32,单个 block 计算量下降一半,block 数量上升一倍,虽然总计算量不变,但实际性能还是更好
- 说明 在 Occupancy 极低的情况下,提高 Occupancy 让硬件来调度是更好的选择
- 终版 kernel 总结
- sm120 的 tma + ldmatrix(移花接木 swizzle、双寄存器 buffer) + mma + 3 级流水线 + grid swizzle kernel
- 我几乎用上了我能理解的一切
# 6. 结束
经过艰苦的 coding,我完成了一个自认完美的 kernel,这张卡已经被我榨干了。在高性能计算的战场上,顶级显卡是靠性能取胜(H100/B100),而我们在移动端显卡上的优化,则是一场在脚踏板上跳舞的微雕战争。虽然性能提升微小,但我们对硬件底层的掌控力,才是开发者最核心的护城河。
这真是 gemm 系列最后一篇了,原本还有个 todo,使用 tma 将 Cs 写回 gmem,但预期收益不大,就没动力了。
所以,gemm 系列完结 ~ 撒花 ~
以上。
如有错误,请大家指正。完整 kernel 和测试代码可以从 github 获取:https://github.com/WingEdge777/vitamin-cuda/tree/main/kernels/hgemm_sm120