The Architecture of Llama, the Flagship Open-Source Large Language Model, and RoPE
Published: 2024-04-30 19:46:13 · Source: 明日丽的AI厨房


The plan was to cover Google's recent paper on transformers that dynamically allocate compute, but Meta has just released Llama 3, which is big news, so let's first talk about the Llama architecture.

As a decoder-only transformer, Llama's architecture is not that different from the GPT-2 architecture we discussed earlier. The differences mainly come down to three points:

1. Grouped-Query Attention (GQA)

2. Rotary Position Embedding (RoPE)

3. RMSNorm (in place of LayerNorm)

Compared with Llama 2, the main changes in Llama 3 are:

1. Training on far more data (15T tokens)

2. A new tokenizer that expands the vocabulary to 128,256 tokens (up from 32,000 in the previous version)

[Figure: the overall Llama architecture]

First, Grouped-Query Attention. Standard Multi-Head Attention splits q, k, and v into N heads, computes attention within each head separately, and then concatenates the results. To save computation while preserving quality, Grouped-Query Attention instead has a whole group of query heads share a single pair of K and V heads when computing attention.

Take hidden_size = 768, num_heads = 8, and num_key_value_heads = 2, i.e., two KV groups and a head_dim of 96. Printing the tensor shapes gives:
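A minimal sketch that reproduces those shapes (the Linear projections here are illustrative stand-ins, not Llama's actual layers):

import torch
import torch.nn as nn

bs, seqlen = 1, 10
hidden_size, num_heads, num_kv_heads = 768, 8, 2
head_dim = hidden_size // num_heads  # 96

# Stand-in projections: q gets 8 heads, while k and v only get 2 heads each.
wq = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
wk = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
wv = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

x = torch.randn(bs, seqlen, hidden_size)
xq = wq(x).view(bs, seqlen, num_heads, head_dim)
xk = wk(x).view(bs, seqlen, num_kv_heads, head_dim)
xv = wv(x).view(bs, seqlen, num_kv_heads, head_dim)

print(xq.shape)  # torch.Size([1, 10, 8, 96])
print(xk.shape)  # torch.Size([1, 10, 2, 96])
print(xv.shape)  # torch.Size([1, 10, 2, 96])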

In the code, k and v are repeated 8 / 2 = 4 times so that every query head has a matching key/value head:

keys = repeat_kv(keys, self.n_rep)      # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2)                 # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)             # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2)         # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
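The repeat_kv helper used above is not shown in the post; in Meta's Llama reference implementation it is roughly the following, expanding each KV head n_rep times and folding the repeats back into the head dimension (equivalent to torch.repeat_interleave(x, repeats=n_rep, dim=2)):

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )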

[Figure: the q · k matrix multiplication, drawn out]

RoPE (Rotary Position Embedding) is one way of encoding positions. Its core idea is to use the distance between tokens in the sequence to produce a relative positional relationship.

Concretely, the q and k vectors are rotated by a fixed angle scaled by their positions: given a base angle θ, if q sits at position m and k at position n, then q is rotated by m·θ and k by n·θ. The benefit is that the resulting encoding is no longer absolute; it arises from the relative distance between q and k!
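A one-line sketch of why this yields relative positions: writing $R_\theta$ for the 2-D rotation matrix by angle $\theta$ (spelled out just below), rotations satisfy $R_a^{\top} R_b = R_{b-a}$, so the attention dot product depends only on $n - m$:

$$(R_{m\theta}\, q)^{\top} (R_{n\theta}\, k) = q^{\top} R_{m\theta}^{\top} R_{n\theta}\, k = q^{\top} R_{(n-m)\theta}\, k$$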

Viewed as a 2-D vector, the rotation process is as follows, where R denotes the rotation matrix:
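Reconstructing that figure as math: for a 2-D vector $q_m = (x_1, x_2)^{\top}$ at position $m$,

$$q_m' = R_{m\theta}\, q_m = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}$$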

The problem now is that q and k are high-dimensional vectors that cannot be rotated directly, so they are split into pairs of 2-D sub-vectors, each pair given its own angle θ_i and rotated group by group! The rotation matrix then becomes:
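As in the RoPE paper, the full $d$-dimensional rotation is block-diagonal, one 2-D rotation per pair of dimensions, with the standard frequency schedule for $\theta_i$:

$$R^{d}_{\Theta,m} = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & & & \\ \sin m\theta_1 & \cos m\theta_1 & & & \\ & & \ddots & & \\ & & & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ & & & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}, \qquad \theta_i = 10000^{-2(i-1)/d}$$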

[Figure: the full rotation applied to q and k across all 2-D groups]

To sum up, RoPE splits q and k into dim/2 groups of 2-D vectors, gives each group a fixed angle θ_i, and rotates according to the positional relationship of q and k, so that attention effectively sees (n − m)·θ_i. That is how it obtains a relative position encoding.

In the code, we never actually multiply by that big rotation matrix R; an equivalent implementation is used instead. We do not even need explicit sin and cos terms: complex numbers do the job directly:
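The equivalence is the usual identification of 2-D rotation with complex multiplication: treat each pair $(x_1, x_2)$ as the complex number $x_1 + i x_2$; rotating by $m\theta$ is then just multiplying by $e^{i m\theta}$:

$$(x_1 + i x_2)\, e^{i m\theta} = (x_1 \cos m\theta - x_2 \sin m\theta) + i\, (x_1 \sin m\theta + x_2 \cos m\theta)$$

which is exactly the rotation matrix above applied to $(x_1, x_2)^{\top}$.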


In the code, Llama precomputes the rotation angles for every position:

self.freqs_cis = precompute_freqs_cis(
    self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
)
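For reference, precompute_freqs_cis in Meta's implementation builds the per-pair frequencies $\theta_i$ and turns every position-frequency product $m\theta_i$ into a unit complex number $e^{i m\theta_i}$ via torch.polar:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per 2-D pair: theta_i = 10000^(-2i/dim), for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Angles m * theta_i for every position m = 0 .. end - 1.
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    # Unit complex numbers e^(i * m * theta_i); shape (end, dim // 2), dtype complex64.
    return torch.polar(torch.ones_like(freqs), freqs)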



Note: with hidden_size = 768, num_heads = 8, and num_key_value_heads = 2, the precomputed tensor works out as shown below (since the dimensions are grouped into 2-D pairs, the head_dim of 96 is further divided by 2, giving 48 frequencies).
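A quick check of that shape using the precompute_freqs_cis above (max_seq_len = 2048 is just an assumed example value; the model passes max_seq_len * 2 as end):

freqs_cis = precompute_freqs_cis(dim=96, end=2048 * 2)
print(freqs_cis.shape)  # torch.Size([4096, 48]) -- 96 / 2 = 48 complex rotations per position
print(freqs_cis.dtype)  # torch.complex64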

Finally, the code that applies the rotary position embedding:

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk'
    tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for
    broadcasting compatibility. The resulting tensors contain rotary embeddings
    and are returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key
        tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
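apply_rotary_emb relies on one more helper not shown in the post, reshape_for_broadcast; in the reference implementation it reshapes freqs_cis from (seqlen, head_dim // 2) to (1, seqlen, 1, head_dim // 2) so it broadcasts over the batch and head dimensions:

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # x is the complex view of xq: (bs, seqlen, n_heads, head_dim // 2).
    ndim = x.ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep the seqlen and last dims; set every other dim to 1 for broadcasting.
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)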

