zoukankan      html  css  js  c++  java
  • 笔记:GAT入门学习

    GAT图注意力网络

    GAT 采用了 Attention 机制,可以为不同节点分配不同权重,训练时依赖于成对的相邻节点,而不依赖具体的网络结构,可以用于 inductive 任务。

    假设 Graph 包含 $N$ 个节点,每个节点的特征向量为 $h_i$,维度是 $F$,如下所示:

    \begin{gathered}
    \boldsymbol{h}=\left\{h_{1}, h_{2}, \ldots, h_{N}\right\} \\
    h_{1} \in R^{F}
    \end{gathered}

    节点 $j$ 是节点 $i$ 的邻居,则可以使用 Attention 机制计算节点 $j$ 对于节点 $i$ 的重要性,即 Attention Score:

    \begin{gathered}
    e_{i j}=\operatorname{Attention}\left(W h_{i}, W h_{j}\right) \\
    \alpha_{i j}=\operatorname{Softmax}_{j}\left(e_{i j}\right)=\frac{\exp \left(e_{i j}\right)}{\sum_{k \in N_{i}} \exp \left(e_{i k}\right)}
    \end{gathered}

    注意这个 $w$ 都是同一个

    GAT 具体的 Attention 做法如下,把节点 $i、j$ 的特征向量 $h'_i$、$h'_j$ 拼接在一起,然后和一个 $2F'$ 维的向量 $a$ 计算内积。激活函数采用 LeakyReLU,公式如下:

    $$
    \alpha_{i j}=\frac{\exp \left(\operatorname{LeakyReLU}\left(a^{T}\left[W h_{i} \| W h_{j}\right]\right)\right)}{\sum_{k \in N_{i}} \exp \left(\operatorname{LeakyReLU}\left(a^{T}\left[W h_{i} \| W h_{k}\right]\right)\right)}
    $$
    || 表示拼接操作

    经过 Attention 之后节点 $i$ 的特征向量如下:

    $$h_{i}^{\prime}=\sigma\left(\sum_{j \in N_{i}} \alpha_{i j} W h_{j}\right)$$

    GAT 也可以采用 Multi-Head Attention,如果有 K 个 Attention,则需要把 K 个 Attention 生成的向量拼接在一起,如下:

    $$h_{i}^{\prime}=\operatorname{concat}\left(\sigma\left(\sum_{j \in N_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right)\right)$$

    但是如果是最后一层,则 K 个 Attention 的输出不进行拼接,而是求平均:

    $$h_{i}^{\prime}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in N_{i}} \alpha_{i j}^{k} W^{k} h_{j}\right)$$

    网络结构:

    样例来自 https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gat.py

    class GAT(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super(GAT, self).__init__()
    
            # num_features: Alias for num_node_features.
            self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)
    
            # On the Pubmed dataset, use heads=8 in conv2.
            self.conv2 = GATConv(8 * 8, out_channels, heads=1, concat=False,
                                 dropout=0.6)
    
        def forward(self, x, edge_index):
            ipdb.set_trace()
            x_copy = x.clone()
            x = F.dropout(x, p=0.6, training=self.training)
            x = F.elu(self.conv1(x, edge_index))
            x = F.dropout(x, p=0.6, training=self.training)
            x = self.conv2(x, edge_index)
            return x + x_copy  # Residual connection, 避免孤立节点变成全0
            
            # return F.log_softmax(x, dim=-1)  # log_softmax ??
            return x   #  我觉得这个位置还不要softmax
    

    参考链接:https://ai.baidu.com/forum/topic/show/972764

  • 相关阅读:
    [转]Could not load file or assembly 'XXX' or one of its dependencies.
    网页上显示别人电脑没安装的字体,例如LED字体
    JS 保留小数点后面2位小数
    ASP.NET2.0揭秘读书笔记五——维护应用程序状态之cookie
    C#高级编程读书笔记之.NET体系结构
    ASP.NET2.0揭秘读书笔记之八——页面输出缓存
    《大话设计模式》读书笔记一 简单工厂模式
    C#高级编程读书笔记之继承
    ASP.NET 2.0揭秘读书笔记七——使用用户配置文件Profile
    终于成功安装了SQL SqlServer2005
  • 原文地址:https://www.cnblogs.com/lfri/p/15546394.html
Copyright © 2011-2022 走看看