    GCN code parsing 
    2018-07-18 20:39:11

     

    utils.py 

    --- load data 
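    load_data reads two whitespace-separated text files. The sketch below assumes the standard Cora layout: each cora.content row is a paper ID, the binary word features, then a class label, and each cora.cites row is a cited-paper ID followed by a citing-paper ID. It parses a tiny hypothetical excerpt held in memory with the same np.genfromtxt calls; the IDs and the 4-column feature width are made up for illustration (real rows carry 1433 features).

```python
import io
import numpy as np

# Hypothetical excerpt in the cora.content layout:
# <paper_id> <binary word features ...> <class_label>
CONTENT = """101 0 1 0 0 Neural_Networks
205 1 0 0 1 Rule_Learning
307 0 0 1 0 Neural_Networks
"""

# Hypothetical excerpt in the cora.cites layout: <cited_id> <citing_id>
CITES = """101 205
307 205
"""

# Read everything as strings first, because the last column is a class name.
rows = np.genfromtxt(io.StringIO(CONTENT), dtype=np.dtype(str))
ids = rows[:, 0].astype(np.int32)             # first column: paper IDs
features = rows[:, 1:-1].astype(np.float32)   # middle columns: bag-of-words
raw_labels = rows[:, -1]                      # last column: class name

# Citation pairs are purely numeric, so they can be read as integers directly.
edges_unordered = np.genfromtxt(io.StringIO(CITES), dtype=np.int32)

print(features.shape)  # (3, 4)
```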

    import numpy as np
    import scipy.sparse as sp
    import torch

    # encode_onehot, normalize and sparse_mx_to_torch_sparse_tensor are
    # helper functions defined elsewhere in utils.py.

    def load_data(path="../data/cora/", dataset="cora"):
        """Load citation network dataset (cora only for now)"""
        print('Loading {} dataset...'.format(dataset))
    
        # Each row of cora.content: <paper_id> <1433 binary word features> <class_label>
        idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset),
                                            dtype=np.dtype(str))
        features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
        labels = encode_onehot(idx_features_labels[:, -1])
    
        # build graph
        # Map raw paper IDs to consecutive row indices 0..N-1.
        idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
        idx_map = {j: i for i, j in enumerate(idx)}
        # Each row of cora.cites is a (cited, citing) pair of paper IDs.
        edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset),
                                        dtype=np.int32)
        edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                         dtype=np.int32).reshape(edges_unordered.shape)
        adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                            shape=(labels.shape[0], labels.shape[0]),
                            dtype=np.float32)
    
        # build symmetric adjacency matrix: take A^T where A^T > A, else keep A,
        # i.e. the elementwise max(A, A^T)
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    
        # Row-normalize the features, and the adjacency with self-loops: D^-1 (A + I)
        features = normalize(features)
        adj = normalize(adj + sp.eye(adj.shape[0]))
    
        # Fixed split: 140 training, 300 validation, 1000 test nodes
        idx_train = range(140)
        idx_val = range(200, 500)
        idx_test = range(500, 1500)
    
        # Convert to PyTorch: dense features, integer class indices, sparse adjacency
        features = torch.FloatTensor(np.array(features.todense()))
        labels = torch.LongTensor(np.where(labels)[1])
        adj = sparse_mx_to_torch_sparse_tensor(adj)
    
        idx_train = torch.LongTensor(idx_train)
        idx_val = torch.LongTensor(idx_val)
        idx_test = torch.LongTensor(idx_test)
    
        return adj, features, labels, idx_train, idx_val, idx_test
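    encode_onehot is called above but not shown. A minimal sketch of such a helper, assuming it only has to map each class name to a one-hot row: classes are sorted so the column order is reproducible (iterating a Python set of strings can change between runs because of hash randomization). The last lines show the inverse step that load_data performs with np.where(labels)[1].

```python
import numpy as np

def encode_onehot(labels):
    """One-hot encode a 1-D array of class names (sketch, sorted class order)."""
    classes = sorted(set(labels))
    class_to_col = {c: i for i, c in enumerate(classes)}
    onehot = np.zeros((len(labels), len(classes)), dtype=np.int32)
    for row, label in enumerate(labels):
        onehot[row, class_to_col[label]] = 1
    return onehot

labels = np.array(["Neural_Networks", "Rule_Learning", "Neural_Networks", "Theory"])
onehot = encode_onehot(labels)

# np.where on a 2-D array returns (row_indices, col_indices); with exactly one 1
# per row, the column indices are the integer class IDs, in row order.
class_idx = np.where(onehot)[1]
print(class_idx)  # [0 1 0 2]
```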

    Shapes of the returned tensors for Cora:

    ## adj: torch.Size([2708, 2708])  (sparse)
    ## features: torch.Size([2708, 1433])
    ## labels: torch.Size([2708])
    ## idx_train: torch.Size([140])
    ## idx_val: torch.Size([300])
    ## idx_test: torch.Size([1000])
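    The graph-building part of load_data can be traced on a toy graph. This sketch uses dense NumPy arrays instead of scipy.sparse, hypothetical paper IDs, and a hypothetical row_normalize standing in for normalize, assuming normalize divides each row by its sum (as pygcn's helper of that name does). It remaps IDs through idx_map, checks that the adj.T > adj expression equals the elementwise max(A, A^T), adds self-loops, and row-normalizes to get D^-1 (A + I) with D_ii = sum_j (A + I)_ij. The Kipf and Welling paper itself uses the symmetric form D^-1/2 (A + I) D^-1/2; row normalization is a different, simpler choice.

```python
import numpy as np

def row_normalize(mx):
    """Divide each row of a dense matrix by its sum; all-zero rows stay zero."""
    rowsum = mx.sum(axis=1).astype(np.float64)
    with np.errstate(divide="ignore"):
        r_inv = 1.0 / rowsum
    r_inv[np.isinf(r_inv)] = 0.0          # avoid inf for isolated / empty rows
    return mx * r_inv[:, None]

# Hypothetical paper IDs and (cited, citing) pairs in the cora.cites layout.
idx = np.array([101, 205, 307, 412], dtype=np.int32)
edges_unordered = np.array([[101, 205], [307, 205], [205, 101]], dtype=np.int32)

# Map raw paper IDs to row indices 0..N-1, as idx_map does in load_data.
idx_map = {j: i for i, j in enumerate(idx)}
edges = np.array([idx_map[j] for j in edges_unordered.flatten()],
                 dtype=np.int32).reshape(edges_unordered.shape)

# Directed 0/1 adjacency (a scipy coo_matrix would sum duplicate pairs instead).
n = len(idx)
adj = np.zeros((n, n), dtype=np.float32)
adj[edges[:, 0], edges[:, 1]] = 1.0

# adj + A^T*(A^T > A) - A*(A^T > A) picks A^T wherever it is larger.
adj_sym = adj + adj.T * (adj.T > adj) - adj * (adj.T > adj)
assert np.array_equal(adj_sym, np.maximum(adj, adj.T))

# Add self-loops, then row-normalize: D^-1 (A + I).
adj_norm = row_normalize(adj_sym + np.eye(n, dtype=np.float32))
print(np.allclose(adj_norm.sum(axis=1), 1.0))  # True

# Feature rows are normalized the same way; an all-zero row stays zero.
feats = np.array([[1.0, 3.0], [0.0, 0.0]])
print(row_normalize(feats))
```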

    train.py 
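    The tensors returned by load_data feed a graph convolutional network. As a sketch, assuming the two-layer architecture from Kipf and Welling (ReLU hidden layer, log-softmax output) trained with a negative log-likelihood loss on the idx_train nodes, the forward pass and loss can be written in NumPy. NumPy stands in for PyTorch here; gcn_forward, nll_loss, accuracy and row_normalize are hypothetical names, a 6-node ring graph with random features and weights replaces Cora, and dropout, biases and the gradient step are omitted.

```python
import numpy as np

def row_normalize(mx):
    """D^-1 mx: divide each row by its sum (rows here are never all-zero)."""
    return mx / mx.sum(axis=1, keepdims=True)

def relu(x):
    return np.maximum(x, 0.0)

def log_softmax(x):
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = x - x.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def gcn_forward(adj, features, w0, w1):
    """Two-layer GCN: log_softmax(adj @ relu(adj @ X @ W0) @ W1)."""
    hidden = relu(adj @ (features @ w0))
    return log_softmax(adj @ (hidden @ w1))

def nll_loss(log_probs, labels, idx):
    """Mean negative log-likelihood over the nodes listed in idx."""
    return -log_probs[idx, labels[idx]].mean()

def accuracy(log_probs, labels, idx):
    """Fraction of nodes in idx whose arg-max class matches the label."""
    return (log_probs[idx].argmax(axis=1) == labels[idx]).mean()

# Toy sizes; for Cora these would be 2708 nodes, 1433 features and 7 classes.
n_nodes, n_feats, n_hidden, n_classes = 6, 5, 4, 3
rng = np.random.default_rng(0)

# Ring graph with self-loops, row-normalized like adj in load_data.
a = np.zeros((n_nodes, n_nodes))
for i in range(n_nodes):
    a[i, (i + 1) % n_nodes] = a[(i + 1) % n_nodes, i] = 1.0
adj = row_normalize(a + np.eye(n_nodes))

features = row_normalize(rng.random((n_nodes, n_feats)))
labels = rng.integers(0, n_classes, size=n_nodes)
idx_train = np.array([0, 1, 2])

w0 = rng.normal(scale=0.1, size=(n_feats, n_hidden))
w1 = rng.normal(scale=0.1, size=(n_hidden, n_classes))

output = gcn_forward(adj, features, w0, w1)
loss_train = nll_loss(output, labels, idx_train)
acc_train = accuracy(output, labels, idx_train)
print(output.shape)  # (6, 3)
```

    In a real training loop the loss would be minimized with respect to w0 and w1 by an optimizer; only the forward computation is shown here.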

  • Original post: https://www.cnblogs.com/wangxiaocvpr/p/9332476.html