微信扫码
添加专属顾问
我要投稿
一、GraphSAGE实现原理
相比DeepWalk、Node2Vec以及GCN等直推式(transductive)Graph Embedding框架,归纳式(inductive)框架GraphSAGE不再学习图中所有节点的Embedding表达,而是学习一个为每个节点产生其对应Embedding表示的映射。它的核心思想是通过学习如何聚合局部领域信息来生成节点的嵌入,而不依赖整个图的结构。并且GraphSAGE通过局部邻居采样和归纳式框架设计,能够更好地扩展到大规模图数据,同时便于处理图结构的动态变化。
在直推式学习中,模型仅对训练集中的样本进行预测,不对测试集中的未见样本进行预测。
直推式学习只对训练数据中的样本进行预测,不对新样本进行泛化,因此适用于需要对训练数据进行标注的任务,但不需要对新数据进行预测的场景。
在归纳式学习中,模型从训练数据中学习到的知识可以泛化到未见过的测试数据上。旨在构建一个从输入到输出的映射,使得模型能够对未见过的样本进行预测。
归纳式学习则旨在通过训练数据构建一个泛化的模型,使得模型能够对未见过的新数据进行准确预测,因此适用于需要进行预测的任务。
GraphSAGE中的SAGE指的是SAmple and aggreGatE,即不是参考全局的图结构对每个顶点都训练一个单独的Embedding向量而,是通过对每个节点的邻居节点进行采样,训练出一组 aggregator functions,这些函数学习如何从一个节点的局部邻居聚合特征信息。每个聚合函数从一个顶点的不同的hops或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成其对应的Embedding。
a.采样邻居:对于每个目标节点,从其邻居中随机采样固定数量的邻居节点。这样做可以减少计算量并保持算法的可扩展性。
b.聚合邻居信息:GraphSAGE定义了多种聚合函数(如均值聚合器、池化聚合器、LSTM聚合器等),用以聚合采样出的邻居节点的特征。这些聚合函数可以学习到如何从邻居节点的特征中提取有用信息。
c.更新节点嵌入:将聚合后的邻居特征与目标节点自身的特征结合(通常是通过拼接或求和),然后通过一个神经网络层来更新节点的嵌入。
d.重复与优化:对图中所有节点重复上述步骤,通过反向传播和梯度下降等优化方法来训练模型参数。
上图是为红色的目标节点生成Embedding的过程。k表示距离目标节点的搜索深度,k=1就是目标节点的相邻节点,k=2表示目标节点的二跳邻居节点。对于上图中的例子:
第一步是采样,k=1采样了3个节点,对k=2采用了5个节点;
第二步是聚合邻居节点的信息,获得目标节点的Embedding;
第三步是使用聚合得到的信息,也就是目标节点的Embedding,来预测图中想预测的信息;
二、GraphSAGE伪代码
这里的K指的是网络的层数,也代表着每个顶点能够聚合的邻接点的跳数,如K=2的时候每个顶点可以最多根据其2跳邻接点的信息学习其自身的Embedding表示。在每一层的循环k中,对每个顶点v,首先使用v的邻接点的k-1层的embedding表示来产生其邻居顶点的第k层聚合表示
,之后将
和顶点v的第k-1层表示
进行拼接,经过一个非线性变换产生顶点v的第k层Embedding表示
。
三、GraphSAGE的聚合器
Aggregator 的作用是把一个向量的集合转换成向量,也就是聚合。和其他机器学习任务中的数据(如图像,文本等)不同,图中的节点是没有顺序的(node’s neighbors have no natural ordering),即aggregator function操作的是一个无序的向量集合(
代表了节点v的邻居节点集合),所以希望构造出的聚合函数是对称的(即对顺序不敏感,改变输入的顺序,函数的输出结果不会发生任何变化),同时具有较高的表达能力。
Mean aggregator:显然对向量集合,对应元素取均值是最直接的想法,将目标顶点和邻居顶点的第k-1层向量拼接起来,然后对向量的每个维度求均值。
LSTM aggregator:和mean aggregator相比,LSTM有更大的表达能力。但是LSTM不符合symmetric的性质,输入是有顺序的。所以把相邻节点的向量集合随机打乱顺序,然后作为LSTM的输入。
Pooling aggregator:尝试了pooling做aggregator, 所有相邻节点的向量共享权重,先经过一个非线性全连接层,然后做max-pooling。
四、GraphSAGE的损失函数
有监督损失函数:根据具体的任务而定,针对节点分类任务的常规交叉熵样式预测。
无监督损失函数:
损失函数的蓝色部分试图强制说明,如果节点u和v在实际图中接近,则它们的节点嵌入在语义上应该相似。在理想情况下,我们期望和
的内积很大。如此大的数值输入到Sigmoid输出会接近1且log(1)=0。
损失函数的粉红色部分试图强制执行相反的操作!也就是说,如果节点u和v在实际图形中实际上相距较远,则我们期望它们的节点嵌入是不同的/相反的。在理想情况下,我们期望和
的内积为较大的负数。可以解释为,嵌入
和
差别很大,以至于它们之间的距离大于90度。两个大负数的乘积变成一个大正数。如此大的数值输入到Sigmoid输出会接近1,log(1)=0。由于可能有更多的节点u远离我们的目标节点v在图中,我们从远离节点v的节点分布中仅采样了几个负节点u:
。这样可以确保训练时的损失功能达到平衡。另外添加epsilon可以确保我们永远不会取到log(0)。
五、GraphSAGE代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader
# 导入所需的库
# torch:PyTorch核心库
# torch.nn:包含了神经网络层的库
# torch.nn.functional:包含了神经网络函数的库
# torch_geometric.nn:PyTorch Geometric中的图神经网络模块
# torch_geometric.datasets:包含了一些常用的图数据集
# torch_geometric.data:定义了用于处理图数据的数据结构和函数
class GraphSage(nn.Module):
def __init__(self, in_channels, hidden_channels, num_layers):
super(GraphSage, self).__init__()
self.convs = nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
def forward(self, x, edge_index):
for conv in self.convs:
x = conv(x, edge_index)
x = F.relu(x)
return x
# 定义GraphSage模型类,继承自nn.Module
# 初始化方法__init__中定义了GraphSage模型的网络层
# forward方法中定义了模型的前向传播过程
# 加载数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
# 使用Planetoid加载Cora数据集,存储在/tmp/Cora目录下
# 划分数据集
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# 将数据集划分为mini-batch,每个batch大小为64,进行随机打乱
# 实例化GraphSage模型
model = GraphSage(in_channels=dataset.num_features, hidden_channels=16, num_layers=2)
# 创建GraphSage模型的实例,指定输入特征维度、隐藏层维度和层数
# 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# 使用Adam优化器和交叉熵损失函数进行模型训练
def train():
model.train()
total_loss = 0
for data in train_loader:
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
# 定义训练函数train,对模型进行训练
# 遍历每个mini-batch,计算损失并更新模型参数
# 训练模型
for epoch in range(100):
loss = train()
print(f'Epoch {epoch + 1}, Loss: {loss:.4f}')
# 训练模型100个epoch,打印每个epoch的损失值
53AI,企业落地大模型首选服务商
产品:场景落地咨询+大模型应用平台+行业解决方案
承诺:免费场景POC验证,效果验证后签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2025-02-01
2025-01-01
2024-08-13
2024-04-25
2025-02-04
2024-07-25
2024-06-13
2024-04-26
2024-09-23
2024-04-12
2025-02-24
2025-02-23
2025-02-23
2025-02-23
2025-02-23
2025-02-22
2025-02-22
2025-02-22