zoukankan      html  css  js  c++  java
  • 『TensorFlow』正则化添加方法整理

    一、基础正则化函数

    tf.contrib.layers.l1_regularizer(scale, scope=None)

    返回一个用来执行L1正则化的函数,函数的签名是func(weights)
    参数:

    • scale: 正则项的系数.
    • scope: 可选的scope name

    tf.contrib.layers.l2_regularizer(scale, scope=None)

    先看看tf.contrib.layers.l2_regularizer(weight_decay)都执行了什么:

    import tensorflow as tf
    sess=tf.Session()
    weight_decay=0.1
    tmp=tf.constant([0,1,2,3],dtype=tf.float32)
    """
    l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)
    a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) 
    """
    #**上面代码的等价代码
    a=tf.get_variable("I_am_a",initializer=tmp)
    a2=tf.reduce_sum(a*a)*weight_decay/2;
    a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2)
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2)
    #**
    sess.run(tf.global_variables_initializer())
    keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    for key in keys:
      print("%s : %s" %(key.name,sess.run(key)))
    我们很容易可以模拟出tf.contrib.layers.l2_regularizer都做了什么,不过会让代码变丑。
    以下比较完整实现L2 正则化。
    import tensorflow as tf
    sess=tf.Session()
    weight_decay=0.1                                                #(1)定义weight_decay
    l2_reg=tf.contrib.layers.l2_regularizer(weight_decay)           #(2)定义l2_regularizer()
    tmp=tf.constant([0,1,2,3],dtype=tf.float32)
    a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp)  #(3)创建variable,l2_regularizer复制给regularizer参数。
                                                                    #目测REXXX_LOSSES集合
    #regularizer定义会将a加入REGULARIZATION_LOSSES集合
    print("Global Set:")
    keys = tf.get_collection("variables")
    for key in keys:
      print(key.name)
    print("Regular Set:")
    keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    for key in keys:
      print(key.name)
    print("--------------------")
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)   #(4)则REGULARIAZTION_LOSSES集合会包含所有被weight_decay后的参数和,将其相加
    l2_loss=tf.add_n(reg_set)
    print("loss=%s" %(sess.run(l2_loss)))
    """
    此处输出0.7,即:
       weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7
    其实代码自己写也很方便,用API看着比较正规。
    在网络模型中,直接将l2_loss加入loss就好了。(loss变大,执行train自然会decay)
    """

    二、添加正则化方法

    a、原始办法

    正则化常用到集合,下面是最原始的添加正则办法(直接在变量声明后将之添加进'losses'集合或tf.GraphKeys.LOESSES也行):

    import tensorflow as tf
    import numpy as np
    
    def get_weights(shape, lambd):
    
        var = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(lambd)(var))
        return var
    
    
    x = tf.placeholder(tf.float32, shape=(None, 2))
    y_ = tf.placeholder(tf.float32, shape=(None, 1))
    batch_size = 8
    layer_dimension = [2, 10, 10, 10, 1]
    n_layers = len(layer_dimension)
    cur_lay = x
    in_dimension = layer_dimension[0]
    
    for i in range(1, n_layers):
        out_dimension = layer_dimension[i]
        weights = get_weights([in_dimension, out_dimension], 0.001)
        bias = tf.Variable(tf.constant(0.1, shape=[out_dimension]))
        cur_lay = tf.nn.relu(tf.matmul(cur_lay, weights)+bias)
        in_dimension = layer_dimension[i]
    
    mess_loss = tf.reduce_mean(tf.square(y_-cur_lay))
    tf.add_to_collection('losses', mess_loss)
    loss = tf.add_n(tf.get_collection('losses'))

    b、tf.contrib.layers.apply_regularization(regularizer, weights_list=None)

    先看参数

    • regularizer:就是我们上一步创建的正则化方法
    • weights_list: 想要执行正则化方法的参数列表,如果为None的话,就取GraphKeys.WEIGHTS中的weights.

    函数返回一个标量Tensor,同时,这个标量Tensor也会保存到GraphKeys.REGULARIZATION_LOSSES中.这个Tensor保存了计算正则项损失的方法.

    tensorflow中的Tensor是保存了计算这个值的路径(方法),当我们run的时候,tensorflow后端就通过路径计算出Tensor对应的值

    现在,我们只需将这个正则项损失加到我们的损失函数上就可以了.

    如果是自己手动定义weight的话,需要手动将weight保存到GraphKeys.WEIGHTS中,但是如果使用layer的话,就不用这么麻烦了,别人已经帮你考虑好了.(最好自己验证一下tf.GraphKeys.WEIGHTS中是否包含了所有的weights,防止被坑)

    c、使用slim

    使用slim会简单很多:

     with slim.arg_scope([slim.conv2d, slim.fully_connected],
                                activation_fn=tf.nn.relu,
                                weights_regularizer=slim.l2_regularizer(weight_decay)):
        pass

    此时添加集合为tf.GraphKeys.REGULARIZATION_LOSSES。

  • 相关阅读:
    Python_Openpyxl 浅谈(最全总结 足够初次使用)
    @requestMapping的produces和consumes属性
    JDK、CGLIB、Spring 三种实现代理的区别(三)Spring的ProxyFactory
    SpringBoot实现限制ip访问次数
    springboot过滤器禁止ip频繁访问
    工厂模式,简单工厂模式,抽象工厂模式三者有什么区别
    eclipse中git用本地或线上分支完全覆盖本地分支——reset
    Eclipse 中git插件文件冲突解决
    客名利
    Bootstrap3基础 栅格系统 页面布局随 浏览器大小的变化而变化
  • 原文地址:https://www.cnblogs.com/hellcat/p/9474393.html
Copyright © 2011-2022 走看看