  • CIFAR10: network training techniques

    1. Data augmentation

      1) Random cropping

      Pad each side of the original image with 4 pixels, then crop it back to a 32*32 image at a random offset.

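    import tensorflow as tf  # TensorFlow 1.x API
    # Pad to (imageHeight+8) x (imageWidth+8), i.e. 4 zero pixels on each side.
    distorted_images = tf.image.resize_image_with_crop_or_pad(record_images,
                                                              imageHeight+8, imageWidth+8)
    # Crop back to the original size at a random offset.
    distorted_images = tf.random_crop(distorted_images, size = [batch_size, imageHeight, imageWidth, 3])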
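
      The pad-then-crop step can also be written without any framework. The following is a minimal sketch, assuming an image stored as nested Python lists (rows of pixels, each pixel a list of channel values); the helper name `pad_and_random_crop` is made up for illustration. Note that `tf.random_crop` draws a single offset for the whole tensor, so with a 4-D `size` every image in the batch shares the same crop window, whereas this sketch handles one image at a time.

```python
import random

def pad_and_random_crop(image, pad=4, rng=random):
    """Zero-pad an H x W x C image by `pad` pixels per side, then crop back to H x W."""
    height, width = len(image), len(image[0])
    zero = [0] * len(image[0][0])           # one all-zero pixel with the same channel count
    padded_width = width + 2 * pad
    padded = (
        [[zero] * padded_width for _ in range(pad)]                  # top border
        + [[zero] * pad + row + [zero] * pad for row in image]       # left/right borders
        + [[zero] * padded_width for _ in range(pad)]                # bottom border
    )
    # Offsets run from 0 to 2*pad inclusive, so pad=4 gives 9 positions per axis.
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    return [row[left:left + width] for row in padded[top:top + height]]

# Toy 32 x 32 x 3 image whose pixel values encode their own coordinates.
image = [[[r, c, 0] for c in range(32)] for r in range(32)]
crop = pad_and_random_crop(image, pad=4, rng=random.Random(0))
print(len(crop), len(crop[0]), len(crop[0][0]))  # 32 32 3
```

      Calling the helper once per image gives each image its own offset.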

      2) Random flip, brightness and contrast adjustment, standardization

    def distort(image):
        # Per-image augmentation; assumes float32 images with pixel values in [0, 255].
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=63)
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
        return tf.image.per_image_standardization(image)

    # A tensor supports neither len() nor item assignment, so map over the batch instead.
    distorted_images = tf.map_fn(distort, distorted_images)
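
      To make the last two operations concrete, here is a minimal framework-free sketch of a left-right flip and of per-image standardization. It follows the documented behaviour of `tf.image.per_image_standardization`, which divides by max(stddev, 1/sqrt(N)) so that a uniform image does not cause a division by zero. The function names and the flat-list representation are assumptions for illustration.

```python
import math

def flip_left_right(image):
    """Mirror an image stored as a list of rows."""
    return [row[::-1] for row in image]

def per_image_standardize(pixels):
    """Shift a flat list of pixel values to zero mean and scale to (near) unit variance."""
    n = len(pixels)
    mean = sum(pixels) / n
    variance = sum((p - mean) ** 2 for p in pixels) / n
    # Lower-bound the stddev by 1/sqrt(n) so a constant image maps to all zeros.
    adjusted_stddev = max(math.sqrt(variance), 1.0 / math.sqrt(n))
    return [(p - mean) / adjusted_stddev for p in pixels]

standardized = per_image_standardize([0.0, 2.0, 4.0, 6.0])
print(standardized)
```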

      3) self.data = self.data / 127.5 - 1 # scale the pixel values (not the labels) to the range -1 to 1
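
      As a quick check of the arithmetic in item 3), the sketch below applies the same formula to the two extremes and the midpoint of the 8-bit pixel range. The function name is hypothetical.

```python
def scale_to_unit_range(pixel):
    """Map a pixel value in [0, 255] to [-1, 1]."""
    return pixel / 127.5 - 1

print([scale_to_unit_range(p) for p in (0, 127.5, 255)])  # [-1.0, 0.0, 1.0]
```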

    2. Learning rate

      1) Linear decay

      2) Exponential decay

      3) Piecewise-constant decay (a fixed rate per step interval)

    global_step = tf.Variable(0, trainable=False)
    boundaries = [10000, 15000, 20000, 25000]
    values = [0.1, 0.05, 0.01, 0.005, 0.001]
    learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
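
      Items 1) and 2) are listed without code, so here is a minimal pure-Python sketch of all three schedules for comparison. The function names and the `end_lr`, `decay_rate` and `decay_steps` parameters are assumptions for illustration; the piecewise version assigns a boundary step to the interval on its left, which matches the documented behaviour of `tf.train.piecewise_constant`.

```python
import bisect

def linear_decay(step, initial_lr, end_lr, decay_steps):
    """Interpolate linearly from initial_lr to end_lr over decay_steps, then hold end_lr."""
    fraction = min(step, decay_steps) / decay_steps
    return initial_lr + (end_lr - initial_lr) * fraction

def exponential_decay(step, initial_lr, decay_rate, decay_steps):
    """Scale the rate by decay_rate for every decay_steps steps (smoothly, no staircase)."""
    return initial_lr * decay_rate ** (step / decay_steps)

def piecewise_constant(step, boundaries, values):
    """Return the value of the interval containing step; needs len(values) == len(boundaries) + 1."""
    return values[bisect.bisect_left(boundaries, step)]

boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
for step in (0, 10000, 10001, 30000):
    print(step, piecewise_constant(step, boundaries, values))
```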

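       To summarize the training recipe used in many papers: start training with a learning rate of 0.01, divide it by 10 whenever the validation accuracy stops improving, and stop once it reaches 0.00001.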
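
      The recipe above can be sketched as a small plateau-based scheduler. This is one possible reading, not a reference implementation: the class name, the `patience` parameter (how many epochs without improvement count as "stopped improving") and the choice to keep training at 0.00001 until the next plateau are all assumptions.

```python
class PlateauSchedule:
    """Divide the learning rate by 10 on a validation plateau; stop at the floor."""

    def __init__(self, initial_lr=0.01, factor=0.1, min_lr=1e-5, patience=3):
        self.lr = initial_lr
        self.factor = factor
        self.min_lr = min_lr
        self.patience = patience       # assumed: stale epochs tolerated before decaying
        self.best = float("-inf")
        self.stale_epochs = 0
        self.done = False              # becomes True when training should stop

    def update(self, val_acc):
        """Record one epoch's validation accuracy and return the rate for the next epoch."""
        if val_acc > self.best:
            self.best = val_acc
            self.stale_epochs = 0
            return self.lr
        self.stale_epochs += 1
        if self.stale_epochs >= self.patience:
            self.stale_epochs = 0
            # The small tolerance absorbs floating-point error in repeated multiplication.
            if self.lr * self.factor < self.min_lr * (1 - 1e-9):
                self.done = True       # already at the floor: stop instead of decaying
            else:
                self.lr *= self.factor
        return self.lr

schedule = PlateauSchedule(patience=1)
rates = [schedule.update(0.5) for _ in range(5)]   # accuracy never improves after epoch 1
print(rates, schedule.done)
```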

    3. Weight decay

    # Add L2 weight decay to the loss.
    # The L2 loss is computed in fp32 for numerical stability.
    l2_loss = weight_decay * tf.add_n(
        [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
    tf.summary.scalar('l2_loss', l2_loss)
    loss = cross_entropy_mean + l2_loss
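
      For intuition, the same penalty can be computed by hand. The sketch below assumes toy variables stored as flat lists and an arbitrary `weight_decay` of 0.1 (the original does not state a value); `tf.nn.l2_loss` computes half the sum of squares, which is what the helper reproduces.

```python
def l2_loss(weights):
    """Half the sum of squared weights, the quantity tf.nn.l2_loss computes."""
    return sum(w * w for w in weights) / 2

def total_loss(cross_entropy, variables, weight_decay):
    """Cross-entropy plus weight_decay times the summed L2 loss of every variable."""
    return cross_entropy + weight_decay * sum(l2_loss(v) for v in variables)

variables = [[1.0, -2.0], [3.0]]   # two toy "trainable variables"
print(total_loss(0.5, variables, weight_decay=0.1))
```

      Because the derivative of w*w/2 is w, this term adds weight_decay * w to the gradient of each weight, which is what shrinks the weights at every step.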

    4. Optimizer

    #Define the optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
     
    # Run the batch-normalization moving-average updates (UPDATE_OPS) before each training step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_op = optimizer.minimize(loss, global_step)
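
      The update rule behind `tf.train.MomentumOptimizer` is documented as: accumulate the gradient into a velocity term, then step along that velocity. The following is a minimal sketch of that rule on a single scalar weight; the function name, the toy objective f(w) = w**2 and the learning rate of 0.01 are assumptions for illustration.

```python
def momentum_step(weight, velocity, grad, learning_rate, momentum=0.9):
    """One update: velocity = momentum * velocity + grad, then weight -= lr * velocity."""
    velocity = momentum * velocity + grad
    weight = weight - learning_rate * velocity
    return weight, velocity

# Minimize f(w) = w**2, whose gradient is 2*w, from a toy starting point.
weight, velocity = 5.0, 0.0
for _ in range(200):
    weight, velocity = momentum_step(weight, velocity, grad=2 * weight, learning_rate=0.01)
print(weight)
```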
  • Original article: https://www.cnblogs.com/wt-seu/p/12382130.html