zoukankan      html  css  js  c++  java
  • GCN code parsing

    GCN code parsing 
    2018-07-18 20:39:11

     

    utils.py 

    --- load data 

    def load_data(path="../data/cora/", dataset="cora"):
        """Read a citation-network dataset from disk and return it as torch tensors.

        Two files are read: ``<path><dataset>.content`` and
        ``<path><dataset>.cites``. ``path`` is concatenated verbatim with
        ``dataset``, so it must already end with a path separator.

        Each row of the ``.content`` file is treated as: first column = integer
        node id, last column = class label, everything in between = feature
        values. Each row of the ``.cites`` file is treated as a pair of node ids
        forming one edge.

        Returns:
            A 6-tuple ``(adj, features, labels, idx_train, idx_val, idx_test)``:
            the normalized adjacency (with self-loops added) converted by
            ``sparse_mx_to_torch_sparse_tensor``, a dense ``FloatTensor`` of
            normalized features, a ``LongTensor`` of class indices, and three
            ``LongTensor`` index vectors for the fixed train / validation / test
            split (rows 0-139, 200-499 and 500-1499 respectively).
        """
        print('Loading {} dataset...'.format(dataset))

        # Everything is read as strings first; columns are converted below.
        content = np.genfromtxt("{}{}.content".format(path, dataset),
                                dtype=np.dtype(str))
        node_ids = np.array(content[:, 0], dtype=np.int32)
        features = sp.csr_matrix(content[:, 1:-1], dtype=np.float32)
        labels = encode_onehot(content[:, -1])

        # Translate the raw node ids used in the .cites file into row positions
        # of the .content file, so edges index directly into the matrices.
        row_of_id = {node_id: row for row, node_id in enumerate(node_ids)}
        raw_edges = np.genfromtxt("{}{}.cites".format(path, dataset),
                                  dtype=np.int32)
        remapped = [row_of_id.get(node_id) for node_id in raw_edges.flatten()]
        edges = np.array(remapped, dtype=np.int32).reshape(raw_edges.shape)

        num_nodes = labels.shape[0]
        edge_weights = np.ones(edges.shape[0])
        adj = sp.coo_matrix((edge_weights, (edges[:, 0], edges[:, 1])),
                            shape=(num_nodes, num_nodes),
                            dtype=np.float32)

        # Symmetrize: wherever the transposed entry is larger, replace the
        # entry with the transposed one (element-wise maximum of adj and adj.T).
        transposed_is_larger = adj.T > adj
        adj = (adj
               + adj.T.multiply(transposed_is_larger)
               - adj.multiply(transposed_is_larger))

        # `normalize` is defined elsewhere in the project; self-loops are added
        # to the adjacency before it is normalized.
        features = normalize(features)
        adj = normalize(adj + sp.eye(adj.shape[0]))

        # Hard-coded split by row position.
        train_rows = range(140)
        val_rows = range(200, 500)
        test_rows = range(500, 1500)

        # Convert everything to torch; labels go from one-hot rows to the
        # column index of the non-zero entry in each row.
        features = torch.FloatTensor(np.array(features.todense()))
        labels = torch.LongTensor(np.where(labels)[1])
        adj = sparse_mx_to_torch_sparse_tensor(adj)

        idx_train = torch.LongTensor(train_rows)
        idx_val = torch.LongTensor(val_rows)
        idx_test = torch.LongTensor(test_rows)

        return adj, features, labels, idx_train, idx_val, idx_test
    View Code

    ## adj: torch.Size([2708, 2708])
    ## features: torch.Size([2708, 1433])
    ## labels: torch.Size([2708])
    ## idx_train: torch.Size([140])
    ## idx_val: torch.Size([300])
    ## idx_test: torch.Size([1000])

    train.py 

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

     

  • 相关阅读:
    淡季买房注意细节 防售楼部“挂羊头卖狗肉”
    买房容易选房难 八大把关教您如何选好房
    socket发送接收字段采用Base64加密笔记
    深入理解JDK、JRE
    Socket读取JSONArray字串越界等相关问题
    android采用MediaPlayer监听EditText实现语音播报手机号码(阿拉伯数字)
    读取properties文件
    关于android客户端在线版本更新的总结(json源码)
    验证码
    base64举例
  • 原文地址:https://www.cnblogs.com/wangxiaocvpr/p/9332476.html
Copyright © 2011-2022 走看看