目录

PyTorchchapter-34transformer-6-RoPE

目录

【PyTorch][chapter-34][transformer-6] RoPE

前言:

本篇重点是介绍一下Positional Encoding 和  ROPE的区别

方案一: 固定位置编码:Positional Encoding

https://latex.csdn.net/eq?x_m%3Dx_m+P_m (m 词向量的位置)

https://i-blog.csdnimg.cn/direct/61159fa8b7b84e799641f0e65fd03eda.png

方案二: 旋转位置编码: LLAMA 采用了ROPE

https://latex.csdn.net/eq?x_m%3Dx_m%20e%5E%7Bi%5Ctheta_%7Bm%7D%7D

https://i-blog.csdnimg.cn/direct/85334124a234483290f7708579d36169.png


目录:

  1. 预备知识  self-attention
  2. 绝对位置编码
  3. ROPE 简介
  4. ROPE 数学原理
  5. PyTorch 代码实现

一  预备知识( self-attention

1  Transformer 计算过程

https://i-blog.csdnimg.cn/direct/5f565a9d37ef45b98e5a9eebb9bfb003.png

我们假设 https://latex.csdn.net/eq?S_N 是输入的token序列.其词向量编码用 https://latex.csdn.net/eq?E_N%3DX 表示

其中 https://latex.csdn.net/eq?x_i%5Cin%20R%5Ed 代表词向量(未包含位置信息)。

self-attention 机制首先会将位置信息融 入到词嵌入中,然后将它们转换成查询(queries)、键(keys)和值(values)表示。

https://i-blog.csdnimg.cn/direct/05f85dd0b4714b60a8db2ca2bde36ba9.png

其中:

f: 表示位置编码函数

https://latex.csdn.net/eq?q_m%2Ck_n%2Cv_n :表示融入了第m,n个位置的词向量.

query 和key 用于计算attention weights.

attention score用 公式(2)表示:

https://i-blog.csdnimg.cn/direct/670607943e01468ea5551c6b73fa0dbd.png

输出

https://i-blog.csdnimg.cn/direct/91f7d786bdc04bc9897fa20d876dea61.png

2   问题:

如下图两个句子:

https://i-blog.csdnimg.cn/direct/ff650112fa0544c3824af68d57048b5b.png

dog 和 pig  之间的attention weight:

第一个句子 dog    m=2   pig n=5

第二个句子  dog   m=5   pig n=2

如果不包含位置编码,我们发现计算出来的attention weights是一样的,但这两个

句子中单词的意义完全是相反的.  因为 https://latex.csdn.net/eq?x_m%2Cx_n 里面没有位置信息

https://latex.csdn.net/eq?q_m%5ETk_n%3D%28x_mW_q%29%5ET%28x_nW_k%29%3DW_q%5ET%28x_m%5ETx_n%29W_k

以transformer 为基础的模型 解决方法是选择一个合适的位置编码函数f


**二

绝对位置编码**

https://i-blog.csdnimg.cn/direct/bb1c9e2a78d94ce1bf8d353b0ac21dfd.png

公式表示:

https://i-blog.csdnimg.cn/direct/053f0902cd4b41e7b22337c07c1e3239.png

m 为词向量的位置, https://latex.csdn.net/eq?p_m 词向量的编码,常用正余弦编码如下(pos 对应m)

https://i-blog.csdnimg.cn/direct/8d77f96a9d9245299d7d6a694fa8d419.png

https://latex.csdn.net/eq?x_m%3Df%28x%2Cm%29%3Dx_m+p_m

其中
https://latex.csdn.net/eq?p_m
代表位置编码

如下图第一个句子,dog的位置m =2

https://i-blog.csdnimg.cn/direct/ff650112fa0544c3824af68d57048b5b.png

2:  问题

主要问题:

(1) 无法捕捉相对位置关系。

如下  我们计算 i 和 dog  之间的 attention score

https://i-blog.csdnimg.cn/direct/df37c947ea034e8eb8c0b762356e9c38.png

在第一个句子中: i: m= 3 ,dog n= 6

在第一个句子中: i: m= 1 ,dog n= 4

绝对位置变化,通过公式(2) 计算出来attention score 必然是不一样的, 但是

这两个句子本质上是一样的,所以我们需要一种相对位置编码。

(2) 最长长度限制

  • 问题 :在某些模型中(如BERT和GPT-2),位置编码的长度被限制在固定范围内(如BERT限制在512,GPT-2限制在1024)。这限制了模型能够处理的文本长度。
  • 影响 :最长长度限制可能导致模型在处理超过其能力范围的文本时性能下降,或需要额外的处理步骤来适应更长的文本。

3:代码实现

# -*- coding: utf-8 -*-
"""
Created on Fri Feb 28 14:53:30 2025

@author: chengxf2
"""

import torch
import math
import torch.nn as nn
import matplotlib.pyplot as plt

def showEncoding(pe):
    plt.figure(figsize=( 12 , 8 )) 
    plt.imshow(pe, cmap= 'coolwarm' , vmin=- 1 , vmax= 1 ) 
    plt.colorbar() 
    plt.title( '正弦位置编码' ) 
    plt.xlabel( '维度' ) 
    plt.ylabel( '位置' ) 
    plt.tight_layout() 
    plt.show() 
    
    plt.figure(figsize=( 12 , 6 )) 
    Dimensions = [ 0 , 21 ] 
    for d in Dimensions: 
        plt.plot(pe[:, d], label= f'Dim {d} ' ) 
    plt.legend() 
    plt.title( '特定维度的正弦位置编码' ) 
    plt.xlabel( '位置' ) 
    plt.ylabel( '值' ) 
    plt.tight_layout() 
    plt.show()

class PositionalEncoding(nn.Module):
      def __init__(self, d_model=128, max_seq_len=100):
         super(PositionalEncoding,self).__init__() 
         #创建一个位置编码矩阵
         pe = torch.zeros(max_seq_len,d_model)
         position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
         pe[:,0::2] =  torch.sin(position*div_term)
         pe[:,1::2] =  torch.cos(position*div_term)
         self.register_buffer("pe", pe)
      
      def forward(self, x):
          
          batch_size,seq_len, d_model = x.shape
          
          x = x+self.pe[:,:seq_len]
          return x

    

net = PositionalEncoding()

pe = net.pe.numpy()
showEncoding(pe)

三     ROPE(Rotary Position Embedding,旋转位置编码)

一般都是作用再query,key上面,不是直接作用在词向量上

3.1 公式

基于Transformer的语言模型通常利用各个标记(token)的位置信息实现自注意力机制如方程(2)所示, https://latex.csdn.net/eq?q_m%5ETk_n 通常能够实现不同位置标记之间的知识传递。 为了融入相对位置信息, 我们需要让查询向量 https://latex.csdn.net/eq?q_m 和键向量 https://latex.csdn.net/eq?k_n 的内积通过一个函数 g 来计算.该函数仅将词嵌入 https://latex.csdn.net/eq?x_m%2Cx_n 和它们的相对位置 m - n 作为输入变量。换句话说, 我们希望内积仅以相对形式编码位置信息:

https://i-blog.csdnimg.cn/direct/cfb58cff1ea54866b0ee86a82d31b22f.png

最终目标是找到一个等效的编码机制,以解决函数 https://latex.csdn.net/eq?f_q%28x_m%2C%20m%29https://latex.csdn.net/eq?f_k%28x_n%2C%20n%29 ,以符合上述关系

3.2   Rotary position embedding- A 2D case

我们从维度d=2的一个简单情况开始。在这些设定下,我们利用二维平面上向量的几何性质及其复数形式来证明( 参考备注1: ),我们的公式方程(4)的一个解是:

https://i-blog.csdnimg.cn/direct/63616c0f438546f69e4c75326aed6fdf.png

其中,Re[·]表示复数的实部, https://latex.csdn.net/eq?%28W_kx_n%29%5E* 表示 ( https://latex.csdn.net/eq?W_kx_n ) 的共轭复数。θ ∈ R 是一个预设的非零常数。我们可以进一步将 f{q,k} 表示为乘法矩阵的形式:

备注1:

推导过程:

https://i-blog.csdnimg.cn/direct/d0d3e3ca08ec48d2a81da56b6492d6d8.png


四    ROPE 矩阵表示

1  2D 矩阵表示

https://i-blog.csdnimg.cn/direct/621f1b288994461f81ba61744495ed81.png

https://i-blog.csdnimg.cn/direct/8d02a5bdcdb14794830d2e55e0c1fe00.png

2   通用的数学形式

https://i-blog.csdnimg.cn/direct/5b795cf98d6e4a4984b9dab5387348bd.png

https://i-blog.csdnimg.cn/direct/383319ea7207476f95d1dd1b4cf83c58.png

— we divide the d-dimension (embedding dimension) space into d/2 sub-spaces and apply rotations individually.

https://i-blog.csdnimg.cn/direct/1ad485e9170e4e238e948523d64ee34d.png

https://i-blog.csdnimg.cn/direct/e93b8e7a1fbd4f539fa7341cb827e58c.png

其中:

https://i-blog.csdnimg.cn/direct/f729ea28ef4e4a07931c2cd9bbb66852.png

3  Efficient Implementation(使用哈达码积)

考虑原矩阵的稀疏性,可以用下面矩阵实现

https://i-blog.csdnimg.cn/direct/6f380e232a41471f80f0fafebea3e169.png


五  PyTorch 代码实现

https://i-blog.csdnimg.cn/direct/a8965b37a727412098488540471da3e4.png

sub-space 划分的方式不一样.论文里面是把 ![[x_1,x_2],[[x_3,x_4],[x_5,x_6]. 作为一个sub-space,

也有部分代码是把 https://latex.csdn.net/eq?%5Bx_1%2Cx_5%5D%2C%5Bx_2%2Cx_6%5D%2C%5Bx_3%2Cx_7%5D%2C%5Bx_4%2Cx_8%5D 作为一个sub-space,本质上是一致的.

# -*- coding: utf-8 -*-
"""
Created on Mon Mar 10 09:58:43 2025

@author: chengxf2
"""

import torch

class RoPE(torch.nn.Module):
    #RotaryEmbedding
    def __init__(self, max_seq_len=1000, base=10000, head_dim=64):
        super().__init__()
        #词向量的维度为d. 两两为一组,sub-space: theta为d/2
        slice_i = torch.arange(0, head_dim, 2).float()
        theta = torch.pow(base, -slice_i/head_dim)
        self.register_buffer('theta', theta)
        self.max_seq_len = max_seq_len
        position = torch.arange(0, self.max_seq_len)
        #不同位置对应不同的旋转角度。克罗克内积 outer product
        angle = torch.einsum("i,j->ij",position,theta)
        #[x1,x2] 共用一个旋转角度
        angles = torch.cat((angle,angle),dim=-1)
        #[batch_size, num_heads, seq_len, head_dim]
        self.register_buffer("cos_cached", angles.cos()[None, None, :, :])
        self.register_buffer("sin_cached", angles.sin()[None, None, :, :])
        
    
    
    def forward(self, x):
        batch_size, num_head, seq_len, head_dim = x.shape
        cos = self.cos_cached[:,:,:seq_len, ...]
        sin = self.sin_cached[:,:,:seq_len, ...]
        embedding = self.apply_rotary_pos_emb(x,cos, sin)
        return embedding
    
    def transform_tensor(self,x):
        if x.size(-1) % 2 != 0:
            raise ValueError("最后一维的长度必须是偶数")

        # 将最后一维分成偶数位和奇数位
        # 偶数位 [x2, x4, x6, ...]
        even_indices = x[..., 1::2]  
        # 奇数位 [x1, x3, x5, ...]
        odd_indices =  x[..., 0::2]  
        # 对偶数位取反
        even_indices_neg = -even_indices  # [-x2, -x4, -x6, ...]
        # 交叉合并结果
        result = torch.stack((even_indices_neg, odd_indices), dim=-1)
        # 将最后两维展平
        result = result.view(*x.shape[:-1], -1)
        return result
        
    
    def apply_rotary_pos_emb(self, x, cos, sin):
        
        cosEmb = x*cos
        x_transformer = self.transform_tensor(x)
        sinEmb = x_transformer*sin
        x_emb = cosEmb+sinEmb
        return x_emb
         

batch_size = 2
num_heads = 8
seq_len = 5
d_model = 64
q = torch.randn((batch_size,num_heads,seq_len, d_model))
embddding  = RoPE()
q_emb =  embddding(q)
print(q_emb.shape)