From Online Softmax to FlashAttention
When computing attention, the dot products can be accumulated block by block, but softmax cannot simply be computed on each block and stitched together, so the computation has to be redesigned.
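A toy example makes the asymmetry concrete (the values and the two-block split below are illustrative, not from the original article):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([0.5, -1.0, 2.0, 0.3])

# A dot product can be accumulated block by block.
blocks = np.split(np.arange(4), 2)                 # two blocks of two indices each
partial = sum(np.dot(x[b], w[b]) for b in blocks)
print(np.allclose(partial, np.dot(x, w)))          # True

# Softmax cannot: each block only sees its own normalizer.
def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

blockwise = np.concatenate([softmax(x[b]) for b in blocks])
print(np.allclose(blockwise, softmax(x)))          # False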
First, look at how softmax is computed, shown in Eq. 1:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}, \qquad i = 1, \dots, N \tag{1}$$
import numpy as np

inputs = np.array([0, 7, 6, 12, 10], dtype=np.float16)
eInputs = np.exp(inputs)            # exp(12) ≈ 1.6e5 overflows float16 (max ≈ 65504) -> inf
result = eInputs / np.sum(eInputs)
print(eInputs)
print(np.sum(eInputs))
print(result)
[1.000e+00 1.097e+03 4.035e+02       inf 2.203e+04]
inf
[ 0.  0.  0. nan  0.]
The standard fix, safe softmax, subtracts the maximum before exponentiating, so every exponent is at most zero:

import numpy as np

inputs = np.array([0, 7, 6, 12, 10], dtype=np.float16)
max_val = max(inputs)
emInputs = np.exp(inputs - max_val)
result1 = emInputs / np.sum(emInputs)
print(emInputs)
print(np.sum(emInputs))
print(result1)
print(sum(result1))
[6.139e-06 6.737e-03 2.480e-03 1.000e+00 1.354e-01]
1.145
[5.364e-06 5.886e-03 2.167e-03 8.735e-01 1.183e-01]
0.9998794794082642
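Subtracting the maximum changes nothing mathematically, because the factor $e^{-m}$ cancels between numerator and denominator:

$$\frac{e^{x_i - m}}{\sum_{j=1}^{N} e^{x_j - m}} = \frac{e^{-m}\, e^{x_i}}{e^{-m}\sum_{j=1}^{N} e^{x_j}} = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}, \qquad m = \max_j x_j$$

The next snippet checks this numerically on inputs that do not overflow: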
import numpy as np
inputs = np.array([1, 1, 3, 3, 3], dtype=np.float16)
#softmax
eInputs = np.exp(inputs)
result = eInputs/np.sum(eInputs)
print(result)
# safe softmax
max_val = max(inputs)
emInputs = np.exp(inputs-max_val)
result1 = emInputs/np.sum(emInputs)
print(result1)
[0.04138 0.04138 0.3057 0.3057 0.3057 ]
[0.04138 0.04138 0.3057 0.3057 0.3057 ]
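Safe softmax still needs three passes over the data: find the max, sum the exponentials, then divide. Online softmax fuses the first two passes by keeping a running maximum $m_i$ and a running denominator $d_i$, rescaling the denominator whenever the maximum grows. For $i = 1, \dots, N$, with $m_0 = -\infty$ and $d_0 = 0$:

$$m_i = \max(m_{i-1}, x_i), \qquad d_i = d_{i-1}\, e^{m_{i-1} - m_i} + e^{x_i - m_i}$$

After one pass, $d_N = \sum_j e^{x_j - m_N}$, so a second pass gives $\text{softmax}(x_i) = e^{x_i - m_N} / d_N$. The 2-pass implementation below follows this recurrence: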
# online SoftMax 2-pass
import torch
L = 8
inputs = torch.randn(L)
result = torch.zeros(L)
m = torch.tensor(float("-inf"))
d = 0
for i in range(L):
    m_new = torch.max(m, inputs[i])
    d = d * (m - m_new).exp() + (inputs[i] - m_new).exp()
    m = m_new

for i in range(L):
    result[i] = (inputs[i] - m).exp() / d
print('online softmax result:',result)
print(torch.sum(result))
# safe softmax, 3 passes (max, exp-sum, divide)
max_value = torch.max(inputs)
eX = torch.exp(inputs-max_value)
result1 = eX/torch.sum(eX)
print('safe softmax result:', result1)
print(torch.sum(result1))
online softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])
tensor(1.)
safe softmax result: tensor([0.0595, 0.0548, 0.3192, 0.1136, 0.0562, 0.2336, 0.0774, 0.0856])
tensor(1.0000)
Flash Attention
Multi-pass self-attention is essentially the 2-pass online softmax applied to attention (Figure 4). Can it be written directly as a single pass? The answer is yes.
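The idea, sketched in the same notation as the online-softmax recurrence above (here $x_i = q \cdot k_i$ is the score of one query row against key $k_i$, and $v_i$ is the matching value row): carry an un-normalized output accumulator $o_i$ alongside $m_i$ and $d_i$, and rescale it by the same factor whenever the maximum changes:

$$m_i = \max(m_{i-1}, x_i), \qquad d_i = d_{i-1}\, e^{m_{i-1}-m_i} + e^{x_i - m_i}, \qquad o_i = o_{i-1}\, e^{m_{i-1}-m_i} + e^{x_i - m_i}\, v_i$$

The final output row is $o_N / d_N$, so normalization happens once at the end rather than per step. FlashAttention applies this update block by block, which is exactly what the code below does.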
Code analysis
The FlashAttention pseudocode, written out in PyTorch (the `# step n` comments below track the algorithm's numbered steps):
import time
import torch
torch.manual_seed(0)
NEG_INF = -1e10 # -infinity
EPSILON = 1e-10
Q_LEN = 6
K_LEN = 6
Q_BLOCK_SIZE = 3
KV_BLOCK_SIZE = 3
P_DROP = 0.2
Tr = Q_LEN // Q_BLOCK_SIZE    # number of Q blocks
Tc = K_LEN // KV_BLOCK_SIZE   # number of K/V blocks
Q = torch.randn(1, 1, Q_LEN, 4, requires_grad=True).to(device='cpu')
K = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')
V = torch.randn(1, 1, K_LEN, 4, requires_grad=True).to(device='cpu')
# step 4
Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
print("----------------FlashAttentionV1------------------------")
O = torch.zeros_like(Q, requires_grad=True)
l = torch.zeros(Q.shape[:-1])[..., None]
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF
# print(O.shape, l.shape, m.shape)
# step 5
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))
# print(O_BLOCKS[0].shape, l_BLOCKS[0].shape, m_BLOCKS[0].shape)
# step 6
start_time1 = time.time()
for j in range(Tc):
    # step 7
    Kj = K_BLOCKS[j]
    Vj = V_BLOCKS[j]
    # step 8
    for i in range(Tr):
        # step 9
        Qi = Q_BLOCKS[i]
        Oi = O_BLOCKS[i]
        li = l_BLOCKS[i]
        mi = m_BLOCKS[i]
        # step 10
        S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)  # Qi*Kj.T
        # step 11
        # mask = S_ij.ge(0.5)
        # S_ij = torch.masked_fill(S_ij, mask, value=0)
        # step 12
        m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
        P_ij = torch.exp(S_ij - m_block_ij)
        l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
        # step 13
        mi_new = torch.maximum(m_block_ij, mi)
        li_new = torch.exp(mi - mi_new) * li + \
                 torch.exp(m_block_ij - mi_new) * l_block_ij
        # step 14
        # m = torch.nn.Dropout(p=P_DROP)
        # P_ij_Vj = m(P_ij_Vj)
        # step 15
        O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi \
                      + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
        # print(f'-----------Attention : Q{i}xK{j}---------')
        # print(O_BLOCKS[i].shape)
        # print(O_BLOCKS[0])
        # print(O_BLOCKS[1])
        # print('\n')
        # step 16
        l_BLOCKS[i] = li_new
        m_BLOCKS[i] = mi_new
O = torch.cat(O_BLOCKS, dim=2)
l = torch.cat(l_BLOCKS, dim=2)
m = torch.cat(m_BLOCKS, dim=2)
print(O.shape, time.time()-start_time1)
print(O)
print("----------------FlashAttentionV2------------------------")
O2 = torch.zeros_like(Q, requires_grad=True)
O2_BLOCKS = list(torch.split(O2, Q_BLOCK_SIZE, dim=2))
start_time2 = time.time()
for i in range(Tr):
    # V2: outer loop over query blocks; each query block keeps its own running max / denominator
    Qi = Q_BLOCKS[i]
    Oi = O2_BLOCKS[i]
    li = torch.zeros((*Q.shape[:-2], Q_BLOCK_SIZE, 1))
    mi = torch.ones((*Q.shape[:-2], Q_BLOCK_SIZE, 1)) * NEG_INF
    for j in range(Tc):
        Kj = K_BLOCKS[j]
        Vj = V_BLOCKS[j]
        S_ij = torch.einsum("... i d, ... j d -> ... i j", Qi, Kj)
        mi_new = torch.maximum(torch.max(S_ij, dim=-1, keepdims=True)[0], mi)
        P_ij = torch.exp(S_ij - mi_new)
        li = torch.exp(mi - mi_new) * li + torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
        P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
        Oi = torch.exp(mi - mi_new) * Oi + P_ij_Vj
        mi = mi_new
    O2_BLOCKS[i] = Oi / li    # normalize once, after all KV blocks have been seen
O2 = torch.cat(O2_BLOCKS, dim=2)
print(O2.shape, time.time()-start_time2)
print(O2)
print("----------------Standard Self-Attention------------------------")
start_time3 = time.time()
scores = torch.matmul(Q, K.transpose(-2, -1))  # note: no 1/sqrt(d) scaling, matching the blockwise versions above
attention_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
print(output.shape, time.time()-start_time3)
print(output)
----------------FlashAttentionV1------------------------
torch.Size([1, 1, 6, 4]) 0.0015511512756347656
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)
----------------FlashAttentionV2------------------------
torch.Size([1, 1, 6, 4]) 0.0009410381317138672
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<CatBackward0>)
----------------Standard Self-Attention------------------------
torch.Size([1, 1, 6, 4]) 0.00012636184692382812
tensor([[[[ 0.2281, -0.2178, -0.3508,  0.1571],
          [-0.1962, -0.6078, -0.4992, -0.5868],
          [ 0.3373,  0.3694,  0.2818,  0.2253],
          [-0.3096, -0.6828, -0.4914, -0.9161],
          [ 0.0873,  0.6567,  0.1782,  0.1638],
          [ 0.1808, -0.2194, -0.4053,  0.1305]]]], grad_fn=<UnsafeViewBackward0>)