zoukankan      html  css  js  c++  java
  • GraphSAGE 代码解析(二)

     1 # global unique layer ID dictionary for layer name assignment
     2 _LAYER_UIDS = {}
     3 
     4 def get_layer_uid(layer_name=''):
     5     """Helper function, assigns unique layer IDs."""
     6     if layer_name not in _LAYER_UIDS:
     7         _LAYER_UIDS[layer_name] = 1
     8         return 1
     9     else:
    10         _LAYER_UIDS[layer_name] += 1
    11         return _LAYER_UIDS[layer_name]

    这里_LAYER_UIDS = {} 是记录layer及其出现次数的字典。

    在 get_layer_uid()函数中,若layer_name从未出现过,如今出现了,则将_LAYER_UIDS[layer_name]设为1;否则累加。

    作用: 在class Layer中,当未赋variable scope的name时,通过实例化Layer的次数来标定不同的layer_id.

    例子:简化一下class Layer可以看出:

     1 class Layer():
     2     def __init__(self):
     3         layer = self.__class__.__name__
     4         name = layer + '_' + str(get_layer_uid(layer))
     5         print(name) 
     6 
     7 layer1 = Layer()
     8 layer2 = Layer()
     9 
    10 # Output:
    11 # Layer_1
    12 # Layer_2
    View Code

     2. class Layer

    class Layer主要定义基本的层的API。

     1 class Layer(object):
     2     """Base layer class. Defines basic API for all layer objects.
     3     Implementation inspired by keras (http://keras.io).
     4     # Properties
     5         name: String, defines the variable scope of the layer.
     6         logging: Boolean, switches Tensorflow histogram logging on/off
     7 
     8     # Methods
     9         _call(inputs): Defines computation graph of layer
    10             (i.e. takes input, returns output)
    11         __call__(inputs): Wrapper for _call()
    12         _log_vars(): Log all variables
    13     """
    14 
    15     def __init__(self, **kwargs):
    16         allowed_kwargs = {'name', 'logging', 'model_size'}
    17         for kwarg in kwargs.keys():
    18             assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
    19         name = kwargs.get('name')
    20         if not name:
    21             layer = self.__class__.__name__.lower() # "layer"
    22             name = layer + '_' + str(get_layer_uid(layer))
    23         self.name = name
    24         self.vars = {}
    25         logging = kwargs.get('logging', False)
    26         self.logging = logging
    27         self.sparse_inputs = False
    28 
    29     def _call(self, inputs):
    30         return inputs
    31 
    32     def __call__(self, inputs):
    33         with tf.name_scope(self.name):
    34             if self.logging and not self.sparse_inputs:
    35                 tf.summary.histogram(self.name + '/inputs', inputs)
    36             outputs = self._call(inputs)
    37             if self.logging:
    38                 tf.summary.histogram(self.name + '/outputs', outputs)
    39             return outputs
    40 
    41     def _log_vars(self):
    42         for var in self.vars:
    43             tf.summary.histogram(self.name + '/vars/' + var, self.vars[var])
    View Code

    方法:

    __init__(): 获取传入的name, logging, model_size参数。初始化实例变量name, vars{}, logging, sparse_inputs

    _call(inputs): 定义层的计算图:获取input, 返回output.

    __call__(inputs): 相当于_call()的装饰器,在实现列_call()基本功能后,丰富了其功能,这里主要通过tf.summary.histogram() 可以查看inputs与outputs分布情况的直方图。

    _log_vars(): 记录所有变量。实现时主要将vars中的各个变量以直方图形式显示。

    3. class Dense

    Dense layer主要用于实现全连接层的基本功能。即为了最终得到 Relu(Wx + b)。

    __init__(): 用于获取初始化成员变量。其中num_features_nonzero和featureless的作用目前还不清楚。

    _call(): 用于实现并且返回Relu(Wx + b)

     1 class Dense(Layer):
     2     """Dense layer."""
     3 
     4     def __init__(self, input_dim, output_dim, dropout=0.,
     5                  act=tf.nn.relu, placeholders=None, bias=True, featureless=False,
     6                  sparse_inputs=False, **kwargs):
     7         super(Dense, self).__init__(**kwargs)
     8 
     9         self.dropout = dropout
    10 
    11         self.act = act
    12         self.featureless = featureless
    13         self.bias = bias
    14         self.input_dim = input_dim
    15         self.output_dim = output_dim
    16 
    17         # helper variable for sparse dropout
    18         self.sparse_inputs = sparse_inputs
    19         if sparse_inputs:
    20             self.num_features_nonzero = placeholders['num_features_nonzero']
    21 
    22         with tf.variable_scope(self.name + '_vars'):
    23             self.vars['weights'] = tf.get_variable('weights', shape=(input_dim, output_dim),
    24         dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer(),                                             
    25         regularizer=tf.contrib.layers.l2_regularizer(FLAGS.weight_decay))
    26             if self.bias:
    27                 self.vars['bias'] = zeros([output_dim], name='bias')
    28 
    29         if self.logging:
    30             self._log_vars()
    31 
    32     def _call(self, inputs):
    33         x = inputs
    34         x = tf.nn.dropout(x, 1 - self.dropout)
    35 
    36         # transform
    37         output = tf.matmul(x, self.vars['weights'])
    38 
    39         # bias
    40         if self.bias:
    41             output += self.vars['bias']
    42 
    43         return self.act(output)
    View Code

  • 相关阅读:
    BoundsChecker使用
    完成端口(Completion Port)详解
    VC内存泄露检查工具:VisualLeakDetector
    AcceptEx函数与完成端口的结合使用例子
    IOCP之accept、AcceptEx、WSAAccept的区别
    Visual C++ 6.0安装
    eclipse中在线安装FindBugs
    几种开源SIP协议栈对比
    全情投入是做好工作的基础——Leo鉴书39
    CheckStyle检查项目分布图
  • 原文地址:https://www.cnblogs.com/shiyublog/p/9894617.html
Copyright © 2011-2022 走看看