大模型学习笔记-Llama-3模型架构之RMS-Norm与激活函数SwiGLU

目录

大模型学习笔记——Llama 3模型架构之RMS Norm与激活函数SwiGLU

上文简单介绍了 。在以后的文章中将逐步学习并记录Llama 3模型中的各个部分。本文将首先介绍归一化模块RMS Norm与激活函数SwiGLU。

归一化模块是各个网络结构中必有得模块之一。Llama 3模型基于Transformer,Transformer中采用的归一化模块通常为层归一化Layer Norm(LN),如下图所示。而Llama模型采用LN的改进版RMS Norm。

https://i-blog.csdnimg.cn/direct/0e0eecb702374e849b286c3d43f0d291.png#pic_center

其中,N为batch size的大小,C为特征图通道数,H为特征图高,W为特征图宽。

LN层的计算为: https://i-blog.csdnimg.cn/direct/35de13c777a2412597354e38a6ca8857.png#pic_center

实质上,LN是针对layer维度进行标准化,在C,H,W上进行归一化,也就是与batch无关,执行完有B个均值,B个方差。每个样本共用相同的均值和方差。

RMS Norm是一种简化版的层归一化,它是通过均方根(Root Mean Square)计算得到归一化尺度,不需要计算均值和标准差。具体如下:

https://i-blog.csdnimg.cn/direct/3a378d1c54e34a1b824674826604d762.png#pic_center

与层归一化相比,RMSNorm不需要计算均值和标准差,减少了计算复杂度,提高了计算的效率,同时保持了良好的归一化效果。

Llama 3采用的是SwiGLU(Swish-Gated Linear Unit)激活函数,SwiGLU激活函数是一种结合Swish激活函数与GLU(Gated Linear Unit)机制的激活函数。

2.1 常用激活函数

其实,Swish激活函数与GLU激活函数都不是最常见的激活函数,在transformer结构中常用的激活函数是ReLU(Rectified Linear Unit)、GELU(Gaussian Error Linear Unit)等。在GPT3中采用的就是GELU激活函数,他们的形式如下图所示:

https://i-blog.csdnimg.cn/direct/1383818512844aa8bdf5c853df9831a5.png#pic_center

2.2 Swish激活函数

Swish激活函数是一种平滑且连续的激活函数,在Transformer等模型中应用广泛。具体计算公式如下所示:

Swish(x) = x * sigmoid(βx)

其中 sigmoid(x) = 1 / (1 + exp(-x)), β 是一个可学习的参数或一个常数。

通常 β 设置为 1,即 Swish(x) = x * sigmoid(x),此时的激活函数通常称为SiLU,有没有熟悉的感觉。当 β = 0 时,Swish 变为线性函数 x/2。 当 β 趋近于无穷大时,Swish 接近 ReLU。 因此,Swish 可以看作是 ReLU 和线性函数之间的插值。Swish与RuLU的对比如下所示:

https://i-blog.csdnimg.cn/direct/a217377b024b469aad275662ed79f2e7.png#pic_center

2.3 GLU(Gated Linear Unit)激活函数

GLU是一种强大的激活函数,它通过引入门控机制来动态地控制信息流。 它在许多任务中都表现出色,尤其是在需要处理复杂依赖关系的场景中。具体计算函数如下:

GLU(x) = (W_a * x + a) ⊗ σ(W_b * x +b)

其中,W_a, W_b 是权重矩阵; a, b 是偏置向量;σ 是 sigmoid 函数。在 GLU 中,输入特征被分成两个部分:

  • 线性部分: (W_a * x + a)
  • 门控部分:σ(W_b * x +b)

具体表现形式如下所示:

https://i-blog.csdnimg.cn/direct/987ec49cb3c54ded94d7e92a40bffab8.png#pic_center

2.4 SwiGLU激活函数

SwiGLU(Swish-Gated Linear Unit)是一种结合Swish激活函数与GLU(Gated Linear Unit)机制的激活函数。具体计算如下所示:

https://i-blog.csdnimg.cn/direct/d851b203fe324d06b56324de5074f94e.png#pic_center SwiGLU是对输入进行两次线性变换,将其中一个结果通过Swish激活函数,将两个结果逐元素相乘,形成最终输出。

SwiGLU在处理复杂的文本数据时表现出更好的泛化能力。由于引入了Swish激活函数,其平滑的曲线特性有助于稳定梯度,减少梯度消失和爆炸的可能性。

1、RMS Norm在模型中的作用仅仅是减少参数量提高效率这么简单吗?

答案当然是否定的,以下是我的一些思考,仅提供参考:

  • 中心化不是必须的 :LN 通过减去均值来中心化数据,使得数据分布以 0 为中心。 然而,对于某些任务,这种中心化可能并不是必需的。 神经网络可以通过后续的权重和偏置来学习合适的偏移。
  • 更关注方差/幅度 :RMSNorm 本质上是对输入的幅度进行归一化。 在很多情况下,幅度信息可能比中心化信息更重要。 例如,在语音识别中,信号的能量 (幅度) 通常是重要的特征。
  • 与优化器的结合 :RMSNorm 的设计理念与一些优化器 (如 Adam) 相似,它们都关注梯度的幅度。 这种一致性可能有助于提高训练的稳定性。

2、为什么采用Swish与GLU结合,而不是使用GELU与GLU结合?

从我的认知中,可能主要是基于以下两个方面的考虑:

  1. 效率考量 :虽然从函数的表现形式上来看,Swish与GELU基本差不多。但是GELU采用的主要使用高斯函数,当然,可以使用tanh和sqrt等进行相似计算。而Swish采用Sigmoid函数。从计算角度看,Swish的计算效率更高。
  2. 实验效果 :感觉当初开发Llama 3模型的时候,开发者应该考虑过GELU激活函数,但是最终选择 Swish + GLU 可能是因为在 Llama 3 的特定架构、数据集和训练设置下,这种组合取得了更好的实验结果。 深度学习模型的选择很多时候是经验性的,需要大量的实验来验证。