zoukankan      html  css  js  c++  java
  • tf.trainable_variables()

    https://blog.csdn.net/shwan_ma/article/details/78879620

    一般来说,打印tensorflow变量的函数有两个:
    tf.trainable_variables () 和 tf.all_variables()
    不同的是:
    tf.trainable_variables () 指的是需要训练的变量
    tf.all_variables() 指的是所有变量

    一般而言,我们更关注需要训练的训练变量:
    值得注意的是,在输出变量名时,要对整个graph进行初始化

    一、打印需要训练的变量名称

    variable_names = [v.name for v in tf.trainable_variables()]
    print(variable_names)
    1
    2
    二、打印需要训练的变量名称和变量值

    variable_names = [v.name for v in tf.trainable_variables()]
    values = sess.run(variable_names)
    for k,v in zip(variable_names, values):
    print("Variable: ", k)
    print("Shape: ", v.shape)
    print(v)
    1
    2
    3
    4
    5
    6
    这里提供一个函数,打印变量名称,shape及其变量数目

    def print_num_of_total_parameters(output_detail=False, output_to_logging=False):
    total_parameters = 0
    parameters_string = ""

    for variable in tf.trainable_variables():

    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
    variable_parameters *= dim.value
    total_parameters += variable_parameters
    if len(shape) == 1:
    parameters_string += ("%s %d, " % (variable.name, variable_parameters))
    else:
    parameters_string += ("%s %s=%d, " % (variable.name, str(shape), variable_parameters))

    if output_to_logging:
    if output_detail:
    logging.info(parameters_string)
    logging.info("Total %d variables, %s params" % (len(tf.trainable_variables()), "{:,}".format(total_parameters)))
    else:
    if output_detail:
    print(parameters_string)
    print("Total %d variables, %s params" % (len(tf.trainable_variables()), "{:,}".format(total_parameters)))

    萍水相逢逢萍水,浮萍之水水浮萍!
  • 相关阅读:
    Matlab随笔之三维图形绘制
    Matlab随笔之模拟退火算法
    Matlab随笔之矩阵入门知识
    Matlab随笔之求解线性方程
    Matlab随笔之分段线性函数化为线性规划
    Matlab随笔之指派问题的整数规划
    Matlab随笔之线性规划
    Android单位转换 (px、dp、sp之间的转换工具类)
    Android禁止输入表情符号
    设计模式之策略模式
  • 原文地址:https://www.cnblogs.com/AIBigTruth/p/10504535.html
Copyright © 2011-2022 走看看