zoukankan      html  css  js  c++  java
  • DGL学习(六): GCN实现

    GCN可以认为由两步组成:

    对于每个节点 $u$

    1)汇总邻居的表示$h_v$ 产生中间表示 $hat h_u$

    2) 使用$W_u$线性投影 $hat h_v$, 再经过非线性变换 $f$ , 即 $h_u = f(W_u hat h_u)$

    首先定义message函数和reduce函数。

    import dgl
    import dgl.function as fn
    import torch as th
    import torch.nn as nn
    import torch.nn.functional as F
    from dgl import DGLGraph
    
    ## 定义消息函数 和 reduce函数
    gcn_msg = fn.copy_src(src='h', out='m')
    gcn_reduce = fn.sum(msg='m', out='h')

    定义GCN

    ## 定义GCNLayer
    class GCNLayer(nn.Module):
        def __init__(self, in_feats, out_feats):
            super(GCNLayer, self).__init__()
            self.linear = nn.Linear(in_feats, out_feats)
    
        def forward(self, g, feature):
            # Creating a local scope so that all the stored ndata and edata
            # (such as the `'h'` ndata below) are automatically popped out
            # when the scope exits.
            with g.local_scope():
                g.ndata['h'] = feature
                g.update_all(gcn_msg, gcn_reduce)
                h = g.ndata['h']
                return self.linear(h)
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layer1 = GCNLayer(1433, 16)
            self.layer2 = GCNLayer(16, 7)
    
        def forward(self, g, features):
            x = F.relu(self.layer1(g, features))
            x = self.layer2(g, x)
            return x
    net = Net()
    print(net)
  • 相关阅读:
    33. 搜索旋转排序数组
    54. 螺旋矩阵
    46. 全排列
    120. 三角形最小路径和
    338. 比特位计数
    746. 使用最小花费爬楼梯
    spring boot的一些常用注解
    SSM整合Dubbo案例
    一些面试题
    Spring Aop和Spring Ioc(二)
  • 原文地址:https://www.cnblogs.com/liyinggang/p/13370943.html
Copyright © 2011-2022 走看看