zoukankan      html  css  js  c++  java
  • 4 TensorFlow入门之dropout解决overfitting问题

    ————————————————————————————————————

    写在开头:此文参照莫烦python教程(墙裂推荐!!!)

    ————————————————————————————————————

    dropout解决overfitting问题

    • overfitting:当机器学习学习得太好了,就会出现过拟合(overfitting)问题。所以,我们就要采取一些措施来避免过拟合的问题。此实验就来看一下dropout对于解决过拟合问题的效果。
    • 例子实验内容:识别手写数字。此实验的步骤和上一篇的识别手写数字步骤很相似。
    • 例子实验的数据集:sklearn中的datasets

    • 主要运用的函数tf.nn.dropout()

    • 主要参数keep_prob。keep_prob表示留下来的结果的百分比,比如你要drop0.4,那么keep_prob就为0.6
    import tensorflow as tf
    from sklearn.datasets import load_digits
    from sklearn.cross_validation import train_test_split
    from sklearn.preprocessing import LabelBinarizer
    
    #加载数据
    digits = load_digits()
    X = digits.data
    y = digits.target
    y = LabelBinarizer().fit_transform(y)  #把数字变成1x10的向量
    X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = .3)  #把数据分成train数据和test数据
    
    #定义添加层
    def add_layer(inputs,in_size,out_size,activation_function=None):
        #定义添加层内容,返回这层的outputs
        Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一个in_size行、out_size列的矩阵,开始时用随机数填满
        biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一个1行out_size列的矩阵,用0.1填满
        Wx_plus_b = tf.matmul(inputs,Weights)+biases  #预测
        #实现dropout,keep_drop为丢弃后剩下的百分比
        Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)
        if activation_function is None:  #如果没有激励函数,那么outputs就是预测值
            outputs = Wx_plus_b
        else:  #如果有激励函数,那么outputs就是激励函数作用于预测值之后的值
            outputs = activation_function(Wx_plus_b)
        return outputs
    
    #定义计算正确率的函数
    def t_accuracy(t_xs,t_ys):
        global prediction
        y_pre = sess.run(prediction,feed_dict={xs:t_xs,keep_prob:1})#测试结果不dropout
        correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
        accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
        result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys,keep_prob:1})
        return result
    
    #定义输入输出值,和keep_drop值
    keep_prob = tf.placeholder(tf.float32)
    xs = tf.placeholder(tf.float32, [None, 64])  # 8x8
    ys = tf.placeholder(tf.float32, [None, 10])
    
    #添加层
    l1 = add_layer(xs, 64, 50,activation_function=tf.nn.tanh)
    prediction = add_layer(l1, 50, 10,activation_function=tf.nn.softmax)
    
    #误差
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))  # loss
    
    #训练
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    
    #开始训练
    sess = tf.Session()
    merged = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    sess.run(init)
    for i in range(1000):
        # 设置keep_drop为1,即不进行dropout
        sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})
        if i % 50 == 0:
            # 输出正确率
            print (t_accuracy(X_test,y_test)) 
    0.20925926
    0.7574074
    0.81296295
    0.8388889
    0.85555553
    0.8537037
    0.84814817
    0.8537037
    0.85555553
    0.8537037
    0.85555553
    0.8537037
    0.8574074
    0.85555553
    0.8574074
    0.8574074
    0.8611111
    0.8574074
    0.85925925
    0.8611111
    
    for i in range(1000):
        # 设置keep_drop为0.5
        sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 0.5})
        if i % 50 == 0:
            # 输出正确率
            print (t_accuracy(X_test,y_test)) 
    0.86851853
    0.89444447
    0.91481483
    0.9166667
    0.91481483
    0.9222222
    0.9259259
    0.9222222
    0.9296296
    0.94074076
    0.94074076
    0.9351852
    0.9351852
    0.9351852
    0.9351852
    0.93333334
    0.94074076
    0.9351852
    0.93703705
    0.9351852
    

    由上面的结果可知,当dropout为0.5时,效果明显比一点儿也不丢弃的好!


    *点击[这儿:TensorFlow]发现更多关于TensorFlow的文章*


  • 相关阅读:
    <JSP> 入门
    <Html> 标签
    <MyBatis>入门八 工作原理
    <MyBatis>入门七 缓存机制
    <Zookeeper>入门 概念
    <SpringMvc>入门七 拦截器
    <SpringMvc>入门六 异常处理
    <Ajax> 入门
    <设计模式> 代理模式 Proxy Pattern
    <SpringMvc>入门五 文件上传
  • 原文地址:https://www.cnblogs.com/surecheun/p/9648967.html
Copyright © 2011-2022 走看看