zoukankan      html  css  js  c++  java
  • GCN code walkthrough

    GCN code walkthrough
    2018-07-18 20:39:11

     

    utils.py 

    --- load data 

    def load_data(path="../data/cora/", dataset="cora", *,
                  idx_train=range(140), idx_val=range(200, 500),
                  idx_test=range(500, 1500)):
        """Load a citation-network dataset (written for Cora) as GCN inputs.

        Reads two whitespace-separated text files:

        * ``{path}{dataset}.content`` -- one row per node: integer node id in
          the first column, feature columns, and the class label in the last
          column.
        * ``{path}{dataset}.cites`` -- one edge per row, given as two node ids.

        Args:
            path: Directory prefix. It is joined to ``dataset`` by plain string
                formatting, so it must end with a path separator.
            dataset: File stem of the ``.content`` / ``.cites`` files.
            idx_train: Keyword-only. Node indices (row positions in
                ``.content``, not raw node ids) for the training split.
            idx_val: Keyword-only. Node indices for the validation split.
            idx_test: Keyword-only. Node indices for the test split.
                The three defaults reproduce the original fixed Cora split
                (140 train / 300 val / 1000 test nodes).

        Returns:
            Tuple ``(adj, features, labels, idx_train, idx_val, idx_test)``:

            * ``adj`` -- symmetrized adjacency matrix plus self-loops, passed
              through ``normalize`` and converted by
              ``sparse_mx_to_torch_sparse_tensor``.
            * ``features`` -- dense ``FloatTensor`` of shape
              (num_nodes, num_features), after ``normalize``.
            * ``labels`` -- ``LongTensor`` of per-node class indices.
            * The split indices -- int64 (``LongTensor``) tensors.

        Raises:
            ValueError: if any split index lies outside ``[0, num_nodes)``.
        """
        print('Loading {} dataset...'.format(dataset))

        idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset),
                                            dtype=np.dtype(str))
        features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
        labels = encode_onehot(idx_features_labels[:, -1])
        num_nodes = labels.shape[0]

        # build graph: map raw node ids to row positions 0..num_nodes-1
        idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
        idx_map = {j: i for i, j in enumerate(idx)}
        edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset),
                                        dtype=np.int32)
        # An endpoint id that does not appear in .content maps to None here,
        # which makes the int32 conversion below fail.
        edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                         dtype=np.int32).reshape(edges_unordered.shape)
        adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                            shape=(num_nodes, num_nodes),
                            dtype=np.float32)

        # build symmetric adjacency matrix: element-wise max(adj, adj.T).
        # Where adj.T > adj the result is adj + adj.T - adj == adj.T;
        # everywhere else it is adj unchanged.
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

        features = normalize(features)
        # add self-loops (A + I) before normalizing the adjacency
        adj = normalize(adj + sp.eye(adj.shape[0]))

        def _to_index_tensor(name, split):
            # Convert a split to an int64 tensor, rejecting indices that do not
            # fit this dataset so the mismatch surfaces here rather than as an
            # opaque indexing error during training.
            arr = np.array(split, dtype=np.int64)
            if arr.size and (arr.min() < 0 or arr.max() >= num_nodes):
                raise ValueError(
                    "{} contains node indices outside [0, {}): min={}, max={}"
                    .format(name, num_nodes, arr.min(), arr.max()))
            return torch.from_numpy(arr)

        idx_train = _to_index_tensor("idx_train", idx_train)
        idx_val = _to_index_tensor("idx_val", idx_val)
        idx_test = _to_index_tensor("idx_test", idx_test)

        # toarray() yields an ndarray directly instead of an np.matrix
        features = torch.FloatTensor(features.toarray())
        # Column positions of the nonzero entries are the class ids. This
        # presumes encode_onehot yields exactly one nonzero per row -- confirm.
        labels = torch.LongTensor(np.where(labels)[1])
        adj = sparse_mx_to_torch_sparse_tensor(adj)

        return adj, features, labels, idx_train, idx_val, idx_test
    View Code

    ## adj: torch.Size([2708, 2708])
    ## features: torch.Size([2708, 1433])
    ## labels: torch.Size([2708])
    ## idx_train: torch.Size([140])
    ## idx_val: torch.Size([300])
    ## idx_test: torch.Size([1000])

    train.py 

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

  • 相关阅读:
    TCP/IP学习-链路层
    Linux下搭建Wordpress环境
    DiskMgr的限制项
    Win10系统Start Menu上的图标莫名消失
    powershell
    第一个页面的文本域中输入的值怎么在第二个页面中显示
    php 文本框里面显示数据库调出来的资料
    php代码
    php表单提交方法汇总
    php将SQL查询结果赋值给变量
  • 原文地址:https://www.cnblogs.com/wangxiaocvpr/p/9332476.html
Copyright © 2011-2022 走看看