zoukankan      html  css  js  c++  java
  • Graph Convolution Network 理解与实现

    Graph Convolution Network 理解与实现

    https://zhuanlan.zhihu.com/p/51990489

    Graph Convolution作为Graph Networks的一个分支,可以说几乎所有的图结构网络都是大同小异,详见综述,而Graph Convolution Network又是Graph Networks中最简单的一个分支。理解了它便可以理解很多近年来的图结构网络,比如Scene Graph Generation中的Message Passing机制等。后续打算持续更新一些原始GCN的变体。

    【相关文章和网站】:

    1. Paper: Semi-Supervised Classification with Graph Convolutional Networks, 2016
    2. Paper: Gated Graph Sequence Neural Networks, 2016
    3. Website: How powerful are Graph Convolutional Networks?
    4. Github: 关于Gated Graph Convolution Network的Pytorch实现 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0
    5. 其实Graph Convolution Network (GCN)可以看作Graph Networks的一个分支(只有Node feature,无Edge feature和global attribute),而Graph Networks则有一篇2018年的综述:Relational inductive biases, deep learning, and graph networks, 2018

    【Graph Convolution Network和传统CNN的关系】

    img

    我们不妨把传统的CNN的输入图片(I)也定义为一个Graph,他包含一堆Pixel集合({p_i})看作是Node, 而graph的边则是通过pixel的连通性定义的,所以每个pixel有至多8个edge和他相连。而Convolution其实就是把他的8个neighbour pixel的feature和他自己的feature乘以一个可学习的参数化kernel,来update这个pixel的feature.

    那么由此,就不难理解GCN了。GCN主要的区别在于,他的node间的边,不是通过连通性定义的,而是需要给定了一个edge set,或者说graph的adjacent matrix。而且由于每个node可以有任意数量的neighbour node,所以update feature时,所有node其实是乘以了同一套参数。

    【公式化】

    这里我们参考Semi-Supervised Classification with Graph Convolutional Networks, 2016给出Graph Convolution的最终公式,忽略了原文的推导过程。

    GCN可以定义为如下公式:

    [Z=GCN(X,A) ]

    • 这里(Xin R^{N imes C})是node输入,包含N个node,每个Node有C维的feature, A是Adjacent matrix,,(A_{ij})定义node i和node j间是否有边edge,(Zin R^{N imes F})是输出, F表示新特征的维度

    详细展开如下:

    [Z=hat{D}^{-frac{1}{2}}hat{A}hat{D}^{-frac{1}{2}}XTheta ]

    • 这里(Thetain R^{C imes F})就是要学习的参数,(hat{A}=A+I),(I)是单位矩阵,(hat{D}_{ii}=sum_{j}hat{A}_{ij})是对角线矩阵,对角线上每个元素表示,node i的neighbor数(包括自身)。所以其实等式右边可以看作(XTheta)把所有node的feature从C映射到F维,而每个node的新feature(Z_iin R^F)等于所有和他相连的node(包括自身)的F维feature的加权和,即average

    【伪代码实现】

    [input : X, A, output : Z ]

    [Y = f_c(X) ]

    [Z = (A+I) * Y / (Acdot sum(1)+1) ]

    【Gated Graph Convolution Network】

    但是上述Node特征更新的方式比较原始,Gated Graph Sequene Neural Networks, ICLR, 2016将Graph Convolution的X to Z的更新改成了GRU(LSTM)的形式。同时设计了一个Graph-Level的特征。下面实现参考了上文的思想,但做了些简化,比如原文将Incoming Edges和Outgoing Edges区分了这里我就沿用朴素Graph Convolution的A,不做拓展。

    【Gated Graph Convolution Network 公式&伪代码】

    [input:X^t,output:X^{t+1}(即Z) ]

    [Y=Aast f_c(X^t) ]

    [U=sigma(W_1Y+W_2X^t) ]

    [R=sigma(W_3Y+W_4X^t) ]

    [X^{t+1}_{tem}=tanh(W_5Y+W_6(Rcdot X^t)) ]

    [X^{t+1}=(1-U)cdot X^t+Ucdot X^{t+1}_{tem} ]

    • 上述W都是可学习的参数

    【Graph-Level特征获取】

    很多应用需要将一整个graph整合成一个特征,而原始的Graph Convolution则只能生成每个node的特征。graph-level的定义如下:

    [h_G=tanh(sum_{nodes}sigma([X^T,X^0]))cdot tanh(f_{c_2}([X^T,X^0]) ]

    当然,还有很多文章,采取更为简单的graph-level feature提取方法:

    [h_G=sum_{nodes}X^T_i,or h_G=frac{1}{Num(nodes)}sum_{nodes}X^T_i ]

    【Code】

    关于Gated Graph Convolution Network的代码,可以参考以下Github项目 KaihuaTang/GGNN-for-bAbI-dataset.pytorch.1.0

    本文来自博客园,作者:甫生,转载请注明原文链接:https://www.cnblogs.com/fusheng-rextimmy/p/15387340.html

  • 相关阅读:
    函数、包和错误处理
    程序流程控制
    poj 2515 Birthday Cake
    poj 2094 多项式求和。
    hdu 3625 第一类striling 数
    hdu 4372 第一类stirling数的应用/。。。好题
    poj 1845 Sumdiv
    hdu 3641 Treasure Hunting 强大的二分
    poj 3335 /poj 3130/ poj 1474 半平面交 判断核是否存在 / poj1279 半平面交 求核的面积
    hdu 2841 Visible Trees
  • 原文地址:https://www.cnblogs.com/fusheng-rextimmy/p/15387340.html
Copyright © 2011-2022 走看看