Graph Convolution Network 理解与实现
https://zhuanlan.zhihu.com/p/51990489
Graph Convolution作为Graph Networks的一个分支,可以说几乎所有的图结构网络都是大同小异,详见综述,而Graph Convolution Network又是Graph Networks中最简单的一个分支。理解了它便可以理解很多近年来的图结构网络,比如Scene Graph Generation中的Message Passing机制等。后续打算持续更新一些原始GCN的变体。
【相关文章和网站】:
- Paper: Semi-Supervised Classification with Graph Convolutional Networks, 2016
- Paper: Gated Graph Sequence Neural Networks, 2016
- Website: How powerful are Graph Convolutional Networks?
- Github: 关于Gated Graph Convolution Network的Pytorch实现 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0
- 其实Graph Convolution Network (GCN)可以看作Graph Networks的一个分支(只有Node feature,无Edge feature和global attribute),而Graph Networks则有一篇2018年的综述:Relational inductive biases, deep learning, and graph networks, 2018
【Graph Convolution Network和传统CNN的关系】
我们不妨把传统的CNN的输入图片(I)也定义为一个Graph,他包含一堆Pixel集合({p_i})看作是Node, 而graph的边则是通过pixel的连通性定义的,所以每个pixel有至多8个edge和他相连。而Convolution其实就是把他的8个neighbour pixel的feature和他自己的feature乘以一个可学习的参数化kernel,来update这个pixel的feature.
那么由此,就不难理解GCN了。GCN主要的区别在于,他的node间的边,不是通过连通性定义的,而是需要给定了一个edge set,或者说graph的adjacent matrix。而且由于每个node可以有任意数量的neighbour node,所以update feature时,所有node其实是乘以了同一套参数。
【公式化】
这里我们参考Semi-Supervised Classification with Graph Convolutional Networks, 2016给出Graph Convolution的最终公式,忽略了原文的推导过程。
GCN可以定义为如下公式:
- 这里(Xin R^{N imes C})是node输入,包含N个node,每个Node有C维的feature, A是Adjacent matrix,,(A_{ij})定义node i和node j间是否有边edge,(Zin R^{N imes F})是输出, F表示新特征的维度
详细展开如下:
- 这里(Thetain R^{C imes F})就是要学习的参数,(hat{A}=A+I),(I)是单位矩阵,(hat{D}_{ii}=sum_{j}hat{A}_{ij})是对角线矩阵,对角线上每个元素表示,node i的neighbor数(包括自身)。所以其实等式右边可以看作(XTheta)把所有node的feature从C映射到F维,而每个node的新feature(Z_iin R^F)等于所有和他相连的node(包括自身)的F维feature的加权和,即average
【伪代码实现】
【Gated Graph Convolution Network】
但是上述Node特征更新的方式比较原始,Gated Graph Sequene Neural Networks, ICLR, 2016将Graph Convolution的X to Z的更新改成了GRU(LSTM)的形式。同时设计了一个Graph-Level的特征。下面实现参考了上文的思想,但做了些简化,比如原文将Incoming Edges和Outgoing Edges区分了这里我就沿用朴素Graph Convolution的A,不做拓展。
【Gated Graph Convolution Network 公式&伪代码】
- 上述W都是可学习的参数
【Graph-Level特征获取】
很多应用需要将一整个graph整合成一个特征,而原始的Graph Convolution则只能生成每个node的特征。graph-level的定义如下:
当然,还有很多文章,采取更为简单的graph-level feature提取方法:
【Code】
关于Gated Graph Convolution Network的代码,可以参考以下Github项目 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0