首页
avatar

WingEdge777

一文理解 PyTorch 进行分布式应用开发 - 分布式推理入门实战

如今模型越来越大。当模型参数量达到数百亿级别,即使通过 INT4 量化压到单卡能装下权重,推理时的 KV Cache 和激活值也会随 batch size 和序列长度线性增长,单卡显存很快捉襟见肘——多卡分布式推理几乎是必经之路。

使用 PyTorch 进行分布式应用开发 - 分布式推理入门实战

但想直接阅读 Megatron-LM、DeepSpeed、vllm、sglang 源码学习时,错综复杂的工程封装往往让人望而生畏。本文抛开所有工程包袱,仅用不到 200 行核心 PyTorch 代码,带你从零手搓大模型分布式的三大核心流派的实现。

本文使用一个简单的两层 Linear 模型,两个 GPU,手把手带你理解并实现三种主流并行推理模式:数据并行 (DP)、张量并行 (TP)、流水线并行 (PP)。同时了解基本的 torch.distributed 库的通信原语使用。

环境要求: PyTorch >= 2.0,2× NVIDIA GPU,NCCL backend。推荐使用 NVIDA 的 NGC PyTorch 镜像

目录

  1. 环境准备与基础概念
  2. 模型定义
  3. 数据并行 (DP)
  4. 张量并行 (TP)
  5. 流水线并行 (PP)
  6. 运行结果
  7. 三种并行对比

1. 环境准备与基础概念

1.1 分布式初始化

import os
import torch
import torch.distributed as dist

def setup_dist(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup_dist():
    dist.destroy_process_group()

关键概念:

  • rank: 进程编号,0 或 1
  • world_size: 总进程数,本文固定为 2
  • NCCL: NVIDIA 集合通信库,GPU 间通信的高效实现

1.2 进程启动

import torch.multiprocessing as mp

# 主进程初始化模型和数据(保持在 CPU 上,避免 CUDA 跨进程问题)
torch.manual_seed(42)
model = TwoLayerModel(hidden_size)
x = torch.randn(batch_size, hidden_size)

# 计算参考输出
with torch.inference_mode():
    ref_output = model(x)

# 启动多进程,rank 自动作为第一个参数传入
mp.spawn(run_demo, args=(world_size, model, x, ref_output), nprocs=world_size, join=True)

注意: 模型和数据必须保留在 CPU 上创建,由各子进程自行 .cuda() 到对应设备。mp.spawn 创建的子进程无法直接共享主进程的 CUDA 上下文。

1.3 通信原语

本文用到的通信原语:

原语类型说明
all_gather集合通信每个 rank 收集所有 rank 的数据
all_reduce集合通信所有 rank 的数据做规约操作(如求和)
isend/irecv点对点通信异步发送/接收数据
batch_isend_irecv点对点通信批量提交多个异步 send/recv 操作
all_gather:
Rank 0: [a] ──┐
              ├──> [[a], [b]] (所有 rank 都有)
Rank 1: [b] ──┘

all_reduce(SUM):
Rank 0: 10 ──┐
             ├──> 30 (所有 rank 都有)
Rank 1: 20 ──┘

isend/irecv:
Rank 0               Rank 1
  │                    │
  ├─── isend ─────────>│
  │<──────── irecv ────┤
  │                    │

2. 模型定义

class TwoLayerModel(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.Linear(hidden_size, hidden_size, bias=False),
        ])

    def forward(self, x):
        x = torch.relu(self.layers[0](x))
        x = self.layers[1](x)
        return x

使用 ModuleList 便于后续切分不同层。同时设置 bias=False 是为了在演示张量并行(TP)时,无需额外处理偏置项的切分逻辑,从而降低入门的理解成本。


3. 数据并行 (DP)

3.1 原理

       输入 [B, H]

      ┌─────┴─────┐
      ▼           ▼
┌──────────┐ ┌──────────┐
│  Rank 0  │ │  Rank 1  │
│ 完整模型 │ │ 完整模型 │
│ Batch 0  │ │ Batch 1  │
└────┬─────┘ └────┬─────┘
     │            │
     ▼            ▼
 [B/2, H]     [B/2, H]
     │            │
     └─────┬──────┘

       all_gather


      [B, H] 输出

每个 GPU 持有完整模型副本,数据按 batch 切分,通过 all_gather 收集结果。

3.2 代码实现

class DPModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.gathered = None

    def forward(self, x):
        local_batch = x.shape[0] // self.world_size
        start_idx = self.rank * local_batch
        end_idx = (self.rank + 1) * local_batch
        local_x = x[start_idx:end_idx]
        out = self.model(local_x)
        if self.gathered is None:
            self.gathered = [torch.empty_like(out) for _ in range(self.world_size)]
        dist.all_gather(self.gathered, out)
        return torch.cat(self.gathered, dim=0)

要点:

  • 每个 rank 只取自己负责的 batch 切片进行计算
  • gathered 缓冲区在首次 forward 时按实际输出形状动态分配
  • all_gather 后拼接得到完整输出

训练 vs 推理: 本文的 DP 聚焦推理场景,核心是 all_gather 收集各 rank 的输出。训练场景下的 DP(如 DistributedDataParallel)核心则是 all_reduce 同步梯度。

3.3 all_gather

Rank 0: [a, b] ──┐
                 ├──> all_gather ──> [[a,b], [c,d]] (两个 rank 都有)
Rank 1: [c, d] ──┘

4. 张量并行 (TP)

4.1 原理

Column Parallel: 权重按列切分,输出在 hidden 维度分片 Row Parallel: 权重按行切分,输入 hidden 维度分片,all_reduce 聚合

Column Parallel Linear:
X [B,H]      W [H,H]

        ┌───────┴───────┐
        ▼               ▼
   W_0 [H/2,H]     W_1 [H/2,H]
        │               │
        ▼               ▼
    X @ W_0^T       X @ W_1^T
     [B,H/2]         [B,H/2]

Row Parallel Linear:
X [B,H] ─切分─> X_0 [B,H/2]    X_1 [B,H/2]
                    │              │
                    ▼              ▼
               X_0 @ W_0^T    X_1 @ W_1^T
                    │              │
                    └──────┬───────┘

                    all_reduce(SUM)


                      [B, H] 输出

Column + Row 组合是 TP 的精髓所在:Column Parallel 的输出 [B, H/2] 恰好是 Row Parallel 需要的分片输入,两者天然衔接,中间无需任何通信,只在最后 all_reduce 一次。

在真实的 LLM 中,Column Parallel 和 Row Parallel 通常被应用于 Attention 模块的 QKV linear 投影 和 Out linear 投影。在做切分时,通常是按注意力头的数量进行切分,保证每个 GPU 独立计算一部分头,最后再 all_reduce 汇总。

4.2 代码实现

class ColumnParallelLinear(nn.Module):
    def __init__(self, full_linear):
        super().__init__()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        out_per_rank = full_linear.out_features // world_size
        self.weight = nn.Parameter(full_linear.weight.detach()[rank*out_per_rank : (rank+1)*out_per_rank].clone())

    def forward(self, x):
        return x @ self.weight.t()

class RowParallelLinear(nn.Module):
    def __init__(self, full_linear):
        super().__init__()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        in_per_rank = full_linear.in_features // world_size
        self.weight = nn.Parameter(full_linear.weight.detach()[:, rank*in_per_rank : (rank+1)*in_per_rank].clone())

    def forward(self, x):
        out = x @ self.weight.t()
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out

class TPModel(nn.Module):
    def __init__(self, full_model):
        super().__init__()
        self.fc1 = ColumnParallelLinear(full_model.layers[0])
        self.fc2 = RowParallelLinear(full_model.layers[1])

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

4.3 all_reduce

Rank 0: 10 ──┐
             ├──> all_reduce(SUM) ──> 30 (两个 rank 都有)
Rank 1: 20 ──┘

5. 流水线并行 (PP)

5.1 原理

时间 T0:
┌──────────┐    ┌──────────┐
│  Rank 0  │    │  Rank 1  │
│ Layer 0  │    │  (等待)  │
│  计算中  │    │          │
└────┬─────┘    └──────────┘
     │ isend

   激活值

时间 T1:
┌──────────┐    ┌──────────┐
│  Rank 0  │    │  Rank 1  │
│  (等待)  │    │ Layer 1  │
│          │    │  计算中  │
└──────────┘    └────┬─────┘
     ▲               │ isend
     │               ▼
     └──irecv     最终输出

PP 将模型按层切分到不同 GPU,数据像流水线一样依次穿过。但这种天然的顺序依赖会导致一个致命代价:流水线气泡 (Pipeline Bubble)。 如图所示,当 Rank 0 在埋头苦算时,Rank 1 只能干瞪眼等待。两卡 PP 的理论空转时间高达 50%,这也解释了为什么在我们的基准测试中,PP 的延迟远高于 TP。 为了榨干硬件算力、压缩气泡,工业界引入了微批次 (Micro-batch) 技术。其核心思想是:不要等一整批数据 (Batch) 慢吞吞地全走完 Layer 0,而是把它切成更小的碎片。 Layer 0 算完第一个小块就赶紧发走,让下游的 GPU 提前‘转起来’。

需要注意的是,填补气泡的调度策略在训练和推理中截然不同:在训练场景下,通常采用 1F1B (一前向一反向) 调度来兼顾显存压力;而在我们关注的纯推理场景中,则演变为 Chunked Prefill (分块预填充) 或连续的前向 Micro-Batch 派发,通过让数据块无缝衔接地流过各层,从而最大化吞吐量。

5.2 代码实现

class PPModel(nn.Module):
    def __init__(self, full_model):
        super().__init__()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        self.rank = rank
        self.world_size = world_size
        self.is_first = (rank == 0)
        self.is_last = (rank == world_size - 1)
        self.total_layers = len(full_model.layers)

        layers_per_rank = len(full_model.layers) // world_size
        self.start_idx = rank * layers_per_rank
        self.layers = nn.ModuleList([
            full_model.layers[i] for i in range(self.start_idx, (rank+1) * layers_per_rank)
        ])

        self.recv_buffer = None

    def _send(self, tensor, dst):
        reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, tensor.contiguous(), dst)])
        for req in reqs:
            req.wait()

    def _recv(self, src):
        reqs = dist.batch_isend_irecv([dist.P2POp(dist.irecv, self.recv_buffer, src)])
        for req in reqs:
            req.wait()
        return self.recv_buffer

    def forward(self, x):
        if self.recv_buffer is None:
            self.recv_buffer = torch.empty_like(x)

        # 1. 接收阶段
        # 首 rank 使用传入的 x,其余 rank 忽略 x,从前一个 rank 接收激活值
        if self.is_first:
            out = x
        else:
            out = self._recv(self.rank - 1)

        # 2. 计算阶段
        for i, layer in enumerate(self.layers):
            out = layer(out)
            if self.start_idx + i < self.total_layers - 1:
                out = torch.relu(out)

        # 3. 发送阶段
        if self.is_first:
            self._send(out, self.rank + 1)
            return self._recv(self.world_size - 1)
        elif self.is_last:
            self._send(out, 0)
            return None
        else:
            self._send(out, self.rank + 1)
            return None

5.3 为什么用 batch_isend_irecv

你可能会问:为什么不直接用 dist.send() / dist.recv()

  1. NCCL 不支持同步 P2Pdist.send() / dist.recv() 是同步阻塞 API,仅 Gloo 后端可用。NCCL 后端只支持异步的 isend / irecv
  2. 官方推荐的 P2P 方式batch_isend_irecv 是 PyTorch 官方推荐的 NCCL 点对点通信 API。虽然本文每次只提交一个 op,但在更复杂的流水线场景中(如多微批次 1F1B 调度),一个 rank 可能需要同时发送和接收不同微批次的数据,此时将多个 P2P 操作打包提交可以避免因调用顺序不一致导致的死锁。
# _send / _recv 封装了 batch_isend_irecv 的样板代码
def _send(self, tensor, dst):
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, tensor.contiguous(), dst)])
    for req in reqs:
        req.wait()

def _recv(self, src):
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.irecv, self.recv_buffer, src)])
    for req in reqs:
        req.wait()
    return self.recv_buffer

6. 运行结果

python test.py
==================================================
Demo 0: Single GPU
==================================================
[Single GPU] Output match reference: True, Avg time: 10.009 ms

==================================================
Demo 1: DP
==================================================
[DP] Output match reference: True, Avg time: 5.123 ms

==================================================
Demo 2: TP
==================================================
[TP] Output match reference: True, Avg time: 5.748 ms

==================================================
Demo 3: PP
==================================================
[PP] Output match reference: True, Avg time: 11.032 ms

7. 三种并行对比

维度DPTPPP
切分对象数据模型参数模型层
模型副本完整部分参数部分层
内存占用
通信算子all_gatherall_reduceisend/irecv
通信频率每次推理一次每组 Column+Row 一次每阶段一次
适用场景小模型大 batch大 hidden超深模型
延迟特点低延迟低延迟存在 bubble

选择建议:

  • DP: 模型能单卡放下,需要高吞吐
  • TP: 单层参数太大,追求低延迟
  • PP: 层数太多,可接受一定延迟

💡 LLM 推理的阶段特性: LLM 推理特有的两个阶段。Prefill(预填充) 阶段是计算密集型的,大 Batch 进来时,TP 能有效降低首字延迟(TTFT);而 Decode(解码) 阶段是访存密集型的,此时按 TP 切分不仅切分了计算,也相当于把庞大的 KV Cache 分摊到了多张卡的显存中。

实际大模型推理通常组合使用:TP + PP 处理超大模型,DP 提升吞吐。理解了这三种基础范式,再去阅读 Megatron-LM/DeepSpeed 或 vllm/sglang 源码,就不再是看天书了。

8. 结束

本文为了聚焦核心逻辑,所有代码均在单机双卡环境下运行。但在真实的工业界场景中,通常通过环境变量配置结合 torchrun 即可利用 NCCL 完成多机多卡推理。需要注意的是,多机环境往往还需要 RDMA 网络(如 InfiniBand 或 RoCE)的硬件支持来实现低时延的数据交换,不过这部分超出了本入门教程的范围,所以没有加以讨论。

掌握了分布式推理的宏观通信原语,仅仅是第一步。当张量被切分到单张 GPU 上后,如何把单卡硬件的性能压榨到极致?如何写出比 PyTorch 原生算子更快、显存占用更低的手写 CUDA Kernel?

实际上,将高性能的自研 Kernel 包装成 PyTorch Ops,正是当代顶级推理框架(如 vLLM, SGLang)的标准做法。如果你对底层的极致性能优化感兴趣,欢迎浏览本人的硬核算子库项目 vitamin-cuda,里面包含了上百个手写 CUDA 算子。欢迎共同学习、交流与进步!

完整代码见 vitamin-cuda torch_dist

以上

AI infer LLM PyTorch distribution Tensor Parallelism Data Parallelism Pipeline Parallelism