zoukankan      html  css  js  c++  java
  • GraphSAGE 代码解析(一)

    原创文章~转载请注明出处哦。其他部分内容参见以下链接~

    GraphSAGE 代码解析(二) - layers.py

    GraphSAGE 代码解析(三) - aggregators.py

    GraphSAGE 代码解析(四) - models.py

    GraphSAGE代码详解

    example_data:

    1. toy-ppi-G.json 图的信息

    { 
      directed: false
      graph : {
                  {name: disjoint_union(,) }
               nodes:  [
                            {  
                                    test: false
                             id: 0
                             features: [ ... ]
                             val: false
                              lable: [ ... ]
                           }
                           {...}
                             ...
                      ]
    
                links: [
                           {  
                                    test_removed: false
                            train_removed: false
                            target: 800 # 指向的节点id(默认从小节点指向大节点)
                            source: 0   # 从0节点按顺序展示
                             }
                             {...}
                               ...
                        ]
          }
    }
    View Code

    2. toy-ppi-class_map.json

    3. toy-ppi-feats.npy 预训练好得到的features

    4. toy-ppi-id_map.json 节点编号与序号的一一对应;数据格式为:{"0": 0, "1": 1,..., "14754": 14754}

    5. toy-ppi-walks.txt 

        从一点出发随机游走到邻居节点的情况,对于每个点取198次(即可能有重复情况)

        例如:0    708 表示从0点走到708点。

    1. __init__.py

    1 from __future__ import print_function  
    2 #即使在python2.X,使用print就得像python3.X那样加括号使用。
    3 
    4 from __future__ import division          
    5 # 导入python未来支持的语言特征division(精确除法),
    6 # 当我们没有在程序中导入该特征时,"/"操作符执行的是截断除法(Truncating Division);
    7 # 当我们导入精确除法之后,"/"执行的是精确除法, "//"执行截断除除法

    2. unsupervised_train.py

    1 if __name__ == '__main__':
    2   tf.app.run()
    3 # https://blog.csdn.net/fxjzzyo/article/details/80466321
    4 # tf.app.run()的作用:通过处理flag解析,然后执行main函数
    5 # 如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,如test(),则你应该这样写入口tf.app.run(test())
    6 # 如果你的代码中的入口函数叫main(),则你就可以把入口写成tf.app.run()
    1 def main(argv=None):
    2   print("Loading training data..")
    3   train_data = load_data(FLAGS.train_prefix, load_walks=True)
    4   # load_data函数在graphsage.utils中定义
    5 
    6   print("Done loading training data..")
    7   train(train_data)
    8   # train函数在该文件中定义def train(train_data, test_data=None)

    3. utils.py - func: load_data

    (1) 读入id_map, class_map

    1 if isinstance(G.nodes()[0], int):
    2         def conversion(n): return int(n)
    3     else:
    4         def conversion(n): return n

    a. isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()。 

    isinstance(object, classinfo)

    参数
    object -- 实例对象。
    classinfo -- 可以是直接或间接类名、基本类型或者由它们组成的元组。

    返回值
    如果对象的类型与参数二的类型(classinfo)相同则返回 True,否则返回 False。

    >>>a = 2
    >>> isinstance (a,int)
    True
    >>> isinstance (a,str)
    False
    >>> isinstance (a,(str,int,list))    # 是元组中的一个返回 True
    True

    type() 与 isinstance() 区别:

    type() 不会认为子类是一种父类类型,不考虑继承关系。
    isinstance() 会认为子类是一种父类类型,考虑继承关系。
    如果要判断两个类型是否相同推荐使用 isinstance()。

     1 class A:
     2     pass
     3  
     4 class B(A):
     5     pass
     6  
     7 isinstance(A(), A)    # returns True
     8 type(A()) == A        # returns True
     9 isinstance(B(), A)    # returns True
    10 type(B()) == A        # returns False
    View Code

    b. G.nodes()

    返回的是图中节点n与节点属性nodedata。https://networkx.github.io/documentation/stable/reference/classes/generated/networkx.Graph.nodes.html

    例子:

    >>> G = nx.path_graph(3)
    >>> list(G.nodes)
    [0, 1, 2]
    >>> list(G)
    [0, 1, 2]
    View Code

    获取nodedata:

    >>> G.add_node(1, time='5pm')
    >>> G.nodes[0]['foo'] = 'bar'
    >>> list(G.nodes(data=True))
    [(0, {'foo': 'bar'}), (1, {'time': '5pm'}), (2, {})]
    >>> list(G.nodes.data())
    [(0, {'foo': 'bar'}), (1, {'time': '5pm'}), (2, {})]
    
    >>> list(G.nodes(data='foo'))
    [(0, 'bar'), (1, None), (2, None)]
    
    >>> list(G.nodes(data='time'))
    [(0, None), (1, '5pm'), (2, None)]
    
    >>> list(G.nodes(data='time', default='Not Available'))
    [(0, 'Not Available'), (1, '5pm'), (2, 'Not Available')]
    View Code

    If some of your nodes have an attribute and the rest are assumed to have a default attribute value you can create a dictionary from node/attribute pairs using the default keyword argument to guarantee the value is never None:

    >>> G = nx.Graph()
    >>> G.add_node(0)
    >>> G.add_node(1, weight=2)
    >>> G.add_node(2, weight=3)
    >>> dict(G.nodes(data='weight', default=1))
    {0: 1, 1: 2, 2: 3}
    View Code

    ----------------------------

    在utils.py中,判断G.nodes()[0] 是否为int型(即不带nodedata)。

    若为int型,则将n转为int型;否则直接返回n.

    b. conversion() 函数

    1 id_map = json.load(open(prefix + "-id_map.json"))
    2 id_map = {conversion(k): int(v) for k, v in id_map.items()}

    前面定义的conversion()函数在id_map这里用到了,把外存中的文件内容读到内存中,用dict类型的id_map存储。

    id_map.json文件中数据格式为:{"0": 0, "1": 1,..., "14754": 14754},也即id_map的迭代中k为str类型,v为int型。数据文件中G.nodes()[0] 显然是带nodedata的,也就算一般采用 def conversion(n): return n,返回的n为类型的(就是前面形参k的类型);

    但是为什么当G.nodes()[0] 不带nodedata时,要返回int(n)?

    c. class_map:  {"0": [.0,1,..], "1": [.0,1,..]...} ?含义?

    list(class_map.values()): [ [...], [...], ... ,[...] ]
    list(class_map.values())[0]: 表示取第一个[...] =>含义? 
    if isinstance(list(class_map.values())[0], list):
        def lab_conversion(n): return n
    else:
        def lab_conversion(n): return int(n)

    (2) Remove node

    1 # Remove all nodes that do not have val/test annotations
    2     # (necessary because of networkx weirdness with the Reddit data)
    3     broken_count = 0
    4     for node in G.nodes():
    5         if not 'val' in G.node[node] or not 'test' in G.node[node]:
    6             G.remove_node(node)
    7             broken_count += 1

    这里删除的节点是不具有'val','test'属性 的节点,而不是'val','test' 属性值为None的节点。

    区分开 if not 'val' in G.node[node] 和 if not G.node[n]['val']的不同意义。

    broken_count  记录删去的没有val 或者 test的属性的节点的数目。

    e. G.edges()

    1 for edge in G.edges():
    2         if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
    3                 G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
    4             G[edge[0]][edge[1]]['train_removed'] = True
    5         else:
    6             G[edge[0]][edge[1]]['train_removed'] = False

    G.edges() 得到edge_list, [( , ), ( , ), ... ( , )].list中每一个元素是所表示边的两个节点信息。若设置data = True,则会显示边的权重等属性信息。

    >>> G = nx.Graph()   # or DiGraph, MultiGraph, MultiDiGraph, etc
    >>> G.add_path([0,1,2])
    >>> G.add_edge(2,3,weight=5)
    >>> G.edges()
    [(0, 1), (1, 2), (2, 3)]
    >>> G.edges(data=True) # default edge data is {} (empty dictionary)
    [(0, 1, {}), (1, 2, {}), (2, 3, {'weight': 5})]
    >>> list(G.edges_iter(data='weight', default=1))
    [(0, 1, 1), (1, 2, 1), (2, 3, 5)]
    >>> G.edges([0,3])
    [(0, 1), (3, 2)]
    >>> G.edges(0)
    [(0, 1)]
    View Code

    代码中edge对edges迭代,每次去list中的一个元组,而edge[0], edge[1]则分别表示两个顶点。

    若两个顶点中至少有一个的val/test不为空,则将该边的'train_removed'设为True,否则为False.

    该操作为保证'train_removed'不为空。

    (3) 获取训练数据features并标准化

    1 if normalize and not feats is None:
    2         from sklearn.preprocessing import StandardScaler
    3         train_ids = np.array([id_map[n] for n in G.nodes(
    4         ) if not G.node[n]['val'] and not G.node[n]['test']])
    5         train_feats = feats[train_ids]
    6         scaler = StandardScaler()
    7         scaler.fit(train_feats)
    8         feats = scaler.transform(feats)

    这里if not feats is None 等价于 if feats is not None.

    将val,test均为None的node选为训练数据,通过id_map获取其在feature表中的索引值,添加到train_ids数组中。根据索引train_ids,train_fests获取这些nodes的features.

    StandardScaler的用法:

    http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html

    Methods:

    fit(X[, y]) : Compute the mean and std to be used for later scaling.

    transform(X[, y, copy]) : Perform standardization by centering and scaling

    fit_transform(X[, y]) : Fit to data, then transform it.

    例子:

    >>> from sklearn.preprocessing import StandardScaler
    >>> data = [[0, 0], [0, 0], [1, 1], [1, 1]]
    >>> scaler = StandardScaler()
    >>> print(scaler.fit(data))
    StandardScaler(copy=True, with_mean=True, with_std=True)
    >>> print(scaler.mean_)
    [0.5 0.5]
    >>> print(scaler.transform(data))
    [[-1. -1.]
     [-1. -1.]
     [ 1.  1.]
     [ 1.  1.]]
    >>> print(scaler.transform([[2, 2]]))
    [[3. 3.]]
    
    # 计算得
    # 均值[0.5, 0.5], 
    # 方差:1/4 * [(0 - 0.5)^2 * 2 + (1 - 0.5)^2 * 2] = 1/4 = 0.25
    # 标准差:0.5
    # 对于[2,2] transform 标准化之后: (2 - 0.5) / 0.5 = 3
    View Code

    (4) Load walks

    在unsupervised_train.py的main函数中:

    1 train_data = load_data(FLAGS.train_prefix, load_walks=True)

    load_walks = True,需要执行utils.py中的load_walks操作。

    1 if load_walks:  # false by default
    2         with open(prefix + "-walks.txt") as fp:
    3             for line in fp:
    4                 walks.append(map(conversion, line.split()))

    map() 的用法:http://www.runoob.com/python/python-func-map.html

    map(function, iterable, ...)

    map() 会根据提供的函数对指定序列做映射。

    第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新列表。

    例子:

    >>>def square(x) :            # 计算平方数
    ...     return x ** 2
    ... 
    >>> map(square, [1,2,3,4,5])   # 计算列表各个元素的平方
    [1, 4, 9, 16, 25]
    >>> map(lambda x: x ** 2, [1, 2, 3, 4, 5])  # 使用 lambda 匿名函数
    [1, 4, 9, 16, 25]
     
    # 提供了两个列表,对相同位置的列表数据进行相加
    >>> map(lambda x, y: x + y, [1, 3, 5, 7, 9], [2, 4, 6, 8, 10])
    [3, 7, 11, 15, 19]
    View Code

    walks初始化为[], 之后append的是游走的节点对的对象。

    例子:walks.txt:

    0    708
    0    3163
    0    276
    1 def conversion(n): return n
    2 walks = []
    3 with open("walks.txt") as fp:
    4     for line in fp:
    5         print(line.split())
    6         walks.append(map(conversion, line.split()))
    7 print(walks) 
    8 print(len(walks))
    View Code

    输出:

    ['0', '708']
    ['0', '3163']
    ['0', '276']
    [<map object at 0x7f5bc0d68da0>, <map object at 0x7f5bc0d68e48>, <map object at 0x7f5bc0d68f28>]
    3

    (5) 函数返回值

    1 return G, feats, id_map, walks, class_map

    ------------------------------------------------------------------------------------

    4. unsupervised_train.py - func: train(train_data)

    1 def train(train_data, test_data=None):

    这里的train_data是上文所述的load_data函数的返回值。

    变量含义:

    G = train_data[0]    #
    features = train_data[1]    # 训练数据的features
    id_map = train_data[2]     # "n" : n
    context_pairs = train_data[3] if FLAGS.random_context else None #random walk的点对
    1 if not features is None:
    2     # pad with dummy zero vector
    3     features = np.vstack([features, np.zeros((features.shape[1],))])

    这里vstack为features添加列一行0向量,用于WX + b中与b相加。

    1 placeholders = construct_placeholders()
    2 # def construct_placeholders()定义的placeholders包含:
    3 # batch1, batch2, neg_samples, dropout, batch_size

    minibatch是EdgeMinibatchIterator的一个实例,转至minibatch.py看class EdgeMinibatchIterator(object)的定义。

    5. minibatch.py - class EdgeMinibatchIterator

    https://www.cnblogs.com/shiyublog/p/9902423.html

    6. unsupervised_train.py - func train

    继续回来看unsupervised_trian.py 中的train函数

    变量:

    1 adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    2 adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    adj_info记录邻居信息,是一个矩阵,矩阵每一行对应每一个节点的邻居节点编号数组。

    (1)选择模型

    接下来根据输入参数判断选择6种模型(graphsage_mean,gcn,graphsage_seq,graphsage_maxpool,graphsage_meanpool,n2v)中的哪一种。

    以graphsage开头的几种是graphsage的几种变体,由于aggregator不同而不同。可以通过设定SampleAndAggregate()中的aggregator_type进行选择。默认为mean.

    其中gcn与graphsage的参数不同在于:

    gcn的aggregator中进行列concat的操作,因此其维数是graphsage的二倍。

    a. graphsage_maxpool 

     1 sampler = UniformNeighborSampler(adj_info)

    首先看UniformNeighborSampler,该类用于sample节点的邻居,在neigh_samplers.py中。

    neigh_samplers.py

     1 class UniformNeighborSampler(Layer):
     2     """
     3     Uniformly samples neighbors.
     4     Assumes that adj lists are padded with random re-sampling
     5     """
     6     def __init__(self, adj_info, **kwargs):
     7         super(UniformNeighborSampler, self).__init__(**kwargs)
     8         self.adj_info = adj_info
     9 
    10     def _call(self, inputs):
    11         ids, num_samples = inputs
    12         adj_lists = tf.nn.embedding_lookup(self.adj_info, ids) 
    13         adj_lists = tf.transpose(tf.random_shuffle(tf.transpose(adj_lists)))
    14         adj_lists = tf.slice(adj_lists, [0,0], [-1, num_samples])
    15         return adj_lists

    1.  tf.nn.embedding_lookup 用于根据ids在adj_info中找到各个对应位的向量。

    2. adj_lists = tf.transpose(tf.random_shuffle(tf.transpose(adj_lists)))

        adj_lists = tf.slice(adj_lists, [0,0], [-1, num_samples]) 的过程见下:

    id0 id1 id2...   --transpose--> id0 [...]  --shuffle--> id1 [...]  --transpose--> id1 id2 id0 --slice--> id1 id2

    []    []    []                                id1 [...]                     id2 [...]                         []     []    []                  []    []

                                                  id2 [...]                     id0 [...]

    均匀:shuffle打乱0维的顺序,即打乱行顺序,以此使下面采样可以“均匀”。为了使用shuffle函数,需要在shuffle前后transpose一下。

    采样:slice之后,相当于随机挑选了num_samples个样本,并保留了这些样本的全部属性特征。

    3. 最后的adj_lists即为均匀采样后的表示邻居信息的矩阵。

    ---------------------------------------------------

    回到unsupervised_train.py 的train()函数.

    1 sampler = UniformNeighborSampler(adj_info)

    sampler获取均匀采样后的邻居节点信息。

    ---------------------------------------------------

    1 layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
    2                SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]

    其中SAGEInfo在models.py中。

    models.py 

    https://www.cnblogs.com/shiyublog/p/9879875.html

    1 # SAGEInfo is a namedtuple that specifies the parameters 
    2 # of the recursive GraphSAGE layers
    3 SAGEInfo = namedtuple("SAGEInfo",
    4     ['layer_name', # name of the layer (to get feature embedding etc.)
    5      'neigh_sampler', # callable neigh_sampler constructor
    6      'num_samples',
    7      'output_dim' # the output (i.e., hidden) dimension
    8     ])

    namedtuple 命名元组,可以给tuple命名,用法见下:

    https://www.cnblogs.com/chenlin163/p/7259061.html

     1 import collections
     2 
     3 MyTupleClass = collections.namedtuple('MyTupleClass',['name', 'age', 'job'])
     4 obj = MyTupleClass("Tomsom",12,'Cooker')
     5 print(obj.name)
     6 print(obj.age)
     7 print(obj.job)
     8 
     9 # Output:
    10 # Tomsom
    11 # 12
    12 # Cooker
    13 #############################
    14 
    15 Person=collections.namedtuple('Person','name age gender') 
    16 # 以空格分开,表示这个namedtuple有三个元素
    17 
    18 print( 'Type of Person:',type(Person))
    19 Bob=Person(name='Bob',age=30,gender='male')
    20 print( 'Representation:',Bob)
    21 Jane=Person(name='Jane',age=29,gender='female')
    22 print( 'Field by Name:',Jane.name)
    23 for people in [Bob,Jane]:
    24     print ("%s is %d years old %s" % people)
    25 
    26 # Output:
    27 # Type of Person: <class 'type'>
    28 # Representation: Person(name='Bob', age=30, gender='male')
    29 # Field by Name: Jane
    30 # Bob is 30 years old male
    31 # Jane is 29 years old female
    32 #############################
    33 
    34 # 在使用namedtyuple的时候要注意其中的名称不能使用Python的关键字,如class def等
    35 # 不能有重复的元素名称,比如:不能有两个’age age’。如果出现这些情况,程序会报错。
    36 # 但是,在实际使用的时候可能无法避免这种情况,
    37 # 比如:可能我们的元素名称是从数据库里读出来的记录,这样很难保证一定不会出现Python关键字。
    38 # 这种情况下的解决办法是将namedtuple的重命名模式打开,
    39 # 这样如果遇到Python关键字或者有重复元素名 时,自动进行重命名。
    40 
    41 with_class=collections.namedtuple('Person','name age class gender',rename=True)
    42 print with_class._fields
    43 two_ages=collections.namedtuple('Person','name age gender age',rename=True)
    44 print two_ages._fields
    45 
    46 # Output:
    47 # ('name', 'age', '_2', 'gender')
    48 # ('name', 'age', 'gender', '_3')
    49 
    50 # 使用rename=True的方式打开重命名选项。
    51 # 可以看到第一个集合中的class被重命名为 ‘_2' ; 
    52 # 第二个集合中重复的age被重命名为 ‘_3'
    53 # namedtuple在重命名的时候使用了下划线 _ 加元素所在索引数的方式进行重命名
    54 ##############################
    55 
    56 # 附两段官方文档代码实例:
    57 # 1) namedtuple基本用法
    58 >>> # Basic example
    59 >>> Point = namedtuple('Point', ['x', 'y'])
    60 >>> p = Point(11, y=22) # instantiate with positional or keyword arguments
    61 >>> p[0] + p[1] # indexable like the plain tuple (11, 22)
    62 33
    63 >>> x, y = p # unpack like a regular tuple
    64 >>> x, y
    65 (11, 22)
    66 >>> p.x + p.y # fields also accessible by name
    67 33
    68 >>> p # readable __repr__ with a name=value style
    69 Point(x=11, y=22)
    70 
    71 # 2) namedtuple结合csv和sqlite用法
    72 EmployeeRecord = namedtuple('EmployeeRecord', 'name, age, title, department, paygrade')
    73 import csv
    74 for emp in map(EmployeeRecord._make, csv.reader(open("employees.csv", "rb"))):
    75 print(emp.name, emp.title)
    76 
    77 import sqlite3
    78 conn = sqlite3.connect('/companydata')
    79 cursor = conn.cursor()
    80 cursor.execute('SELECT name, age, title, department, paygrade FROM employees')
    81 for emp in map(EmployeeRecord._make, cursor.fetchall()):
    82 print(emp.name, emp.title)
    View Code

    对于FLAGS.dim_1FLAGS.dim_2,定义为:

    1 flags.DEFINE_integer(
    2     'dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
    3 flags.DEFINE_integer(
    4     'dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')

    若GCN,因为有concat操作,故使用2x.

    对于FLAGS.samples_1FLAGS.samples_2,定义为:

    1 flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
    2 flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2')

    对应论文中的K = 1 ,第一层S1 = 25; K = 2 ,第二层S2 = 10。

    ----------------------------------------------------------

    1 model = SampleAndAggregate(placeholders,
    2                            features,
    3                            adj_info,
    4                            minibatch.deg,
    5                            layer_infos=layer_infos,
    6                            aggregator_type="maxpool",
    7                            model_size=FLAGS.model_size,
    8                            identity_dim=FLAGS.identity_dim,
    9                            logging=True)

    SampleAndAggregate在models.py中。

    class SampleAndAggregate(GeneralizedModel)主要包含的函数有:

    1. def __init__(self, placeholders, features, adj, degrees, layer_infos, concat=True, aggregator_type="mean",  model_size="small", identity_dim=0, **kwargs)

    2. def sample(self, inputs, layer_infos, batch_size=None)

    3. def aggregate(self, samples, input_features, dims, num_samples, support_sizes, batch_size=None,
    aggregators=None, name=None, concat=False, model_size="small")

    4. def _build(self)

    5. def build(self)

    6. def _loss(self)

    7. def _accuracy(self)

    ---------------------------------------------------------------

    (2) Session

    Config

     1 config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
     2 # 参数初始化为False:
     3 # tf.app.flags.DEFINE_boolean('log_device_placement', False,
     4 #                     """Whether to log device placement.""")
     5 
     6 config.gpu_options.allow_growth = True
     7 # 控制GPU资源使用率
     8 # 使用allow_growth option,刚一开始分配少量的GPU容量,然后按需慢慢的增加,
     9 # 由于不会释放内存,所以会导致碎片
    10 
    11 #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    12 # 设置每个GPU应该拿出多少容量给进程使用,
    13 # per_process_gpu_memory_fraction =0.4代表 40%
    14 
    15 config.allow_soft_placement = True
    16 # 自动选择运行设备
    17 # 在tf中,通过命令 "with tf.device('/cpu:0'):",允许手动设置操作运行的设备。
    18 # 如果手动设置的设备不存在或者不可用,就会导致tf程序等待或异常,
    19 # 为了防止这种情况,可以设置tf.ConfigProto()中参数allow_soft_placement=True,
    20 # 允许tf自动选择一个存在并且可用的设备来运行操作。

     Initialize session

     1 # Initialize session
     2 sess = tf.Session(config=config)
     3 merged = tf.summary.merge_all()
     4 # tf.summary()能够保存训练过程以及参数分布图并在tensorboard显示。
     5 # merge_all 可以将所有summary全部保存到磁盘,以便tensorboard显示。
     6 # 如果没有特殊要求,一般用这一句就可一显示训练时的各种信息了
     7 
     8 summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
     9 # 指定一个文件用来保存图。
    10 # 格式:tf.summary.FileWritter(path,sess.graph)
    11 # 可以调用其add_summary()方法将训练过程数据保存在filewriter指定的文件中

    Init variables

    1 sess.run(tf.global_variables_initializer(),
    2      feed_dict={adj_info_ph: minibatch.adj})

    ---------------------------------------------------------

    (4) Train model

    1 feed_dict = minibatch.next_minibatch_feed_dict()

    next_minibatch_feed_dict() 在minibatch.py的class EdgeMinibatchIterator(object)中定义。

    1 def next_minibatch_feed_dict(self):
    2     start_idx = self.batch_num * self.batch_size
    3     self.batch_num += 1
    4     end_idx = min(start_idx + self.batch_size, len(self.train_edges))
    5     batch_edges = self.train_edges[start_idx: end_idx]
    6     return self.batch_feed_dict(batch_edges)
    View Code

    函数中获取下个edgeminibatch的起始与终止序号,将batch后的边的信息传给batch_feed_dict(self, batch_edges)函数,更新placeholders中的batch1, batch2, batch_size信息。

     1 def batch_feed_dict(self, batch_edges):
     2     batch1 = []
     3     batch2 = []
     4     for node1, node2 in batch_edges:
     5         batch1.append(self.id2idx[node1])
     6         batch2.append(self.id2idx[node2])
     7 
     8     feed_dict = dict()
     9     feed_dict.update({self.placeholders['batch_size']: len(batch_edges)})
    10     feed_dict.update({self.placeholders['batch1']: batch1})
    11     feed_dict.update({self.placeholders['batch2']: batch2})
    12 
    13     return feed_dict
    View Code

    也即next_minibatch_feed_dict()返回的是下一个edge minibatch的placeholders信息。

    =======================================

         感谢您的支持!             感谢您的支持!

    感谢您的打赏!

    (梦想还是要有的,万一您喜欢我的文章呢)

  • 相关阅读:
    2018CodeM复赛
    poj3683
    bzoj3991
    bzoj2809
    bzoj1001
    bzoj1412
    计蒜之道2018复赛
    HDU2255
    bzoj1010
    bzoj2006
  • 原文地址:https://www.cnblogs.com/shiyublog/p/9819086.html
Copyright © 2011-2022 走看看