    GCN code parsing 
    2018-07-18 20:39:11

     

    utils.py 

    --- load data 

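    import numpy as np
    import scipy.sparse as sp
    import torch

    # encode_onehot, normalize and sparse_mx_to_torch_sparse_tensor are
    # helper functions defined elsewhere in utils.py (not shown here).

    def load_data(path="../data/cora/", dataset="cora"):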
        """Load citation network dataset (cora only for now)"""
        print('Loading {} dataset...'.format(dataset))
    
        idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset),
                                            dtype=np.dtype(str))
        features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
        labels = encode_onehot(idx_features_labels[:, -1])
    
        # build graph
        idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
        idx_map = {j: i for i, j in enumerate(idx)}
        edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset),
                                        dtype=np.int32)
        edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                         dtype=np.int32).reshape(edges_unordered.shape)
        adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                            shape=(labels.shape[0], labels.shape[0]),
                            dtype=np.float32)
    
        # build symmetric adjacency matrix
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    
        features = normalize(features)
        adj = normalize(adj + sp.eye(adj.shape[0]))
    
        idx_train = range(140)
        idx_val = range(200, 500)
        idx_test = range(500, 1500)
    
        features = torch.FloatTensor(np.array(features.todense()))
        labels = torch.LongTensor(np.where(labels)[1])
        adj = sparse_mx_to_torch_sparse_tensor(adj)
    
        idx_train = torch.LongTensor(idx_train)
        idx_val = torch.LongTensor(idx_val)
        idx_test = torch.LongTensor(idx_test)
    
        return adj, features, labels, idx_train, idx_val, idx_test
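
The "build graph" step in load_data is the least obvious part, so here is a minimal sketch of it on a toy graph. The paper IDs and citation pairs are made up for illustration and are not taken from the cora files; only the ID remapping and the symmetrization expression mirror the code above.

```python
import numpy as np
import scipy.sparse as sp

# Hypothetical toy data: three papers with arbitrary IDs and two citation pairs.
paper_ids = np.array([31336, 1061127, 1106406], dtype=np.int32)
raw_edges = np.array([[31336, 1061127],
                      [1106406, 1061127]], dtype=np.int32)

# Map each raw paper ID to its row index, as load_data does with idx_map.
idx_map = {j: i for i, j in enumerate(paper_ids)}
edges = np.array([idx_map[j] for j in raw_edges.flatten()],
                 dtype=np.int32).reshape(raw_edges.shape)

# Directed adjacency matrix: one entry per citation pair.
n = len(paper_ids)
adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                    shape=(n, n), dtype=np.float32)

# Same expression as in load_data: wherever adj.T holds a larger value
# than adj, that value is copied over, which makes the matrix symmetric.
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

dense = adj.toarray()
print(edges)
print(dense)
```

After remapping, the edges refer to row indices 0..n-1 instead of raw paper IDs, and every directed citation shows up in both directions in the resulting matrix.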

    ## adj: torch.Size([2708, 2708])
    ## features: torch.Size([2708, 1433])
    ## labels: torch.Size([2708])
    ## idx_train: torch.Size([140])
    ## idx_val: torch.Size([300])
    ## idx_test: torch.Size([1000])
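
Two steps behind these shapes are easy to miss: the adjacency matrix gets self-loops and is normalized before it is converted to a tensor, and labels is 1-D because np.where(labels)[1] turns each one-hot row into a class index. The sketch below illustrates both on toy data. Note that row_normalize is a hypothetical stand-in: the body of normalize is not shown in this post, so row normalization is an assumption here.

```python
import numpy as np
import scipy.sparse as sp


def row_normalize(mx):
    """Divide each row by its sum; all-zero rows stay zero (assumed behaviour)."""
    rowsum = np.asarray(mx.sum(axis=1), dtype=np.float64).flatten()
    r_inv = np.zeros_like(rowsum)
    nonzero = rowsum != 0
    r_inv[nonzero] = 1.0 / rowsum[nonzero]
    return sp.diags(r_inv).dot(mx)


# Toy symmetric adjacency for a 3-node path graph 0 - 1 - 2.
adj = sp.csr_matrix(np.array([[0, 1, 0],
                              [1, 0, 1],
                              [0, 1, 0]], dtype=np.float32))

# Add self-loops before normalizing, as load_data does with sp.eye.
adj_norm = row_normalize(adj + sp.eye(adj.shape[0])).toarray()

# One-hot labels -> class indices, the same trick as np.where(labels)[1].
onehot = np.array([[1, 0, 0],
                   [0, 0, 1],
                   [0, 1, 0]])
class_idx = np.where(onehot)[1]

print(adj_norm)
print(class_idx)
```

With this kind of normalization each row of the adjacency matrix sums to 1, so multiplying by it averages a node's features with those of its neighbours.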

    train.py 
  • Original article: https://www.cnblogs.com/wangxiaocvpr/p/9332476.html