目录

从Online-Softmax到FlashAttention

从Online Softmax到FlashAttention

前言

最近在学习 FlashAttention,看到一份不错的手稿分享下🤗

manuscript:

0. Abstract

FlashAttention 的关键创新是使用类似于 Online Softmax 的思想来 tile 分块 self-attention 的计算,这样可以融合整个多头注意力层,而无需多次重复访问 GPU 的全局内存来获取临时变量和注意力分数矩阵

A A

A 。本文将简要解释如何 tiling 分块 self-attention 的计算,以及如何从 online softmax 中推导出 FlashAttention 计算

1. The Self-Attention

Self-Attention 的计算可以描述为:

O

s o f t m a x ( Q K T ) V \begin{equation} O = \mathrm{softmax}\left(QK^T\right)V \end{equation}

O

=

softmax

(

Q

K

T

)

V

其中

Q , K , V , O ∈ R L × D Q,K,V,O \in \mathbb{R}^{L\times D}

Q

,

K

,

V

,

O

R

L

×

D ,

L L

L 是序列长度,

D D

D 是每个头的维度,softmax 将按列应用于

Q K T QK^T

Q

K

T

Note :这里我们忽略了多头、多 batch,因为在这些维度上的计算是完全并行的,也就是说我们只关注单头、单 batch 的计算就行。另外为了简单起见,我们也忽略了注意力掩码以及缩放因子

1 D \frac{1}{\sqrt{D}}

D

1

计算 self-attention 的标准方法是将其分解为以下几个阶段:

X

Q K T A

s o f t m a x ( X ) O

A V \begin{align} X & = QK^T \ A & = \mathrm{softmax}(X) \ O & = AV \end{align}

X

A

O

=

Q

K

T

=

softmax

(

X

)

=

A

V

我们称

X X

X 矩阵为 pre-softmax logits,

A A

A 矩阵为注意力分数(attention score),

O O

O 矩阵为输出

FlashAttention 的一个惊人之处在于,我们不需要在全局内存(global memory)上实现

X X

X 和

A A

A 矩阵,而是将公式 (1) 中的整个计算融合到单个 CUDA kernel 中。这要求我们设计一种能够精心管理片上内存(on-chip memory)的算法,因为 NVIDIA GPU 的共享内存(shared memory)非常小

对于矩阵乘法等经典算法,我们通常会使用 tiling 技术来确保片上内存不超过硬件限制。下图提供了一个例子,在 kernel 执行期间,无论矩阵形状如何,只有

3 T 2 3T^2

3

T

2 个元素存储在片上。这种 tiling 方法之所以有效,是因为加法是关联的,允许将整个矩阵乘法分解为许多 tile-wise 矩阵乘法的总和

然而,Self-Attention 包含一个不直接关联的 softmax 运算符,因此很难像下图那样简单地对 Self-Attention 进行 tile,那有没有办法让 softmax 具有关联性呢?🤔

https://i-blog.csdnimg.cn/direct/158826e34d414323b046e9a21c2c339f.png#pic_center

上图简要说明了如何对矩阵乘法

C

A × B C=A\times B

C

=

A

×

B 的输入和输出矩阵进行 tile(分块),矩阵被划分为

T × T T\times T

T

×

T 个小块。对于每个输出块,我们从左到右遍历

A A

A 中的相关块,从上到下遍历

B B

B 中的相关块,并将它们的值从全局内存(global memory)加载到片上内存(on-chip memory)(以蓝色表示,整体片上内存占用为

O ( T 2 ) O(T^2)

O

(

T

2

) )

接着逐块进行矩阵乘法,对于位置

( i , j ) (i,j)

(

i

,

j

) ,我们先从片上内存中加载块内所有

i i

i 行的元素即

A [ i , k ] A[i,k]

A

[

i

,

k

] 以及所有

j j

j 列的元素即

B [ k , j ] B[k,j]

B

[

k

,

j

] (以红色表示),然后利用

A [ i , k ] A[i,k]

A

[

i

,

k

] 和

B [ k , j ] B[k,j]

B

[

k

,

j

] 计算得到片上内存中的

C [ i , j ] C[i,j]

C

[

i

,

j

] 。以此循环,直到完成一个块的计算后,我们再将片上

C C

C 块的结果写回主内存并继续处理下一个块

实际 tiling 的应用要复杂得多,大家可以参考:

2. (Safe) Softmax

我们先回顾一下 softmax 算子,下面是 softmax 计算的通用公式:

s o f t m a x ( { x 1 , … , x N } )

{ e x i ∑ j

1 N e x j } i

1 N \begin{equation} \mathrm{softmax}({x_1,\ldots,x_N})=\left{\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}}\right}_{i=1}^N \end{equation}

softmax

({

x

1

,

,

x

N

})

=

{

j

=

1

N

e

x

j

e

x

i

}

i

=

1

N

值得注意的是,当

x i x_i

x

i

比较大时

e x i e^{x_i}

e

x

i

很容易溢出,例如 float16 可以支持的最大值是 65536,这意味着当

x ⩾ 11 x \geqslant 11

x

11 时,

e x e^x

e

x 将超过 float16 的有效范围

为了解决这个问题,像 pytorch、tensorflow 等框架通常会使用一种被称为 “safe” 的 softmax 计算方法:

e x i ∑ j

1 N e x j

e x i − m ∑ j

1 N e x j − m \begin{equation} \frac{e^{x_{i}}}{\sum_{j=1}^{N}e^{x_{j}}}=\frac{e^{x_{i}-m}}{\sum_{j=1}^{N}e^{x_{j}-m}} \end{equation}

j

=

1

N

e

x

j

e

x

i

=

j

=

1

N

e

x

j

m

e

x

i

m

其中

m

max ⁡ j

1 N ( x j ) m=\max_{j=1}^{N}(x_j)

m

=

max

j

=

1

N

(

x

j

)

我们将

x i x_i

x

i

减去它们中的最大值,这样可以确保每个元素

x i − m ⩽ 0 x_i-m \leqslant 0

x

i

m

0 ,从而再做指数运算时可以确保其值不会溢出

我们可以将 safe softmax 的计算总结为下面的 3-pass 算法:

Algorithm 3-pass safe softmax

NOTATIONS

  • { m i } {m_i}

    {

    m

    i

    } :

    max ⁡ j

    1 i { x j } \max_{j=1}^i{x_j}

    max

    j

    =

    1

    i

    {

    x

    j

    } ,初始值

    m 0

    − ∞ m_0=-\infty

    m

    0

    =

  • { d i } {d_i}

    {

    d

    i

    } :

    ∑ j

    1 i e x j − m N \sum_{j=1}^{i}e^{x_{j}-m_{N}}

    j

    =

    1

    i

    e

    x

    j

    m

    N

    ,初始值

    d 0

    0 d_0=0

    d

    0

    =

    0 ,

    d N d_N

    d

    N

    是 safe softmax 的分母

  • { a i } {a_i}

    {

    a

    i

    } :最终的 softmax 值

BODY

for

i ← 1 , N i\leftarrow1,N

i

1

,

N do

m i ← max ⁡ ( m i − 1 , x i ) \begin{equation} m_i \leftarrow \max (m_{i-1},x_i) \end{equation}

m

i

max

(

m

i

1

,

x

i

)

end

for

i ← 1 , N i\leftarrow1,N

i

1

,

N do

d i ← d i − 1 + e x i − m N \begin{equation} d_i \leftarrow d_{i-1} + e^{x_i-m_N} \end{equation}

d

i

d

i

1

e

x

i

m

N

end

for

i ← 1 , N i\leftarrow1,N

i

1

,

N do

a i ← e x i − m N d N \begin{equation} a_i \leftarrow \frac{e^{x_i-m_N}}{d_N} \end{equation}

a

i

d

N

e

x

i

m

N

end

该算法要求我们对

[ 1 , N ] [1,N]

[

1

,

N

] 进行 3 次迭代,在 Transformer 的 self-attention 中,每个

x i x_i

x

i

是通过

Q K T QK^T

Q

K

T 计算得到的。如果我们无法将所有的

x i x_i

x

i

都缓存到片上内存(SRAM)中(实际上我们也没有足够大的片上内存 SRAM 来容纳所有的

x i x_i

x

i

),那么在每次迭代时不得不访问

Q , K Q,K

Q

,

K 以动态重新计算

x i x_i

x

i

,这种频繁的访问会导致大量的 I/O 操作,从而降低整体效率

3. Online Softmax

如果我们能够在一个循环中融合公式 (7)、(8)、(9),那么就可以将 global memory 的访问时间降低 3 倍,不幸的是,我们不能在同一个循环中融合公式 (7)、(8),因为公式 (8) 的计算需要依赖

m N m_N

m

N

,而

m N m_N

m

N

只有在第一个循环完成之后才能够确定

既然数据之间存在着依赖关系,那我们可以创建另外一个序列

d i ′ :

∑ j

1 i e x j − m i d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{i}}

d

i

:=

j

=

1

i

e

x

j

m

i

作为原始序列

d i ′ :

∑ j

1 i e x j − m N d_{i}^{ \prime} := \sum_{j=1}^{i}e^{x_{j}-m_{N}}

d

i

:=

j

=

1

i

e

x

j

m

N

的替代,以消除对

N N

N 的依赖,并且这两个序列的第

N N

N 项是相同的即

d N

d N ′ d_N=d^{\prime}_N

d

N

=

d

N

,因此我们可以安全地用

d N ′ d^{\prime}_N

d

N

替换公式 (9) 中的

d N d_N

d

N

Note :在数学和计算机科学中,符号

:

:=

:= 用于表示 “定义为” 或 “被定义为”,例如当我们写下

x :

y x:=y

x

:=

y 时,我们是在说明 “我们将

x x

x 定义为

y y

y ” 或 “

x x

x 等于

y y

y (这是它的定义)”

此外,我们还可以找到

d i ′ d^{\prime}_i

d

i

d i − 1 ′ d^{\prime}_{i-1}

d

i

1

之间的递归关系:

d i ′

∑ j

1 i e x j − m i

( ∑ j

1 i − 1 e x j − m i ) + e x i − m i

( ∑ j

1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i

d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{equation} \begin{aligned} d_{i}^{\prime} & =\sum_{j=1}^ie^{x_j-m_i} \ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_i}\right)+e^{x_i-m_i} \ & =\left(\sum_{j=1}^{i-1}e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \ & =d_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} \end{equation}

d

i

=

j

=

1

i

e

x

j

m

i

=

(

j

=

1

i

1

e

x

j

m

i

)

e

x

i

m

i

=

(

j

=

1

i

1

e

x

j

m

i

1

)

e

m

i

1

m

i

e

x

i

m

i

=

d

i

1

e

m

i

1

m

i

e

x

i

m

i

这个递归公式只依赖于

m i m_i

m

i

m i − 1 m_{i-1}

m

i

1

,此外我们还可以在一个循环中同时计算

m j m_j

m

j

d j ′ d^{\prime}_j

d

j

Algorithm 2-pass online softmax

for

i ← 1 , N i\leftarrow 1,N

i

1

,

N do

m i ← max ⁡ ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} m_i & \leftarrow \max (m_{i-1}, x_i) \ d^{\prime}i & \leftarrow d^{\prime}{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned}

m

i

d

i

max

(

m

i

1

,

x

i

)

d

i

1

e

m

i

1

m

i

e

x

i

m

i

end

for

i ← 1 , N i\leftarrow 1,N

i

1

,

N do

a i ← e x i − m N d N ′ a_i \leftarrow \frac{e^{x_i - m_N}}{d^{\prime}_N}

a

i

d

N

e

x

i

m

N

end

这是 Online Softmax 论文中提出的算法,但它仍然需要两次循环才能完成 softmax 的计算,我们能否将循环次数减少到 1 次来最小化全局的 I/O 呢?🤔

4. FlashAttention

不幸的是,对于 softmax 来说,答案是否定的,但在 Self-Attention 中,我们最终的目标并不是注意力分数矩阵

A

s o f t m a x ( Q K T ) A=\mathrm{softmax}(QK^T)

A

=

softmax

(

Q

K

T

) ,而是输出矩阵

O

A × V O=A \times V

O

=

A

×

V ,既然无法找到 softmax 的一次递归形式,那我们换个思路思考下能否找到矩阵

O O

O 的一次递归形式呢?

下面我们尝试下先将 Self-Attention 计算的第

k k

k 行公式转化为递归算法:

Note :由于所有行的计算都是独立的,为了简单起见,这里我们只解释一行的计算

Algorithm Multi-pass Self-Attention

NOTATIONS

  • Q [ k , : ] Q[k,:]

    Q

    [

    k

    ,

    :

    ] :查询矩阵

    Q Q

    Q 的第

    k k

    k 行向量

  • K T [ : , i ] K^T[:,i]

    K

    T

    [

    :

    ,

    i

    ] :键矩阵

    K T K^T

    K

    T 的第

    i i

    i 列向量

  • O [ k , : ] O[k,:]

    O

    [

    k

    ,

    :

    ] :输出矩阵

    O O

    O 的第

    k k

    k 行

  • V [ i , : ] V[i,:]

    V

    [

    i

    ,

    :

    ] :值矩阵

    V V

    V 的第

    i i

    i 行

  • o i \bm{o}_i

    o

    i

    ∑ j

    1 i a j V [ j , : ] \sum_{j=1}^ia_jV[j,:]

    j

    =

    1

    i

    a

    j

    V

    [

    j

    ,

    :

    ] ,存储

    A [ k , : i ] × V [ : i , : ] A[k,:i]\times V[:i,:]

    A

    [

    k

    ,

    :

    i

    ]

    ×

    V

    [

    :

    i

    ,

    :

    ] 结果的行向量

BODY

for

i ← 1 , N i\leftarrow 1,N

i

1

,

N do

x i ← Q [ k , : ] K T [ : , i ] m i ← max ⁡ ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} x_i & \leftarrow Q[k,:]K^T[:,i] \ m_i & \leftarrow \max (m_{i-1},x_i) \ d^{\prime}i & \leftarrow d^{\prime}{i-1} e^{m_{i-1}-m_i} + e^{x_i-m_i} \end{aligned}

x

i

m

i

d

i

Q

[

k

,

:

]

K

T

[

:

,

i

]

max

(

m

i

1

,

x

i

)

d

i

1

e

m

i

1

m

i

e

x

i

m

i

end

for

i ← 1 , N i\leftarrow 1,N

i

1

,

N do

a i ← e x i − m N d N ′ o i ← o i − 1 + a i V [ i , : ] \begin{align} a_i & \leftarrow \frac{e^{x_i-m_N}}{d^{\prime}_N} \ \bm{o}i & \leftarrow \bm{o}{i-1} + a_iV[i,:] \end{align}

a

i

o

i

d

N

e

x

i

m

N

o

i

1

a

i

V

[

i

,

:

]

end

O [ k , : ] ← o N O[k,:] \leftarrow \bm{o}_N

O

[

k

,

:

]

o

N

我们将公式 (12) 中的

a i a_i

a

i

替换公式 (11) 中的定义,有:

o i :

∑ j

1 i ( e x j − m N d N ′ V [ j , : ] ) \begin{align} \bm{o}i := \sum{j=1}^i\left(\frac{e^{x_j-m_N}}{d_N^{\prime}}V[j,:]\right) \end{align}

o

i

:=

j

=

1

i

(

d

N

e

x

j

m

N

V

[

j

,

:

]

)

该公式仍然依赖

m N m_N

m

N

d N d_N

d

N

变量,这两个值在前一个循环完成之前无法确定。但我们可以再次使用第 3 节中的替代技巧,即创建替代序列

o ′ \bm{o}^{\prime}

o

′ :

o i ′ :

( ∑ j

1 i e x j − m i d i ′ V [ j , : ] ) \bm{o}i^{\prime}:=\left(\sum{j=1}^i\frac{e^{x_j-m_i}}{d_i^{\prime}}V[j,:]\right)

o

i

:=

(

j

=

1

i

d

i

e

x

j

m

i

V

[

j

,

:

]

)

o \bm{o}

o 和

o ′ \bm{o}^{\prime}

o

′ 的第

N N

N 个元素相同即

o N ′

o N \bm{o}^{\prime}_N=\bm{o}_N

o

N

=

o

N

,并且我们可以发现

o i ′ \bm{o}^{\prime}_i

o

i

o i − 1 ′ \bm{o}^{\prime}_{i-1}

o

i

1

之间存在如下的递归关系:

o i ′

∑ j

1 i e x j − m i d i ′ V [ j , : ]

( ∑ j

1 i − 1 e x j − m i d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ]

( ∑ j

1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ]

( ∑ j

1 i − 1 e x j − m i − 1 d i − 1 ′ V [ j , : ] ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i , : ]

o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{equation} \begin{aligned} \bm{o}^{\prime}{i}& = \sum{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d^{\prime}{i}}V[j, : ]\ & = \left(\sum{j=1}^{i-1}\frac{e^{x_{j}-m_{i}}}{d^{\prime}{i} }V[j, : ]\right) + \frac{e^{x{i}-m_{i}}}{d^{\prime}{i}}V[i, : ]\ & = \left(\sum{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}{ i-1}}\frac{e^{x{j}-m_{i}}}{e^{x_{j}-m_{i-1}}}\frac{d^{\prime}{i-1}}{d^{ \prime}{i}}V[j, : ]\right) + \frac{e^{x_{i}-m_{i}}}{d^{\prime}{i}}V[i, : ]\ & = \left(\sum{j=1}^{i-1}\frac{e^{x_{j}-m_{i-1}}}{d^{\prime}{ i-1}}V[j, : ]\right) \frac{d^{\prime}{i-1}}{d^{\prime}{i}}e^{m{i-1}-m_{i}} + \frac{e^{x_{i}-m_{i}}}{d^{\prime}{i}}V[i, : ]\ & = \bm{o}^{\prime}{i-1}\frac{d^{\prime}{i-1} e^{m{i-1}-m_{ i}}}{d^{\prime}{i}} + \frac{e^{x{i}-m_{i}}}{d^{\prime}_{i}}V[i, : ] \end{aligned} \end{equation}

o

i

=

j

=

1

i

d

i

e

x

j

m

i

V

[

j

,

:

]

=

(

j

=

1

i

1

d

i

e

x

j

m

i

V

[

j

,

:

]

)

d

i

e

x

i

m

i

V

[

i

,

:

]

=

(

j

=

1

i

1

d

i

1

e

x

j

m

i

1

e

x

j

m

i

1

e

x

j

m

i

d

i

d

i

1

V

[

j

,

:

]

)

d

i

e

x

i

m

i

V

[

i

,

:

]

=

(

j

=

1

i

1

d

i

1

e

x

j

m

i

1

V

[

j

,

:

]

)

d

i

d

i

1

e

m

i

1

m

i

d

i

e

x

i

m

i

V

[

i

,

:

]

=

o

i

1

d

i

d

i

1

e

m

i

1

m

i

d

i

e

x

i

m

i

V

[

i

,

:

]

它仅依赖于

d i ′ ,   d i − 1 ′ ,   m i ,   m i − 1 ,   x i d^{\prime}i,\ d^{\prime}{i-1}, \ m_i, \ m_{i-1}, \ x_i

d

i

,

d

i

1

,

m

i

,

m

i

1

,

x

i

,因此我们可以在单个循环中融合 Self-Attention 中的所有计算

Algorithm FlashAttention

for

i ← 1 , N i \leftarrow 1,N

i

1

,

N do

x i ← Q [ k , : ] K T [ : , i ] m i ← max ⁡ ( m i − 1 , x i ) d i ′ ← d i − 1 ′ e m i − 1 − m i + e x i − m i o i ′ ← o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{aligned} {x_{i}} & \leftarrow {Q[k,:]} {K^{T}[:,i]}\ {m_{i}} & \leftarrow \max\left( {m_{i-1},x_{i}} \right)\ {d^{\prime}{i}} & \leftarrow {d^{\prime}{i-1}e^{m_{i-1}-m_{i}}}+ {e^{x_{i}-m_{i}}}\ {o^{\prime}{i}} & \leftarrow {o^{\prime}{i-1}}\frac{ {d^{\prime}{i-1}e^{m{i-1}-m_{i}}}}{ { d^{\prime}{i}}}+\frac{ {e^{x{i}-m_{i}}}}{ {d^{\prime}_{i}}} {V[i,:]}\end{aligned}

x

i

m

i

d

i

o

i

Q

[

k

,

:

]

K

T

[

:

,

i

]

max

(

m

i

1

,

x

i

)

d

i

1

e

m

i

1

m

i

e

x

i

m

i

o

i

1

d

i

d

i

1

e

m

i

1

m

i

d

i

e

x

i

m

i

V

[

i

,

:

]

end

O [ k , : ] ← o N ′ O[k,:] \leftarrow \bm{o}^{\prime}_N

O

[

k

,

:

]

o

N

状态量

x i ,   m i ,   d i ′ , o i ′ x_i, \ m_i, \ d^{\prime}_i, \bm{o}^{\prime}_i

x

i

,

m

i

,

d

i

,

o

i

占用的内存都很小,可以非常轻松的放入 GPU 的 shared memory 中。另外由于此算法中的所有操作都是关联的,因此它与 tiling 兼容,如果我们逐个 tiling 计算,则该算法可以表示为:

Algorithm FlashAttention(Tiling)

NEW NOTATIONS

  • b b

    b :tile 的 block size 大小

  • #tiles \text{#tiles}

    #tiles :每行 tile 的数量,

    N

    b × #tiles N= b \times \text{#tiles}

    N

    =

    b

    ×

    #tiles

  • x i \bm{x}_i

    x

    i

    :存储第

    i i

    i 个 tile 的

    Q [ k ] K T Q[k]K^T

    Q

    [

    k

    ]

    K

    T 值的向量

    [ ( i − 1 ) b : i b ] [(i-1)b:ib]

    [(

    i

    1

    )

    b

    :

    ib

    ]

  • m i ( l o c a l ) m_i^{\mathrm{(local)}}

    m

    i

    (

    local

    )

    x i \bm{x}_i

    x

    i

    内部的局部最大值

BODY

for

i ← 1 , #tile i \leftarrow 1,\text{#tile}

i

1

,

#tile do

x i   ←   Q [ k , : ] K T [ : , ( i − 1 ) b : i b ] m i (local)  

  max ⁡ j

1 b ( x i [ j ] ) m i   ←   max ⁡ ( m i − 1 , m i (local) ) d i ′   ←   d i − 1 ′ e m i − 1 − m i + ∑ j

1 b e x i [ j ] − m i o i ′   ←   o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + ∑ j

1 b e x i [ j ] − m i d i ′ V [ j + ( i − 1 ) b , : ] \begin{split}\bm{x}{i}&\ \leftarrow\ Q[k, : ] K^{T}[:,(i-1) b : i b]\ m{i}^{\text{(local)}}&\ =\ \max_{j=1}^{b} (\bm{x}{i} [j]) \ m{i}&\ \leftarrow\ \max\big(m_{i-1},m_{i}^{\text{(local)}} \big)\ d_{i}^{\prime}&\ \leftarrow\ d_{i-1}^{\prime} e^{m_{i-1}-m_{i}} + \sum_{j=1}^{b} e^{\bm{x}{i}[j]-m{i}}\ \bm{o}{i}^{\prime}&\ \leftarrow\ \bm{o}{i-1}^{\prime} \frac{d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}}{d_{i}^{\prime}} + \sum_{j=1}^{b} \frac{e^{\bm{x}{i}[j]-m{i}}}{d_{i}^{\prime}}V[j+(i-1) b, : ]\end{split}

x

i

m

i

(local)

m

i

d

i

o

i

Q

[

k

,

:

]

K

T

[

:

,

(

i

1

)

b

:

ib

]

=

j

=

1

max

b

(

x

i

[

j

])

max

(

m

i

1

,

m

i

(local)

)

d

i

1

e

m

i

1

m

i

j

=

1

b

e

x

i

[

j

]

m

i

o

i

1

d

i

d

i

1

e

m

i

1

m

i

j

=

1

b

d

i

e

x

i

[

j

]

m

i

V

[

j

(

i

1

)

b

,

:

]

end

O [ k , : ] ← o N / b ′ O[k,:] \leftarrow \bm{o}^{\prime}_{N/b}

O

[

k

,

:

]

o

N

/

b

下图说明了如何将该算法应用到硬件上:

https://i-blog.csdnimg.cn/direct/011cab53bdb4407ea1db49aee3982c94.png#pic_center

上图说明了 FlashAttention 在硬件上的计算方式,其中蓝色块表示在 SRAM 中的 tile,而红色块对应于 tile 中的第

i i

i 行,

L L

L 表示序列长度,

D D

D 表示每个头的维度,

B B

B 表示 block size 的大小

值得注意的是,整体 SRAM 的内存占用仅仅取决于

B B

B 和

D D

D ,与

L L

L 无关,因此,该算法可以扩展到长上下文而不会遇到内存问题(GPU 的共享内存很小,例如 H100 架构中为 228kb/SM)。在计算过程中,我们从左到右遍历

K T K^T

K

T 和

A A

A ,从上到下遍历

V V

V ,并相应地更新

m ,   d ,   O m,\ d, \ O

m

,

d

,

O 的状态

结语

这篇文章主要分享了 FlashAttention 的思想,FlashAttention 想要做的就是在单个循环中完成 self-attention 的整个计算

我们先从 safe softmax 出发分析了 safe softmax 算法需要循环迭代 3 次,第一次循环求取序列中的最大值

m N m_N

m

N

,第二次循环求取 safe softmax 的分母

d N d_N

d

N

,第三次循环求取最终的 softmax 值

a i a_i

a

i

。由于 SRAM 片上内存空间有限,我们无法将所有的

x i x_i

x

i

缓存,导致我们在每次迭代时需要重新访问

Q , K Q,K

Q

,

K 以动态重新计算

x i x_i

x

i

,这造成了大量的 I/O 操作,我们并不希望看到

在 online softmax 中提出了一个有意思的想法,那就是创建一个新的序列

d i ′ d_i^{\prime}

d

i

来替换原始的

d i d_i

d

i

,主要区别在于将原来

d i d_i

d

i

定义中的

m N m_N

m

N

替换为了

m i m_i

m

i

,以消除对

N N

N 的依赖,并且我们发现

d N

d N ′ d_N=d_N^{\prime}

d

N

=

d

N

,那么整个 safe softmax 的计算就只需要迭代 2 次即可,第一次循环迭代可以同时计算

m i m_i

m

i

d i ′ d_i^{\prime}

d

i

FlashAttention 想要把 self-attention 中的循环次数降低到一次,但是发现对于其中的 softmax 而言无法实现,因此它另辟蹊径试图寻找输出矩阵

O O

O 的一次迭代形式。经过推导我们发现输出行向量

o i \bm{o}_i

o

i

仍然依赖于

m N m_N

m

N

d N d_N

d

N

变量,这两个值在前一个循环完成之前无法确定,因此仍需要两次循环

FlashAttention 借助 online softmax 的替代思想,将序列

o \bm{o}

o 替换为

o ′ \bm{o}^{\prime}

o

′ ,替换之后发现

o i ′ \bm{o}_i^{\prime}

o

i

不再依赖于

m N m_N

m

N

d N d_N

d

N

等变量,而仅依赖于

d i ′ ,   d i − 1 ′ ,   m i ,   m i − 1 ,   x i d^{\prime}i,\ d^{\prime}{i-1}, \ m_i, \ m_{i-1}, \ x_i

d

i

,

d

i

1

,

m

i

,

m

i

1

,

x

i

,从而可以在单个循环中融合 self-attention 中的所有计算

最后,由于 flashattention 算法中的所有操作都是关联的,因此可以逐个 tiling 分块计算,并且整体 SRAM 的内存占用与序列长度无关,因此可以扩展到长上下文而不用担心遇到内存问题

OK,以上就是关于这篇手稿的全部内容了,大家感兴趣的可以看看,作为 FlashAttention 的入门还是没有问题的😄

参考