More than code

More Than Code
The efficiency of your iteration of reading, practicing and thinking decides your understanding of the world.
  1. 首页
  2. 未分类
  3. 正文

Pytorch FSDP1 Introduction

2026年1月1日 9点热度 0人点赞 0条评论

整理一下FSDP1相关的实现,偏源码级,适合想读代码的同学来看

Core

  • 核心流程如图


  • 图中可以看到,FSDP虽然shard了模型,但是在计算(forward)的过程中,还是需要收集所有的参数做计算。所以他还是被划分成DataParallel

Init

主要分为两块:

  • 把模型切分成fsdp unit

  • 把每一个fsdp unit的参数展开成FlatParameter

  • 以一个CausalTransformer为例,模型会被组织成一个树形的结构。每一个节点都会有自己的Parameter,定义自己的forward计算逻辑。

FSDP的第一步是切分这个模型,把它划分成若干个FSDP unit。torch在这里把切分模型的任务主要给了用户,用户需要定义自己的模型切割策略,告诉FSDP那些module应该被划分成一个unit。

  • torch默认提供的策略,比如size_based_auto_wrap_policy就是一个参考,可以根据module的参数大小来确定是否做切割。

  • 还有一些策略,直接根据module的类型来切割。比如上面的CausalTransformer中,我们可以把每一个TransformerBlock划分成一个FSDP unit

  • 有关unit大小的tradeoff,paper中并没有相关的讨论,这里就根据我的理解说一下。

    • FSDP unit越大,肯定对内存的要求越高。因为计算时需要收集FSDP unit内所有的参数。

    • 而FSDP unit越小,比如极端情况就是一个小矩阵的乘法,此时发起通信的开销比较大(虽然通信的总量是固定的,但是启动通信算子本身就有开销,这也是为什么要做batch)

      • 不过猜测启动算子的开销应该到FSDP unit大到一个点的时候,就可以被overlap起来?或者是占比比较少了。
    • 从上面的分析看,感觉更像是一个分段函数,开始的时候对着FSDP unit变大,吞吐会变大。到某一个点后这个增长的斜率会变缓

  • 如图,如果按照TransformerBlock切的话,上面的模型切分完之后就是这个样子。再重复一下,每一个FSDP unit在执行时会收集所有分片的参数,所以这里切分就需要保证单个TransformerBlock能够放到GPU中,否则就需要继续做切分。

  • 核心思路是把所有的参数拍平,都放到FlatParameter中。然后再根据world size去切分FlatParameter,每一个rank拿到一部分的FlatParameter。
    • 使用这种方式设计的原因在论文中有提,感兴趣可以看《PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel》
  • 实现上,这里有几个需要注意的点:
    • 通信库中(nccl)有部分高性能的通信算子要求各个rank输入的tensor大小相同。比如All Gather希望所有tensor大小相同。这里会把FlatParameter按照world size做一下padding,让FlatParameter的大小能够整除world size

    • 因为后续使用上,还会涉及到通过view把FlatParameter的切片变换成原始的参数,所以为了提高访存效率,这里还会对每一个子的parameter按照16byte做对齐。

Runtime

有了切分好的参数后,就可以看一下FSDP是怎么把上面概念上的通信算子加入到计算逻辑中了。
FSDP wrap了原始模型,在forward链路增加了pre-forward/post-forward的逻辑。

  • pre-forward的逻辑就是做unshard,通过All Gather收集参数,给后面模型的计算做准备

  • post-forward的逻辑就是做reshard,释放掉unsharded param来释放内存

forward阶段的逻辑比较好理解,这里需要关注的点主要是backward阶段,因为backward阶段是torch的autograd引擎自动执行的,无法直接像forward一样直接增加逻辑进去。不过autograd中提供了很多hook的能力来完成这个任务。

  • _pre_forward阶段
    • 会做All Gather,构建Unsharded Param

    • 注册post_backward_hook,注册到AccumulateGrad对象上。每一个Parameter在计算图中都会有一个AccumulateGrad对象,用来累积梯度。当累积的梯度计算完(猜测就是AccumulateGrad入边都遍历完了),会调用这个hook。

      • 执行_post_backward_hook,会对这个FSDP unit的梯度做reduce-scatter
  • Module computation阶段,使用unsharded param做计算。

  • _post_forward

    • 释放unsharded param,节省内存

    • 注册pre_backward_hook。注册到当前FSDP unit的输出上。tensor.register_hook的作用是当有梯度传过来的时候,会调用这个hook。对应的就是有梯度传到当前FSDP unit了,是当前FSDP unit开始做反向传播的时机了。

      • 执行_pre_backward_hook,和pre forward一样,做All Gather,构建Unsharded Param参与梯度计算。
  • _post_backward_final_callback
    • 由FSDP root state注册到_execution_engine上的。这块因为我对torch的autograd并不熟悉,看注释就是在整个backward pass结束的时候调用。

    • 这块主要做的是一些收尾逻辑,比如状态的清理有核心点在prepare_gradient_for_optim中,把反向传播阶段积累的梯度放到flat_param.grad上。从而实现每一个rank自己拥有的FlatParam上,都有sharded param + sharded grad,然后optimizer.step()即可。

Autograd

实现上,还有一个需要考虑的事情。既然FSDP引入了FlatParam,而用户定义的模型的参数都是保存在他们自身的模块中的。那么怎么做梯度的reduce-scatter,以及何时做,就成了一个问题。

  • 比如可以尝试hack autograd引擎,在计算阶段手动遍历一下所有的param,看梯度是否已经计算完毕,然后发起reduce scatter。但是这个遍历的时机也很难把握,以及还需要处理mini-batch的case

  • 这里使用了torch的movement类型的操作,把原始module的参数看作是普通的tensor。真正的parameter只有FlatParameter。而原始module的参数是FlatParameter经过split + view变化之后得到的。

  • 因为这些split/view也是会被autograd处理的,那么在反向传播的时候,梯度会自动从上面原始模块的参数Tensor中传播,直到下面的FlatParameter。

  • 同时ReduceScatter的时机上面也提到了,因为Parameter上有一个AccumulateGrad的对象,可以在他累积梯度完成之后注册一个hook,完美契合我们希望做ReduceScatter时间。

如果阅读代码的话,有一个额外的点,就是在注册这个AccumulateGrad的post hook的时候:

  • 新版本torch.compile支持了这个操作,直接执行register_post_accumulate_grad_hook即可

  • 老版本没支持的时候,这里会hack的先做一次假的操作(expand_as),得到一个新的tensor。通过新tensor的grad_fn.next_functions[0][0],即在计算图上找这个函数的上游(反向传播的下一个节点),就是AccumulateGrad这个对象了。

还有一个实现上需要注意的点,就是FlatParameter在做unshard/reshard时的实现:

  • 在计算dx的时候,需要用到linear中的权重。根据我们上面的讲解,linear的权重实际上是unsharded flat param的一个view。

  • 在reshard阶段,我们会清空unsharded param,节省空间。backward的unshard阶段,我们会把unsharded param再拷贝回来。这个过程中我们需要保证上面的Linear的权重在反向传播时是有效的。

    • 比如一种比较挫的实现,就是reshard阶段只是设置FlatParameter.data = None。因为Linear还持有这个view,所以不会发生GC,导致空间无法被释放。

    • 而unshard的时候,如果只是设置FlatParameter.data = unsharded_param。只是更新了这个指针,Linear在autograd中,还是引用的forwad时用的weight tensor。

  • 这里的做法是,这些tensor底层会共享相同的一个Storage的对象。上层autograd看到的是这里的tensor,而不会感知底层的Storage。

    • reshard时,直接操作这个Storage对象,把他的空间释放到。

    • unshard时,resize这个Storage对象,并做拷贝。

    • 在这个过程中,autograd始终看到的是相同的tensor/Storage对象。但是Storage对象对应的数据是有改变的。

Computation/Communication Overlapping

  • overlapping相关主要分两块

    • 控制多个stream来实现计算/通信的并发

    • 通过Prefetch,减少等待参数ready的时间

  • stream,这块我也没抠过代码,主要知识还是来源于GPT老师
    • stream可以看作是任务队列,队列内的操作是FIFO的。不同stream之间是并行的。

    • 核心的几个API,在FSDP代码中会经常看到

      • with stream(StreamA),表示切换stream,默认情况下有一个stream。切换stream后,会把CUDA相关的操作都提交到对应的新的stream上。那么通过with stream切换,就可以实现提交多个并发的任务

      • streamA.wait_stream(StreamB),stream之间做同步。相当于在streamA中提交一个任务,这个任务的语义是在StreamB当前提交的所有任务执行完之后,再去执行streamA的操作。

      • tensor.record_stream(StreamA),告诉allocator,tensor会在StreamA上使用,所以在StreamA的操作完成之前,不要释放tensor

    • 这里是pre forward阶段的示意图。这种在stream之间切换,pipeline的模式。就是多种任务的overlapping。
      • PreUnshard stream使用的是precision casting的资源(不知道是不是独立的),以及cpu-gpu的移动(PCIE的带宽)

      • Unshard stream使用的是网络带宽

      • Compute stream使用的是GPU的计算资源。

  • 第二块就是prefetch,这块相对简单
    • FSDP维护了每一个fsdp unit的pre/post forward的执行顺序。执行的时候,就会根据当前的阶段,以及之前记录的pre/post forward的执行顺序,来尝试对下一个FSDP unit做prefetch,即做unshard param(All gather)

    • pre_forward阶段,记录pre forward的执行顺序。这里假设是静态图,所以只会在第一个iteration做记录。每一个FSDP unit的状态上会记录自己执行的顺序。

      • 即假设forward是根据pre forward的顺序来的。
    • post_forward阶段,记录每个FSDP unit post forward的执行顺序,用来做backward的prefetch
      • 即假设backward是根据post forward的逆顺序来的

  • prefetch这块有一个额外需要注意的点,就是forward阶段的prefetch默认是关闭的。
    • 因为CPU的执行一般快于GPU,所以按顺序执行forward的时候,后面FSDP unit的All Gather都已经被提前发起了。

    • 当部分CPU bound的场景时,All Gather快于CPU的执行,此时需要通过CPU预先发出后面的All Gather。才需要prefetch forward

  • 那么相关的问题,既然CPU执行这么快,backward阶段应该也不需要prefetch呢?

    • 这块在论文中已经有回答,因为ProcessGroup中通信时串行的,如果不做Prefetch的话,那么有可能先执行ReduceScatter,再执行AllGather。导致阻塞下一次backward compute的进行

    • 所以做一下prefetch可以先发起All Gather,让后面的Backward compute和当前的ReduceScatter并行。

MemoryManagement

  • 这块主要是follow论文中的描述,以及对应的实现。这块不看论文/经过大量实践的话,感觉完全发现不了这个问题,估计是工程团队实践出真知了,感谢分享。

  • 先看一下Pytorch的CUDA Allocator的工作原理:

    • 使用GPU内存/释放内存的时候,需要调用cudaMalloc/cudaFree,会产生设备的同步,开销比较大。

    • 为了缓解这个问题,torch自己搞了一个缓存分配器,会先从cuda中分配比较大的block,然后自己内部复用这些block,尽可能的减少cudaMalloc/cudaFree的调用

    • torch的allocator是在CPU中工作的。是在执行到这个代码的时候立刻做决策的,而CPU/GPU是异步的,所以他需要做估计,来决定尝试复用那些block。

    • 在单个stream的语义下,所有的操作都是串行的。所以这个决策比较容易做,看一下队列里的任务对于某一个block都用完了,就可以做复用了。

    • 在多个stream的情况下,因为执行是并发的,stream之间没有顺序保证,所以复用策略比较保守,此时就会更容易多分配一些内存。

    • 那么,那么可能出现一些stream疯狂的发起操作,尝试申请内存,然后发现内存满了,触发cuda的内存分配。

      • 比如producer stream(unshard stream),发起了很多all gather操作。占用了大量内存无法被复用。(因为这些内存还等着用完了给consumer stream做计算)
    • 同时,block的分配是stream级别的,无法跨stream使用。就更容易出现这个问题。比如上面的case,producer stream分配了一堆block,consumer没有的用了,就会在分配内存的时候触发同步。

  • 这里核心问题是cpu执行过快,解决方法是all gather做限流,当发现all gather发的比较多了,就挂起cpu做等待。

    • 实现上,有一个_FreeEventQueue。每次All Gather之前,会等待上上个step的unsharded param释放。

    • 比如图中,第三个all gather要发起的时候,发现此时有两个unsharded param还没释放,说明可能有两个inflight All Gather,就会做一下等待,等待上上个unsharded释放后,那么此时一定只有一个All Gather。在发起当前的All Gather

    • 所以任何时间最多有两个并发的All Gather

  • 实现上,_unshard/_reshard在整个链路中多次配合使用(pre/post, forward/backward)。来做param状态的切换,以及实现上面的限流逻辑。

ParamInit

  • 好像没看到论文里说记录fake device上的操作,这块可能是我理解不到位。也可能是在torchdistX中做了,这块没有研究。

  • 目前看到的init主要就是每一个FSDP unit单独做初始化

    • 如果用户参数的设备是meta(没参数,只有shape),就会做materialize_module,把它移动到GPU上。

    • 调用module.reset_parameter()或者调用用户传入的param_init_fn

    • 初始化完成后会立刻做shard

Implementation Details

这里记录一些在阅读源码时可能比较迷惑的代码,以及一些invariance,方便在读代码时有一个直观的感觉

FlatParam状态切换

  • FlatParam在执行过程中,会进行状态的切换。需要明确一下,否则读代码的时候容易搞混
    • 因为在Optimizer的视角下,只能看到FlatParam和对应的梯度,此时他应该对应这个rank所负责的那块分片的参数

    • 而在Forward/Backward阶段,我们需要unsharded param做计算,同时也需要autograd参与,需要保证在这个计算过程中FlatParam的data/grad这两个字段都是unsharded的。

  • 同时,因为还有MixedPrecision/CPUOffload这些feature,也需要涉及到参数的变化。

  • 为了维护上述的状态,FlatParam会把很多中间状态的参数存到他的成员变量中,在阅读代码的时候可以看FlatParameter中的注释,这里列举几个变量给一个直观的概念:

    • _local_shard,保存的就是当前shard的参数。在optimizer阶段。flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()

    • _full_param_padded,unsharded param。在计算阶段,flat_param.data.data_ptr() == flat_param._full_param_padded.data_ptr()

    • _mp_shard,低精度的参数,用于混合精度训练

    • _cpu_grad,用于offload grad到CPU上

  • FlatHandle有一个状态的切换,代码中的逻辑会根据这个状态来决定走不同的分支
    • 比如根据当前状态决定是否做prefetch(BACKWARD_PRE/BACKWARD_POST)

    • 比如在BACKWARD_PRE阶段,如果发现进入了forward的链路,说明是启动了activation checkpoint,会跳过这次preforward(因为不需要重复做unshard)

  • FSDP state中也有一个training state,不过状态更少,主要是用来做一些入口的检查,防止不合法的状态。相对比较简单。

LazyInit

  • 或者这一节也想讨论一下module的嵌套。

FSDP中,有一个root state,套着所有的子的fsdp state。这个root state负责初始化runtime相关的各种变量,比如stream,event queue等。

  • 在forward时,第一个执行的state就是root state,此时会做LazyInit,做上面初始化的操作
    • 之所以不在init时做,是因为init时用户可能自己单独用FSDP去wrap module,所以无法得知正确的fsdp state。
  • LazyInit中,root初始化相关变量后,会把这些变量都传递给子的fsdp state,防止重复初始化。

  • 那为什么代码中有一个reset_lazy_init呢

    • 这是因为用户可能单独调用一个子模块的state_dict,此时会触发all gather,需要lazy init所初始化的stream。但是他并不是一个真的root。真的root需要在forward时才能知道。所以最后再重置一下状态。

有关FSDP的嵌套:

  • 嵌套实际上是一个不太好的行为,因为嵌套的FSDP unit会在递归的调用preforward,内存占用是累加的。违背了FSDP希望节省内存的初衷。

  • 不过还是有可能出现嵌套,比如一个简单的情况CausalLM中,我们的root FSDP state中,包含若干子的FSDP state,和lm_head + embedding这两个没有被wrapp的模块,这种情况下lm_head/embedding就会持续在内存中,和root fsdp state一起。

    • 延展一点就是如果共享参数的param如果被手动放到了不同的shard上会怎么样。答案是不支持

Features

这一节简写一下FSDP相关的一些小feature的实现

StateDict

允许用户通过state_dict读取模型参数/进行修改。state_dict有多种模式,这里主要看FULL_STATE_DICT这种,对应state_dict返回的是unsharded unflatten param。也就是最符合直觉的一种。

  • FSDP的实现没有重写state dict,而是注册了_pre_state_dict_hook, _post_state_dict_hook, _pre_load_state_dict_hook, _post_load_state_dict_hook。

  • 主要做的就是两块:

    • Unshard param,然后把初始化的时候记录的原始的模型参数都设置回来,方便用户通过named_parameters()来访问。这块实现上和forward阶段的unshard是一样的,做All Gather

    • Reduce grad,因为grad也是shard的,返回给用户的时候需要把grad也收集起来。

  • 在state dict访问结束后,有可能用户修改了参数的内容,这里还会有一个write back的选项,把参数同步到flat param中

GradientClipping

因为GradientClipping需要收集全局的梯度,所以这里FSDP也提供了一个针对FSDP模型的实现:

  • 每个模型本地计算自己的sum(grad ** p_norm)

  • all_reduce,计算全局的sum。再计算全局的norm

  • 放缩梯度即可

MixedPrecision

https://zhuanlan.zhihu.com/p/694288870 这篇文章的图示画的已经非常清楚了。这里主要提一些要注意的点:

  • FSDP的mixed precision是fsdp unit级别的,没有做到更细粒度的划分。

  • 在forward/backward的unshard阶段,会把master weight(_local_shard)在精度转化,放到_mp_shard中,做后面的计算。

    • post_unshard阶段就会释放掉这些低精度的参数。
  • Backward reduce-scatter时,会根据配置转化成对应的低精度的类型做reduce。累计精度的时候,会用全精度的。
    • 这里还有一篇讨论reduce type的文章https://main-horse.github.io/posts/reduction-precision/

  • ShardedGradScaler
    • pytorch提供了一个GradScaler的工具,为了避免梯度计算时的underflow的问题,他会先给loss乘上一个比较大的值,然后做backward pass。最后再把梯度缩回来。

    • 这个过程中,因为乘上了一个比较大的scale,所以可能出现overflow,当出现overflow的时候,这一个pass会被跳过。

    • 对于FSDP来说,需要保证当一个rank出现了overflow,所有的worker都需要停止这次更新。ShardedGradScaler帮忙做了这个事情,会在optimizer之前做一次all reduce,同步overflow的数量,所有worker看到相同的overflow的数量,作出相同的决策,是选择做梯度更新或者是跳过更新。

CPUOffload

  • CPU offload,支持将master weight和grad offload到CPU上。核心思路可以参考Zero-offload。同时如果和上面的MixedPrecision配合起来的话,可以做到Zero-offload一样的效果。

  • 启动cpu offload之后,会把_local_shard对应的master weight放到CPU上。

  • unshard阶段,会把master weight从CPU移动到GPU上。如果有mixed precision,也会在这个阶段做转化。计算完之后再释放掉这段内存

  • backward阶段,做完reduce-scatter之后,会把grad移动到_cpu_grad中。因为optimizer要求param和grad在同一个设备上,才能做计算。

UseOriginParam

use_orig_param的目的是让用户在使用FSDP的时候,也可以通过parameter()访问到原始的那些参数,同时支持了frozen/un-frozen混合的训练。给用户更丝滑的使用体验。
这一节因为比较细节,更多是讲实现相关的。按需阅读即可
在代码上,主要关注的就是四个地方:

  • 参数划分
    • ShardParamInfo中会记录,对于当前rank,每一个参数是否有一些数据保存在当前rank中。如果有,他对应的offset/length是什么

    • FSDP提供给用户的参数就是这里被切分后的参数,梯度也是一样。

      • 比如一个weight,有一部分在rank0,一部分在rank1。那么rank0就看到的是一个1D的tensor,对应这个weight的一个切片。
  • unshard
    • 这里也是把原始模型的参数替换成tensor slice。但是会把新的tensor slice也注册到module.parameter中

    • 生成的tensor slice会被保存下来。

      • 因为如果有reentrant AC,在backward pre阶段会再次调用forward。此时需要用slice再替换掉之前的parameter。

      • 为了保证两次forward相同,这里会用第一次forward的tensor slice来替换,只不过data会被重新复制。

      • 相关的一个问题是,为什么use_orig_param=False不需要做这个处理呢?

        • use_orig_param=False的时候,reshard不会做任何操作,不会去操作这些tensor slice。所以再次forward的时候,没有人动过这里的结构。

        • 而use_orig_param=True的时候,reshard会把原始的parameter替换回来,所以pre forward时候需要再把之前的tensor slice拿回来。

        • 至于data,我个人感觉可能不用重新赋值,因为他们应该始终都是指向同一个flat param的storage

  • reshard

    • 会把之前module的param切回来。然后把data设置成当前shard的flat param的slice。如果param不在当前shard的话,会被设置成0

    • 在post backward阶段的reshard,会把grad也切分,分发到对应的param上。如果require_grad=false,梯度也会被清理掉。

      • 这里就解释了fsdp是如何支持frozen/unfrozen一起训练了:
        • Use orig param= true的时候,在backward阶段,会把梯度再重新赋值给各个param。如果是frozen的param,require grad就是false。此时就不会传梯度给他。

        • 同时也可以解释为什么说一起用的时候资源占用会多,因为fsdp的backward是统一给的flat param。在backward阶段没有区分frozen/un frozen。而是计算完之后,最后分发的时候做的区分。

          • 个人感觉做这个区分应该不算很复杂,可能就是给tensor param哪里设置一下require_grad = false就行?
  • manual
    • 因为用户有权限访问原始的param,那么他就有可能去修改原始param的数据,或者是尝试主动去清理梯度。这些操作会被记录下来

    • 在pre_unshard阶段,会去同步用户的操作和flat param/grad。核心函数在_writeback_orig_params中。

    • 这里会去判断param的数据是否有变化(是否和flat param指向相同位置),如果有变化会把对应位置的数据拷贝到param中。梯度也是同理。

Reference

https://zhuanlan.zhihu.com/p/694288870
PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel

标签: 暂无
最后更新:2026年1月1日

sheep

think again

点赞
< 上一篇
下一篇 >

文章评论

取消回复

COPYRIGHT © 2021 heavensheep.xyz. ALL RIGHTS RESERVED.

THEME KRATOS MADE BY VTROIS