PyTorch第三方库 2. GNN使用PyG的实现( 二 )


该范式包含这样三个步骤:
邻接节点信息变换邻接节点信息聚合到中心节点聚合信息变换
基于此范式,我们可以定义聚合邻接节点信息来生成中心节点表征的图神经网络 。在PyG中,基类是所有基于消息传递范式的图神经网络的基类,它大大地方便了我们对图神经网络的构建 。
【步骤】
class MessagePassing(aggr='add', flow='source_to_target', node_dim=0)
aggr: 定义要使用的聚合方案(“add”、“mean"或"max”)
flow: 定义消息传递的流向(“"或"”)
: 定义沿着哪个轴线传播
()()(查看的源码,可以看到其函数的定义(在PyG中是通过函数来实现上述过程)PyG教程(7):剖析邻域聚合)e()方法()方法继承基类的
数学定义以及PyG实现-
import torchfrom torch_geometric.nn import MessagePassingfrom torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add', flow='source_to_target')# "Add" aggregation (Step 5).# flow='source_to_target' 表示消息从源节点传播到目标节点self.lin = torch.nn.Linear(in_channels, out_channels)#所有逻辑在forward()方法中实现def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]# Step 1: Add self-loops to the adjacency matrix.# 使用torch_geometric.utils.add_self_loops() 给边索引添加自循环边【对应1. 向邻接矩阵添加自环边】edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))# Step 2: Linearly transform node feature matrix.# torch.nn.Linear 线性变换【对应2. 对节点表征做线性变换】x = self.lin(x)# Step 3: Compute normalization.【对应3. 计算归一化系数】row, col = edge_indexdeg = degree(col, x.size(0), dtype=x.dtype)deg_inv_sqrt = deg.pow(-0.5)norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]# Step 4-5: Start propagating messages.# progagate包含了(先调用message(),再aggregate,再update)return self.propagate(edge_index, x=x, norm=norm)#MessagePassing.propagate(edge_index, size=None, **kwargs):# 开始传播消息的起始调用 。它以edge_index(边的端点的索引)和flow(消息的流向)以及一些额外的数据为参数,size=(N,M)设置对称邻接矩阵的形状 。def message(self, x_j, norm):# x_j has shape [E, out_channels]# Step 4: Normalize node features.return norm.view(-1, 1) * x_jfrom torch_geometric.datasets import Planetoiddataset = Planetoid(root='dataset/Cora', name='Cora')data = http://www.kingceram.com/post/dataset[0]print(data.x)#node_featureprint(data.edge_index)#边索引net = GCNConv(data.num_features, 64)h_nodes = net(data.x, data.edge_index)print(h_nodes.shape)
输出结果:
2. 节点/边表征以及图表征
在图节点预测或边预测任务中,首先需要生成节点表征(Node ) 。我们使用图神经网络来生成节点表征,并通过基于监督学习的对图神经网络的训练,使得图神经网络学会产生高质量的节点表征 。高质量的节点表征能够用于衡量节点的相似性,同时高质量的节点表征也是准确分类节点的前提 。
图表征学习要求根据节点属性、边和边的属性(如果有的话)生成一个向量作为图的表征,基于图表征我们可以做图的预测.
基于图同构网络(Graph, GIN)的图表征网络是当前最经典的图表征学习网络,为了得到图表征首先需要做节点表征,然后做图读出 。GIN中节点表征的计算遵循WL Test算法中节点标签的更新方法,因此它的上界是WL Test算法 。在图读出中,我们对所有的节点表征(加权,如果用的话)求和,这会造成节点分布信息的丢失 。

PyTorch第三方库  2. GNN使用PyG的实现