[PyG]1.如何使⽤GCN完成⼀个最基本的训练过程(含GCN实现)
0. 前⾔
为啥要学习Pytorch-Geometric呢?(下⽂统⼀简称为PyG) 简单来说,是⽬前做的项⽬有⽤到,还有1个特点,就是相⽐NYU的DeepGraphLibrary, DGL的问题是API⽐较棘⼿,⽽且⽬前没有迁移的必要性。
图卷积框架能做的事情⽐较多,提供了很多⽅便的数据集和各种GNN SOTA的实现,其实最吸引我的就是这个framework的API⽐较友好,再加之使⽤PyG做项⽬的⼈⽐较多,⽣态对我这种做3D mesh的⼈⽐较友好。
注意, 本教程完全基于官⽅最新 (2020.04.14) 的教程,在此基础上,完成了简化版本的GCN的实现,对GCN的官⽅实现感兴趣的童鞋可以康康[1]。
下⾯,我将完全按照[1]的步骤来,不同之处在于,我在这⾥将基于PyG的最新版本(1.4.3)来分析GCN的简化版实现,让⼤家更加理解GCN的实现原理, 以下是阐述顺序:
①图数据的Data Handling
②Common Benchmark Datasets
③Mini-batches
④Data Transforms
⑤Learning Methods on Graphs
此外,我所使⽤的环境是:
Ubuntu 18.04
Cuda10.0
pytorch 1.4.0 conda install pytorch=1.4.0 cudatoolkit=10.0
pytorch geometric 1.4.3
torch-scatter pip install torch-scatter==latest+cu100 -f pytorch-geometric/whl/torch-1.4.0.html
torch-spline-conv pip install torch-spline-conv==latest+cu100 -f pytorch-geometric/whl/torch-1.4.0.html
torch-cluster pip install torch-cluster==latest+cu100 -f pytorch-geometric/whl/torch-1.4.0.html
torch-sparse pip install torch-sparse==latest+cu100 -f pytorch-geometric/whl/torch-1.4.0.html
1. 图结构的数据处理
⾸先,图是什么?图是边和点的相关关系的组合。在PyG中,⼀个简单的graph可以被描述为torch_geometric.data.Data[2]的实例,其中有⼏个重要的属性需要说明,此外,如果你的图需要扩展,那么你可以对torch_geometric.data.Data这个类进⾏修改即可。
图1.1 torch_geometric.data.Data的常⽤成员变量
通常来讲,对分⼀般的任务来说,数据类只需要有x,edge_index,edge_attr,y等⼏个属性即可,⽽且,这些属性都是optional(可选)的,也就是说,Data类并不局限于这些属性。
举个栗⼦,可以扩展data.face(torch.LongTensor, [3, num_faces])来保存3D mesh的三⾓形的连接关系.
图1.2 torch_geometric.data.Data的官⽅说明
图1.3 Data实例(3个节点,4条边(双向), 每个节点有2个特征[-1, 2], [0, 3], [1, 1].)
需要注意的是,尽管图只有2条边,我们还是需要定义4个index tuple来考虑边的双向关系。
图1.3搭建的graph的⽰意图如下:
2. 常见Benchmark数据集
尽管最近Bengio团队是基于DGL开发的6个Benchmark数据集,但是在pyG上做这个也没问题呀~。所以也不必直接因此就转去DGL。
PyTorch Geometric包含了⼤量的基础数据集, 所有的Planetoid datasets (Cora, Citeseer, Pubmed), 来
⾃多特蒙德⼯⼤的清洗过的图分类数据集, ⼀系列3D点云和mesh的数据集,⽐如FAUST,ShapeNet等。
PyG提供了这些数据的⾃动下载,并将其处理成之前说的Data形式,以ENZYMES数据集为例(包含600个图和6个类别):
图2.1 ENZYMES数据集的解析
由图2.1可见,其中的每个样本都是Data的instance,有顶点特征x,连接关系edge_index以及类别y 3个属性. 可以看出,ENZYMES的每个数据都是1个图。
注意: 可以通过使⽤dataset=dataset.shuffle()来对数据集进⾏shuffle。
此外,教程上还提供了Planetoid的Cora数据集的说明(⽤于semi-supervised graph node classification), 这⾥Cora数据集的数据有3个新的属性train_mask, test_mask, val_mask, 这3个属性⽤于表征需要训练、测试和验证的数据节点。
Cora与ENZYMES的区别是,Cora中的每个数据是整个图中的1个节点,⽽ENZYMES的每个数据都是1个独⽴的图。
图2.2 Cora数据集说明
3. Mini-Batches
我们知道,神经⽹络通常是按Batch训练的,PyG通过创建稀疏的邻接矩阵(sparse block diagnol adjacency matrices)实现在mini-batch上的并⾏化。
图3.1 PyG mini-batch对不同的节点、边数量的图的批处理
并按照node dimension来拼接节点特征x和类别特征y。通过这种⽅式,PyG可以在⼀个Batch中塞进不同nodes和edges数的样本。
图3.2 ENZYMES数据集加载说明(未shuffle)
(注意,这⾥的DataLoader⽤的是PyG⾃⼰的,⽽不是pytorch的,此外,use_node_attr=False时, x为[nodes_num, 3];
use_node_attr=True时, x为[nodes_num, 21])
这⾥,torch_geometric.data.Batch继承⾃ torch_geometric.data.Data,多了⼀个名为batch的属性,其作⽤是标⽰每个节点属于哪个图(ENZYMES)/样本.
此外,torch_geometric.data.DataLoader也只是pytorch的Dataloader重写了collate函数的版本⽽已。
正常传递给pytorch的Dataloader的参数,如pin_memory,num_workers等都可以传给torch_geometric.data.DataLoader.
当然,⽤户可以通过使⽤torch-scatter[3]对节点数据特征x进⾏⾃定义的处理并使⽤⾃定义的Dataset和Dataloader来处理⾃⼰的特殊形式数据[4].
4. Data Transforms
同torchvision在pytorch中的使⽤类似,我们也需要对graph数据进⾏处理和变换。PyG提供了⾃⼰的transform⽅式和⼯具包,要求的输⼊为Data对象,并返回transformed的Data对象。
类似地,transform可以通过ansforms.Compose来进⾏⼀系列的拼接。
作者举得例⼦是ShapeNet数据集(包含17,000 3D shape point clouds and per point labels from 16 shape categories)的Airplane类,作者通过pre_transform = T.KNNGraph(k=6)将point cloud数据变为graph数据集。
图4.1 ShapeNet数据集处理(将点云数据变为graph数据)
如有其它需要,⽤户可以⾃⼰去ansforms 进⾏查阅是否有符合⾃⼰⽬的的transform,没有的话⾃⼰写~
5. Learning Methods on Graphs
在搞定前4步后,现在让我们开始搞起第1个GNN~,这⾥,我们将会使⽤最基础的GCN层来复现Cora Citation数据集上的实验,若要理解GCN,需要从Fourier变换讲起,类⽐time domain --> frequency domain, 经过Hemlholtz公式,将vertex domain变到 spectral domain来分析,这样⼀来,vertex domain的卷积就变成了spectral domain的点乘,节省了计算量。
此外,变换的过程中, 还涉及到Laplacian矩阵L的意义(每个vertex的散度Divergence:可以理解为每个vertex的信息的增益情况,出射为正,⼊射为负),因为L的性质(半正定,特征值⼤于等于0等)
,假设其特征值为,特征向量为,通过与频谱图对⽐:就可以类⽐为Fourier变换的basis函数;
就类⽐为频率w GCN [5]就是在此基础上,经由2步优化得到的,它既考虑了self-loop,也考虑了k-localize(局部性),还对度进⾏了renormalization,避免马太效应过于明显,使得模型不会很容易陷⼊local minima。好了,就不再多提了,对GCN的推导和出现感兴趣的,可以看[6-7](先理解Laplacian矩阵和变换在图论中的⼀般含义, 再去油管上看台湾⼤学姜成翰助教关于GNN的教程)进⾏学习,下⾯我们看代码。
ubuntu使用入门教程5.1 GCN 在PyG 的实现
PyG提供了MessagePassing 这个base class,通过继承这个类,我们可以实现各种基于消息传递的GNN,借由MessagePassing
⽤户不再需要关注message的progation的内部实现细节,MessagePassing 主要关注其UPDATE, AGGREGATION, MESSAGE 这3个成员函数。
⽤户在实现⾃⼰的GNN时,⼀般只overwrite AGGREGATION, UPDATE这2个成员函数,MESSAGE/Propagate⽤MessagePassing ⾃带的。(官⽅的GCN就是这样的~)我们的⽬标是: 实现1个与官⽅⼀致的简化版的GCN,并通过实现它来掌握如何在PyG中定义图卷积。⾸先,我们先定义⼀个图数据data (有向图, 4个节点,3条边, 每个节点的特征维度都是1, 值也都为1):λU U λ

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。