zoukankan      html  css  js  c++  java
  • tensorflow学习之(八)使用dropout解决overfitting(过拟合)问题

    #使用dropout解决overfitting(过拟合)问题
    #如果有dropout,在feed_dict的参数中一定要加入dropout的值
    import tensorflow as tf
    from sklearn.datasets import load_digits
    from sklearn.cross_validation import train_test_split
    from sklearn.preprocessing import LabelBinarizer
    
        # Load scikit-learn's bundled "digits" dataset of handwritten digits 0-9.
    digits = load_digits()
        # X: the pixel features of each image; y: the digit label belonging to each row of X.
    X, y = digits.data, digits.target
        # One-hot encode the labels. fit_transform() learns the set of classes and
        # returns a 2-D 0/1 array with one column per class (10 digit classes here)
        # in place of the original 1-D label vector.
    y = LabelBinarizer().fit_transform(y)
        # Randomly shuffle the samples and hold out a share of them as a test set.
        #   test_size    - a float is the fraction held out (0.3 -> 30%); an int
        #                  would be an absolute number of test samples.
        #   random_state - seed for the shuffle. Any fixed integer, 0 included,
        #                  reproduces the same split on every run; leaving it unset
        #                  (None, as here) gives a different split on each run.
        # The pieces come back in the order X_train, X_test, y_train, y_test.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3
    )
    
    # Illustrative LabelBinarizer example kept as an inert string literal: the code
    # inside it does not run. It shows fit_transform() encoding string labels and
    # inverse_transform() recovering them. With only two classes ('yes'/'no') the
    # encoding is a single 0/1 column, as the expected output below shows.
    '''
    #fit_transform()、inverse_transform使用的例子
    #程序
    from sklearn import preprocessing
    feature = [[0,1], [1,1], [0,0], [1,0]]
    label= ['yes', 'no', 'yes', 'no']
    lb = preprocessing.LabelBinarizer() #构建一个转换对象
    Y = lb.fit_transform(label)
    re_label = lb.inverse_transform(Y)#还原之前的label
    print(Y)
    print(re_label)
    #结果
    [[1]
     [0]
     [1]
     [0]]
    ['yes' 'no' 'yes' 'no']
    '''
    
    # Define one fully connected neural-network layer (with dropout).
    def add_layer(inputs, in_size, out_size, layer_name, activation_function=None,
                  keep_prob=None):
        """Add one fully connected layer to the graph and return its output tensor.

        Computes ``activation_function(dropout(inputs @ Weights + biases))`` and
        records a histogram summary of the output under ``layer_name + '/output'``.

        Args:
            inputs: input tensor, presumably 2-D (batch, in_size) since it is
                matrix-multiplied by an (in_size, out_size) weight matrix.
            in_size: number of input features (rows of the weight matrix).
            out_size: number of units in this layer (columns of the weight matrix).
            layer_name: prefix for the TensorBoard histogram tag.
            activation_function: optional callable applied after dropout; when
                None the layer output is the linear (dropped-out) pre-activation.
            keep_prob: probability that each unit is kept by dropout; a float or a
                scalar tensor such as a placeholder. When None, falls back to the
                module-level ``keep_pro`` placeholder, which preserves the
                original behaviour; that global must then exist before this call.

        Returns:
            The layer's output tensor, presumably of shape (batch, out_size).
        """
        if keep_prob is None:
            # Backward-compatible default: the original read this global directly.
            keep_prob = keep_pro
        Weights = tf.Variable(tf.random_normal([in_size, out_size]))
        # Biases start at a small positive value (0.1) rather than exactly zero.
        biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
        Wx_plus_b = tf.matmul(inputs, Weights) + biases
        # Dropout counters overfitting; feed keep_prob < 1 only while training.
        Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)
        if activation_function is None:
            outputs = Wx_plus_b
        else:
            outputs = activation_function(Wx_plus_b)
        tf.summary.histogram(layer_name + '/output', outputs)
        return outputs
    
    # Placeholders fed at run time through feed_dict.
    # keep_pro is the dropout keep probability. add_layer reads it as a global, so
    # it must be fed on every sess.run that evaluates the network.
    keep_pro = tf.placeholder(tf.float32)
    xs = tf.placeholder(tf.float32, [None, 64])  # None: any batch size; 64 = 8*8 pixels
    ys = tf.placeholder(tf.float32, [None, 10])  # one-hot targets for the 10 digit classes

    # Network: 64 -> 50 (tanh hidden layer) -> 10 (softmax output layer).
    l1 = add_layer(xs, 64, 50, 'l1', activation_function=tf.nn.tanh)
    prediction = add_layer(l1, 50, 10, 'l2', activation_function=tf.nn.softmax)

    # Cross-entropy loss between the prediction and the one-hot targets.
    # Clip before the log: softmax outputs can underflow to exactly 0, and
    # log(0) = -inf would turn the loss and its gradients into NaN.
    cross_entropy = tf.reduce_mean(
        -tf.reduce_sum(ys * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)), axis=1))
    tf.summary.scalar('loss', cross_entropy)
    train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy)

    sess = tf.Session()
    merged = tf.summary.merge_all()
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())

    # Separate writers so TensorBoard can plot the train and test loss side by side.
    train_writer = tf.summary.FileWriter("../../logs/train", sess.graph)
    test_writer = tf.summary.FileWriter("../../logs/test", sess.graph)

    for i in range(500):
        # Train with dropout active: each unit is kept with probability 0.6.
        sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_pro: 0.6})
        if i % 50 == 0:
            # Record summaries with dropout disabled (keep every unit) so the
            # logged loss reflects the full network.
            train_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train, keep_pro: 1})
            test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test, keep_pro: 1})
            train_writer.add_summary(train_result, i)
            test_writer.add_summary(test_result, i)

    # Flush pending events to disk and release the event files; otherwise the
    # last summaries may never be written.
    train_writer.close()
    test_writer.close()
  • 相关阅读:
    SSH中使用延迟加载报错Exception occurred during processing request: could not initialize proxy
    SSH整合方案二(不带hibernate.cfg.xml)
    SSH整合方案一(带有hibernate.cfg.xml)
    hibernate4整合spring3出现java.lang.NoClassDefFoundError: [Lorg/hibernate/engine/FilterDefinition;
    jquery实现图片上传前的预览
    EL11个内置对象
    linux修改主机名,关闭图形化界面,绑定ip地址,修改ip地址
    VMTurbo:应对散乱虚拟机的强劲工具
    虚拟架构与云系统监控与管理解决方案
    VMTurbo采用红帽企业虚拟化软件
  • 原文地址:https://www.cnblogs.com/Harriett-Lin/p/9591448.html
Copyright © 2011-2022 走看看