  • [Repost] Deep learning summary: points to note when using dropout and Batch Normalization in PyTorch, and when using dropout and BN in TensorFlow

    Original article:

    https://blog.csdn.net/weixin_40759186/article/details/87547795

    ---------------------------------------------------------------------------------------------------------------

    Points to note when using dropout and BN in PyTorch

    Dropout in PyTorch:

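    Dropout is applied during training and switched off at test time.
    In PyTorch, net.eval() puts every submodule into evaluation mode: Dropout layers stop dropping units, and BatchNorm layers normalize with their stored running statistics instead of the batch statistics and no longer update them. The learnable weights are not frozen by eval(); they only change when an optimizer step runs. In principle, net.eval() should be called before evaluating on any validation or test set.
    net.train() switches back to training mode (dropout active, BN uses and updates batch statistics). It does not control gradient computation: autograd records gradients in either mode, and gradient tracking is turned off with torch.no_grad().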

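    import torch

    N_HIDDEN = 300  # placeholder value; the post does not show its hyperparameters
    # placeholder toy data; the original data setup is not shown in the post
    x = torch.linspace(-1, 1, 20).unsqueeze(1)
    y = x + 0.3 * torch.randn(20, 1)
    test_x = torch.linspace(-1, 1, 20).unsqueeze(1)

    net_dropped = torch.nn.Sequential(
        torch.nn.Linear(1, N_HIDDEN),
        torch.nn.Dropout(0.5),  # drop 50% of the neurons
        torch.nn.ReLU(),
        torch.nn.Linear(N_HIDDEN, N_HIDDEN),
        torch.nn.Dropout(0.5),  # drop 50% of the neurons
        torch.nn.ReLU(),
        torch.nn.Linear(N_HIDDEN, 1),
    )
    optimizer_drop = torch.optim.Adam(net_dropped.parameters(), lr=0.01)  # assumed optimizer settings
    loss_func = torch.nn.MSELoss()

    for t in range(500):
        pred_drop = net_dropped(x)
        loss_drop = loss_func(pred_drop, y)
        optimizer_drop.zero_grad()
        loss_drop.backward()
        optimizer_drop.step()

        if t % 10 == 0:
            # change to eval mode in order to fix the dropout effect
            net_dropped.eval()  # dropout layers now pass their inputs through unchanged
            with torch.no_grad():  # no gradients are needed for evaluation
                test_pred_drop = net_dropped(test_x)
            # change back to train mode
            net_dropped.train()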
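    To make the train/eval switch concrete, here is a minimal NumPy sketch of inverted dropout, the scheme torch.nn.Dropout documents: during training each element is zeroed with probability p and the survivors are scaled by 1/(1-p), and in eval mode the layer is an identity. The function name `dropout` and its `training` flag are illustrative only, not PyTorch API.

```python
import numpy as np

def dropout(x, p, training, rng):
    """Inverted dropout (illustrative sketch, not the PyTorch implementation).

    training=True  -> zero each element with probability p, scale survivors by 1/(1-p)
    training=False -> identity, which is what net.eval() selects for nn.Dropout
    """
    if not training or p == 0.0:
        return x
    keep = rng.random(x.shape) >= p          # True for elements that survive
    return np.where(keep, x / (1.0 - p), 0.0)

rng = np.random.default_rng(0)
x = np.ones(10_000)

train_out = dropout(x, p=0.5, training=True, rng=rng)
eval_out = dropout(x, p=0.5, training=False, rng=rng)

print(np.unique(train_out))         # [0. 2.]: survivors are scaled to 2.0, the rest are 0.0
print(round(train_out.mean(), 2))   # close to 1.0: the expected value is preserved
print(np.array_equal(eval_out, x))  # True: eval mode leaves the input unchanged
```

    Because the scaling happens at training time, nothing needs to be rescaled at test time; that is why forgetting net.eval() silently makes test predictions noisy rather than raising an error.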

     

    Batch Normalization in PyTorch:

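    net.eval() puts the whole network into evaluation mode, which also freezes BN's running statistics, moving_mean and moving_var (called running_mean and running_var in PyTorch): in eval mode BN normalizes with them and does not update them. If this is unclear, see the snippet below: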

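    # inside the network's __init__, for each hidden layer i (fragment):
    if self.do_bn:
        bn = nn.BatchNorm1d(10, momentum=0.5)
        # IMPORTANT: register the layer on the Module; a layer kept only in a plain
        # Python list is not tracked, so net.eval()/net.train() and net.parameters() would miss it
        setattr(self, 'bn%i' % i, bn)
        self.bns.append(bn)

    # evaluation at the start of each epoch
    for epoch in range(EPOCH):
        print('Epoch: ', epoch)
        layer_inputs, pre_acts = [], []
        for net, l in zip(nets, losses):
            net.eval()              # set eval mode to fix moving_mean and moving_var
            pred, layer_input, pre_act = net(test_x)
            layer_inputs.append(layer_input)
            pre_acts.append(pre_act)
            net.train()             # free moving_mean and moving_var
        plot_histogram(*layer_inputs, *pre_acts)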

    [Figure: moving_mean and moving_var]
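    The following is a minimal NumPy sketch of what BatchNorm1d does in each mode, following the update rule given in the PyTorch documentation: running = (1 - momentum) * running + momentum * batch statistic, where the batch is normalized with the biased variance but running_var stores the unbiased one. The class `BatchNorm1dSketch` is hypothetical and leaves out the learnable gamma/beta to stay short.

```python
import numpy as np

class BatchNorm1dSketch:
    """Hypothetical BN over a (batch, features) array, modeled on torch.nn.BatchNorm1d.
    gamma/beta (the learnable affine part) are omitted."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.momentum, self.eps = momentum, eps
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, x):
        if self.training:
            mean = x.mean(axis=0)
            var = x.var(axis=0)                      # biased variance normalizes the batch
            n, m = x.shape[0], self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            # the running estimate uses the unbiased variance
            self.running_var = (1 - m) * self.running_var + m * var * n / (n - 1)
        else:
            # eval mode: use the stored statistics and leave them unchanged
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)

rng = np.random.default_rng(0)
bn = BatchNorm1dSketch(3, momentum=0.5)   # momentum=0.5 as in the snippet above
for _ in range(20):                       # train mode: running stats move toward the data
    bn(rng.normal(loc=5.0, scale=2.0, size=(64, 3)))

bn.eval()
frozen_mean = bn.running_mean.copy()
x_eval = rng.normal(loc=5.0, scale=2.0, size=(8, 3))
out = bn(x_eval)
print(np.allclose(bn.running_mean, frozen_mean))  # True: eval mode does not touch the stats
print(bn.running_mean)                            # close to 5 in every feature
```

    Calling bn.train() again would make the next forward pass both use the batch statistics and move the running averages, which is exactly what the net.eval() / net.train() pair in the loop toggles.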

    Points to note when using dropout and BN in TensorFlow

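    Both dropout and BN take a training argument that says whether the graph is running in training or test mode. In test mode, dropout drops nothing, and BN normalizes with its stored moving_mean and moving_var instead of the batch statistics (and does not update them).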

    tf_is_training = tf.placeholder(tf.bool, None)  # to control dropout when training and testing
    
    # dropout net
    d1 = tf.layers.dense(tf_x, N_HIDDEN, tf.nn.relu)
    d1 = tf.layers.dropout(d1, rate=0.5, training=tf_is_training)   # drop out 50% of inputs
    
    d2 = tf.layers.dense(d1, N_HIDDEN, tf.nn.relu)
    d2 = tf.layers.dropout(d2, rate=0.5, training=tf_is_training)   # drop out 50% of inputs
    d_out = tf.layers.dense(d2, 1)

    for t in range(500):
        sess.run([o_train, d_train], {tf_x: x, tf_y: y, tf_is_training: True})  # train, set is_training=True

        if t % 10 == 0:
            # plotting
            plt.cla()
            o_loss_, d_loss_, o_out_, d_out_ = sess.run(
                [o_loss, d_loss, o_out, d_out],
                {tf_x: test_x, tf_y: test_y, tf_is_training: False}  # test, set is_training=False
            )
    def add_layer(self, x, out_size, ac=None):
        x = tf.layers.dense(x, out_size, kernel_initializer=self.w_init, bias_initializer=B_INIT)
        self.pre_activation.append(x)
        # momentum plays an important role: the default of 0.99 is too high in this case!
        if self.is_bn:
            x = tf.layers.batch_normalization(x, momentum=0.4, training=tf_is_training)  # apply BN when enabled
        out = x if ac is None else ac(x)
        return out
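    Why is 0.99 "too high" here? tf.layers.batch_normalization updates its moving averages as moving = momentum * moving + (1 - momentum) * batch_stat, so momentum is the weight on the old value (the opposite of PyTorch's BatchNorm momentum, which weights the new batch statistic). The plain-Python sketch below follows that rule under a simplifying assumption that every batch has the same mean of 1.0; `moving_average_after` is a hypothetical helper.

```python
def moving_average_after(steps, momentum, batch_mean=1.0, init=0.0):
    """Apply TF-style moving-average updates `steps` times.

    Assumes (for illustration) that every batch has the same mean `batch_mean`;
    `init` mirrors moving_mean's zero initialization.
    """
    moving = init
    for _ in range(steps):
        moving = momentum * moving + (1 - momentum) * batch_mean
    return moving

for momentum in (0.99, 0.4):
    for steps in (10, 100):
        value = moving_average_after(steps, momentum)
        print(f"momentum={momentum}, steps={steps}: moving_mean={value:.3f}")
```

    With momentum 0.99 the estimate is only about 0.10 after 10 updates and about 0.63 after 100, so a short training run leaves the test-time statistics far from the true ones; with 0.4 it is essentially 1.0 after 10 updates.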
     

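    Setting BN's training argument to True only makes BN normalize each batch with that batch's own mean and variance; it does not make BN update moving_mean and moving_var by itself. Those updates are separate assign ops that are not part of the gradient computation, so minimize() does not run them. TensorFlow places them in the tf.GraphKeys.UPDATE_OPS collection, and you must make sure they run with every training step, typically by making the train op depend on them: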

            # !! IMPORTANT !! the moving_mean and moving_variance need to be updated,
            # pass the update_ops with control_dependencies to the train_op
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train = tf.train.AdamOptimizer(LR).minimize(self.loss)
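    What goes wrong when the update ops are skipped can be simulated without TensorFlow. The NumPy sketch below (hypothetical helpers `train_bn` and `bn_inference`; the initial moving_mean = 0, moving_var = 1 and epsilon = 1e-3 are taken to match tf.layers.batch_normalization's defaults) trains a single BN unit with and without running the moving-average updates, then normalizes test data the way training=False does.

```python
import numpy as np

def train_bn(batches, momentum, run_update_ops):
    """Simulate a TF-style BN unit during training. With run_update_ops=False the
    moving-average assignments never execute, as when the train op does not
    depend on tf.GraphKeys.UPDATE_OPS."""
    moving_mean, moving_var = 0.0, 1.0      # initial values of the moving statistics
    for x in batches:
        # the training forward pass normalizes with these batch statistics either way
        batch_mean, batch_var = x.mean(), x.var()
        if run_update_ops:
            moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
            moving_var = momentum * moving_var + (1 - momentum) * batch_var
    return moving_mean, moving_var

def bn_inference(x, moving_mean, moving_var, eps=1e-3):
    # training=False: normalize with the stored moving statistics
    return (x - moving_mean) / np.sqrt(moving_var + eps)

rng = np.random.default_rng(0)
batches = [rng.normal(5.0, 2.0, size=256) for _ in range(200)]
test_x = rng.normal(5.0, 2.0, size=1000)

for flag in (True, False):
    mm, mv = train_bn(batches, momentum=0.9, run_update_ops=flag)
    out = bn_inference(test_x, mm, mv)
    print(f"update ops run={flag}: test mean={out.mean():.2f}, test std={out.std():.2f}")
```

    With the updates, the test outputs come out roughly zero-mean and unit-variance; without them, the moving statistics never leave 0 and 1, so inference passes the data through almost unnormalized (mean near 5, std near 2) even though training itself looked fine.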
  • Original post: https://www.cnblogs.com/devilmaycry812839668/p/10632730.html