zoukankan      html  css  js  c++  java
  • tensorflow学习之路---解决过拟合

    '''

    思路:
    1、调用数据集 2、定义用来实现神经元功能的函数(包括解决过拟合) 3、定义输入和输出的数据
    4、定义隐藏层(函数)和输出层(函数) 5、分析误差和优化数据(改变权重)
    6、执行神经网络

    '''
    import tensorflow as tf
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelBinarizer


    #调用数据
    digits = load_digits()#下载数据
    X = digits.data #样本特征
    Y = digits.target #样本准确值
    y = LabelBinarizer().fit_transform(Y) #将数据转化为二值数组

    X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = 0.3)#分配数据
    '''

    扩展知识点
    train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和test data,形式为: 
    X_train,X_test, y_train, y_test = cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)
    参数代表含义: 
    train_data:所要划分的样本特征集 
    train_target:所要划分的样本结果 
    test_size:样本占比,如果是整数的话就是样本的数量 
    random_state:是随机数的种子。
    '''
    print(len(X_train))

    #定义用来实现神经元功能的函数
    def add_layer(inputs,in_size,out_size,keep_prob,layer_name,activation_function=None):
      Weights = tf.Variable(tf.random_normal([in_size,out_size]))
      biases = tf.Variable(tf.zeros([1,out_size])+0.1)
      Wx_plus_Bx = tf.matmul(inputs,Weights)+biases
    #在这里处理过拟合
      Wx_plus_b = tf.nn.dropout(Wx_plus_Bx,keep_prob)
      if activation_function==None:
        outputs = Wx_plus_b
      else:
        outputs = activation_function(Wx_plus_b)

      tf.summary.histogram(layer_name+'/outputs',outputs)
      return outputs

    #定义输入和输出的数据

    x_data = tf.placeholder(tf.float32,[None,64])#这是因为sklearn中的手写图片的像素和、为8*8
    y_data = tf.placeholder(tf.float32,[None,10])#数字只有10个
    keep_prob = tf.placeholder(tf.float32)#定义过拟合数

    #定义隐藏层和输出层
    layer = add_layer(x_data,64,50,keep_prob,'l1',tf.nn.tanh)#隐藏层
    prediction = add_layer(layer,50,10,keep_prob,'l2',tf.nn.softmax)#输出层

    #分析误差和优化数据
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_data*tf.log(prediction),reduction_indices=[1]))

    scalar_loss = tf.summary.scalar('loss',cross_entropy)
    train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy)

    #初始化所有的变量
    init = tf.global_variables_initializer()

    merged = tf.summary.merge_all()#定义一个图框

    '''
    因为sess是在sess的时候才出现的,所以应该写在sess的面
    train_writer = tf.summary.FileWriter('logs/train',sess.graph)
    test_writer = tf.summary.FileWriter('logs/test',sess.graph)
    '''

    #执行
    with tf.Session()as sess:
      sess.run(init)
      #写入网页,这当中只有histogram和scaler同时出现才能写入网页
      train_writer = tf.summary.FileWriter('Logs/train',sess.graph)
      test_writer = tf.summary.FileWriter('Logs/test',sess.graph)
      for i in range(1000):
        sess.run(train_step,feed_dict = {x_data:X_train,y_data:y_train,keep_prob:0.6})
        if i%50==0:
          train_result = sess.run(merged,feed_dict={x_data:X_train,y_data:y_train,keep_prob:1})
    '''
    这个merged会自动的将预测值的精确度求出来
    '''
          test_result = sess.run(merged,feed_dict={x_data:X_test,y_data:y_test,keep_prob:1})
          train_writer.add_summary(train_result,i)#将数据划入图中
          test_writer.add_summary(test_result,i)#将数据划入图中

    '''
    这里出现一个错误:就是test_result = sess.run(scalar_loss,feed_dict={x_data:X_test,y_data:y_test,keep_prob:1})
    train_result = sess.run(scalar_loss,feed_dict={x_data:X_train,y_data:y_train,keep_prob:1})中的scalar_loss
    改为merged的时候,再次执行就会报错
    解决办法
    1、我们可以关机,然后把logs文件里面的东西删除,然后在执行一次。因为他是系统日志文件
    2、由于我,这里只是想损失函数loss通过tensorboard显示出来而已,并且字典表也正常赋值了:
    result = sess.run(merged,feed_dict={xs:x_data,ys:y_data})
    一切都很正常,想来想去感觉这个函数应该可以采用其他方式替换:
    merged = tf.summary.merge_all()
    这是tensorflow提供的合并所有summary信息的api,但是我只是想合并损失函数loss的summary
    '''

  • 相关阅读:
    码农提高工作效率-黄博文
    myeclipse与tomcat,运行jsp程序
    Ultraedit和写字板修改Tomcat 6.0的server.xml不生效
    MySQL5.5.33对应的JDBC驱动包怎样使用?
    Java是用JDBC连接MySQL数据库
    myeclipse trial expired暂时解决办法
    Json数据使用及学习方法
    在C#中使用json字符串
    vs2012换肤功能,vs2012主题及自定义主题
    给Notepad++换主题
  • 原文地址:https://www.cnblogs.com/MyUniverse/p/9432221.html
Copyright © 2011-2022 走看看