admin 管理员组文章数量: 1184232
显存
显存占用分析
- Model States
- 模型参数
- 后向传递计算得到的 梯度
- 优化器状态
- Activation
- 前向计算过程中产生的 中间激活
数据类型
- float32(FP32):32 位浮点数,也称为单精度。
- float16(FP16):16 位浮点数,表示范围较小,也被称为半精度。
- bfloat16(BF16):扩大了指数位数,缩小了小数位数,因此表示的范围更大,精度更弱。
一般采用 16 位的表示,那么一个参数占用 2byte,即 2B。
FP16 的精度高,但是表示范围小,容易上溢;
BF16 的表示范围大,但精度低,因此更容易下溢,为了避免溢出问题,提出了混合精度方案。
训练过程
训练大模型时通常会采用 AdamW 优化器 ,并用 混合精度 训练来加速训练,基于这个前提分析显存占用。
在一次训练迭代中,每个可训练模型参数都会对应 1 个梯度 ,并对应 2 个优化器状态 (Adam 优化器梯度的一阶动量和二阶动量)。
推理过程
在神经网络的推理阶段,没有优化器状态和梯度,也不需要保存中间激活。 模型推理阶段占用的显存要远小于训练阶段 。
如果使用 float16 来进行推理, 推理阶段模型参数占用的显存大概是 2 Φ 2\mathbf\Phi 2 Φ 。
模型参数
符号说明:
| 数学符号 | 定义 |
|---|---|
| l | 模型层数 |
| d | 隐层维度 |
| h | 注意力头数 |
| b | batch size |
| s | 序列长度 |
| V | 词表大小 |
| μ | 向量的均值 |
| σ | 向量的方差 |
从输入到输出的顺序依次计算:
Embedding 层:词嵌入矩阵即一个 V → d V \rightarrow d V → d 无偏置线性层,将 V V V 大小的 one-hot 编码映射成 d d d 大小的 token。参数个数 $ Vd $。
- Positional Embedding:如果采用可训练式的位置编码,会有一些可训练模型参数,数量比较少。如果采用相对位置编码,例如 RoPE 和 ALiBi,则不包含可训练的模型参数。我们忽略这部分参数。。
l l l 个 block:
Self-attention:attention 层中有四个 d → d d \rightarrow d d → d 线性层,包含了权重: W q W_q W q 、 W k W_k W k 、 W v W_v W v 、 W o u t W_{out} W o u t 以及各自的偏置。
- 权重矩阵 n 的形状 [ d , d ] [d,d] [ d , d ] ,参数个数 d 2 d^2 d 2 ,
- 偏置形状 [ d ] [d] [ d ] ,参数个数 d。
- 总计参数量 4 d 2 + 4 d 4d^2+4d 4 d 2 + 4 d .
Layer Normalization:设层输入是 x i n x_{in} x in ,
layer normalization 公式: x o u t = γ ⊙ α + β x_{out}= \gamma \odot \alpha+\beta x o u t = γ ⊙ α + β , α = x i n − μ ( σ 2 + ϵ ) \alpha=\frac{x_{in}−\mu}{\sqrt{(\sigma^2+\epsilon)}} α = ( σ 2 + ϵ ) x in − μ 。
其中 μ \mu μ 表示 x i n x_{in} x in 的均值,$ \sigma$ 表示 x i n x_{in} x in 的方差, ϵ \epsilon ϵ 防止除零, γ \gamma γ 和 β \beta β 是可学习的参数,形状都是 [ d ] [d] [ d ] ,参数个数 d d d ,一层的参数个数 2 d 2d 2 d 。
因为 self-attention 和 mlp 后各有一层 layer nromalization。所以总参数个数 4 d 4d 4 d 。
mlp:共有两个带偏置的线性层,隐层维度默认为 4 d 4d 4 d :
- 第一个是 d → 4 d d \rightarrow 4d d → 4 d ,权重矩阵形状 [ d , 4 d ] [d,4d] [ d , 4 d ] ,偏置形状 [ 4 d ] [4d] [ 4 d ] ,层参数 4 d 2 + 4 d 4d^2+4d 4 d 2 + 4 d ;
- 第二个是 4 d → d 4d \rightarrow d 4 d → d ,权重矩阵形状 [ 4 d , d ] [4d,d] [ 4 d , d ] ,偏置形状 [ d ] [d] [ d ] ,层参数 4 d 2 + d 4d^2+d 4 d 2 + d ;
- mlp 的总参数个数 8 d 2 + 5 d 8d^2+5d 8 d 2 + 5 d
每个 block 的参数个数共计 12 d 2 + 13 d 12d^2+13d 12 d 2 + 13 d .
输出层和 Embedding 层共用参数。
因此,模型共计参数 l ∗ ( 12 d 2 + 13 d ) + V d l∗(12d^2+13d)+Vd l ∗ ( 12 d 2 + 13 d ) + V d
CodeGen 350M 参数
Name Size Embedding transformer.wte.weight torch.Size([51200, 1024]) transformer.h.0.ln_1.weight torch.Size([1024]) transformer.h.0.ln_1.bias torch.Size([1024]) Self-attention transformer.h.0.attn.qkv_proj.weight torch.Size([3072, 1024]) Self-attention-out transformer.h.0.attn.out_proj.weight torch.Size([1024, 1024]) mlp transformer.h.0.mlp.fc_in.weight torch.Size([4096, 1024]) transformer.h.0.mlp.fc_in.bias torch.Size([4096]) transformer.h.0.mlp.fc_out.weight torch.Size([1024, 4096]) transformer.h.0.mlp.fc_out.bias torch.Size([1024])
不同版本 LLaMA 模型的参数量
| 实际参数量 | 隐藏维度 h | 层数 l | 12 l h 2 12lh^2 12 l h 2 |
|---|---|---|---|
| 6.7B | 4096 | 32 | 6,442,450,944 |
| 13.0B | 5120 | 40 | 12,582,912,000 |
| 32.5B | 6656 | 60 | 31,897,681,920 |
| 65.2B | 8192 | 80 | 64,424,509,440 |
优化器状态
在训练过程中,模型的每个参数会记录梯度用于更新,此外优化器也会额外记录一些数据,称为 优化器状态 。
设模型参数为 $ \mathbf\Phi$, 那么梯度的元素数量为
Φ
\mathbf\Phi
Φ
,模型参数(fp16)、模型梯度(fp16)和优化器状态(fp32),
总占用
:
2
Φ
+
2
Φ
+
K
Φ
=
(
4
+
K
)
Φ
2\mathbf\Phi +2\mathbf\Phi+K\mathbf\Phi = (4+K)\mathbf\Phi
2
Φ
+
2
Φ
+
K
Φ
=
(
4
+
K
)
Φ
- 总占用和参数量有关,和输入大小无关。
- 在整个训练过程中都要存在显存中。 模型参数一般只能通过并行切分 (Tensor Parallelism/Pipeline Parallism)能减少。 优化器状态一般通过 ZeRO 来减少。
- 不同优化器的 K 值不同 ,算法的中间变量、框架的实现都有可能有一定区别。
AdamW 优化器 对模型中的每个参数记录了两个动量(一阶和二阶动量) m t m_t m t 和 v t v_t v t 。
- 在 混合精度训练 中,会使用 float16 的模型参数 进行前向传递和后向传递,计算得到 float16 的梯度 ;
- 在 优化器 更新模型参数时,会使用 float32 的优化器状态 、 float32 的梯度 、 float32 的模型参数 来更新模型参数。
- 使用 AdamW 优化器 和 混合精度训练 来训练参数量为 Φ \mathbf\Phi Φ 的大模型, 模型参数、梯度和优化器状态占用的显存大小为 $ 20\mathbf\Phi$ bytes
2 + 4 ⏟ weights + 2 + 4 ⏟ gradients + 4 + 4 ⏟ Adam states = 20 \underbrace{2+4}_{\text {weights}} +\underbrace{2+4}_{\text {gradients}} + \underbrace{4+4}_{\text {Adam states}} = 20 weights 2 + 4 + gradients 2 + 4 + Adam states 4 + 4 = 20【注】:有的参考资料中,没有考虑 fp32 的梯度,计算得到总显存为 2 Φ + 2 Φ + 12 Φ = 16 Φ 2\mathbf\Phi +2\mathbf\Phi+12\mathbf\Phi = 16\mathbf\Phi 2 Φ + 2 Φ + 12 Φ = 16 Φ ,此处参考
中间激活值
激活(activations) 指的是前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量
中间激活值占用显存 分两个部分分析:Self-Attention 和 MLP,Embedding 没有中间值。
Self-Attention 块的中间激活占用显存大小为 11 b s d + 5 b s 2 h 11bsd+5bs^2h 11 b s d + 5 b s 2 h
对于 MLP 块,需要保存的中间激活值为 19 b s d 19bsd 19 b s d 。
layer norm 需要保存其输入,大小为 2 b s d 2bsd 2 b s d ,2 个 layer norm 需要保存的中间激活为 $ 4bsd $
对于 l l l 层 transformer 模型, 最终合计 l ∗ ( 34 b s d + 5 b s 2 h ) l*(34bsd +5bs^2h) l ∗ ( 34 b s d + 5 b s 2 h ) 。
- 激活值 与输入数据的大小( 批次大小 b 和 序列长度 )成正相关。
- 在训练过程中是变化值,特别是 batch size 大的时候成倍增长很容易导致 OOM。
- 可以通过 重计算 、 并行切分 策略减少。
在一次训练迭代中
- 模型参数(或梯度)占用的显存大小 只与 模型参数量 和 参数数据类型 有关,与输入数据的大小是没有关系的。
- 优化器状态占用的显存大小 与 优化器类型 有关,与 模型参数量 有关,与输入数据的大小无关。
- 中间激活值 与输入数据的大小( 批次大小 b b b 和 序列长度 s s s )是成正相关的,随着 批次大小 b b b 和 序列长度 s s s 的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足 OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。
以 GPT3-175B 为例,直观对比模型参数与中间激活的显存大小。GPT3 的模型配置如下。假设采用混合精度训练,模型参数和中间激活都采用 float16 数据类型,每个元素占 2 个 bytes。
| 模型名 | 参数量 | 层数 | 隐藏维度 | 注意力头数 |
|---|---|---|---|---|
| GPT3 | 175B | 96 | 12288 | 96 |
GPT3 的模型参数量为 175B,占用的显存大小为 2 ∗ 175 ∗ 1 0 9 bytes = 350 GB 2*175*10^9 \text{bytes}=350 \text{GB} 2 ∗ 175 ∗ 1 0 9 bytes = 350 GB 。GPT3 模型需要占用 350GB 的显存。
GPT3 的序列长度 l l l 为 2048 。对比不同的批次大小 b b b 占用的中间激活:
当 l l l = 1 时,中间激活占用显存为 ( 34 b s d + 5 b s 2 h ) ∗ l = 275 , 414 , 777 , 856 bytes ≈ 275 GB (34bsd+5bs^2h)∗l=275,414,777,856 \text{bytes}\approx 275 \text{GB} ( 34 b s d + 5 b s 2 h ) ∗ l = 275 , 414 , 777 , 856 bytes ≈ 275 GB ,大约是模型参数显存的 0.79 倍。
当 l l l = 64 时,中间激活占用显存为 ( 34 b s d + 5 b s 2 h ) ∗ l = 17626545782 bytes ≈ 17.6 TB (34bsd+5bs^2h)∗l=17626545782 \text{bytes}\approx 17.6 \text{TB} ( 34 b s d + 5 b s 2 h ) ∗ l = 17626545782 bytes ≈ 17.6 TB ,大约是模型参数显存的 50 倍。
当 l l l = 128 时,中间激活占用显存为, $ (34bsd+5bs^2h)∗l=35253091565568 \text{bytes}\approx 35.3 \text{TB}$ 大约是模型参数显存的 101 倍。
可以看到随着批次大小 b b b 的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用 激活重计算 技术来减少中间激活,理论上可以将中间激活显存从 O ( n ) O(n) O ( n ) 减少到 O ( n ) O(\sqrt{n}) O ( n ) ,代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。
版权声明:本文标题:显存计算_激活值显存计算 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.roclinux.cn/b/1773989461a3568041.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论