CIFAR10: Network Training Techniques

    1. Data augmentation

      1) Random cropping

      Pad each side of the original image with 4 pixels (32x32 becomes 40x40), then randomly crop it back to a 32x32 image.

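    import tensorflow as tf  # TensorFlow 1.x API

    # Zero-pad 4 pixels on each side (32x32 -> 40x40)
    distorted_images = tf.image.resize_image_with_crop_or_pad(record_images, imageHeight + 8, imageWidth + 8)
    # Crop back to imageHeight x imageWidth at an independent random offset per image
    distorted_images = tf.map_fn(lambda img: tf.random_crop(img, [imageHeight, imageWidth, 3]), distorted_images)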
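      The sketch below makes the pad-then-crop step concrete without TensorFlow. It is a minimal pure-Python version that works on a single-channel image stored as a list of rows. The function name `pad_and_random_crop` and its `rng` parameter are hypothetical. The sketch assumes zero padding, which is what `tf.image.resize_image_with_crop_or_pad` does when it enlarges an image. The crop offset can be anything from 0 to 8 in each direction, so the network sees the content shifted by up to 4 pixels.

```python
import random


def pad_and_random_crop(image, pad=4, rng=random):
    """Zero-pad a 2-D image (list of rows) by `pad` pixels on every side,
    then cut out a randomly placed window of the original size."""
    h, w = len(image), len(image[0])
    # Build the (h + 2*pad) x (w + 2*pad) padded image with a zero border.
    padded = [[0] * (w + 2 * pad) for _ in range(h + 2 * pad)]
    for r in range(h):
        for c in range(w):
            padded[r + pad][c + pad] = image[r][c]
    # Any top-left corner in [0, 2*pad] keeps the crop window inside the padded image.
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    return [row[left:left + w] for row in padded[top:top + h]]


img = [[r * 32 + c + 1 for c in range(32)] for r in range(32)]
out = pad_and_random_crop(img)
print(len(out), len(out[0]))  # 32 32
```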

      2) Random flipping, random brightness and contrast adjustment, and standardization

    def augment(img):
        img = tf.cast(img, tf.float32)  # max_delta=63 assumes pixel values in [0, 255]
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=63)
        img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
        return tf.image.per_image_standardization(img)

    # A tensor does not support item assignment, so apply the per-image ops with tf.map_fn
    distorted_images = tf.map_fn(augment, distorted_images, dtype=tf.float32)
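      `tf.image.per_image_standardization` is documented as computing `(x - mean) / adjusted_stddev`, where `adjusted_stddev = max(stddev, 1/sqrt(N))` and N is the number of values in the image. The sketch below re-implements that formula in plain Python on a flattened list of pixel values. It is an illustration, not TensorFlow's own code. The floor on the standard deviation keeps a completely flat image from causing a division by zero.

```python
import math


def per_image_standardization(pixels):
    """Apply the documented per-image standardization formula to a flat
    list holding every value (H * W * C) of one image."""
    n = len(pixels)
    mean = sum(pixels) / n
    stddev = math.sqrt(sum((p - mean) ** 2 for p in pixels) / n)  # population std
    adjusted_stddev = max(stddev, 1.0 / math.sqrt(n))  # avoids dividing by 0 on flat images
    return [(p - mean) / adjusted_stddev for p in pixels]


print(per_image_standardization([0.0, 255.0]))  # [-1.0, 1.0]
```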

      3) Rescale pixel values to the range [-1, 1]:

    self.data = self.data / 127.5 - 1  # maps pixels (not labels) from [0, 255] to [-1, 1]
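      The rescaling is a linear map. Dividing by 127.5 takes [0, 255] to [0, 2], and subtracting 1 shifts that to [-1, 1]. The small sketch below uses hypothetical helper names. It includes the inverse map, which is handy when turning network inputs back into viewable images.

```python
def to_unit_range(pixels):
    """Map pixel values in [0, 255] linearly onto [-1, 1]."""
    return [p / 127.5 - 1 for p in pixels]


def from_unit_range(values):
    """Inverse map: [-1, 1] back to [0, 255]."""
    return [(v + 1) * 127.5 for v in values]


print(to_unit_range([0, 127.5, 255]))  # [-1.0, 0.0, 1.0]
```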

    2. Learning rate

      1) Linear decay
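      The post lists linear decay without giving details. One common form is sketched below. It assumes the rate falls in a straight line from a starting value to a final value over a fixed number of steps, then stays put. `initial_lr`, `final_lr` and `total_steps` are illustrative values, not taken from the post. In TensorFlow 1.x, `tf.train.polynomial_decay` with `power=1.0` produces this kind of schedule.

```python
def linear_decay(step, initial_lr=0.1, final_lr=0.0, total_steps=30000):
    """Interpolate linearly from initial_lr at step 0 to final_lr at
    total_steps, then hold final_lr (all defaults are illustrative)."""
    frac = min(step, total_steps) / total_steps
    return initial_lr + (final_lr - initial_lr) * frac


print(linear_decay(0), linear_decay(15000), linear_decay(40000))
```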

      2) Exponential decay
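      Exponential decay is likewise not spelled out. The sketch below mirrors the formula documented for `tf.train.exponential_decay`: `lr = initial_lr * decay_rate ** (step / decay_steps)`. Setting `staircase=True` switches to integer division, so the rate drops in discrete jumps. The default values are illustrative assumptions.

```python
def exponential_decay(step, initial_lr=0.1, decay_rate=0.96, decay_steps=1000, staircase=False):
    """initial_lr * decay_rate ** (step / decay_steps); with staircase=True
    the exponent is an integer, so the rate only changes every decay_steps."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent


print(exponential_decay(5000), exponential_decay(5500, staircase=True))
```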

      3) Piecewise constant decay (fixed rate within each step interval)

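    global_step = tf.Variable(0, trainable=False)
    # lr = 0.1 up to step 10000, 0.05 up to 15000, 0.01 up to 20000, 0.005 up to 25000, then 0.001
    boundaries = [10000, 15000, 20000, 25000]
    values = [0.1, 0.05, 0.01, 0.005, 0.001]  # must have len(boundaries) + 1 entries
    learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)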
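      The sketch below shows which learning rate the `piecewise_constant` call above uses at a given step. It is a pure-Python mirror of the documented rule, using the same `boundaries` and `values`. Each boundary step itself still uses the earlier value: 0.1 holds through step 10000, and 0.05 starts at step 10001.

```python
from bisect import bisect_left

BOUNDARIES = [10000, 15000, 20000, 25000]
VALUES = [0.1, 0.05, 0.01, 0.005, 0.001]


def piecewise_constant(step, boundaries=BOUNDARIES, values=VALUES):
    """values[0] while step <= boundaries[0]; values[i] while
    boundaries[i-1] < step <= boundaries[i]; values[-1] after the last boundary."""
    assert len(values) == len(boundaries) + 1
    # bisect_left counts the boundaries strictly smaller than step.
    return values[bisect_left(boundaries, step)]


for s in (0, 10000, 10001, 17500, 30000):
    print(s, piecewise_constant(s))
```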

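      A training recipe that recurs across many papers: start with a learning rate of 0.01, divide it by 10 whenever validation accuracy stops improving, and continue down to 0.00001, where training stops.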
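      The divide-by-10-on-plateau recipe can be sketched as a small controller that receives one validation accuracy per evaluation. `PlateauSchedule` and its `patience` parameter are hypothetical. `patience` is the number of evaluations without improvement that count as a plateau; the post does not say how a plateau is detected. Three divisions take 0.01 to 0.00001. This sketch assumes training continues at 0.00001 until the next plateau and then stops.

```python
class PlateauSchedule:
    """Hypothetical reduce-on-plateau controller: divide the rate by 10 on
    each plateau, and stop at the first plateau after max_drops divisions."""

    def __init__(self, initial_lr=0.01, max_drops=3, patience=3):
        self.initial_lr = initial_lr
        self.max_drops = max_drops  # 0.01 / 10**3 = 0.00001
        self.patience = patience
        self.drops = 0
        self.best_acc = float("-inf")
        self.stale = 0

    @property
    def lr(self):
        return self.initial_lr / 10 ** self.drops

    def update(self, val_acc):
        """Record one validation accuracy; return False once training should stop."""
        if val_acc > self.best_acc:
            self.best_acc, self.stale = val_acc, 0
            return True
        self.stale += 1
        if self.stale < self.patience:
            return True
        # Plateau reached: stop if already at the smallest rate, else divide by 10.
        if self.drops == self.max_drops:
            return False
        self.drops += 1
        self.stale = 0
        return True
```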

    3. Weight decay

    # Add L2 weight decay to the loss. tf.nn.l2_loss(v) = sum(v ** 2) / 2;
    # the penalty is computed in fp32 for numerical stability.
    l2_loss = weight_decay * tf.add_n(
        [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
    tf.summary.scalar('l2_loss', l2_loss)
    loss = cross_entropy_mean + l2_loss
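      Numerically, `tf.nn.l2_loss` returns half the sum of squares of a tensor. The total loss is therefore `cross_entropy_mean + weight_decay * sum(sum(v**2) / 2)`. The gradient of the penalty with respect to a weight `w` is `weight_decay * w`. As a result, under plain SGD the penalty scales each weight by `(1 - lr * weight_decay)` on every step, which is where the name "weight decay" comes from. The plain-Python sketch below uses illustrative numbers; `weight_decay=5e-4` is an assumed value, not from the post.

```python
def l2_loss(weights):
    """Same quantity as tf.nn.l2_loss: sum of squares divided by 2 (no square root)."""
    return sum(w * w for w in weights) / 2


def total_loss(cross_entropy_mean, weight_tensors, weight_decay=5e-4):
    """Cross-entropy plus weight_decay times the summed l2_loss of every weight tensor."""
    return cross_entropy_mean + weight_decay * sum(l2_loss(t) for t in weight_tensors)


print(l2_loss([3.0, 4.0]))  # 12.5
```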

    4. Optimizer

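    # Define the optimizer: SGD with momentum 0.9
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)

    # Batch normalization registers its moving mean/variance updates in UPDATE_OPS;
    # make the training step depend on them so they run on every step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_op = optimizer.minimize(loss, global_step=global_step)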
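      `tf.train.MomentumOptimizer` (with the default `use_nesterov=False`) is documented as keeping one accumulator per variable. Each step computes `accumulation = momentum * accumulation + gradient`, then `variable -= learning_rate * accumulation`. The scalar sketch below applies that rule to f(x) = x², whose gradient is 2x; `momentum_sgd` and its defaults are hypothetical. The accumulator carries past gradients forward, which smooths the sequence of updates. The sketch has no moving statistics, so the batch-normalization `UPDATE_OPS` dependency has no counterpart here.

```python
def momentum_sgd(grad_fn, x0, learning_rate=0.1, momentum=0.9, steps=200):
    """Scalar momentum update:
        accumulation = momentum * accumulation + gradient
        x -= learning_rate * accumulation
    """
    x, accumulation = x0, 0.0
    for _ in range(steps):
        accumulation = momentum * accumulation + grad_fn(x)
        x -= learning_rate * accumulation
    return x


# Minimize f(x) = x**2 (gradient 2x); the minimum is at x = 0.
print(momentum_sgd(lambda x: 2 * x, x0=5.0))
```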
Original post: https://www.cnblogs.com/wt-seu/p/12382130.html