仅仅使用pytorch来手撕transformer架构1位置编码的类的实现和向前传播
仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播
如果对transformer架构原理模糊的话可以看
P E ( p o s , 2 i )
sin ( p o s 1000 0 2 i / d model ) , PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
P
E
(
p
os
,
2
i
)
=
sin
(
1000
0
2
i
/
d
model
p
os
)
,
P E ( p o s , 2 i + 1 )
cos ( p o s 1000 0 2 i / d model ) . PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right).
P
E
(
p
os
,
2
i
1
)
=
cos
(
1000
0
2
i
/
d
model
p
os
)
.
其中,pos表示位置,d_model表示模型的维度,一般设置为512。
PositionalEncoding
类是 Transformer 模型中一个非常重要的组件,它的作用是为输入的嵌入向量(embedding)添加位置信息。Transformer 是一个基于序列的模型,但它本身并不像 RNN 那样能够自然地处理序列中的位置信息。因此,需要通过某种方式将位置信息注入到模型中,
PositionalEncoding
就是实现这一功能的关键部分。
代码解析
1. 初始化方法 __init__
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
d_model
:表示嵌入向量的维度,即每个位置编码的大小。dropout
:用于正则化,防止过拟合。max_len
:表示模型能够处理的最大序列长度。
1.1生成位置编码
位置编码的生成基于论文《Attention Is All You Need》中提出的方法。具体来说,位置编码是一个固定大小的向量(维度为
d_model
),它通过正弦和余弦函数生成。这种编码方式能够捕捉到位置信息,并且可以处理比训练时序列长度更长的序列。
torch.zeros(max_len, d_model)
:创建一个形状为(max_len, d_model)
的零矩阵,用于存储位置编码。torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
:生成一个从 0 到max_len-1
的序列,并将其转换为浮点数,然后通过unsqueeze(1)
将其扩展为二维张量,形状为(max_len, 1)
。这表示每个置的索引。torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
:计算分母部分,用于调整正弦和余弦函数的频率。torch.arange(0, d_model, 2)
表示每隔一个维度取一个值(因为d_model
是偶数),-math.log(10000.0) / d_model
是一个缩放因子,用于控制频率的变化范围。pe[:, 0::2] = torch.sin(position * div_term)
和pe[:, 1::2] = torch.cos(position * div_term)
:分别将正弦和余弦函数的值赋值到位置编码矩阵的偶数和奇数位置上。这样,每个位置编码的偶数维度由正弦函数生成,奇数维度由余弦函数生成。pe = pe.unsqueeze(0).transpose(0, 1)
:将位置编码矩阵的形状调整为(max_len, 1, d_model)
,然后转置为(1, max_len, d_model)
,以便在后续操作中可以直接与输入张量相加。self.register_buffer('pe', pe)
:将位置编码矩阵注册为一个缓冲区(buffer),这样它不会被视为模型的参数,但会在模型的state_dict
中保存,并且会随着模型的移动(如从 CPU 移动到 GPU)而自动移动。在代码中,位置编码(Positional Encoding)的生成是基于论文《Attention Is All You Need》中提出的公式。具体公式如下:
对于每个位置 ( pos ) 和每个维度 ( i ),位置编码 ( PE(pos, i) ) 的计算公式为:
P E ( p o s , i )
{ sin ( p o s 1000 0 2 i / d m o d e l ) if i is even cos ( p o s 1000 0 2 i / d m o d e l ) if i is odd PE(pos, i) = \begin{cases} \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) & \text{if } i \text{ is even} \ \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) & \text{if } i \text{ is odd} \end{cases}
PE
(
p
os
,
i
)
=
⎩
⎨
⎧
sin
(
1000
0
2
i
/
d
m
o
d
e
l
p
os
)
cos
(
1000
0
2
i
/
d
m
o
d
e
l
p
os
)
if
i
is even
if
i
is odd
其中:
p o s pos
p
os 是位置索引(从 0 开始)。
i i
i 是维度索引(从 0 开始)。
d m o d e l d_{model}
d
m
o
d
e
l
是嵌入向量的维度。
10000 10000
10000 是一个常数,用于控制频率的变化范围。
1.2代码与公式对应关系
在代码中,位置编码的生成过程与上述公式一一对应:
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
position
:position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
这里生成了一个从 0 到
max_len-1
的序列,表示每个位置的索引 ( pos )。unsqueeze(1)
将其扩展为二维张量,形状为(max_len, 1)
,方便后续的广播操作。div_term
:div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
这里计算了公式中的分母部分:
1 1000 0 2 i / d m o d e l \frac{1}{10000^{2i/d_{model}}}
1000
0
2
i
/
d
m
o
d
e
l
1
其中:
torch.arange(0, d_model, 2)
表示每隔一个维度取一个值(因为偶数维度用正弦,奇数维度用余弦)。(-math.log(10000.0) / d_model)
是公式中的指数部分。torch.exp(...)
计算 ( e^{x} ),从而得到分母的值。
pe[:, 0::2]
和pe[:, 1::2]
:pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term)
这里分别计算了偶数维度和奇数维度的位置编码:
pe[:, 0::2]
对应公式中的正弦部分:P E ( p o s , i )
sin ( p o s 1000 0 2 i / d m o d e l ) (for even i ) PE(pos, i) = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \quad \text{(for even } i\text{)}
PE
(
p
os
,
i
)
=
sin
(
1000
0
2
i
/
d
m
o
d
e
l
p
os
)
(for even
i
)
pe[:, 1::2]
对应公式中的余弦部分:P E ( p o s , i )
cos ( p o s 1000 0 2 i / d m o d e l ) (for odd i ) PE(pos, i) = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \quad \text{(for odd } i\text{)}
PE
(
p
os
,
i
)
=
cos
(
1000
0
2
i
/
d
m
o
d
e
l
p
os
)
(for odd
i
)
1.3max_len的选择
在代码中,
max_len
表示模型能够处理的最大序列长度。这个值是预先设定的,用于定义位置编码矩阵的大小。在实际应用中,
max_len
应该足够大,以覆盖你预期中可能遇到的最长序列。
1.3.1 代码中的 max_len
pe = torch.zeros(max_len, d_model)
这里,
pe
是一个形状为
(max_len, d_model)
的零矩阵,用于存储位置编码。
max_len
是这个矩阵的行数,表示可以处理的最大序列长度。
1.3.2选择 max_len
的值
在实际应用中,选择
max_len
的值需要考虑以下因素:
- 数据集的特性
:如果你的数据集中序列的长度变化很大,那么
max_len
应该足够大,以覆盖最长的序列。 - 模型的容量
:较大的
max_len
会增加模型的参数量,因为位置编码矩阵的大小与max_len
直接相关。这可能会增加模型的复杂度和训练时间。 - 计算资源
:较大的
max_len
会占用更多的内存和计算资源,特别是在使用 GPU 进行训练时。
1.3.3 实际应用
在实际应用中,
max_len
的值通常根据数据集的统计特性来确定。例如,如果数据集中 99% 的序列长度都小于 100,那么可以选择
max_len = 100
。这样可以确保模型能够处理大多数序列,同时避免不必要的计算开销。
1.3.4总结
max_len
是一个重要的超参数,它定义了模型能够处理的最大序列长度。选择合适的
max_len
值需要考虑数据集的特性、模型的容量和计算资源。在实际应用中,可以通过分析数据集的统计特性来确定
max_len
的值。
2. 前向传播方法 forward
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
x
:输入的嵌入向量,形状为(seq_len, batch_size, d_model)
。self.pe[:x.size(0), :]
:根据输入序列的实际长度(x.size(0)
),从预定义的位置编码矩阵中取出相应长度的位置编码。x + self.pe[:x.size(0), :]
:将位置编码加到输入的嵌入向量上。由于位置编码的形状为(seq_len, 1, d_model)
,而输入嵌入向量的形状为(seq_len, batch_size, d_model)
,PyTorch 会自动进行广播操作,将位置编码应用到每个样本上。self.dropout(x)
:对加了位置编码后的嵌入向量应用 Dropout,以防止过拟合。
完整代码:
# 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
总结
PositionalEncoding
类的作用是为输入的嵌入向量添加位置信息,使得 Transformer 模型能够感知序列中每个元素的位置。位置编码通过正弦和余弦函数生成,能够捕捉到位置信息,并且可以处理比训练时序列长度更长的序列。