zoukankan      html  css  js  c++  java
  • GraphSAGE 代码解析

    class EdgeMinibatchIterator

        """ This minibatch iterator iterates over batches of sampled edges or
        random pairs of co-occuring edges.
    
        G -- networkx graph
        id2idx -- dict mapping node ids to index in feature tensor
        placeholders -- tensorflow placeholders object
        context_pairs -- if not none, then a list of co-occuring node pairs (from random walks)
        batch_size -- size of the minibatches
        max_degree -- maximum size of the downsampled adjacency lists
        n2v_retrain -- signals that the iterator is being used to add new embeddings to a n2v model
        fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context
        """

    def __init__(self, G, id2idx, placeholders, context_pairs=None, batch_size=100, max_degree=25,

    n2v_retrain=False, fixed_n2v=False, **kwargs) 中具体介绍以下:

    1 self.nodes = np.random.permutation(G.nodes())
    2 # 函数shuffle与permutation都是对原来的数组进行重新洗牌,即随机打乱原来的元素顺序
    3 # shuffle直接在原来的数组上进行操作,改变原来数组的顺序,无返回值
    4 # permutation不直接在原来的数组上进行操作,而是返回一个新的打乱顺序的数组,并不改变原来的数组。
    1 self.adj, self.deg = self.construct_adj()

    这里重点看construct_adj()函数。

     1 def construct_adj(self):
     2         adj = len(self.id2idx) * 
     3             np.ones((len(self.id2idx) + 1, self.max_degree))
     4         # 该矩阵记录训练数据中各节点的邻居节点的编号
     5         # 采样只取max_degree个邻居节点,采样方法见下
     6         # 同样进行了行数加一操作
     7 
     8         deg = np.zeros((len(self.id2idx),))
     9         # 该矩阵记录了每个节点的度数
    10 
    11         for nodeid in self.G.nodes():
    12             if self.G.node[nodeid]['test'] or self.G.node[nodeid]['val']:
    13                 continue
    14             neighbors = np.array([self.id2idx[neighbor]
    15                                   for neighbor in self.G.neighbors(nodeid)                   
    16                                   if (not self.G[nodeid][neighbor]['train_removed'])])
    17             # Graph.neighbors() Return a list of the nodes connected to the node n.
    18             # 在选取邻居节点时进行了筛选,对于G.neighbors(nodeid) 点node的邻居,
    19             # 只取该node与neighbor相连的边的train_removed = False的neighbor
    20             # 也就是只取不是val, test的节点。
    21             # neighbors得到了邻居节点编号数列。
    22 
    23             deg[self.id2idx[nodeid]] = len(neighbors)
    24             # deg各位取值为该位对应nodeid的节点的度数,
    25             # 也即经过上面筛选后得到的邻居数
    26 
    27             if len(neighbors) == 0:
    28                 continue
    29             if len(neighbors) > self.max_degree:
    30                 neighbors = np.random.choice(
    31                     neighbors, self.max_degree, replace=False)
    32             # range: neighbors; size = max_degree; replace: replace the origin matrix or not
    33             # np.random.choice为选取size大小的数列
    34 
    35             elif len(neighbors) < self.max_degree:
    36                 neighbors = np.random.choice(
    37                     neighbors, self.max_degree, replace=True)
    38             # 经过choice随机选取,得到了固定大小max_degree = 25的直接相连的邻居数列
    39 
    40             adj[self.id2idx[nodeid], :] = neighbors
    41            # 把该node的邻居数列,赋值给adj矩阵中对应nodeid位的向量。
    42         return adj, deg

    construct_test_adj()  函数中,与上不同之处在于,可以直接得到邻居而无需根据val/test/train_removed筛选.

    1 neighbors = np.array([self.id2idx[neighbor]
    2                           for neighbor in self.G.neighbors(nodeid)])

  • 相关阅读:
    iOS开发UI—Button基础
    iOS开发UI—UIWindow介绍
    第43月第27天 nginx keeplike高可用
    第43月第23天 商品秒杀 乐观锁
    第43月第22天 github 工程 svn checkout ipa瘦身
    第43月第21天 h264文件格式
    第43月第17天 iOS 子线程开启、关闭runloop performSelector
    第43月第15天 nginx负载均衡 redis
    第43月第11天 kCVPixelFormatType_420YpCbCr8BiPlanarFullRange转rgb
    第43月第10天 uiimage写文件
  • 原文地址:https://www.cnblogs.com/shiyublog/p/9902423.html
Copyright © 2011-2022 走看看