  • Learning the SGD Optimizer in TF [Repost]

    Reposted from: https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer

    1. tf.train.GradientDescentOptimizer

    It provides the following methods:

    1.1 apply_gradients

    apply_gradients(
        grads_and_vars,
        global_step=None,
        name=None
    )

    Apply gradients to variables.

    This is the second part of minimize(). It returns an Operation that applies gradients.

    In other words, it applies the gradients to the variables. It is the second half of minimize().

    1.2 compute_gradients

    compute_gradients(
        loss,
        var_list=None,
        gate_gradients=GATE_OP,
        aggregation_method=None,
        colocate_gradients_with_ops=False,
        grad_loss=None
    )

     Compute gradients of loss for the variables in var_list.

    This is the first part of minimize(). It returns a list of (gradient, variable) pairs where "gradient" is the gradient for "variable". Note that "gradient" can be a Tensor, an IndexedSlices, or None if there is no gradient for the given variable.

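    This computes the gradients of the loss with respect to the variables in var_list. It is the first half of minimize() and returns a list with one (gradient, variable) pair per variable, ready to be passed to apply_gradients for the update.

    Now for the key part:

    Parameters: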
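The split of minimize() into a compute phase and an apply phase can be sketched without TensorFlow. The block below is a toy model only: ToyVariable, grad_fns, and the three functions are hypothetical names that mirror the documented method names, the gradient is derived by hand, and none of this is TensorFlow's actual implementation.

```python
# Toy model of the two-phase update described above; NOT TensorFlow code.

class ToyVariable:
    """A named scalar parameter (hypothetical stand-in for tf.Variable)."""

    def __init__(self, name, value):
        self.name = name
        self.value = float(value)


def compute_gradients(grad_fns, var_list):
    """Phase 1: return a list of (gradient, variable) pairs.

    grad_fns maps a variable name to a function returning dloss/dvar.
    A variable without a gradient function gets None, mirroring the
    "gradient can be None" note in the quoted documentation.
    """
    pairs = []
    for var in var_list:
        fn = grad_fns.get(var.name)
        pairs.append((fn() if fn is not None else None, var))
    return pairs


def apply_gradients(grads_and_vars, learning_rate):
    """Phase 2: apply var <- var - learning_rate * gradient."""
    for grad, var in grads_and_vars:
        if grad is None:
            continue  # nothing to apply for this variable
        var.value -= learning_rate * grad


def minimize(grad_fns, var_list, learning_rate):
    """Phase 1 followed by phase 2."""
    apply_gradients(compute_gradients(grad_fns, var_list), learning_rate)


w = ToyVariable("w", 0.0)
b = ToyVariable("b", 5.0)  # the loss does not depend on b

# loss(w) = (w - 3)^2, so dloss/dw = 2 * (w - 3)
grad_fns = {"w": lambda: 2.0 * (w.value - 3.0)}

for _ in range(50):
    minimize(grad_fns, [w, b], learning_rate=0.1)

print(w.value, b.value)
```

Keeping the two phases separate is what makes it possible to inspect or modify the (gradient, variable) pairs, for example to clip them, before they are applied.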

    • loss: A Tensor containing the value to minimize or a callable taking no arguments which returns the value to minimize. When eager execution is enabled it must be a callable.
    • var_list: Optional list or tuple of tf.Variable to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

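     loss is simply the loss function; there is nothing more to say about it.

     The second parameter, var_list, is usually not passed in. So whose gradients get computed? As stated above, the default is the list of variables collected in the graph under GraphKeys.TRAINABLE_VARIABLES.

     Looking up that API shows: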

     tf.GraphKeys

     The following standard keys are defined:

    The entry for TRAINABLE_VARIABLES reads:

    • TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. See tf.trainable_variables for more details.

    Next, look at:

    tf.trainable_variables

    tf.trainable_variables(scope=None)

    Returns all variables created with trainable=True.

    When passed trainable=True, the Variable() constructor automatically adds new variables to the graph collection GraphKeys.TRAINABLE_VARIABLES.

    This convenience function returns the contents of that collection.

    Returns:

    A list of Variable objects.
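The collection mechanism quoted above (a constructor that registers trainable variables, and an optimizer that falls back to that collection when var_list is None) can be modeled in a few lines. This is a sketch under stated assumptions: the registry dict, ToyVariable, and resolve_var_list are hypothetical illustrations, not TensorFlow internals.

```python
# Toy model of a "trainable variables" graph collection; NOT TensorFlow code.

TRAINABLE_VARIABLES = "trainable_variables"  # hypothetical collection key
_collections = {}                            # collection key -> list of variables


class ToyVariable:
    """Hypothetical stand-in for tf.Variable with a trainable flag."""

    def __init__(self, name, value, trainable=True):
        self.name = name
        self.value = value
        self.trainable = trainable
        if trainable:
            # Registration happens at construction time, as the docs describe.
            _collections.setdefault(TRAINABLE_VARIABLES, []).append(self)


def trainable_variables():
    """Return the contents of the trainable collection."""
    return list(_collections.get(TRAINABLE_VARIABLES, []))


def resolve_var_list(var_list=None):
    """Pick the variables to differentiate: the explicit list, else the collection."""
    return list(var_list) if var_list is not None else trainable_variables()


embeddings = ToyVariable("embeddings", [0.1, 0.2])
nce_weights = ToyVariable("nce_weights", [0.0, 0.0])
global_step = ToyVariable("global_step", 0, trainable=False)

default_names = [v.name for v in resolve_var_list()]
explicit_names = [v.name for v in resolve_var_list([embeddings])]
print(default_names)
print(explicit_names)
```

In this model a variable created with trainable=False never reaches the optimizer unless it is passed explicitly in var_list.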

    Then look at the tf.Variable constructor:

    tf.Variable

    __init__(
        initial_value=None,
        trainable=True,
        collections=None,
        validate_shape=True,
        caching_device=None,
        name=None,
        variable_def=None,
        dtype=None,
        expected_shape=None,
        import_scope=None,
        constraint=None,
        use_resource=None,
        synchronization=tf.VariableSynchronization.AUTO,
        aggregation=tf.VariableAggregation.NONE
    )

    In addition:

    • trainable: If True, the default, also adds the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES. This collection is used as the default list of variables to use by the Optimizer classes.

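     trainable defaults to True, in which case the variable is added to the trainable-variables collection. Therefore,

    in the word2vec implementation,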

    with tf.device('/cpu:0'):
          # Look up embeddings for inputs.
          with tf.name_scope('embeddings'):
            embeddings = tf.Variable(
                tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
            embed = tf.nn.embedding_lookup(embeddings, train_inputs)

    the embeddings variable defined here should be updatable. How does it get updated?

    with tf.name_scope('loss'):
      loss = tf.reduce_mean(
          tf.nn.nce_loss(
              weights=nce_weights,
              biases=nce_biases,
              labels=train_labels,
              inputs=embed,
              num_sampled=num_sampled,
              num_classes=vocabulary_size))

    # Add the loss value as a scalar to summary.
    tf.summary.scalar('loss', loss)

    # Construct the SGD optimizer using a learning rate of 1.0.
    with tf.name_scope('optimizer'):
      optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

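    This uses SGD (stochastic gradient descent). Because var_list is not passed, minimize(loss) differentiates the loss with respect to every variable in GraphKeys.TRAINABLE_VARIABLES, so nce_weights, nce_biases, and embeddings are all updated.

    Each training step computes the loss, takes its gradients, and updates the variables; the next batch of data then goes through the same compute, differentiate, update cycle. That is how the variables get updated.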
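The compute-loss, differentiate, update, repeat cycle can be illustrated with a plain-Python SGD loop. This is a minimal sketch, assuming a hypothetical one-dimensional linear model and hand-written gradients; the data, parameter names, and learning rate are made up for the illustration and have nothing to do with the word2vec model above.

```python
import random

random.seed(0)

# Hypothetical noise-free training data generated from y = 2x + 1.
data = [(x, 2.0 * x + 1.0) for x in [-2.0, -1.0, 0.0, 1.0, 2.0]]

# Every entry here plays the role of a trainable variable.
params = {"w": 0.0, "b": 0.0}
learning_rate = 0.05


def loss_and_grads(params, x, y):
    """Squared error on one sample and its gradients w.r.t. each parameter."""
    err = params["w"] * x + params["b"] - y
    loss = err * err
    grads = {"w": 2.0 * err * x, "b": 2.0 * err}
    return loss, grads


for step in range(500):
    x, y = random.choice(data)                      # draw one sample (the "stochastic" part)
    loss, grads = loss_and_grads(params, x, y)      # compute loss and gradients
    for name in params:                             # update every trainable parameter
        params[name] -= learning_rate * grads[name]

print(params)
```

Both parameters move on every step because both have a gradient, which is the same reason all three trainable variables in the word2vec graph are updated by a single minimize(loss) op.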

  • Original post: https://www.cnblogs.com/BlueBlueSea/p/10616314.html