  • Graph Convolutional Networks (GCN), Part 2: Node Classification (with Example and Code)

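    Machine learning tasks on graphs usually fall into three types: graph classification, node classification, and link prediction. This post implements node classification: a GCN is used to classify the papers (nodes) in the Cora dataset.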

    The Cora dataset:

    Cora consists of machine learning papers, each assigned to one of 7 classes:

    • Case_Based
    • Genetic_Algorithms
    • Neural_Networks
    • Probabilistic_Methods
    • Reinforcement_Learning
    • Rule_Learning
    • Theory

    Every paper in the dataset cites, or is cited by, at least one other paper in the dataset. There are 2708 papers in total.
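To make the graph construction concrete, here is a minimal stdlib sketch that assumes a toy list of (citing, cited) pairs; the pairs and the helper names `build_undirected_edges` and `isolated_nodes` are hypothetical, not part of the Cora files or of DGL. Each citation is stored as an edge in both directions (a common convention when a citation graph is fed to a GCN), and the property stated above, that no paper is isolated, is checked by looking for nodes with degree 0.

```python
from collections import defaultdict

def build_undirected_edges(citations):
    """Turn (citing, cited) pairs into a symmetric, de-duplicated directed edge list."""
    edges = set()
    for src, dst in citations:
        if src == dst:
            continue  # ignore self-citations
        edges.add((src, dst))
        edges.add((dst, src))  # also store the reverse direction
    return sorted(edges)

def isolated_nodes(num_nodes, edges):
    """Return the ids of nodes that have no incident edge."""
    degree = defaultdict(int)
    for src, dst in edges:
        degree[src] += 1
        degree[dst] += 1
    return [n for n in range(num_nodes) if degree[n] == 0]

# Hypothetical toy data: 4 papers, 3 citations (the real Cora graph has 2708 papers)
citations = [(0, 1), (2, 1), (3, 0)]
edges = build_undirected_edges(citations)
print(len(edges))                # 6: every citation becomes two directed edges
print(isolated_nodes(4, edges))  # []: every paper cites or is cited at least once
```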

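    After removing stop words and words with a document frequency below 10, the vocabulary contains 1433 words, so each paper is described by a 1433-dimensional feature vector. Each entry is 0 or 1, indicating whether the corresponding word appears in the paper.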
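The preprocessing just described (stop-word removal, a document-frequency cutoff, binary word-presence features) can be sketched in a few lines of stdlib Python. This is an illustration under assumptions, not the script that produced Cora: the stop-word list, the toy documents and the helper names `build_vocabulary` and `binary_features` are hypothetical, and `min_df` plays the role of the threshold of 10 used for Cora.

```python
from collections import Counter

STOP_WORDS = {"the", "a", "of", "and", "for"}  # hypothetical, tiny stop-word list

def build_vocabulary(documents, min_df):
    """Keep non-stop words that appear in at least `min_df` documents (Cora used 10)."""
    doc_freq = Counter()
    for doc in documents:
        # a set, so each word is counted at most once per document
        doc_freq.update(set(doc.lower().split()) - STOP_WORDS)
    return sorted(w for w, df in doc_freq.items() if df >= min_df)

def binary_features(doc, vocabulary):
    """0/1 vector: 1 if the vocabulary word occurs in the document, else 0."""
    words = set(doc.lower().split())
    return [1 if w in words else 0 for w in vocabulary]

docs = [
    "a neural network for classification",
    "the neural network learns",
    "rule learning and network theory",
]
vocab = build_vocabulary(docs, min_df=2)
print(vocab)                            # ['network', 'neural']
print(binary_features(docs[2], vocab))  # [1, 0]
```

With the real data the same idea yields 1433 vocabulary words, hence 1433-dimensional 0/1 vectors.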

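    Treating each paper as a node and each citation as an edge yields a graph with 2708 nodes. Nodes 0~139 are used for training, 140~539 for validation, and 1708~2707 for testing (140, 500 and 1000 nodes respectively).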

    Code (from https://github.com/rexrex9/gnn/blob/main/gcn.py; explanatory video: https://www.bilibili.com/video/BV1W5411N78Y?from=search&seid=6220646189261474464&spm_id_from=333.337.0.0):

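    import os
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dgl.nn.pytorch import GraphConv
    from dgl.data import CoraGraphDataset
    
    # train() saves checkpoints to models/; create the folder if it does not exist yet
    os.makedirs('models', exist_ok=True)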
    
    class GCN( nn.Module ):
        def __init__(self,
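                     g,            # DGL graph object
                     in_feats,     # input feature dimension
                     n_hidden,     # hidden-layer feature dimension
                     n_classes,    # number of classes
                     n_layers,     # controls depth: the model builds n_layers + 1 GraphConv layers
                     activation,   # activation function
                     dropout       # dropout probability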
                     ):
            super( GCN, self ).__init__()
            self.g = g
            self.layers = nn.ModuleList()
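            # input layer: in_feats -> n_hidden
            self.layers.append( GraphConv( in_feats, n_hidden, activation = activation ))
            # hidden layers: n_hidden -> n_hidden
            for i in range(n_layers - 1):
                self.layers.append(GraphConv(n_hidden, n_hidden, activation = activation ))
            # output layer: n_hidden -> n_classes, no activation (raw logits for CrossEntropyLoss)
            self.layers.append( GraphConv( n_hidden, n_classes ) )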
            self.dropout = nn.Dropout(p = dropout)
    
        def forward( self, features ):
            h = features
            for i, layer in enumerate( self.layers ):
                if i != 0:
                    h = self.dropout( h )
                h = layer( self.g, h )
            return h
    
    def evaluate(model, features, labels, mask):
        model.eval()
        with torch.no_grad():
            logits = model(features)
            logits = logits[mask]
            labels = labels[mask]
            _, indices = torch.max(logits, dim=1)
            correct = torch.sum(indices == labels)
            return correct.item() * 1.0 / len(labels)
    
    def train(n_epochs=100, lr=1e-2, weight_decay=5e-4, n_hidden=16, n_layers=1, activation=F.relu , dropout=0.5):
        data = CoraGraphDataset()
        g = data[0]   # the whole graph: 2708 nodes with 1433-dim features, 7 classes, 10556 edges
        features = g.ndata['feat']
        labels = g.ndata['label']
        train_mask = g.ndata['train_mask']  # nodes 0~139 are for training
        val_mask = g.ndata['val_mask']      # nodes 140~539 are for validation
        test_mask = g.ndata['test_mask']    # nodes 1708~2707 are for testing
        in_feats = features.shape[1]
        n_classes = data.num_classes
    
        model = GCN(g,
                    in_feats,
                    n_hidden,
                    n_classes,
                    n_layers,
                    activation,
                    dropout)
    
        loss_fcn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam( model.parameters(),
                                     lr = lr,
                                     weight_decay = weight_decay)
        best_val_acc = 0
        for epoch in range( n_epochs ):
            model.train()
            logits = model( features )
            loss = loss_fcn( logits[ train_mask ], labels[ train_mask ] )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = evaluate(model, features, labels, val_mask)
            print("Epoch {} | Loss {:.4f} | Accuracy {:.4f} ".format(epoch, loss.item(), acc ))
            if acc > best_val_acc:
                best_val_acc = acc
                torch.save(model.state_dict(), 'models/best_model.pth')
    
        model.load_state_dict(torch.load("models/best_model.pth"))
        acc = evaluate(model, features, labels, test_mask)
        print("Test accuracy {:.2%}".format(acc))
    
    if __name__ == '__main__':
        train()
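The `GCN` class delegates each layer to DGL's `GraphConv`. As a rough NumPy sketch of what one such layer computes on a symmetric graph with DGL's default `norm='both'`: H' = act(D^(-1/2) A D^(-1/2) H W + b), where A is the adjacency matrix, D its diagonal degree matrix, H the node features, and W, b the layer's weight and bias. The GCN formulation of Kipf & Welling adds self-loops first (A + I); the script above does not, so the sketch exposes that as an option. The helpers `normalized_adjacency`, `gcn_layer` and `gcn_forward` are hypothetical illustrations, not DGL functions, and dropout is omitted (evaluation mode).

```python
import numpy as np

def normalized_adjacency(adj, add_self_loops=False):
    """Return D^(-1/2) A D^(-1/2) for a symmetric adjacency matrix A (optionally A + I)."""
    a = adj.astype(float)
    if add_self_loops:
        a = a + np.eye(a.shape[0])
    deg = a.sum(axis=1)
    inv_sqrt = np.zeros_like(deg)
    nonzero = deg > 0
    inv_sqrt[nonzero] = deg[nonzero] ** -0.5  # isolated nodes keep an all-zero row
    return inv_sqrt[:, None] * a * inv_sqrt[None, :]

def gcn_layer(a_norm, h, weight, bias, activation=None):
    """One graph convolution: average neighbour features (normalised), then transform."""
    out = a_norm @ h @ weight + bias
    return activation(out) if activation is not None else out

def relu(x):
    return np.maximum(x, 0.0)

def gcn_forward(adj, features, params):
    """Two layers, like GCN(..., n_layers=1): input layer with ReLU, then output layer."""
    a_norm = normalized_adjacency(adj)
    (w1, b1), (w2, b2) = params
    h = gcn_layer(a_norm, features, w1, b1, activation=relu)
    return gcn_layer(a_norm, h, w2, b2)  # logits: one row per node, one column per class

# Toy graph: path 0 - 1 - 2, 3 input features, 4 hidden units, 2 classes
rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
features = rng.random((3, 3))
params = [(rng.normal(size=(3, 4)), np.zeros(4)), (rng.normal(size=(4, 2)), np.zeros(2))]
logits = gcn_forward(adj, features, params)
print(logits.shape)  # (3, 2)
```

On Cora the same computation runs with A of size 2708 x 2708 (stored sparsely by DGL), H of width 1433, and 7 output columns.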
    

      

    Output:

     ......

  • Original post: https://www.cnblogs.com/picassooo/p/15430894.html