0
点赞
收藏
分享

微信扫一扫

LLM分布式训练第一课(通讯原语)

这个系列作为TFLOPS和显存消耗的续篇,今天开始正式连载

上一部地址:

LLM 参数,显存,Tflops? 训练篇(5) (qq.com)

前一篇文章举了65B模型的训练所消耗的显存的案例,如果把条件降低一点,我们看一下7B的模型需要多少显存?

  • 2byte的模型静态参数权重(以16bit存储)
  • 2byte的模型更新参数权重 (以16bit存储)= 14G
  • 2byte的梯度(以16bit存储)= 14G
  • 2byte的梯度更新(以16bit存储)= 14G
  • 4byte的一阶动量优化器更新(以32bit存储)= 28G
  • 4byte的二阶方差优化器更新(以32bit存储)= 28G

就不算Activation了,基本这些就已经超过目前在售的GPU最大内存80G了,所以要做分布式的训练来弥补单卡显存的上限。

这节我们介绍一下分布式训练系统,如同之前提到的内容,单独的卡目前对较小的模型也无法在做训练的时候load所有的静态,动态数据,所以要靠分布式训练系统来解决,分布式训练系统不只是单纯的能扩大显存的容量,装载更大的模型,和提供更多的batch_size,更长的训练语句长度,同时,它也在一定程度上,线性的提升了TFLOPS算力总量,提升了训练的速度。

在分布式训练的场景中会有很多分类,包括机器内部的分布式并行,跨机器的分布式,和数据分布式并行等等,以下就这几个场景展开讨论一下。

通信原语

在介绍几种分布式并行方式之前,我们要预先了解一下分布式训练中的通信原语。

Pytorch的分布式训练的通信是依赖torch.distributed模块来实现的,torch.distributed提供了point-2-point communication 和collective communication两种通信方式。

point-2-point communication(P2P)提供了send和recv语义,用于任务间的通信。

collective communication(CC)主要提供了scatter/broadcast/gather/reduce/all_reduce/all_gather 语义,不同的backend在提供的通信语义上具有一定的差异性。

我们训练大模型主要考虑的还是CC通信,所以着重讲一下这部分。

1-Broadcast

广播是相对最容易理解的通信原语。

LLM分布式训练第一课(通讯原语)_数据


Broadcast的意思就是节点要把它自己的数据原封不动的发到其他的节点,如图所示GPU0会把一份数据拷贝给GPU0,1,2,3,Broadcast操作是将某节点的输入广播到其他节点上,在分布式LLM训练中常用于网络参数的初始化。

Broadcast的数据具象图:

LLM分布式训练第一课(通讯原语)_数据_02


2-Scatter

Scatter是一种另类的Broadcast。

LLM分布式训练第一课(通讯原语)_权重_03


如图所示,Scatter是把一系列的数据切片成为不同的数据,然后分别提交给相应的数据处理节点,这种通信原语行为一般出现在类似于DP的数据分配起步阶段。

LLM分布式训练第一课(通讯原语)_权重_04


3-Gather

Gather的行为可以认为逻辑上和Scatter正好倒过来,是由多个节点将不同的数据一起发给一个汇聚节点的通信行为。

LLM分布式训练第一课(通讯原语)_数据_05

LLM分布式训练第一课(通讯原语)_数据_06


4-Reduce

Reduce相对复杂一点,是一系列简单运算操作的统称,包含但不限于SUM、MIN、MAX、PROD、LOR等类型的规约操作,Reduce意为减少/精简,因为其操作在每个节点上获取一个输入元素数组,通过执行操作后,将得到精简的更少的元素。下面以Reduce sum为例子(因为好理解)。

LLM分布式训练第一课(通讯原语)_数据_07


如图所示就是一个简单的Reduce Sum的通信示意图,4个节点分别把数据汇聚给0节点,由0节点(主节点)来执行Sum的操作(Gather可以近似认为至少在Gather完成的这一步是不做计算操作)

Reduce的数据具象图和Gather的区别也非常大,我们还是以Reduce Sum为例看一下区别。

LLM分布式训练第一课(通讯原语)_权重_08


5-All-Reduce

All-Reduce是经常可以在各种文档上看到的操作,同时也是最消耗带宽的操作,All Reduce操作可通过单节点上Reduce + Broadcast操作完成。NCCL库里关于实现All Reduce的逻辑是从多个sender那里接收数据,最终合并和分发到每一个节点上,还是以All-Reduce-Sum为例方便理解。

LLM分布式训练第一课(通讯原语)_数据_09

LLM分布式训练第一课(通讯原语)_权重_10

 数据具象图

6-All Gather

All Gather和All Reduce的区别看图就非常容易区别,All Gather和Gather一样,只是把数据收集下来,但是每一个节点,都有Gather来的数据的完整备份。

LLM分布式训练第一课(通讯原语)_数据_11

All-Gather通信简易图

LLM分布式训练第一课(通讯原语)_数据_12

All-Gather 数据具象图

7-All2All

All2All作为全交换操作,可以让每个节点都获取其他节点的值。

在使用 All 2All 时,每一个节点都会向任意一个节点发送数据,每一个节点也都会接收到任意一个节点的数据。每个节点的接收缓冲区和发送缓冲区都是一个分为若干个数据块的数组。

All2All 与 All Gather 相比较,区别在于:All Gather 操作中,不同节点向某一节点收集到的数据是完全相同的,而在 All2All 中,不同的节点向某一节点收集到的数据是不同的。

LLM分布式训练第一课(通讯原语)_原语_13

All2All 数据具象图

业界常用的大模型训练的通信原语库主要有MPI和NCCL,早期MPI在CPU和GPU的分布式通信领域都是主力军,在NCCL推出之后,MPI库现在就只用在了CPU的分布式通信场景,而GPU的分布式通信库目前都是以NCCL为主(在NV的场景)。

举报

相关推荐

0 条评论