从Online-Softmax到FlashAttention
从Online Softmax到FlashAttention
前言
最近在学习 FlashAttention,看到一份不错的手稿分享下🤗
manuscript:
0. Abstract
FlashAttention 的关键创新是使用类似于 Online Softmax 的思想来 tile 分块 self-attention 的计算,这样可以融合整个多头注意力层,而无需多次重复访问 GPU 的全局内存来获取临时变量和注意力分数矩阵
A A
A 。本文将简要解释如何 tiling 分块 self-attention 的计算,以及如何从 online softmax 中推导出 FlashAttention 计算
1. The Self-Attention
Self-Attention 的计算可以描述为:
O
s o f t m a x ( Q K T ) V \begin{equation} O = \mathrm{softmax}\left(QK^T\right)V \end{equation}
O
=
softmax
(
Q
K
T
)
V
其中
Q , K , V , O ∈ R L × D Q,K,V,O \in \mathbb{R}^{L\times D}
Q
,
K
,
V
,
O
∈
R
L
×
D ,
L L
L 是序列长度,
D D
D 是每个头的维度,softmax 将按列应用于
Q K T QK^T
Q
K
T
Note :这里我们忽略了多头、多 batch,因为在这些维度上的计算是完全并行的,也就是说我们只关注单头、单 batch 的计算就行。另外为了简单起见,我们也忽略了注意力掩码以及缩放因子
1 D \frac{1}{\sqrt{D}}
D
1
计算 self-attention 的标准方法是将其分解为以下几个阶段:
X
Q K T A
s o f t m a x ( X ) O
A V \begin{align} X & = QK^T \ A & = \mathrm{softmax}(X) \ O & = AV \end{align}
X
A
O
=
Q
K
T
=
softmax
(
X
)
=
A
V
我们称
X X
X 矩阵为 pre-softmax logits,
A A
A 矩阵为注意力分数(attention score),
O O
O 矩阵为输出
FlashAttention 的一个惊人之处在于,我们不需要在全局内存(global memory)上实现
X X
X 和
A A
A 矩阵,而是将公式 (1) 中的整个计算融合到单个 CUDA kernel 中。这要求我们设计一种能够精心管理片上内存(on-chip memory)的算法,因为 NVIDIA GPU 的共享内存(shared memory)非常小
对于矩阵乘法等经典算法,我们通常会使用 tiling 技术来确保片上内存不超过硬件限制。下图提供了一个例子,在 kernel 执行期间,无论矩阵形状如何,只有
3 T 2 3T^2
3
T
2 个元素存储在片上。这种 tiling 方法之所以有效,是因为加法是关联的,允许将整个矩阵乘法分解为许多 tile-wise 矩阵乘法的总和
然而,Self-Attention 包含一个不直接关联的 softmax 运算符,因此很难像下图那样简单地对 Self-Attention 进行 tile,那有没有办法让 softmax 具有关联性呢?🤔
上图简要说明了如何对矩阵乘法
C
A × B C=A\times B
C
=
A
×
B 的输入和输出矩阵进行 tile(分块),矩阵被划分为
T × T T\times T
T
×
T 个小块。对于每个输出块,我们从左到右遍历
A A
A 中的相关块,从上到下遍历
B B
B 中的相关块,并将它们的值从全局内存(global memory)加载到片上内存(on-chip memory)(以蓝色表示,整体片上内存占用为
O ( T 2 ) O(T^2)
O
(
T
2
) )
接着逐块进行矩阵乘法,对于位置
( i , j ) (i,j)
(
i
,
j
) ,我们先从片上内存中加载块内所有
i i
i 行的元素即
A [ i , k ] A[i,k]
A
[
i
,
k
] 以及所有
j j
j 列的元素即
B [ k , j ] B[k,j]
B
[
k
,
j
] (以红色表示),然后利用
A [ i , k ] A[i,k]
A
[
i
,
k
] 和
B [ k , j ] B[k,j]
B
[
k
,
j
] 计算得到片上内存中的
C [ i , j ] C[i,j]
C
[
i
,
j
] 。以此循环,直到完成一个块的计算后,我们再将片上
C C
C 块的结果写回主内存并继续处理下一个块
实际 tiling 的应用要复杂得多,大家可以参考:
2. (Safe) Softmax
我们先回顾一下 softmax 算子,下面是 softmax 计算的通用公式:
s o f t m a x ( { x 1 , … , x N } )
{ e x i ∑ j
1 N e x j } i
1 N \begin{equation} \mathrm{softmax}({x_1,\ldots,x_N})=\left{\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}}\right}_{i=1}^N \end{equation}
softmax
({
x
1
,
…
,
x
N
})
=
{
∑
j
=
1
N
e
x
j
e
x
i
}
i
=
1
N
值得注意的是,当
x i x_i
x
i
比较大时
e x i e^{x_i}
e
x
i
很容易溢出,例如 float16 可以支持的最大值是 65536,这意味着当
x ⩾ 11 x \geqslant 11
x
⩾
11 时,
e x e^x
e
x 将超过 float16 的有效范围
为了解决这个问题,像 pytorch、tensorflow 等框架通常会使用一种被称为 “safe” 的 softmax 计算方法:
e x i ∑ j
1 N e x j
e x i − m ∑ j
1 N e x j − m \begin{equation} \frac{e^{x_{i}}}{\sum_{j=1}^{N}e^{x_{j}}}=\frac{e^{x_{i}-m}}{\sum_{j=1}^{N}e^{x_{j}-m}} \end{equation}
∑
j
=
1
N
e
x
j
e
x
i
=
∑
j
=
1
N
e
x
j
−
m
e
x
i
−
m
其中
m
max j
1 N ( x j ) m=\max_{j=1}^{N}(x_j)
m
=
max
j
=
1
N
(
x
j
)
我们将
x i x_i
x
i
减去它们中的最大值,这样可以确保每个元素
x i − m ⩽ 0 x_i-m \leqslant 0
x
i
−
m
⩽
0 ,从而再做指数运算时可以确保其值不会溢出
我们可以将 safe softmax 的计算总结为下面的 3-pass 算法:
Algorithm 3-pass safe softmax
NOTATIONS
{ m i } {m_i}
{
m
i
} :
max j
1 i { x j } \max_{j=1}^i{x_j}
max
j
=
1
i
{
x
j
} ,初始值
m 0
− ∞ m_0=-\infty
m
0
=
−
∞
{ d i } {d_i}
{
d
i
} :
∑ j
1 i e x j − m N \sum_{j=1}^{i}e^{x_{j}-m_{N}}
∑
j
=
1
i
e
x
j
−
m
N
,初始值
d 0
0 d_0=0
d
0
=
0 ,
d N d_N
d
N
是 safe softmax 的分母
{ a i } {a_i}
{
a
i
} :最终的 softmax 值
BODY
for
i ← 1 , N i\leftarrow1,N
i
←
1
,
N do
m i ← max ( m i − 1 , x i ) \begin{equation} m_i \leftarrow \max (m_{i-1},x_i) \end{equation}
m
i
←
max
(
m
i
−
1
,
x
i
)
end
for
i ← 1 , N i\leftarrow1,N
i
←
1
,
N do
d i ← d i − 1 + e x i − m N \begin{equation} d_i \leftarrow d_{i-1} + e^{x_i-m_N} \end{equation}
d
i
←
d
i
−
1
e
x
i
−
m
N
end
for
i ← 1 , N i\leftarrow1,N
i
←
1
,
N do
a i ← e x i − m N d N \begin{equation} a_i \leftarrow \frac{e^{x_i-m_N}}{d_N} \end{equation}
a
i
←
d
N
e
x
i
−
m
N
end
该算法要求我们对
[ 1 , N ] [1,N]
[
1
,
N
] 进行 3 次迭代,在 Transformer 的 self-attention 中,每个
x i x_i
x
i
是通过
Q K T QK^T
Q
K
T 计算得到的。如果我们无法将所有的
x i x_i
x
i
都缓存到片上内存(SRAM)中(实际上我们也没有足够大的片上内存 SRAM 来容纳所有的
x i x_i
x
i
),那么在每次迭代时不得不访问
Q , K Q,K
Q
,
K 以动态重新计算
x i x_i
x
i
,这种频繁的访问会导致大量的 I/O 操作,从而降低整体效率
3. Online Softmax
如果我们能够在一个循环中融合公式 (7)、(8)、(9),那么就可以将 global memory 的访问时间降低 3 倍,不幸的是,我们不能在同一个循环中融合公式 (7)、(8),因为公式 (8) 的计算需要依赖
m N m_N
m
N
,而
m N m_N
m
N
只有在第一个循环完成之后才能够确定
既然数据之间存在着依赖关系,那我们可以创建另外一个序列
d i ′ :
∑ j
1 i e x j − m i d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{i}}
d
i
′
:=
∑
j
=
1
i
e
x
j
−
m
i
作为原始序列
d i ′ :
∑ j
1 i e x j − m N d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{N}}
d
i
′
:=
∑
j
=
1
i
e
x
j
−
m
N
的替代,以消除对
N N
N 的依赖,并且这两个序列的第
N N
N 项是相同的即
d N
d N ′ d_N=d^{\prime}_N
d
N
=
d
N
′
,因此我们可以安全地用
d N ′ d^{\prime}_N
d
N
′
替换公式 (9) 中的
d N d_N
d
N
Note :在数学和计算机科学中,符号
:
:=
:= 用于表示 “定义为” 或 “被定义为”,例如当我们写下
x :
y x:=y
x
:=
y 时,我们是在说明 “我们将
x x
x 定义为
y y
y ” 或 “
x x
x 等于
y y
y (这是它的定义)”
此外,我们还可以找到
d i ′ d^{\prime}_i
d
i
′
和
d i − 1 ′ d^{\prime}_{i-1}
d
i
−
1
′
之间的递归关系:
d i ′
∑ j
1 i e x j − m i
( ∑ j
1 i − 1 e x j − m i ) + e x i − m i
( ∑ j
1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i
d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{equation} \begin{aligned} d_{i}^{\prime} & =\sum_{j=1}^ie^{x_j-m_i} \ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_i}\right)+e^{x_i-m_i} \ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \ & =d_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} \end{equation}
d
i
′
=
j
=
1
∑
i
e
x
j
−
m
i
=
(
j
=
1
∑
i
−
1
e
x
j
−
m
i
)
e
x
i
−
m
i
=
(
j
=
1
∑
i
−
1
e
x
j
−
m
i
−
1
)
e
m
i
−
1
−
m
i
e
x
i
−
m
i
=
d
i
−
1
′
e
m
i
−
1
−
m
i
e
x
i
−
m
i
这个递归公式只依赖于
m i m_i
m
i
和
m i − 1 m_{i-1}
m
i
−
1
,此外我们还可以在一个循环中同时计算
m j m_j
m
j
和
d j ′ d^{\prime}_j
d
j
′
Algorithm 2-pass online softmax
for
i ← 1 , N i\leftarrow 1,N
i
←
1
,
N do
m i ← max ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} m_i & \leftarrow \max (m_{i-1}, x_i) \ d^{\prime}i & \leftarrow d^{\prime}{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned}
m
i
d
i
′
←
max
(
m
i
−
1
,
x
i
)
←
d
i
−
1
′
e
m
i
−
1
−
m
i
e
x
i
−
m
i
end
for
i ← 1 , N i\leftarrow 1,N
i
←
1
,
N do
a i ← e x i − m N d N ′ a_i \leftarrow \frac{e^{x_i - m_N}}{d^{\prime}_N}
a
i
←
d
N
′
e
x
i
−
m
N
end
这是 Online Softmax 论文中提出的算法,但它仍然需要两次循环才能完成 softmax 的计算,我们能否将循环次数减少到 1 次来最小化全局的 I/O 呢?🤔
4. FlashAttention
不幸的是,对于 softmax 来说,答案是否定的,但在 Self-Attention 中,我们最终的目标并不是注意力分数矩阵
A
s o f t m a x ( Q K T ) A=\mathrm{softmax}(QK^T)
A
=
softmax
(
Q
K
T
) ,而是输出矩阵
O
A × V O=A \times V
O
=
A
×
V ,既然无法找到 softmax 的一次递归形式,那我们换个思路思考下能否找到矩阵
O O
O 的一次递归形式呢?
下面我们尝试下先将 Self-Attention 计算的第
k k
k 行公式转化为递归算法:
Note :由于所有行的计算都是独立的,为了简单起见,这里我们只解释一行的计算
Algorithm Multi-pass Self-Attention
NOTATIONS
Q [ k , : ] Q[k,:]
Q
[
k
,
:
] :查询矩阵
Q Q
Q 的第
k k
k 行向量
K T [ : , i ] K^T[:,i]
K
T
[
:
,
i
] :键矩阵
K T K^T
K
T 的第
i i
i 列向量
O [ k , : ] O[k,:]
O
[
k
,
:
] :输出矩阵
O O
O 的第
k k
k 行
V [ i , : ] V[i,:]
V
[
i
,
:
] :值矩阵
V V
V 的第
i i
i 行
o i \bm{o}_i
o
i
:
∑ j
1 i a j V [ j , : ] \sum_{j=1}^ia_jV[j,:]
∑
j
=
1
i
a
j
V
[
j
,
:
] ,存储
A [ k , : i ] × V [ : i , : ] A[k,:i]\times V[:i,:]
A
[
k
,
:
i
]
×
V
[
:
i
,
:
] 结果的行向量
BODY
for
i ← 1 , N i\leftarrow 1,N
i
←
1
,
N do
x i ← Q [ k , : ] K T [ : , i ] m i ← max ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} x_i & \leftarrow Q[k,:]K^T[:,i] \ m_i & \leftarrow \max (m_{i-1},x_i) \ d^{\prime}i & \leftarrow d^{\prime}{i-1} e^{m_{i-1}-m_i} + e^{x_i-m_i} \end{aligned}
x
i
m
i
d
i
′
←
Q
[
k
,
:
]
K
T
[
:
,
i
]
←
max
(
m
i
−
1
,
x
i
)
←
d
i
−
1
′
e
m
i
−
1
−
m
i
e
x
i
−
m
i
end
for
i ← 1 , N i\leftarrow 1,N
i
←
1
,
N do
a i ← e x i − m N d N ′ o i ← o i − 1 + a i V [ i , : ] \begin{align} a_i & \leftarrow \frac{e^{x_i-m_N}}{d^{\prime}_N} \ \bm{o}i & \leftarrow \bm{o}{i-1} + a_iV[i,:] \end{align}
a
i
o
i
←
d
N
′
e
x
i
−
m
N
←
o
i
−
1
a
i
V
[
i
,
:
]
end
O [ k , : ] ← o N O[k,:] \leftarrow \bm{o}_N
O
[
k
,
:
]
←
o
N
我们将公式 (12) 中的
a i a_i
a
i
替换公式 (11) 中的定义,有:
o i :
∑ j
1 i ( e x j − m N d N ′ V [ j , : ] ) \begin{align} \bm{o}i := \sum{j=1}^i\left(\frac{e^{x_j-m_N}}{d_N^{\prime}}V[j,:]\right) \end{align}
o
i
:=
j
=
1
∑
i
(
d
N
′
e
x
j
−
m
N
V
[
j
,
:
]
)
该公式仍然依赖
m N m_N
m
N
和
d N d_N
d
N
变量,这两个值在前一个循环完成之前无法确定。但我们可以再次使用第 3 节中的替代技巧,即创建替代序列
o ′ \bm{o}^{\prime}
o
′ :
o i ′ :
( ∑ j
1 i e x j − m i d i ′ V [ j , : ] ) \bm{o}i^{\prime}:=\left(\sum{j=1}^i\frac{e^{x_j-m_i}}{d_i^{\prime}}V[j,:]\right)
o
i
′
:=
(
j
=
1
∑
i
d
i
′
e
x
j
−
m
i
V
[
j
,
:
]
)
o \bm{o}
o 和
o ′ \bm{o}^{\prime}
o
′ 的第
N N
N 个元素相同即
o N ′
o N \bm{o}^{\prime}_N=\bm{o}_N
o
N
′
=
o
N
,并且我们可以发现
o i ′ \bm{o}^{\prime}_i
o
i
′
和
o i − 1 ′ \bm{o}^{\prime}_{i-1}
o
i
−
1
′
之间存在如下的递归关系:
o i ′
∑ j
1 i e x j − m i d i ′ V [ j , : ]
( ∑ j
1 i − 1 e x j − m i d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ]
( ∑ j
1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ]
( ∑ j
1 i − 1 e x j − m i − 1 d i − 1 ′ V [ j , : ] ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i , : ]
o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{equation} \begin{aligned} \bm{o}^{\prime}{i}& = \sum{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d^{\prime}{i}}V[j, : ]\ & = \left(\sum{j=1}^{i-1}\frac{e^{x_{j}-m_{i}}}{d^{\prime}{i} }V[j, : ]\right) + \frac{e^{x{i}-m_{i}}}{d^{\prime}{i}}V[i, : ]\ & = \left(\sum{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}{ i-1}}\frac{e^{x{j}-m_{i}}}{e^{x_{j}-m_{i-1}}}\frac{d^{\prime}{i-1}}{d^{ \prime}{i}}V[j, : ]\right) + \frac{e^{x_{i}-m_{i}}}{d^{\prime}{i}}V[i, : ]\ & = \left(\sum{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}{ i-1}}V[j, : ]\right) \frac{d^{\prime}{i-1}}{d^{\prime}{i}}e^{m{i-1}-m_{i}} + \frac{e^{x_{i}-m_{i}}}{d^{\prime}{i}}V[i, : ]\ & = \bm{o}^{\prime}{i-1}\frac{d^{\prime}{i-1} e^{m{i-1}-m_{ i}}}{d^{\prime}{i}} + \frac{e^{x{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ] \end{aligned} \end{equation}
o
i
′
=
j
=
1
∑
i
d
i
′
e
x
j
−
m
i
V
[
j
,
:
]
=
(
j
=
1
∑
i
−
1
d
i
′
e
x
j
−
m
i
V
[
j
,
:
]
)
d
i
′
e
x
i
−
m
i
V
[
i
,
:
]
=
(
j
=
1
∑
i
−
1
d
i
−
1
′
e
x
j
−
m
i
−
1
e
x
j
−
m
i
−
1
e
x
j
−
m
i
d
i
′
d
i
−
1
′
V
[
j
,
:
]
)
d
i
′
e
x
i
−
m
i
V
[
i
,
:
]
=
(
j
=
1
∑
i
−
1
d
i
−
1
′
e
x
j
−
m
i
−
1
V
[
j
,
:
]
)
d
i
′
d
i
−
1
′
e
m
i
−
1
−
m
i
d
i
′
e
x
i
−
m
i
V
[
i
,
:
]
=
o
i
−
1
′
d
i
′
d
i
−
1
′
e
m
i
−
1
−
m
i
d
i
′
e
x
i
−
m
i
V
[
i
,
:
]
它仅依赖于
d i ′ , d i − 1 ′ , m i , m i − 1 , x i d^{\prime}i,\ d^{\prime}{i-1}, \ m_i, \ m_{i-1}, \ x_i
d
i
′
,
d
i
−
1
′
,
m
i
,
m
i
−
1
,
x
i
,因此我们可以在单个循环中融合 Self-Attention 中的所有计算
Algorithm FlashAttention
for
i ← 1 , N i \leftarrow 1,N
i
←
1
,
N do
x i ← Q [ k , : ] K T [ : , i ] m i ← max ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i o i ′ ← o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{aligned} {x_{i}} & \leftarrow {Q[k,:]} {K^{T}[:,i]}\ {m_{i}} & \leftarrow \max\left( {m_{i-1},x_{i}} \right)\ {d^{\prime}{i}} & \leftarrow {d^{\prime}{i-1}e^{m_{i-1}-m_{i}}}+ {e^{x_{i}-m_{i}}}\ {o^{\prime}{i}} & \leftarrow {o^{\prime}{i-1}}\frac{ {d^{\prime}{i-1}e^{m{i-1}-m_{i}}}}{ { d^{\prime}{i}}}+\frac{ {e^{x{i}-m_{i}}}}{ {d^{\prime}_{i}}} {V[i,:]}\end{aligned}
x
i
m
i
d
i
′
o
i
′
←
Q
[
k
,
:
]
K
T
[
:
,
i
]
←
max
(
m
i
−
1
,
x
i
)
←
d
i
−
1
′
e
m
i
−
1
−
m
i
e
x
i
−
m
i
←
o
i
−
1
′
d
i
′
d
i
−
1
′
e
m
i
−
1
−
m
i
d
i
′
e
x
i
−
m
i
V
[
i
,
:
]
end
O [ k , : ] ← o N ′ O[k,:] \leftarrow \bm{o}^{\prime}_N
O
[
k
,
:
]
←
o
N
′
状态量
x i , m i , d i ′ , o i ′ x_i, \ m_i, \ d^{\prime}_i, \bm{o}^{\prime}_i
x
i
,
m
i
,
d
i
′
,
o
i
′
占用的内存都很小,可以非常轻松的放入 GPU 的 shared memory 中。另外由于此算法中的所有操作都是关联的,因此它与 tiling 兼容,如果我们逐个 tiling 计算,则该算法可以表示为:
Algorithm FlashAttention(Tiling)
NEW NOTATIONS
b b
b :tile 的 block size 大小
#tiles \text{#tiles}
#tiles :每行 tile 的数量,
N
b × #tiles N= b \times \text{#tiles}
N
=
b
×
#tiles
x i \bm{x}_i
x
i
:存储第
i i
i 个 tile 的
Q [ k ] K T Q[k]K^T
Q
[
k
]
K
T 值的向量
[ ( i − 1 ) b : i b ] [(i-1)b:ib]
[(
i
−
1
)
b
:
ib
]
m i ( l o c a l ) m_i^{\mathrm{(local)}}
m
i
(
local
)
:
x i \bm{x}_i
x
i
内部的局部最大值
BODY
for
i ← 1 , #tile i \leftarrow 1,\text{#tile}
i
←
1
,
#tile do
x i ← Q [ k , : ] K T [ : , ( i − 1 ) b : i b ] m i (local)
max j
1 b ( x i [ j ] ) m i ← max ( m i − 1 , m i (local) ) d i ′ ← d i − 1 ′ e m i − 1 − m i + ∑ j
1 b e x i [ j ] − m i o i ′ ← o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + ∑ j
1 b e x i [ j ] − m i d i ′ V [ j + ( i − 1 ) b , : ] \begin{split}\bm{x}{i}&\ \leftarrow\ Q[k, : ] K^{T}[:,(i-1) b : i b]\ m{i}^{\text{(local)}}&\ =\ \max_{j=1}^{b} (\bm{x}{i} [j]) \ m{i}&\ \leftarrow\ \max\big(m_{i-1},m_{i}^{\text{(local)}} \big)\ d_{i}^{\prime}&\ \leftarrow\ d_{i-1}^{\prime} e^{m_{i-1}-m_{i}} + \sum_{j=1}^{b} e^{\bm{x}{i}[j]-m{i}}\ \bm{o}{i}^{\prime}&\ \leftarrow\ \bm{o}{i-1}^{\prime} \frac{d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}}{d_{i}^{\prime}} + \sum_{j=1}^{b} \frac{e^{\bm{x}{i}[j]-m{i}}}{d_{i}^{\prime}}V[j+(i-1) b, : ]\end{split}
x
i
m
i
(local)
m
i
d
i
′
o
i
′
←
Q
[
k
,
:
]
K
T
[
:
,
(
i
−
1
)
b
:
ib
]
=
j
=
1
max
b
(
x
i
[
j
])
←
max
(
m
i
−
1
,
m
i
(local)
)
←
d
i
−
1
′
e
m
i
−
1
−
m
i
j
=
1
∑
b
e
x
i
[
j
]
−
m
i
←
o
i
−
1
′
d
i
′
d
i
−
1
′
e
m
i
−
1
−
m
i
j
=
1
∑
b
d
i
′
e
x
i
[
j
]
−
m
i
V
[
j
(
i
−
1
)
b
,
:
]
end
O [ k , : ] ← o N / b ′ O[k,:] \leftarrow \bm{o}^{\prime}_{N/b}
O
[
k
,
:
]
←
o
N
/
b
′
下图说明了如何将该算法应用到硬件上:
上图说明了 FlashAttention 在硬件上的计算方式,其中蓝色块表示在 SRAM 中的 tile,而红色块对应于 tile 中的第
i i
i 行,
L L
L 表示序列长度,
D D
D 表示每个头的维度,
B B
B 表示 block size 的大小
值得注意的是,整体 SRAM 的内存占用仅仅取决于
B B
B 和
D D
D ,与
L L
L 无关,因此,该算法可以扩展到长上下文而不会遇到内存问题(GPU 的共享内存很小,例如 H100 架构中为 228kb/SM)。在计算过程中,我们从左到右遍历
K T K^T
K
T 和
A A
A ,从上到下遍历
V V
V ,并相应地更新
m , d , O m,\ d, \ O
m
,
d
,
O 的状态
结语
这篇文章主要分享了 FlashAttention 的思想,FlashAttention 想要做的就是在单个循环中完成 self-attention 的整个计算
我们先从 safe softmax 出发分析了 safe softmax 算法需要循环迭代 3 次,第一次循环求取序列中的最大值
m N m_N
m
N
,第二次循环求取 safe softmax 的分母
d N d_N
d
N
,第三次循环求取最终的 softmax 值
a i a_i
a
i
。由于 SRAM 片上内存空间有限,我们无法将所有的
x i x_i
x
i
缓存,导致我们在每次迭代时需要重新访问
Q , K Q,K
Q
,
K 以动态重新计算
x i x_i
x
i
,这造成了大量的 I/O 操作,我们并不希望看到
在 online softmax 中提出了一个有意思的想法,那就是创建一个新的序列
d i ′ d_i^{\prime}
d
i
′
来替换原始的
d i d_i
d
i
,主要区别在于将原来
d i d_i
d
i
定义中的
m N m_N
m
N
替换为了
m i m_i
m
i
,以消除对
N N
N 的依赖,并且我们发现
d N
d N ′ d_N=d_N^{\prime}
d
N
=
d
N
′
,那么整个 safe softmax 的计算就只需要迭代 2 次即可,第一次循环迭代可以同时计算
m i m_i
m
i
和
d i ′ d_i^{\prime}
d
i
′
FlashAttention 想要把 self-attention 中的循环次数降低到一次,但是发现对于其中的 softmax 而言无法实现,因此它另辟蹊径试图寻找输出矩阵
O O
O 的一次迭代形式。经过推导我们发现输出行向量
o i \bm{o}_i
o
i
仍然依赖于
m N m_N
m
N
和
d N d_N
d
N
变量,这两个值在前一个循环完成之前无法确定,因此仍需要两次循环
FlashAttention 借助 online softmax 的替代思想,将序列
o \bm{o}
o 替换为
o ′ \bm{o}^{\prime}
o
′ ,替换之后发现
o i ′ \bm{o}_i^{\prime}
o
i
′
不再依赖于
m N m_N
m
N
和
d N d_N
d
N
等变量,而仅依赖于
d i ′ , d i − 1 ′ , m i , m i − 1 , x i d^{\prime}i,\ d^{\prime}{i-1}, \ m_i, \ m_{i-1}, \ x_i
d
i
′
,
d
i
−
1
′
,
m
i
,
m
i
−
1
,
x
i
,从而可以在单个循环中融合 self-attention 中的所有计算
最后,由于 flashattention 算法中的所有操作都是关联的,因此可以逐个 tiling 分块计算,并且整体 SRAM 的内存占用与序列长度无关,因此可以扩展到长上下文而不用担心遇到内存问题
OK,以上就是关于这篇手稿的全部内容了,大家感兴趣的可以看看,作为 FlashAttention 的入门还是没有问题的😄