zoukankan      html  css  js  c++  java
  • DGL学习(六): GCN实现

    GCN可以认为由两步组成:

    对于每个节点 $u$

    1)汇总邻居的表示$h_v$ 产生中间表示 $hat h_u$

    2) 使用$W_u$线性投影 $hat h_v$, 再经过非线性变换 $f$ , 即 $h_u = f(W_u hat h_u)$

    首先定义message函数和reduce函数。

    import dgl
    import dgl.function as fn
    import torch as th
    import torch.nn as nn
    import torch.nn.functional as F
    from dgl import DGLGraph
    
    ## 定义消息函数 和 reduce函数
    gcn_msg = fn.copy_src(src='h', out='m')
    gcn_reduce = fn.sum(msg='m', out='h')

    定义GCN

    ## 定义GCNLayer
    class GCNLayer(nn.Module):
        def __init__(self, in_feats, out_feats):
            super(GCNLayer, self).__init__()
            self.linear = nn.Linear(in_feats, out_feats)
    
        def forward(self, g, feature):
            # Creating a local scope so that all the stored ndata and edata
            # (such as the `'h'` ndata below) are automatically popped out
            # when the scope exits.
            with g.local_scope():
                g.ndata['h'] = feature
                g.update_all(gcn_msg, gcn_reduce)
                h = g.ndata['h']
                return self.linear(h)
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layer1 = GCNLayer(1433, 16)
            self.layer2 = GCNLayer(16, 7)
    
        def forward(self, g, features):
            x = F.relu(self.layer1(g, features))
            x = self.layer2(g, x)
            return x
    net = Net()
    print(net)
  • 相关阅读:
    EXCEL自动导出HTML
    亡灵序曲超清
    支持国产动画-唐伯卿和曾小兰
    中国表情
    logging 日志
    datetime库运用
    hashlib 加密
    os2
    python json数据处理
    python操作redis
  • 原文地址:https://www.cnblogs.com/liyinggang/p/13370943.html
Copyright © 2011-2022 走看看