  • Graph Fusion GCN (Graph Convolutional Networks)

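    Much real-world data is naturally a graph. Graphs are everywhere in daily life: social networks, knowledge graphs, protein structures, and so on. This article introduces one branch of GNNs (Graph Neural Networks): the GCN (Graph Convolutional Network).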

    PyTorch implementation of GCN

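    Although GCN is fairly hard to understand mathematically, it is very simple to implement. One point worth noting: the adjacency matrix is usually sparse, so sparse operations make the matrix multiplication more efficient. First, the graph convolution layer: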

    import torch
    import torch.nn as nn


    class GraphConvolution(nn.Module):
        """GCN layer"""

        def __init__(self, in_features, out_features, bias=True):
            super(GraphConvolution, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
            if bias:
                self.bias = nn.Parameter(torch.Tensor(out_features))
            else:
                self.register_parameter('bias', None)

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

        def forward(self, input, adj):
            # project the node features first, then aggregate over neighbours
            support = torch.mm(input, self.weight)
            # adj is expected to be a sparse tensor, hence the sparse product
            output = torch.spmm(adj, support)
            if self.bias is not None:
                return output + self.bias
            else:
                return output

        def extra_repr(self):
            return 'in_features={}, out_features={}, bias={}'.format(
                self.in_features, self.out_features, self.bias is not None
            )

    For a GCN, it is enough to stack graph convolution layers. Here, a two-layer GCN is implemented:

    import torch.nn.functional as F


    class GCN(nn.Module):
        """a simple two layer GCN"""

        def __init__(self, nfeat, nhid, nclass):
            super(GCN, self).__init__()
            self.gc1 = GraphConvolution(nfeat, nhid)
            self.gc2 = GraphConvolution(nhid, nclass)

        def forward(self, input, adj):
            h1 = F.relu(self.gc1(input, adj))
            logits = self.gc2(h1, adj)
            return logits

    ReLU is used as the activation function here. Below, this network is applied to a semi-supervised node classification task on a graph.
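
    As a rough check of the sparse product used in the layer's forward pass, the sketch below compares sparse and dense aggregation on a toy graph. The graph, the feature sizes, and all variable names are made up for illustration, and `torch.sparse.mm` is used here as a stand-in for `torch.spmm`; this is not the article's code.

```python
import torch

# Toy undirected graph with 4 nodes and edges 0-1, 1-2, 2-3 (hypothetical data).
edges = torch.tensor([[0, 1, 1, 2, 2, 3],
                      [1, 0, 2, 1, 3, 2]])
values = torch.ones(edges.shape[1])
adj_sparse = torch.sparse_coo_tensor(edges, values, size=(4, 4))
adj_dense = adj_sparse.to_dense()

torch.manual_seed(0)
x = torch.randn(4, 5)        # node features: 4 nodes, 5 features each
weight = torch.randn(5, 2)   # layer weight: 5 input features -> 2 outputs

# Same order of operations as the layer: project features, then aggregate neighbours.
support = torch.mm(x, weight)
out_sparse = torch.sparse.mm(adj_sparse, support)
out_dense = torch.mm(adj_dense, support)

max_diff = (out_sparse - out_dense).abs().max().item()
print("max difference:", max_diff)
```

    Both paths should agree up to floating-point error; the sparse path simply avoids multiplying by the many zero entries of the adjacency matrix.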

    Loading the data takes a single call to load_data:

    # https://github.com/tkipf/pygcn/blob/master/pygcn/utils.py
    adj, features, labels, idx_train, idx_val, idx_test = load_data(path="./data/cora/")

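    Two points are worth noting. First, paper citations naturally form a directed graph, but when building the network the graph should first be converted to an undirected one, that is, every citation is made bidirectional. This has a large effect on the training result: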

    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
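
    To see what the expression above does, here is a small sketch on a made-up three-paper citation graph (the matrix and variable names are illustrative assumptions). It applies the same formula and shows that the result is the element-wise maximum of the matrix and its transpose, which is symmetric.

```python
import numpy as np
import scipy.sparse as sp

# Directed toy citation graph (hypothetical): paper 0 cites 1, paper 1 cites 2.
directed = sp.csr_matrix(np.array([[0, 1, 0],
                                   [0, 0, 1],
                                   [0, 0, 0]], dtype=np.float32))
directed_t = directed.T.tocsr()

# Where the transpose holds the larger entry, take it from the transpose;
# the net effect is the element-wise maximum of the matrix and its transpose.
mask = directed_t > directed
symmetric = directed + directed_t.multiply(mask) - directed.multiply(mask)
symmetric = sp.csr_matrix(symmetric)

print(symmetric.toarray())
```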

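    Second, the official implementation normalizes the adjacency matrix with plain mean (row) normalization; symmetric normalization can of course be used instead: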

    import numpy as np
    import scipy.sparse as sp


    def normalize_adj(adj):
        """compute L=D^-0.5 * (A+I) * D^-0.5"""
        # add self-loops without modifying the caller's matrix
        adj = adj + sp.eye(adj.shape[0])
        degree = np.array(adj.sum(1))
        d_hat = sp.diags(np.power(degree, -0.5).flatten())
        norm_adj = d_hat.dot(adj).dot(d_hat)
        return norm_adj
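
    The difference between the two normalizations can be seen on a three-node path graph. The sketch below is only an illustration: the helper names `row_normalize` and `sym_normalize` and the toy graph are assumptions, not taken from the official implementation.

```python
import numpy as np
import scipy.sparse as sp


def row_normalize(adj):
    """Mean normalization: D^-1 * (A + I); every row sums to 1."""
    adj_hat = adj + sp.eye(adj.shape[0])
    degree = np.asarray(adj_hat.sum(axis=1)).flatten()
    return sp.diags(1.0 / degree).dot(adj_hat)


def sym_normalize(adj):
    """Symmetric normalization: D^-0.5 * (A + I) * D^-0.5."""
    adj_hat = adj + sp.eye(adj.shape[0])
    degree = np.asarray(adj_hat.sum(axis=1)).flatten()
    d_inv_sqrt = sp.diags(np.power(degree, -0.5))
    return d_inv_sqrt.dot(adj_hat).dot(d_inv_sqrt)


# Path graph 0 - 1 - 2 (hypothetical toy example).
adj = sp.csr_matrix(np.array([[0, 1, 0],
                              [1, 0, 1],
                              [0, 1, 0]], dtype=np.float64))

row_norm = row_normalize(adj).toarray()
sym_norm = sym_normalize(adj).toarray()
print(row_norm)
print(sym_norm)
```

    Row normalization averages over each node's neighbourhood but yields a non-symmetric matrix, while the symmetric form keeps the matrix symmetric at the cost of rows that no longer sum to 1.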

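    Here, only the 140 labeled samples in the graph are used to train the GCN. Each epoch computes the outputs for the nodes and then evaluates the loss on the labeled ones: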

        loss_history = []
        val_acc_history = []
        for epoch in range(epochs):
            model.train()
            logits = model(features, adj)
            loss = criterion(logits[idx_train], labels[idx_train])
           
            train_acc = accuracy(logits[idx_train], labels[idx_train])
           
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
           
            val_acc = test(idx_val)
            loss_history.append(loss.item())
            val_acc_history.append(val_acc.item())
            print("Epoch {:03d}: Loss {:.4f}, TrainAcc {:.4f}, ValAcc {:.4f}".format(
                epoch, loss.item(), train_acc.item(), val_acc.item()))
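
    The loop above relies on `model`, `criterion`, `optimizer`, `accuracy` and `test`, which are not shown. As a minimal, self-contained sketch of the same idea (forward pass over all nodes, loss on the labeled subset only), the example below runs on synthetic data. Everything in it is an assumption for illustration: the random graph, the sizes, the `TinyGCN` class, the dense adjacency matrix, and the Adam hyperparameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Synthetic stand-in for Cora (all sizes and data here are made up).
num_nodes, num_feats, num_classes = 40, 16, 3
labels = torch.randint(0, num_classes, (num_nodes,))
class_centers = torch.randn(num_classes, num_feats)
features = class_centers[labels] + 0.5 * torch.randn(num_nodes, num_feats)

# Random symmetric adjacency with self-loops, row-normalized (dense for brevity).
rand = (torch.rand(num_nodes, num_nodes) < 0.1).float()
adj = (((rand + rand.t()) > 0).float() + torch.eye(num_nodes)).clamp(max=1.0)
adj = adj / adj.sum(dim=1, keepdim=True)

idx_train = torch.arange(0, 15)   # only these nodes contribute to the loss
idx_val = torch.arange(15, 40)


class TinyGCN(nn.Module):
    """Hypothetical two-layer GCN using dense matrix products and no bias."""

    def __init__(self, nfeat, nhid, nclass):
        super().__init__()
        self.lin1 = nn.Linear(nfeat, nhid, bias=False)
        self.lin2 = nn.Linear(nhid, nclass, bias=False)

    def forward(self, x, adj):
        h = F.relu(adj @ self.lin1(x))
        return adj @ self.lin2(h)


def accuracy(logits, target):
    """Fraction of nodes whose arg-max class matches the target."""
    return (logits.argmax(dim=1) == target).float().mean()


model = TinyGCN(num_feats, 8, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

loss_history = []
for epoch in range(100):
    model.train()
    logits = model(features, adj)                            # forward pass over ALL nodes
    loss = criterion(logits[idx_train], labels[idx_train])   # loss on labeled nodes only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())

model.eval()
with torch.no_grad():
    val_acc = accuracy(model(features, adj)[idx_val], labels[idx_val]).item()
print("first loss {:.4f}, last loss {:.4f}, val acc {:.4f}".format(
    loss_history[0], loss_history[-1], val_acc))
```

    Unlabeled nodes still influence training: their features flow into the labeled nodes' outputs through the adjacency matrix even though they never appear in the loss.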

    Training for only 200 epochs reaches roughly 80% classification accuracy on the test set, which shows how powerful GCN is.

    Fusing BN and Conv layers
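
    As a sketch of the operation being fused (the notation is an assumption: W and b are the convolution weight and bias; γ, β, μ, σ² and ε are the batch-norm scale, shift, running mean, running variance and epsilon), a convolution followed by batch norm in inference mode collapses into a single affine map per output channel:

```latex
% Convolution followed by batch norm (inference mode), per output channel:
y = W x + b, \qquad
z = \gamma \, \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

% Substituting y gives z = W_{\text{fused}} \, x + b_{\text{fused}} with
W_{\text{fused}} = \operatorname{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\right) W, \qquad
b_{\text{fused}} = \frac{\gamma \, (b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta
```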

    To implement this fusion in PyTorch, the relevant nn.Conv2d parameters are:

    • filter weights,W: conv.weight;
    • bias,b: conv.bias;

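    The relevant nn.BatchNorm2d parameters are:

    • scaling, γ: bn.weight;
    • shift, β: bn.bias;
    • mean estimate, μ: bn.running_mean;
    • variance estimate, σ²: bn.running_var;
    • ε (for numerical stability): bn.eps.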

    The full implementation is as follows (Google Colab, https://colab.research.google.com/drive/1mRyq_LlJW4u_rArzzhEe_T6tmEWoNN1K):

        import torch
        import torchvision
       
        def fuse(conv, bn):
       
            fused = torch.nn.Conv2d(
                conv.in_channels,
                conv.out_channels,
                kernel_size=conv.kernel_size,
                stride=conv.stride,
                padding=conv.padding,
                bias=True
            )
       
            # setting weights
            w_conv = conv.weight.clone().view(conv.out_channels, -1)
            w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
            fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
           
            # setting bias
            if conv.bias is not None:
                b_conv = conv.bias
            else:
                b_conv = torch.zeros( conv.weight.size(0) )
            b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
                                  torch.sqrt(bn.running_var + bn.eps)
                                )
            # the conv bias also passes through BN, so it must be scaled by w_bn
            fused.bias.copy_( torch.mv(w_bn, b_conv) + b_bn )
       
            return fused
       
        # Testing
        # we need to turn off gradient calculation because we didn't write it
        torch.set_grad_enabled(False)
        x = torch.randn(16, 3, 256, 256)
        resnet18 = torchvision.models.resnet18(pretrained=True)
        # removing all learning variables, etc
        resnet18.eval()
        model = torch.nn.Sequential(
            resnet18.conv1,
            resnet18.bn1
        )
        f1 = model.forward(x)
        fused = fuse(model[0], model[1])
        f2 = fused.forward(x)
        d = (f1 - f2).mean().item()
        print("error:",d)
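
    The ResNet test above uses a convolution without bias and reports the mean signed difference, in which positive and negative errors can cancel. As a complementary sketch, the example below checks the fusion on a small convolution that does have a bias and reports the maximum absolute error. The helper `fuse_conv_bn`, the layer sizes and the random statistics are assumptions for illustration, and the helper only handles the default `groups=1`, `dilation=1` case.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def fuse_conv_bn(conv, bn):
    """Hypothetical helper: fold an eval-mode BatchNorm2d into the preceding Conv2d."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      kernel_size=conv.kernel_size, stride=conv.stride,
                      padding=conv.padding, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # gamma / sqrt(var + eps)
    with torch.no_grad():
        # Scale every output channel of the kernel.
        fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
        b_conv = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # The conv bias passes through BN too, so it is shifted and scaled as well.
        fused.bias.copy_((b_conv - bn.running_mean) * scale + bn.bias)
    return fused


conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)
# Give BN non-trivial statistics and affine parameters (random, for illustration).
with torch.no_grad():
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.normal_()
    bn.bias.normal_()
conv.eval()
bn.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    expected = bn(conv(x))
    actual = fuse_conv_bn(conv, bn)(x)

max_abs_error = (expected - actual).abs().max().item()
print("max abs error:", max_abs_error)
```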

    References:

      1. Semi-Supervised Classification with Graph Convolutional Networks https://arxiv.org/abs/1609.02907
      2. How to do Deep Learning on Graphs with Graph Convolutional Networks https://towardsdatascience.com/how-to-do-deep-learning-on-graphs-with-graph-convolutional-networks-7d2250723780
      3. Graph Convolutional Networks http://tkipf.github.io/graph-convolutional-networks
      4. Graph Convolutional Networks in PyTorch https://github.com/tkipf/pygcn
      5. A review of classic work on spectral graph convolution: from ChebNet to GCN https://www.jianshu.com/p/2fd5a2454781
      6. Introduction to the Cora graph dataset, processed with Python, usable for GCN tasks https://blog.csdn.net/yeziand01/article/details/93374216
      7. Speeding up model with fusing batch normalization and convolution http://learnml.today/speeding-up-model-with-fusing-batch-normalization-and-convolution-3
  • Original article: https://www.cnblogs.com/wujianming-110117/p/15240958.html