微信扫码
与创始人交个朋友
我要投稿
本人是某双一流大学硕士生,也最近刚好准备参加 2024年秋招,在找大模型算法岗实习中,遇到了很多有意思的面试,所以将这些面试题记录下来,并分享给那些和我一样在为一份满意的offer努力着的小伙伴们!!!
Self-Attention 的时间复杂度/空间复杂度是怎么计算的?
阅读 Transformer 相关的论文,在讨论 self-attention 的时间和空间复杂度时,都会提到是 O(N^2),其中 N 是序列长度。
关于时间复杂度(time complexity)或空间复杂度 O(·),首先要知道这只是一种定性分析,而不是精确的定量分析。
我们来看下 scaled dot-product attention的时间和空间复杂度:
为了分析时间和时间复杂度,我们把上面的计算过程拆分:
只要我们分析出每个计算的复杂度,就可以得到整体计算的复杂度。
我们先来看第一个矩阵乘法
矩阵乘法的朴素算法时间复杂度是
至于空间复杂度,只看存储 计算结果,复杂度是 ,但是也不要觉得这个数字很大,如果 ,其实存储 Q 和 K 要比 更占内(显)存。除非是序列很长 ,空间复杂度 才会是瓶颈。
简单回顾下矩阵乘法
C = np.zeros((m, l))
for i in range(m):
for j in range(l):
for k in range(n):
C[i][j] += A[i][k] * B[k][j]
显而易见,3 个 for 循环,因此矩阵乘法时间复杂度 。
我在网上查找矩阵乘法时间复杂度分析的资料时,发现很多人喜欢用numpy.dot+画图的方式来直观展示,很有趣:
m = 64
n = 64
l = 64
times = []
ms = []
for i in range(20):
ms.append(m)
begin = time.time()
m1 = np.random.random((m, n))
m2 = np.random.random((n, l))
times.append(time.time() - begin)
m *= 2# 改变 m 的大小, 同理可以改变 n 或 l 的大小
# 画图
fig, ax = plt.subplots()
ax.set_ylabel('Time')
ax.set_xlabel('Array Dimension Size')
ax.plot(ms, times)
plt.show()
可以看到,矩阵乘法时间和其中 m 维度大小成正相关,斜率 ~1,如果改变 n 或 l 也会得到相同的结论,因此矩阵乘法时间复杂度是 。
由于
因此
的时间复杂度为
我们再来看 softmax 时间复杂度,假设 z 是一维向量:
def softmax(x):
m_val = max(x)
x = [i-m_val for i in x]
x = [math.exp(i) for i in x]
deno = sum(x)
return [item / deno for item in x]
softmax([1,2,3])# [0.0900, 0.2447, 0.6652]
Self-Attention包括三个步骤:相似度计算,softmax和加权平均
因此,Self-Attention的时间复杂度是 。
这样,整个
的时间复杂度是:
如果把向量维度 d 看作常数,则可以说 self-attention 的时间复杂度是序列长度的平方。
再来看下空间复杂度,不论是存储
最后存储
的空间复杂度是
这样,整个空间复杂度可以看作:
如果把向量维度 d 看作常数,则可以说 self-attention 的空间复杂度是序列长度的平方。
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-08-03
2024-07-31
2024-07-25
2024-09-12
2024-07-25
2024-08-06
2024-07-09
2024-06-03
2024-10-17
2024-06-01