  • 【Graph Convolution】

    Graph Convolution

    Basic version

    GitHub

    Dataset (Cora):

    wget https://data.deepai.org/Cora.zip
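After unzipping, Cora in its common LINQS layout consists of two text files: `cora.content` (one line per paper: paper ID, binary word-feature columns, class label) and `cora.cites` (one citation per line: cited paper ID, citing paper ID). I am assuming the DeepAI archive keeps that layout; check the extracted file names. The hypothetical helper `parse_cora` below is a minimal sketch that turns the lines of those two files into a feature matrix, integer labels, and an edge list of node indices, which is the kind of input the layer below consumes.

```python
from typing import Dict, Iterable, List, Tuple


def parse_cora(content_lines: Iterable[str], cites_lines: Iterable[str]):
    """Parse Cora-style content/cites lines (hypothetical helper, LINQS format assumed).

    Returns (features, labels, edges, class_names):
      features:    one list of floats per node
      labels:      integer class index per node
      edges:       (citing_idx, cited_idx) pairs; citations to unknown IDs are dropped
      class_names: sorted class names, so labels[i] indexes into this list
    """
    id_to_idx: Dict[str, int] = {}
    features: List[List[float]] = []
    raw_labels: List[str] = []
    for line in content_lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        paper_id, *feats, label = parts
        id_to_idx[paper_id] = len(features)
        features.append([float(x) for x in feats])
        raw_labels.append(label)

    # Sort class names so the label mapping is deterministic across runs.
    class_names = sorted(set(raw_labels))
    class_to_int = {c: k for k, c in enumerate(class_names)}
    labels = [class_to_int[c] for c in raw_labels]

    edges: List[Tuple[int, int]] = []
    for line in cites_lines:
        parts = line.split()
        if len(parts) != 2:
            continue  # skip malformed or blank lines
        cited, citing = parts
        if cited in id_to_idx and citing in id_to_idx:
            edges.append((id_to_idx[citing], id_to_idx[cited]))
    return features, labels, edges, class_names


features, labels, edges, class_names = parse_cora(
    ["31336 0 1 0 Neural_Networks", "1061127 1 0 1 Rule_Learning"],
    ["31336 1061127"],
)
print(features, labels, edges, class_names)
```

Reading the real files would then look like `parse_cora(open("cora/cora.content"), open("cora/cora.cites"))`, with the paths adjusted to wherever the archive was extracted. The edges are directed (citing to cited); GCN usually treats them as undirected when building the adjacency matrix.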

    import math
    
    import torch
    
    from torch.nn.parameter import Parameter
    from torch.nn.modules.module import Module
    
    
    class GraphConvolution(Module):
        """
        Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
        """
    
        def __init__(self, in_features, out_features, bias=True):
            super(GraphConvolution, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = Parameter(torch.FloatTensor(in_features, out_features))
            if bias:
                self.bias = Parameter(torch.FloatTensor(out_features))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()
    
        def reset_parameters(self):
            stdv = 1. / math.sqrt(self.weight.size(1))
            self.weight.data.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.data.uniform_(-stdv, stdv)
    
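        def forward(self, input, adj):
            # input: [N, in_features] node-feature matrix (N = number of nodes)
            # adj:   [N, N] normalized adjacency matrix, usually a sparse tensor
            support = torch.mm(input, self.weight)
            # support: [N, out_features], features after the linear transform
            output = torch.spmm(adj, support)
            # output: [N, out_features], each row aggregates its neighbours' support rows
            if self.bias is not None:
                return output + self.bias
            else:
                return output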
    
        def __repr__(self):
            # Parentheses let the expression span several lines without backslashes.
            return (self.__class__.__name__ + ' ('
                    + str(self.in_features) + ' -> '
                    + str(self.out_features) + ')')
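
The layer computes output = adj · (input · W) + b, i.e. H' = A_hat H W + b, and leaves the choice of A_hat to the caller. The GCN paper linked in the docstring uses the renormalized adjacency A_hat = D^(-1/2) (A + I) D^(-1/2), where D is the degree matrix of A + I. The NumPy sketch below uses dense matrices for readability (the torch layer typically receives `adj` as a sparse tensor); it builds A_hat from an undirected edge list and applies one propagation step. The function names are illustrative, not part of any library.

```python
import numpy as np


def normalized_adjacency(num_nodes, edges):
    """Dense D^(-1/2) (A + I) D^(-1/2) for an undirected graph given as (i, j) pairs."""
    a = np.zeros((num_nodes, num_nodes))
    for i, j in edges:
        a[i, j] = 1.0
        a[j, i] = 1.0  # symmetrize: treat every edge as undirected
    a_tilde = a + np.eye(num_nodes)  # add self-loops
    deg = a_tilde.sum(axis=1)  # every degree is >= 1 thanks to the self-loop
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale row i by d_i^(-1/2) and column j by d_j^(-1/2).
    return d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]


def gcn_step(adj_hat, h, weight, bias=None):
    """One propagation step, mirroring GraphConvolution.forward: adj @ (h @ W) + b."""
    out = adj_hat @ (h @ weight)
    if bias is not None:
        out = out + bias
    return out


# Tiny example: a 3-node path graph 0 - 1 - 2 with 2-dimensional features.
adj_hat = normalized_adjacency(3, [(0, 1), (1, 2)])
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
w = np.eye(2)  # identity weight, so the output is just the smoothed features
print(gcn_step(adj_hat, h, w))
```

With an identity weight, each output row is a degree-weighted average of the node's own features and its neighbours', which is the smoothing effect a GCN layer applies before the learned transform matters.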
    

    torch-geometric

    Dependencies

    pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}.html
    
    pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu111.html
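In the first command, `${TORCH}` stands for the installed PyTorch version including its CUDA tag; the second command is the concrete form for PyTorch 1.8.0 built against CUDA 11.1. A small, hypothetical helper can build that URL from a version string such as `torch.__version__`. Pip wheels of PyTorch report versions like `1.8.0+cu111`; builds from other channels may omit the `+cuXXX` suffix, in which case the tag has to be supplied by hand. The URL pattern is copied from the commands above, and the wheel index may have moved since this post was written.

```python
def pyg_wheel_url(torch_version: str, cuda_tag: str = "") -> str:
    """Build the wheel-index URL used in the pip commands above (hypothetical helper).

    torch_version: e.g. "1.8.0+cu111" or "1.8.0"
    cuda_tag:      e.g. "cu111" or "cpu"; only used when torch_version has no "+" suffix
    """
    version = torch_version.strip()
    if "+" not in version and cuda_tag:
        version = f"{version}+{cuda_tag}"
    return f"https://pytorch-geometric.com/whl/torch-{version}.html"


print(pyg_wheel_url("1.8.0+cu111"))
print(pyg_wheel_url("1.8.0", "cu111"))
```

With PyTorch installed, `pyg_wheel_url(torch.__version__)` gives the value to pass to `-f`, as long as the version string carries its CUDA tag.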
    

    CUDA

    Version selection

    conda install cudatoolkit=11.0 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/linux-64/
    
    conda install cudnn=7.4.1 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64/
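After installing the toolkit, it helps to confirm which CUDA and cuDNN versions PyTorch actually sees, since the PyG wheels above must match PyTorch's CUDA build. A minimal sketch, assuming PyTorch may or may not be installed: it reads `torch.cuda.is_available()`, `torch.version.cuda` and `torch.backends.cudnn.version()`, and returns `None` when `torch` cannot be imported. The function name `cuda_report` is illustrative.

```python
def cuda_report():
    """Return the CUDA/cuDNN versions PyTorch reports, or None if torch is missing."""
    try:
        import torch
    except ImportError:
        return None  # PyTorch is not installed in this environment
    return {
        "torch": torch.__version__,
        # True only if a usable GPU and driver were found at runtime.
        "cuda_available": torch.cuda.is_available(),
        # CUDA version torch was compiled against; None for CPU-only builds.
        "cuda_build": torch.version.cuda,
        # cuDNN version as an integer (e.g. 7401 for 7.4.1); None if unavailable.
        "cudnn": torch.backends.cudnn.version(),
    }


print(cuda_report())
```

Note that `torch.version.cuda` reports the CUDA version PyTorch was built with, which is the value that has to match the `cuXXX` tag in the PyG wheel URL.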
    
    
  • Original post: https://www.cnblogs.com/linzhenyu/p/14896937.html