正文
DP (Data Parallelism,数据并行):
- 原理: 每个人(GPU)拿这本完整的书(模型),读不同的章节(数据)。读完后,大家把学到的心得(梯度)平均一下。
- 瓶颈: 每个显卡必须存一份完整的模型。模型太大(如 70B)时,单卡存不下,DP 直接失效。
- 联系: FSDP 的前身就是 DP。
2. FSDP1 和 FSDP2 是什么?
FSDP (Fully Sharded Data Parallelism) 是 PyTorch 官方推出的、对标准 DP 的一种“显存极致优化版”。
FSDP 的核心思想 (Zero Redundancy)
标准的 DP 每个卡都有完整模型参数。FSDP 认为这太浪费了,于是它把模型参数、梯度、优化器状态全部切碎(Shard),分散存储在所有 GPU 上。
- 计算时: 当 GPU 需要计算某一层时,通过
AllGather把缺失的参数从其他卡拉过来,计算完立刻释放(Free)。 - 结果: 显存占用大幅降低(几乎随着 GPU 数量线性减少),但通信量增加了。
FSDP1 vs. FSDP2
- FSDP1 (完全分片数据并行 - 经典版):
- 实现方式: 使用
FlatParameter。它会把一个模块(Module)里的所有参数“拍扁”成一个长长的一维向量,然后进行切分。 - 缺点:
- 黑盒: 参数被拍扁并包装了,外部很难直接访问原始参数,导致 debug 困难。
- 兼容性: 这种 hack 的方式导致它很难和
torch.compile(PyTorch 2.0 的编译器)配合优化。 - TP 混合难: 因为参数结构变了,想和 TP 结合使用时代码非常复杂。
- 实现方式: 使用
- FSDP2 (Per-Parameter Sharding - 进化版):
- 注:PyTorch 官方文档中有时称其为 _
_FSDP_via_torch.distributed._composable_APIs。_ - 核心改进: 放弃了“拍扁”操作,基于 DTensor (Distributed Tensor) 技术。
- 优势:
- 原生对象: 参数保持原有的 Tensor 结构,只是在底层是分布式的。
- 性能更强: 更好地支持通信和计算的重叠(Communication-Computation Overlap)。
- 支持 torch.compile: 这是 FSDP2 最大的卖点,编译加速更顺滑。
- 组合性: 可以更容易地与 TP、SP(序列并行)结合。
- 注:PyTorch 官方文档中有时称其为 _
3. FSDP 和 DeepSpeed 的关系
这是很多人的混淆点。一句话总结:DeepSpeed 是带头大哥(发明者),FSDP 是 PyTorch 官方的标准化实现(追随者)。
- 起源(ZeRO 技术):
- 微软开发了 DeepSpeed 库,提出了 ZeRO (Zero Redundancy Optimizer) 算法。
- ZeRO-1: 只切分优化器状态。
- ZeRO-2: 切分优化器状态 + 梯度。
- ZeRO-3: 切分优化器状态 + 梯度 + 模型参数。
- 对应关系:
- FSDP $\approx$ DeepSpeed ZeRO-3。
- Meta(Facebook)看到 DeepSpeed 的 ZeRO-3 效果很好,于是把它集成进了 PyTorch 原生代码库中,命名为 FSDP。