AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


FlashAttention原理,终于看明白了!
发布日期:2024-12-26 07:07:39 浏览次数: 1528 来源:丁师兄大模型








offer捷报

训练营学员继拿下腾讯美团字节sp后,百度经过加面,也给了sp offer,且签字费给到了其他大厂的两倍。
目前 LLM 是基于 Transformer 结构,其核心是 self-attention,随着输入序列的不断增大,时间与空间复杂度都呈二次方增长。
为了解决扩大 Transformer 模型上下文长度时面临的挑战,‌斯坦福大学和纽约州立大学布法罗分校的研究者共同提出了 FlashAttention,通过提供一种快速且内存高效的注意力算法,‌无需任何近似即可加速注意力计算并减少内存占用。‌ 
FlashAttention 的核心原理是将输入 QKV 分块,并保证每个块能够在 SRAM(一级缓存)上完成注意力操作,并将结果更新回 HBM,从而降低对高带宽内存(HBM)的读写操作。
总之,FlashAttention 从 GPU 的内存读写入手,减少了内存读写量,从而实现了 2~4 倍的速度提升。
FlashAttention 的核心原理

01

From Online Softmax to FlashAttention

在计算注意力的过程中,点积可以分块累加,而 softmax 不能分块后直接处理,所以需要重新设计计算方式。

首先观察下 softmax 的计算方式,如下式 1 所示:

式1

import numpy as npinputs = np.array([0761210], dtype=np.float16)
eInputs = np.exp(inputs)result = eInputs/np.sum(eInputs)print(eInputs)print(np.sum(eInputs))print(result)
程序输出:
[1.000e+00 1.097e+03 4.035e+02 inf 2.203e+04]inf[0. 0. 0. nan. 0.]
为了缓解这个问题,通常采用一种称为 safe-softmax 的技巧,即每个数字减去最大值再求 softmax,如下式 2。
式2
import numpy as npinputs = np.array([0, 7, 6, 12, 10], dtype=np.float16)max_val = max(inputs)emInputs = np.exp(inputs-max_val)result1 = emInputs/np.sum(emInputs)print(emInputs)print(np.sum(emInputs))print(result1)print(sum(result1))
程序输出:
[6.139e-06 6.737e-03 2.480e-03 1.000e+00 1.354e-01]1.145[5.364e-06 5.886e-03 2.167e-03 8.735e-01 1.183e-01]0.9998794794082642
注意:safe-softmax 与 softmax 的结果一致。
import numpy as npinputs = np.array([11333], dtype=np.float16)
#softmaxeInputs = np.exp(inputs)result = eInputs/np.sum(eInputs)print(result)
# save-softmaxmax_val = max(inputs)emInputs = np.exp(inputs-max_val)result1 = emInputs/np.sum(emInputs)print(result1)
程序输出:
[0.04138 0.04138 0.3057  0.3057  0.3057 ][0.04138 0.04138 0.3057  0.3057  0.3057 ]
如图 1,可将 save-softmax 写成 3 步骤。
图 1 3-pass save softmax
图 2 2-pass online softmax
图 3 di'的迭代形式
# online SoftMax 2-passimport torch
= 8inputs = torch.randn(L)result = torch.zeros(L)
= torch.tensor(float("-inf"))= 0for i in range(L):    m_new = torch.max(m, inputs[i])    d = d * (m - m_new).exp() + (inputs[i] - m_new).exp()    m = m_new
for i in range(L):    result[i] = (inputs[i]-m).exp() / d
print('online softmax result:',result)print(torch.sum(result))
# save-softmax 3步骤max_value = torch.max(inputs)eX = torch.exp(inputs-max_value)result1 = eX/torch.sum(eX)print('save softmax result:', result1)print(torch.sum(result1))
程序输出:
online softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])tensor(1.)save softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])tensor(1.0000)

02

Flash Attention

multi-pass selft-attention 其实就是结合 online softmax 的 2 步骤(图 4),那么可不可以直接写成 1 步骤呢,答案是肯定的。

图 4 多步骤的 self-attention
式3
式4
式5
图5 1-pass flash attention
图6 1-pass flash attention(Tiling)

03

代码分析

FlashAttention 的伪代码:

FlashAttentionV2 的伪代码:
具体数据的代码案例:
import timeimport torch
torch.manual_seed(0)NEG_INF = -1e10  # -infinityEPSILON = 1e-10
Q_LEN = 6K_LEN = 6Q_BLOCK_SIZE = 3KV_BLOCK_SIZE = 3P_DROP = 0.2
Tr = Q_LEN // Q_BLOCK_SIZE # Tr 块数Tc = K_LEN // KV_BLOCK_SIZE # Tc 块数
= torch.randn(11, Q_LEN, 4, requires_grad=True).to(device='cpu')= torch.randn(11, K_LEN, 4, requires_grad=True).to(device='cpu')= torch.randn(11, K_LEN, 4, requires_grad=True).to(device='cpu')
# step 4Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
print("----------------FlashAttentionV1------------------------")= torch.zeros_like(Q, requires_grad=True)= torch.zeros(Q.shape[:-1])[..., None]= torch.ones(Q.shape[:-1])[..., None* NEG_INF# print(O.shape, l.shape, m.shape)
# step 5O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))# print(O_BLOCKS[0].shape, l_BLOCKS[0].shape, m_BLOCKS[0].shape)
# step 6start_time1 = time.time()for j in range(Tc):    # step 7    Kj = K_BLOCKS[j]    Vj = V_BLOCKS[j]    # step 8    for i in range(Tr):        # step 9        Qi = Q_BLOCKS[i]        Oi = O_BLOCKS[i]        li = l_BLOCKS[i]        mi = m_BLOCKS[i]                # step 10        # S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi, Kj) # Qi*Kj.T         S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)                # step 11        # mask = S_ij.ge(0.5)        # S_ij = torch.masked_fill(S_ij, mask, value=0)                # step 12        m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)        P_ij = torch.exp(S_ij - m_block_ij)        l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True+ EPSILON        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)               # step 13        mi_new = torch.maximum(m_block_ij, mi)        li_new = torch.exp(mi - mi_new) * li + \                 torch.exp(m_block_ij - mi_new) * l_block_ij                # step 14        # m = torch.nn.Dropout(p=P_DROP)        # P_ij_Vj = m(P_ij_Vj)                # Step 15        O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi \                      + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj        # print(f'-----------Attention : Q{i}xK{j}---------')        # print(O_BLOCKS[i].shape)        # print(O_BLOCKS[0])        # print(O_BLOCKS[1])        # print('\n')                # step 16        l_BLOCKS[i] = li_new        m_BLOCKS[i] = mi_new
= torch.cat(O_BLOCKS, dim=2)= torch.cat(l_BLOCKS, dim=2)= torch.cat(m_BLOCKS, dim=2)print(O.shape, time.time()-start_time1)print(O)
print("----------------FlashAttentionV2------------------------")O2 = torch.zeros_like(Q, requires_grad=True)O2_BLOCKS = list(torch.split(O2, Q_BLOCK_SIZE, dim=2))
start_time2 = time.time()for i in range(Tr):    Qi = Q_BLOCKS[i]    Oi = O2_BLOCKS[i]    li = torch.zeros((*Q.shape[:-2], Q_BLOCK_SIZE, 1))    mi = torch.ones((*Q.shape[:-2], Q_BLOCK_SIZE, 1)) * NEG_INF    for j in range(Tc):        Kj = K_BLOCKS[j]        Vj = V_BLOCKS[j]        S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)        mi_new = torch.maximum(torch.max(S_ij, dim=-1, keepdims=True)[0], mi)        P_ij = torch.exp(S_ij - mi_new)        li = torch.exp(mi-mi_new)*li+torch.sum(P_ij, dim=-1, keepdims=True)+EPSILON        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)        Oi = torch.exp(mi-mi_new) * Oi + P_ij_Vj        mi = mi_new    O2_BLOCKS[i] = Oi/li
O2 = torch.cat(O2_BLOCKS, dim=2)print(O2.shape, time.time()-start_time2)print(O2)
print("----------------Standard Self-Attention------------------------")start_time3 = time.time()scores = torch.matmul(Q, K.transpose(-2-1))attention_weights = torch.softmax(scores, dim=-1)output = torch.matmul(attention_weights, V)print(output.shape, time.time()-start_time3)print(output)
程序输出:
----------------FlashAttentionV1------------------------torch.Size([1, 1, 6, 4]) 0.0015511512756347656tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],          [-0.1962, -0.6078, -0.4992, -0.5868],          [ 0.3373,  0.3694,  0.2818,  0.2253],          [-0.3096, -0.6828, -0.4914, -0.9161],          [ 0.0873,  0.6567,  0.1782,  0.1638],          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)----------------FlashAttentionV2------------------------torch.Size([1, 1, 6, 4]) 0.0009410381317138672tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],          [-0.1962, -0.6078, -0.4992, -0.5868],          [ 0.3373,  0.3694,  0.2818,  0.2253],          [-0.3096, -0.6828, -0.4914, -0.9161],          [ 0.0873,  0.6567,  0.1782,  0.1638],          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)----------------Standard Self-Attention------------------------torch.Size([1, 1, 6, 4]) 0.00012636184692382812tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],          [-0.1962, -0.6078, -0.4992, -0.5868],          [ 0.3373,  0.3694,  0.2818,  0.2253],          [-0.3096, -0.6828, -0.4914, -0.9161],          [ 0.0873,  0.6567,  0.1782,  0.1638],          [ 0.1808, -0.2194, -0.4053,  0.1305]]]],       grad_fn=<UnsafeViewBackward0>)
FlashAttentionV1、FlashAttentionV2、Standard Self-Attention 的结果是一致的,并无差别,且整体速度 FlashAttentionV2 比 FlashAttentionV1 更快。


END



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询