Graph Convolutional Networks (GCN) Series, Part 2: Node Classification (with Example and Code)

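    Machine learning tasks on graphs usually fall into three types: graph classification, node classification, and link prediction. This post implements a node classification example: using a GCN to classify the papers in the Cora dataset.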

    About the Cora dataset:

    The Cora dataset consists of machine learning papers, which are divided into 7 classes:

    • Case_Based
    • Genetic_Algorithms
    • Neural_Networks
    • Probabilistic_Methods
    • Reinforcement_Learning
    • Rule_Learning
    • Theory

    Every paper in the dataset cites, or is cited by, at least one other paper in the dataset. There are 2708 papers in total.

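    After removing stop words and words with a document frequency below 10, the vocabulary contains 1433 words, so each paper has a 1433-dimensional feature vector. Each entry is 0 or 1, indicating whether the corresponding word appears in the paper.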
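    As a toy illustration of this binary bag-of-words encoding, here is a minimal sketch. The vocabulary, the document, and the helper name binary_bow are made up for illustration; Cora's real vocabulary has 1433 words and its features are already provided by the dataset.

```python
def binary_bow(tokens, vocab):
    """Return a 0/1 vector: entry k is 1 if vocab[k] occurs in tokens, else 0."""
    present = set(tokens)
    return [1 if word in present else 0 for word in vocab]

# Hypothetical 6-word vocabulary (Cora's real one has 1433 entries).
vocab = ["graph", "neural", "network", "policy", "reward", "kernel"]
doc = "a neural network on a graph graph".split()
vec = binary_bow(doc, vocab)
print(vec)  # [1, 1, 1, 0, 0, 0]
```

    Note that "graph" appears twice but still yields a 1: the features record presence, not counts.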

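    Treating each paper as a node and connecting papers by their citation relations gives a graph with 2708 nodes. Nodes 0~139 (140 nodes) are used for training, nodes 140~639 (500 nodes) for validation, and nodes 1708~2707 (1000 nodes) for testing.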
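    A minimal sketch of these two steps, assuming citations arrive as (citing, cited) pairs: each citation is stored as two directed edges so that messages can flow both ways (DGL's Cora graph likewise stores its 10556 edges with each link in both directions), and the splits are boolean masks over node ids. The helper names and the toy citations are hypothetical; in the real code DGL already provides the masks in g.ndata.

```python
def build_undirected_edges(citations):
    """Turn (citing, cited) pairs into a sorted directed edge list with both directions."""
    edges = set()
    for src, dst in citations:
        edges.add((src, dst))
        edges.add((dst, src))
    return sorted(edges)

def range_mask(num_nodes, start, stop):
    """Boolean mask that is True for node ids in the inclusive range [start, stop]."""
    return [start <= i <= stop for i in range(num_nodes)]

# Hypothetical toy citations among 4 papers.
edges = build_undirected_edges([(0, 1), (1, 2), (3, 1)])
print(len(edges))  # 6

# Index ranges of the Cora split described above.
N = 2708
train_mask = range_mask(N, 0, 139)
val_mask = range_mask(N, 140, 639)
test_mask = range_mask(N, 1708, 2707)
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 140 500 1000
```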

    Code (from https://github.com/rexrex9/gnn/blob/main/gcn.py; video walkthrough: https://www.bilibili.com/video/BV1W5411N78Y?from=search&seid=6220646189261474464&spm_id_from=333.337.0.0):

    import os
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dgl.nn.pytorch import GraphConv
    from dgl.data import CoraGraphDataset
    
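    class GCN(nn.Module):
        def __init__(self,
                     g,           # DGL graph object
                     in_feats,    # input feature dimension
                     n_hidden,    # hidden feature dimension
                     n_classes,   # number of classes
                     n_layers,    # number of hidden layers (n_layers + 1 GraphConv layers in total)
                     activation,  # activation function
                     dropout      # dropout probability
                     ):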
            super( GCN, self ).__init__()
            self.g = g
            self.layers = nn.ModuleList()
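            # Input layer: in_feats -> n_hidden
            self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
            # Hidden layers: n_hidden -> n_hidden
            for i in range(n_layers - 1):
                self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
            # Output layer: n_hidden -> n_classes, no activation (raw logits for CrossEntropyLoss)
            self.layers.append(GraphConv(n_hidden, n_classes))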
            self.dropout = nn.Dropout(p = dropout)
    
        def forward( self, features ):
            h = features
            for i, layer in enumerate( self.layers ):
                if i != 0:
                    # Apply dropout to the input of every layer except the first
                    h = self.dropout(h)
                h = layer( self.g, h )
            return h
    
    def evaluate(model, features, labels, mask):
        model.eval()
        with torch.no_grad():
            logits = model(features)
            logits = logits[mask]
            labels = labels[mask]
            _, indices = torch.max(logits, dim=1)
            correct = torch.sum(indices == labels)
            return correct.item() * 1.0 / len(labels)
    
    def train(n_epochs=100, lr=1e-2, weight_decay=5e-4, n_hidden=16, n_layers=1, activation=F.relu, dropout=0.5):
        data = CoraGraphDataset()
        g = data[0]  # the whole graph: 2708 nodes with 1433-dim features, 7 classes, 10556 edges
        features = g.ndata['feat']
        labels = g.ndata['label']
        train_mask = g.ndata['train_mask']  # nodes 0~139 are training nodes
        val_mask = g.ndata['val_mask']      # nodes 140~639 are validation nodes
        test_mask = g.ndata['test_mask']    # nodes 1708~2707 are test nodes
        in_feats = features.shape[1]
        n_classes = data.num_classes
    
        model = GCN(g,
                    in_feats,
                    n_hidden,
                    n_classes,
                    n_layers,
                    activation,
                    dropout)
    
        loss_fcn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam( model.parameters(),
                                     lr = lr,
                                     weight_decay = weight_decay)
        best_val_acc = 0
        os.makedirs('models', exist_ok=True)  # torch.save fails if the directory does not exist
        for epoch in range( n_epochs ):
            model.train()
            logits = model( features )
            loss = loss_fcn( logits[ train_mask ], labels[ train_mask ] )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = evaluate(model, features, labels, val_mask)
            print("Epoch {} | Loss {:.4f} | Accuracy {:.4f} ".format(epoch, loss.item(), acc ))
            if acc > best_val_acc:
                best_val_acc = acc
                torch.save(model.state_dict(), 'models/best_model.pth')
    
        model.load_state_dict(torch.load("models/best_model.pth"))
        acc = evaluate(model, features, labels, test_mask)
        print("Test accuracy {:.2%}".format(acc))
    
    if __name__ == '__main__':
        train()
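    What each GraphConv layer computes can be sketched in plain Python. With its default norm='both', DGL's GraphConv scales the message from node j to node i by 1/sqrt(deg(j) * deg(i)), i.e. symmetric normalization. The sketch below assumes a symmetric edge list (so in- and out-degrees coincide), omits the bias, and by default adds self-loops as in the Kipf & Welling formulation H' = σ(D̃^{-1/2} Ã D̃^{-1/2} H W) with Ã = A + I. DGL's GraphConv does not add self-loops on its own, and the code above uses the graph as returned by CoraGraphDataset. The function name gcn_layer and the toy graph are hypothetical.

```python
import math

def gcn_layer(edges, h, weight, add_self_loops=True, activation=None):
    """One GCN layer, H' = act(D^-1/2 A D^-1/2 H W), without bias.

    edges:  list of (src, dst) pairs, assumed symmetric (both directions present).
    h:      node features as a list of rows (num_nodes x in_dim).
    weight: weight matrix as a list of rows (in_dim x out_dim).
    """
    n = len(h)
    edges = list(edges)
    if add_self_loops:  # A~ = A + I
        edges += [(i, i) for i in range(n)]
    deg = [0] * n
    for _, dst in edges:
        deg[dst] += 1
    in_dim, out_dim = len(weight), len(weight[0])
    # Linear transform first: HW
    hw = [[sum(h[i][k] * weight[k][j] for k in range(in_dim)) for j in range(out_dim)]
          for i in range(n)]
    # Message passing with symmetric normalization 1 / sqrt(deg(src) * deg(dst))
    out = [[0.0] * out_dim for _ in range(n)]
    for src, dst in edges:
        c = 1.0 / math.sqrt(deg[src] * deg[dst])
        for j in range(out_dim):
            out[dst][j] += c * hw[src][j]
    if activation is not None:
        out = [[activation(x) for x in row] for row in out]
    return out

# Hypothetical toy graph: path 0 - 1 - 2, one-hot features, all-ones weight column.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
h = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
w = [[1.0], [1.0], [1.0]]
out = gcn_layer(edges, h, w)
print([round(row[0], 4) for row in out])  # [0.9082, 1.1498, 0.9082]
```

    The middle node has the highest degree, so each incoming message is scaled down more, but it also receives messages from two neighbors; the normalization keeps high-degree nodes from dominating simply by having more edges.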
    

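    The forward pass applies nn.Dropout before every layer except the first. PyTorch's dropout is "inverted" dropout: during training it zeroes each element with probability p and scales the survivors by 1/(1-p), and in eval mode it is the identity. That is why evaluate() calls model.eval() before computing accuracy. A minimal stand-alone sketch of that behavior (the function name inverted_dropout is hypothetical):

```python
import random

def inverted_dropout(values, p, training, rng=random):
    """Zero each value with probability p and scale the survivors by 1 / (1 - p)
    in training mode; act as the identity in eval mode."""
    if not training or p == 0.0:
        return list(values)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]

h = [1.0, 2.0, 3.0, 4.0]
print(inverted_dropout(h, 0.5, training=False))  # [1.0, 2.0, 3.0, 4.0]
dropped = inverted_dropout(h, 0.5, training=True, rng=random.Random(0))
print(dropped)  # each entry is either 0.0 or twice the input
```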
      

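    evaluate() runs the model on the whole graph, keeps only the rows selected by mask, takes the argmax of each row of logits as the predicted class, and returns the fraction that matches the labels. The same computation with plain Python lists, using made-up logits and a hypothetical helper name:

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose argmax logit equals the label."""
    correct = total = 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == label)
        total += 1
    return correct / total

logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.8, 0.1], [0.1, 0.2, 3.0]]
labels = [0, 1, 1, 2]
mask = [True, True, True, False]   # the last node is outside the mask
acc = masked_accuracy(logits, labels, mask)
print(acc)  # 2 of the 3 masked nodes are correct
```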
    Output:

     ......
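    The Loss in each printed line is loss_fcn(logits[train_mask], labels[train_mask]). The forward pass runs over all 2708 nodes, but only the 140 labeled training nodes contribute to the cross-entropy, which makes this a semi-supervised setup. A sketch of that masked cross-entropy with mean reduction (the default of torch.nn.CrossEntropyLoss), using toy numbers and a hypothetical function name:

```python
import math

def masked_cross_entropy(logits, labels, mask):
    """Mean of -log softmax(logits)[label] over the masked nodes."""
    losses = []
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)

logits = [[0.0, 0.0], [math.log(3.0), 0.0], [5.0, -5.0]]
labels = [0, 0, 1]
mask = [True, True, False]   # the badly predicted third node is ignored
loss = masked_cross_entropy(logits, labels, mask)
print(loss)  # (ln 2 + ln(4/3)) / 2
```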

    Original post: https://www.cnblogs.com/picassooo/p/15430894.html