AI知识库

53AI知识库

学习大模型的前沿技术与行业应用场景


图算法之GraphSAGE原理以及代码实现
发布日期:2024-05-07 07:37:17 浏览次数: 1938


一、GraphSAGE实现原理

    相比DeepWalk、Node2Vec以及GCN等直推式(transductive)Graph Embedding框架,归纳式(inductive)框架GraphSAGE不再学习图中所有节点的Embedding表达,而是学习一个为每个节点产生其对应Embedding表示的映射。它的核心思想是通过学习如何聚合局部领域信息来生成节点的嵌入,而不依赖整个图的结构。并且GraphSAGE通过局部邻居采样和归纳式框架设计,能够更好地扩展到大规模图数据,同时便于处理图结构的动态变化。

直推式学习与归纳式学习:
  1. 直推式学习(Transductive Learning):

    在直推式学习中,模型仅对训练集中的样本进行预测,不对测试集中的未见样本进行预测。

    直推式学习只对训练数据中的样本进行预测,不对新样本进行泛化,因此适用于需要对训练数据进行标注的任务,但不需要对新数据进行预测的场景。


  2. 归纳式学习(Inductive Learning):

    在归纳式学习中,模型从训练数据中学习到的知识可以泛化到未见过的测试数据上。旨在构建一个从输入到输出的映射,使得模型能够对未见过的样本进行预测。

    归纳式学习则旨在通过训练数据构建一个泛化的模型,使得模型能够对未见过的新数据进行准确预测,因此适用于需要进行预测的任务。

      GraphSAGE中的SAGE指的是SAmple and aggreGatE,即不是参考全局的图结构对每个顶点都训练一个单独的Embedding向量而,是通过对每个节点的邻居节点进行采样,训练出一组 aggregator functions,这些函数学习如何从一个节点的局部邻居聚合特征信息。每个聚合函数从一个顶点的不同的hops或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成其对应的Embedding。

GraphSAGE实现步骤:
  1. a.采样邻居:对于每个目标节点,从其邻居中随机采样固定数量的邻居节点。这样做可以减少计算量并保持算法的可扩展性。

  2. b.聚合邻居信息:GraphSAGE定义了多种聚合函数(如均值聚合器、池化聚合器、LSTM聚合器等),用以聚合采样出的邻居节点的特征。这些聚合函数可以学习到如何从邻居节点的特征中提取有用信息。

  3. c.更新节点嵌入:将聚合后的邻居特征与目标节点自身的特征结合(通常是通过拼接或求和),然后通过一个神经网络层来更新节点的嵌入。

  4. d.重复与优化:对图中所有节点重复上述步骤,通过反向传播和梯度下降等优化方法来训练模型参数。

       上图是为红色的目标节点生成Embedding的过程。k表示距离目标节点的搜索深度,k=1就是目标节点的相邻节点,k=2表示目标节点的二跳邻居节点。对于上图中的例子:

  • 第一步是采样,k=1采样了3个节点,对k=2采用了5个节点;

  • 第二步是聚合邻居节点的信息,获得目标节点的Embedding;

  • 第三步是使用聚合得到的信息,也就是目标节点的Embedding,来预测图中想预测的信息;


二、GraphSAGE伪代码

      这里的K指的是网络的层数,也代表着每个顶点能够聚合的邻接点的跳数,如K=2的时候每个顶点可以最多根据其2跳邻接点的信息学习其自身的Embedding表示。在每一层的循环k中,对每个顶点v,首先使用v的邻接点的k-1层的embedding表示来产生其邻居顶点的第k层聚合表示,之后将和顶点v的第k-1层表进行拼接,经过一个非线性变换产生顶点v的第k层Embedding表示

三、GraphSAGE的聚合器

      Aggregator 的作用是把一个向量的集合转换成向量,也就是聚合。和其他机器学习任务中的数据(如图像,文本等)不同,图中的节点是没有顺序的(node’s neighbors have no natural ordering),即aggregator function操作的是一个无序的向量集合代表了节点v的邻居节点集合),所以希望构造出的聚合函数是对称的(即对顺序不敏感,改变输入的顺序,函数的输出结果不会发生任何变化),同时具有较高的表达能力。

  • Mean aggregator:显然对向量集合,对应元素取均值是最直接的想法,将目标顶点和邻居顶点的第k-1层向量拼接起来,然后对向量的每个维度求均值。

  • LSTM aggregator:和mean aggregator相比,LSTM有更大的表达能力。但是LSTM不符合symmetric的性质,输入是有顺序的。所以把相邻节点的向量集合随机打乱顺序,然后作为LSTM的输入。

  • Pooling aggregator:尝试了pooling做aggregator, 所有相邻节点的向量共享权重,先经过一个非线性全连接层,然后做max-pooling。


四、GraphSAGE的损失函数

有监督损失函数:根据具体的任务而定,针对节点分类任务的常规交叉熵样式预测。

无监督损失函数

      损失函数的蓝色部分试图强制说明,如果节点u和v在实际图中接近,则它们的节点嵌入在语义上应该相似。在理想情况下,我们期望的内积很大。如此大的数值输入到Sigmoid输出会接近1且log(1)=0

     损失函数的粉红色部分试图强制执行相反的操作!也就是说,如果节点u和v在实际图形中实际上相距较远,则我们期望它们的节点嵌入是不同的/相反的。在理想情况下,我们期望的内积为较大的负数。可以解释为,嵌入差别很大,以至于它们之间的距离大于90度。两个大负数的乘积变成一个大正数。如此大的数值输入到Sigmoid输出会接近1log(1)=0由于可能有更多的节点u远离我们的目标节点v在图中,我们从远离节点v的节点分布中仅采样了几个负节点u:这样可以确保训练时的损失功能达到平衡。另外添加epsilon可以确保我们永远不会取到log(0)。

五、GraphSAGE代码实现

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch_geometric.nn import SAGEConvfrom torch_geometric.datasets import Planetoidfrom torch_geometric.data import DataLoader
# 导入所需的库# torch:PyTorch核心库# torch.nn:包含了神经网络层的库# torch.nn.functional:包含了神经网络函数的库# torch_geometric.nn:PyTorch Geometric中的图神经网络模块# torch_geometric.datasets:包含了一些常用的图数据集# torch_geometric.data:定义了用于处理图数据的数据结构和函数
class GraphSage(nn.Module): def __init__(self, in_channels, hidden_channels, num_layers): super(GraphSage, self).__init__() self.convs = nn.ModuleList() self.convs.append(SAGEConv(in_channels, hidden_channels)) for _ in range(num_layers - 1): self.convs.append(SAGEConv(hidden_channels, hidden_channels))
def forward(self, x, edge_index): for conv in self.convs: x = conv(x, edge_index) x = F.relu(x) return x
# 定义GraphSage模型类,继承自nn.Module# 初始化方法__init__中定义了GraphSage模型的网络层# forward方法中定义了模型的前向传播过程
# 加载数据集dataset = Planetoid(root='/tmp/Cora', name='Cora')
# 使用Planetoid加载Cora数据集,存储在/tmp/Cora目录下
# 划分数据集train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# 将数据集划分为mini-batch,每个batch大小为64,进行随机打乱
# 实例化GraphSage模型model = GraphSage(in_channels=dataset.num_features, hidden_channels=16, num_layers=2)
# 创建GraphSage模型的实例,指定输入特征维度、隐藏层维度和层数
# 训练循环optimizer = torch.optim.Adam(model.parameters(), lr=0.01)criterion = nn.CrossEntropyLoss()
# 使用Adam优化器和交叉熵损失函数进行模型训练
def train(): model.train() total_loss = 0 for data in train_loader: optimizer.zero_grad() out = model(data.x, data.edge_index) loss = criterion(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() total_loss += loss.item() return total_loss / len(train_loader)
# 定义训练函数train,对模型进行训练# 遍历每个mini-batch,计算损失并更新模型参数
# 训练模型for epoch in range(100): loss = train() print(f'Epoch {epoch + 1}, Loss: {loss:.4f}')
# 训练模型100个epoch,打印每个epoch的损失值



53AI,企业落地应用大模型首选服务商

产品:大模型应用平台+智能体定制开发+落地咨询服务

承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业

联系我们

售前咨询
186 6662 7370
预约演示
185 8882 0121

微信扫码

与创始人交个朋友

回到顶部

 
扫码咨询