微信扫码
与创始人交个朋友
我要投稿
本计划先说说Google新论文动态分配计算的transformer,但是最近meta 发布了Llama3,这是个大新闻,所以我们先说说关于的Llama的架构。
同作为decoder-only的transformer ,Llama的架构和我们说过的GPT2相差并不算大,主要体现在以下3点:
1 (GQA, Grouped-Query Attention)
2 旋转位置编码RoPE
3 使用RMSNorm
Llama3对比Llama2其改动主要体现在:
1 使用了更多数据训练(15T)
2 采用了新的 Tokenizer,将词汇表大小扩展至 128,256(前版本为 32,000 Token)
Llama的架构如下:
首先我们说说Grouped-Query Attention,Multi-Head-Attention将q k v 分为N组,每一组分别做 Attention,然后再concat。为了在保证效果的前提下节省计算量,Grouped-Query Attention采用了,一组Q共享一个K V的机制去做Attention
使用hidden_size = 768 ,num_heads = 8, num_key_value_heads =2,也就是分两组,一个head的dim为96,打印出来就是:
代码中 k,v 重复 8/2 = 4 次:
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
q k 矩阵乘法画出来就是
旋转位置编码是位置编码的一种方式,其核心思想在于,利用序列中token的距离产生相对的位置关系
具体做法是按照一个固定的角度,与位置旋转向量q和k,比如角度为θ, q的位置为m,k的位置为n,则q转m倍θ度,k转n倍θ度。这样做的好处在于,现在的编码不是绝对位置,而是相对q与k距离而产生的编码!
下图的R表示旋转矩阵:
以二维向量来看,旋转过程如下
现在的问题在于,q k是高维向量没法直接旋转,所以只能分成一组组二维向量,然后一组给一个θ按组转!于是旋转矩阵就变成了
现在整个qk的旋转如下图:
总结下来,RoPE其实就是将 q,k 分为dim/2 组二维向量,每一组给一个固定的角度θ,按照q k位置关系 (n-m)*θ去旋转,以此获得相对位置编码。
在代码中,我们并不会真的去乘那个大旋转矩阵R,而是采用等价的方式实现,甚至不会用到sin cos直接用复数就行:
在代码中Llama预先计算好所有位置的旋转角度,
self.freqs_cis = precompute_freqs_cis(
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings.
xk (torch.Tensor): Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
注:以hidden_size = 768,num_heads = 8,num_key_value_heads =2算出来如下图(因为是按照二维分组,所以是96还要除2)
最后计算旋转位置编码的代码为:
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
-> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings.
xk (torch.Tensor): Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
Returns:
Tuple of modified query tensor and key tensor with rotary embeddings. :
"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-05-14
2024-04-26
2024-03-30
2024-04-12
2024-05-10
2024-07-18
2024-05-22
2024-05-28
2024-04-25
2024-04-26
2024-11-14
2024-11-13
2024-11-13
2024-11-13
2024-11-12
2024-11-11
2024-11-08
2024-11-07