PyTorch第三方库 2. GNN使用PyG的实现( 三 )


文章插图
详情可参考:
????????????
此模块首先采用模块对图上每一个节点做节点嵌入(Node ),得到节点表征;然后对节点表征做图池化得到图的表征;最后用一层线性变换对图表征转换为对图的预测 。
-节点嵌入——节点表征——图池化——图的表征——线性变换
输入到此节点嵌入模块的节点属性为类别型向量 。
步骤:
1)嵌入 。用对输入向量做嵌入得到第0层节点表征
2)计算节点表征 。
从第1层开始到第层逐层计算节点表征 。(每一层节点表征的计算都以上一层的节点表征[layer]、边和边的属性为输入)
注意事项:的层数越多,此节点嵌入模块的感受野( field)越大,结点i的表征最远能捕获到结点i的距离为的邻接节点的信息 。
- 输入的边属性(为类别型),先将类别型边属性转换为边表征,模块遵循:“消息传递,消息聚合,消息更新”这一过程 。
节点(原子)和边(化学键)的属性都为离散值,属于不同的空间 。
所以通过,将节点属性和边属性分别映射到一个新的空间,在这个新的空间中对节点和边进行消息聚合 。
注:节点属性有多少维,就需要多少个嵌入函数(通过调用 torch.nn.(dim,))可以实例化一个嵌入函数
其中dim为:被嵌入数据可能取值的数量;:要映射到的空间的维度 。
得到的嵌入函数,接收一个x(0
在()函数中,我们对不同属性值得到的不同嵌入向量进行了相加操作,实现了节点不同属性融合在一起(消息聚合) 。
【‘sclub 数据集进行PyG小白入门实战】
数据集——数据集展示——GCN网络定义——输入特征展示——训练模型
1. 数据集
from torch_geometric.datasets import KarateClubdataset = KarateClub()print(f'Dataset:{dataset}:')print('===================')print(f'Number of graphs:{len(dataset)}')print(f'Number of features:{dataset.num_features}')print(f'Number of classes:{dataset.num_classes}')'''Dataset:KarateClub():===================Number of graphs:1 #只有一个图,对点做分类Number of features:34 #每一个点有34个特征Number of classes:4#每个点做4分类'''dataset[0]'''Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])'''edge_index = dataset[0].edge_indexprint(edge_index.t())'''tensor([[ 0,1],[ 0,2],[ 0,3],[ 0,4],[ 0,5],[ 0,6],[ 0,7],[ 0,8],[ 0, 10],[ 0, 11],[ 0, 12],[ 0, 13],[ 0, 17],[ 0, 19],[ 0, 21],[ 0, 31],[ 1,0],[ 1,2],[ 1,3],[ 1,7],[ 1, 13],[ 1, 17],[ 1, 19],[ 1, 21],[ 1, 30],...[33, 29],[33, 30],[33, 31],[33, 32]])'''
图的表示用Data格式
其中上述的 x=[34,34],34*34(M×F——M:为样本的个数;F:每个样本的特征维度)=[2, 156] (:表示图的连接关系(start,end两个序列))
表示是稀疏的,可以看做邻接矩阵,但并不是传统意义上的n*n的邻接矩阵(而是[2,边的个数])y=[34] 标签(34个node)=[34] 哪些点是有标签的,哪些点时无标签的(有的节点木有标签,用来表示哪些节点要计算损失)
.nn:是可以调用的一些层
.data:是可以调用的一些数据(数据结构)
.:是可以调用的一些数据集
.utils:是可以调用的一些基本处理的函数
2. 数据集展示(使用可视化展示)
import matplotlib.pyplot as pltimport networkx as nxdef visualize_graph(G,color):plt.figure(figsize=(7,7))plt.xticks([])plt.yticks([])nx.draw_networkx(G,pos=nx.spring_layout(G,seed=42),with_labels=False,node_color=color,cmap="Set2")plt.show()from torch_geometric.utils import to_networkxdata = http://www.kingceram.com/post/dataset[0]G = to_networkx(data,to_undirected=True)visualize_graph(G,color=data.y)