GraphSAGE图算法:原理详解与代码实现
一、GraphSAGE核心机制与归纳式学习范式
先厘清几项关键认知。在图表征学习领域,早期方案如DeepWalk、Node2Vec乃至GCN,均属于直推式(transductive)模式——训练阶段需观测整张图,每个节点绑定一个固定嵌入向量,一旦图结构动态变化(例如增加节点),必须重新训练整个模型。GraphSAGE则采用截然不同的路径:归纳式(inductive)。它不学习节点的静态表征,而是学习一个“如何利用局部邻居信息生成节点嵌入”的映射函数。换言之,给定一个全新节点,只要知晓其邻居分布,即可即时计算其嵌入向量。这意味着模型天然适配大规模图数据,并能轻松应对动态演化的图结构。
直推式学习与归纳式学习的对比:
- 直推式学习(Transductive Learning):
直推式模式下,模型仅能在训练集内已出现的样本上进行预测,对于未见过的新样本毫无泛化能力。适用场景:手中仅有一批固定节点,只需为其标注,无需推演至新数据。
- 归纳式学习(Inductive Learning):
归纳式模式旨在从训练数据中提取可迁移的规律,使模型能够对从未见过的新样本做出合理推断。这才是真正意义上的“泛化”,也是生产环境中应用最广的范式。
聚焦GraphSAGE本身。SAGE全称Sampling and Aggregation,核心操作简化为两步:采样邻居、聚合信息。它不依赖全局图拓扑,而是为每个节点划定局部邻居采样范围,再训练一组聚合函数(aggregator),这些函数能够从不同跳数(hops)的邻居逐层提取特征。推理时,即便节点完全陌生,只需以相同方式采样邻居并调用聚合函数,即可输出其嵌入。
GraphSAGE实现步骤:
- a. 邻居采样:对每个目标节点,从其邻居中随机抽取固定数量节点。此举旨在控制计算规模,使算法能扩展至百万级节点。
- b. 邻居信息聚合:定义多种聚合函数(均值、池化、LSTM等),将采样到的邻居节点特征整合为统一的向量表示。
- c. 节点嵌入更新:将聚合后的邻居特征与目标节点自身的特征拼接(或求和),接入全连接层进行非线性变换,得到当前层的嵌入。
- d. 迭代与优化:对整个图的所有节点重复上述流程,通过反向传播训练聚合函数及变换层的参数。
上图中展示了为红色目标节点生成Embedding的完整流程。k表示搜索深度:k=1时采样3个直接邻居,k=2时采样5个二跳邻居。具体来说:第一步采样邻居节点;第二步将邻居信息逐层聚合,更新目标节点嵌入;第三步利用该嵌入执行下游预测任务。
二、GraphSAGE伪代码结构
此处K对应网络层数,也决定了每个节点能够聚合的跳数。例如K=2时,节点可聚合两跳内的邻居信息。每层循环中,首先对节点v的所有邻居,利用上一层的嵌入聚合出邻居的当前层表示,随后与v的上一层表示拼接,经过非线性变换,得到v的当前层嵌入。
三、GraphSAGE的聚合器设计
聚合函数需要解决的核心问题:将一个无序的向量集合压缩为单一向量。图上的邻居不具备天然顺序,因此聚合器必须是对称的——无论邻居节点以何种顺序输入,输出结果均保持一致。同时,还需具备足够的表达能力。
- Mean aggregator:最直接的方法——将目标节点和邻居节点的上一层向量拼接,再对各维度取均值。简单、稳定、高效。
- LSTM aggregator:表达能力更强,但LSTM对输入顺序敏感。为绕过此限制,实践中先将邻居节点的向量集合随机打乱,再送进LSTM,通过随机化消除顺序影响。
- Pooling aggregator:所有邻居节点共享一个权重矩阵,先经过非线性全连接层,再对各维度执行max-pooling(或mean-pooling)。效果通常出色,且计算效率高。
四、GraphSAGE的损失函数设定
有监督损失函数:直接依据下游任务定义。若为节点分类,则使用标准交叉熵损失。
无监督损失函数:
损失函数由两部分构成。蓝色部分期望:若节点u与v在图上是相邻或近距离的,则其嵌入向量的内积应较大,经Sigmoid后接近1,log损失逼近0。粉色部分则施加反向约束:若u与v在图相距较远,其嵌入内积应为绝对值较大的负数(夹角超过90度),经Sigmoid后同样接近1,损失也为0。实际训练中,远离节点v的“负样本”u数量远多于正样本,因此从远距离分布中随机采样部分负节点,并添加极小的epsilon防止log(0)。如此正负样本均衡,模型方能学到有意义的嵌入。
五、GraphSAGE代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader
# 加载PyTorch及PyTorch Geometric核心模块
# torch:深度学习框架基础
# torch.nn:神经网络层构建库
# torch.nn.functional:函数式神经网络操作
# torch_geometric.nn:图神经网络专用模块
# torch_geometric.datasets:常用图数据集封装
# torch_geometric.data:图数据容器与加载工具
class GraphSage(nn.Module):
def __init__(self, in_channels, hidden_channels, num_layers):
super(GraphSage, self).__init__()
self.convs = nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
def forward(self, x, edge_index):
for conv in self.convs:
x = conv(x, edge_index)
x = F.relu(x)
return x
# 定义GraphSage模型类,继承自nn.Module
# __init__中构建多层SAGEConv网络
# forward中定义前向传播流程,逐层加非线性激活
# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
# 通过Planetoid获取Cora引文网络,数据缓存至/tmp/Cora
# 划分数据集为mini-batch
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# 按batch_size=64划分,训练时随机打乱
# 实例化GraphSage模型
model = GraphSage(in_channels=dataset.num_features, hidden_channels=16, num_layers=2)
# 构建两层GraphSAGE,输入特征维度由数据集决定,隐藏层维度16
# 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# 采用Adam优化器与交叉熵损失函数
def train():
model.train()
total_loss = 0
for data in train_loader:
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
# 训练函数:遍历每个batch,前向传播、计算损失、反向传播更新参数
# 执行100轮训练
for epoch in range(100):
loss = train()
print(f'Epoch {epoch + 1}, Loss: {loss:.4f}')
# 每轮输出平均损失,监控模型收敛
