zoukankan      html  css  js  c++  java
  • Masked Gradient-Based Causal Structure Learning

    Ng I., Fang Z., Zhu S., Chen Z. and Wang J. Masked Gradient-Based Causal Structure Learning. arXiv preprint arXiv:1911.10500, 2019.

    非线性, 自动地学习因果图.

    主要内容

    NOTEARS将有向无环图凝练成了易处理的条件, 本文将这种思想扩展至非线性的情况:

    [X_i = f_i(X_{mathrm{pa}(i)}) + epsilon_i, ]

    其中(X_i)是因果图的结点, (X_{mathrm{pa}(i)})是其父结点, (epsilon)是无关的噪声.

    上述等式等价于

    [X_i = f_i(A_i circ X) + epsilon_i, ]

    (A_i)是邻接矩阵(A=[A_1|A_2|cdots|A_d] in {0, 1}^{d imes d})的第i列, (A_{ij}=1)表示结点(X_i)直接作用于(X_j).

    所以本文的目标就可以转换为如何估计(A)(实际上有了(A)也就知道了因果图了). (A)应当满足的条件:

    1. (A) 能够表示有向无环图;
    2. (X_i)(f(A_i circ X))必须接近, 比如用常见的

    [|X_i - f(A_i circ X)|_2^2 ]

    来度量.

    直接处理非常麻烦, 首先对上面的问题进行放松, 等价于

    [X_i = f_i(W_i circ X) + epsilon_i, ]

    此时(A = mathcal{A}(W)), 即

    [W_{ij} ot = 0 ightarrow A_{ij} = 1; W_{ij} = 0 ightarrow A_{ij} = 0. ]

    本文更进一步, 令

    [W = g_{ au}(U), quad U in mathbb{R}^{d imes d}. ]

    [[g_{ au}(U)]_{ij} = sigma((u_{ij} + g) / au) = frac{1}{1 + exp(-(u_{ij}+ (g_1 - g_0)) / au)}, ]

    其中

    [g = g_1 - g_0, : g_i mathop{sim}limits^{i.i.d.} mathrm{Gumbel}(0, 1). ]

    注: Gumbel.

    此类操作能保证(g_{ au}(U) in (0, 1)^{d imes d}), 此时能够把([g_{ au}(U)]_{ij})看成是(X_i), (X_j)的关系的紧密型的度量, 在这种情况下

    [[g_{ au}(U)]_{ij} le omega Rightarrow A_{ij} = 0. ]

    或许会问, 为什么不用sigmoid而用一个这么麻烦的东西, 原因是当( au)足够小的时候(如本文取的0.2), ([g_{ au}(U)]_{ij})非常接近(0)或者(1), 而用sigmoid, 作者发现这些值都接近0, 不能很好的模拟有向无环图, 故采用了这个方案.

    接下来, 只需要满足

    [mathbb{E}[mathrm{tr}(e^{g_{ au}(U)}) - d] = 0, ]

    即可保证(g_{ au}(U))能够代表有效无环图. 在实际中, 只需

    [mathbb{E}[mathrm{tr}(e^{g_{ au}(U)}) - d] le xi. ]

    注: 期望是关于(g)的.

    最终的目标

    总结下来,

    [min_{U, heta} quad mathbb{E}_g[frac{1}{2n} sum_{k=1}^n mathcal{L}(x^{(k)}, f(g_{ au}, x^{(k)}; heta))] \ mathrm{s.t.} quad mathbb{E}_g[mathrm{tr}(e^{g_{ au}(U)}) - d] le xi. ]

    注: (mathbb{E})是关于(g)的, (n)的观测数据的总数.

    进一步地, 我们希望(g_{ au})是稀疏的, 故加上正则化项:

    [min_{U, heta} quad mathbb{E}_g[frac{1}{2n} sum_{k=1}^n mathcal{L}(x^{(k)}, f(g_{ au}, x^{(k)}; heta)) + lambda |g_{ au}(U)|_1] \ mathrm{s.t.} quad mathbb{E}_g[mathrm{tr}(e^{g_{ au}(U)}) - d] le xi. ]

    利用augmented Lagrange multiplier, 可得

    [L_p(U, phi, alpha) = mathbb{E}_g[frac{1}{2n} sum_{k=1}^n mathcal{L}(x^{(k)}, f(g_{ au}, x^{(k)}; heta)) + lambda |g_{ au}(U)|_1 + alpha h(U)] + frac{ ho}{2} (mathbb{E}[h(U)])^2, ]

    其中(h(U):= mathrm{tr}(e^{g_{ au}(U)}) - d).

    采用分布更新:

    [U^{t+1}, heta^{t+1} = arg min_{U, heta} L_{ ho^t}(U, phi, alpha^t); \ alpha^{t+1} = alpha^t + ho^t mathbb{E}[h(U^{t+1})]; \ ho^{t+1} = left { egin{array}{ll} eta ho^t, & mathrm{if} : mathbb{E}[h(U^{t+1})] ge gamma mathbb{E}[h(U^t)], \ ho^t, & mathrm{otherwise}. end{array} ight . ]

    其中第一步使用Adam执行1000次迭代计算的.

    文中还讨论了后处理的一些方法, 和(A)是否唯一.

    代码

    GES and PC

    CAM

    NOTEARS

    DAG-GNN

    GraN-DAG

  • 相关阅读:
    安装pykeyboard模块
    Windows Defender Antivirus Service经常性出现占用CPU厉害
    Xpath 语法笔记
    通过docker部署rocketmq双主双从集群
    解决提取Mybatis多数据源公共组件“At least one base package must be specified”的问题
    设计模式-单例模式
    通过阳历生日计算星座,阴历生日,生辰八字,生肖五行
    设计模式-抽象工厂模式
    设计模式-工厂方法模式
    常用的MD5工具类
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/14826309.html
Copyright © 2011-2022 走看看