zoukankan      html  css  js  c++  java
  • node2vec实现源码详解

    一、按照程序执行的顺序,第一步是walker.py中的preprocess_transition_probs()函数

    这个函数的作用是生成两个采样预备数据,alias_nodes,alias_edges。

    两份数据又各自包含两个列表,这两个列表分别对应着alias采样中的概率和另一个选项,具体alias采样详见https://blog.csdn.net/haolexiao/article/details/65157026

    alias_nodes:根据node和它的邻居之间的权重确定采样的概率,权重越高,被采中的概率越大。

    alias_edges:调用get_alias_edge()函数生成,返回在前一个访问顶点为t,当前顶点为v时决定下一次访问哪个邻接点时需要的alias表

     1 def preprocess_transition_probs(self):
     2         """
     3         Preprocessing of transition probabilities for guiding the random walks.
     4         """
     5         G = self.G
     6 
     7         alias_nodes = {}
     8         for node in G.nodes():
     9             unnormalized_probs = [G[node][nbr].get('weight', 1.0)
    10                                   for nbr in G.neighbors(node)]
    11             norm_const = sum(unnormalized_probs)
    12             normalized_probs = [
    13                 float(u_prob)/norm_const for u_prob in unnormalized_probs]
    14             alias_nodes[node] = create_alias_table(normalized_probs)
    15 
    16         alias_edges = {}
    17 
    18         for edge in G.edges():
    19             alias_edges[edge] = self.get_alias_edge(edge[0], edge[1])
    20 
    21         self.alias_nodes = alias_nodes
    22         self.alias_edges = alias_edges
    23 
    24         return

    二、第二个比较重要的函数是node2vec_walk()函数

    该函数是从start_node开始,生成walk_length长度的序列,序列的生成除了考虑当前节点,还考虑前一个遍历的节点。

    采样方法是根据之前生成的alias数据进行采样。

    对每一个节点都生成一个序列

    def node2vec_walk(self, walk_length, start_node):

      

     1 def node2vec_walk(self, walk_length, start_node):
     2 
     3         G = self.G
     4         alias_nodes = self.alias_nodes
     5         alias_edges = self.alias_edges
     6 
     7         walk = [start_node]
     8 
     9         while len(walk) < walk_length:
    10             cur = walk[-1]
    11             cur_nbrs = list(G.neighbors(cur))
    12             if len(cur_nbrs) > 0:
    13                 if len(walk) == 1:
    14                     walk.append(
    15                         cur_nbrs[alias_sample(alias_nodes[cur][0], alias_nodes[cur][1])])
    16                 else:
    17                     prev = walk[-2]
    18                     edge = (prev, cur)
    19                     try:
    20                         prob=alias_edges[edge][0]
    21                         alias=alias_edges[edge][1]
    22                     except KeyError:
    23                         print()
    24                     next_node = cur_nbrs[alias_sample(prob,alias)]
    25                     walk.append(next_node)
    26             else:
    27                 break
    28 
    29         return walk

    三、之后就是调用gensim中的Word2Vec进行训练,得到每个节点的embedding。

  • 相关阅读:
    使用Session防止表单重复提交
    Mysql中的排序规则utf8_unicode_ci、utf8_general_ci的区别总结
    Eclipse 设置文件的默认打开方式
    使用maven创建web项目
    solr配置中文分词器——(十二)
    solr后台界面介绍——(十一)
    solr4.10.3部署到tomcat——(十)
    Java与计算机常识
    solr简介——(九)
    Redis简介——(一)
  • 原文地址:https://www.cnblogs.com/stAr-1/p/13025779.html
Copyright © 2011-2022 走看看