zoukankan      html  css  js  c++  java
  • DAG-GNN: DAG Structure Learning with Graph Neural Networks

    Yu Y., Chen J., Gao T. and Yu M. DAG-GNN: DAG structure learning with graph neural networks. In International Conference on Machine Learning (ICML), 2019.

    有向无环图 + GNN + VAE.

    主要内容

    先前已经有工作(NOTEARS)讨论了如何处理线性SEM模型

    [X = A^TX + Z, ]

    (A in mathbb{R}^{m imes m})为加权的邻接矩阵, (m)代表了有向无环图中变量的数目, (Z)是独立的noise. 需要特别说明的是, 在本文中, 作者假设每一个结点变量(X_i)并非传统的标量而是一个向量 (个人觉得这是很有意思的点, 有点胶囊的感觉), 故(X in mathbb{R}^{m imes d}), 这里(X_i)(X)的第(i)行.

    本文在此基础上更进一步, 考虑非线性的情况:

    [g(X) = A^Tg(X) + f_1(Z), ]

    如果(g)可逆, 则可以进一步表示为

    [X = f_2((I - A^T)^{-1}f_1(Z)). ]

    为了满足这一模型, 作者套用VAE, 进而最大化ELBO:

    [mathcal{L}_{mathrm{ELBO}} = mathbb{E}_{q_{phi}(Z|X)}[log p_{ heta}(X|Z)] - mathbb{D}_{mathrm{KL}}(q_{phi}(Z|X)| p(Z)), ]

    整个VAE的流程是这样的:

    image-20210530180904373

    1. encoder:

      [M_Z, log S_Z = f_4((I - A^T)f_3(X)), \ Z sim mathcal{N}(M_Z, S_Z^2). ]

    2. decoder

    [M_X, S_X = f_2((I - A^T)^{-1}f_1(Z)), \ widehat{X} sim mathcal{N}(M_X, S_X^2). ]

    注: 因为每个结点变量都不是标量, 所以考虑上面的流程还是把(X, Z)拉成向量(md)再看会比较清楚.

    此时

    [mathbb{D}_{mathrm{KL}}(q_{phi}(Z|X)|p(Z)) = frac{1}{2} sum_{i=1}^m sum_{j=1}^d {[S_Z]_{ij}^2 + [M_Z]_{ij}^2 - 2log [S_Z]_{ij} - 1 }. ]

    仅最大化ELBO是不够的, 因为这并不能保证(A)反应有向无环图, 所以我们需要增加条件

    [h(A) = mathrm{tr}[(I+alpha A circ A)^m] = m, ]

    具体推导看NOTEARS, 这里(alpha=frac{c}{m}), (c>0)是一个超参数, 这个原因是

    [(1 + alpha |lambda|)^m le e^{c|lambda|}, ]

    所以合适的(c)能够让条件更加稳定.

    最后目标可以总结为:

    [min_{phi, heta, A} quad -mathcal{L}_{mathrm{ELBO}} \ mathrm{s.t.} quad h(A) = 0. ]

    同样的, 作者采用了augmented Lagrangian来求解

    [(A^k, phi^k, heta^k) = arg min_{A,phi, heta} : -mathcal{L}_{mathrm{ELBO}} + lambda h(A) + frac{c}{2}|h(A)|^2, \ lambda^{k+1} = lambda^k + c^k h(A^k), \ c^{k+1} = left { egin{array}{ll} eta c^k, & mathrm{if} : |h(A^k)| > gamma |h(A^{k-1})|, \ c^k, & otherwise. end{array} ight. ]

    这里(eta > 1, gamma < 1), 作者选择(eta=10, gamma=1/4).

    注: (c)逐渐增大的原因是, 显然当(c = +infty)的时候, (h(A))必须为0.

    注: 作者关于图神经网络的部分似乎就集中在(X)的模型上, 关于图神经网络不是很懂, 就不写了.

    代码

    原文代码

  • 相关阅读:
    struts2实现文件上传和下载
    Struts2中Action之ResultType
    初识Struts2
    Hibernate中get()和load()方法区别
    初识Hibernate框架,进行简单的增删改查操作
    memge和saveOrUpdate的区别
    apt-get install 出现could not open lock file /var/lib/dpkg/lock错误问题
    vscode工程目录文件及文件夹摘选
    C++引用
    内存分配区基本模型
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/14828782.html
Copyright © 2011-2022 走看看