zoukankan      html  css  js  c++  java
  • tensorflow入门(三)

    三种代价函数

    1,二次代价函数

      式子代表预测值与样本值的差得平方和

    由于使用的是梯度下降法,我们对变量w,b分别求偏导:

    这种函数对于处理线性的关系比较好,但是如果遇到s型函数(如下图所示),效率不高。

    从图中我们看出:当我们想要趋近于1时,B点接近于1,变化趋势变小(很正确),A点与1距离较远,变化趋势较大(很正确),C点(假设在x = -3处)远离1,变化趋势很小(发生错误),因此,二次代价函数中单凭梯度的大小决定变化的快慢是不对的。

    由此我们引出了第二个代价函数——交叉熵代价函数

    2,交叉熵代价函数

    右边是babababab的推导过程,最终得到表达式:

    结论如上↑

    3,对数释然代价函数

    我们对上次的代码进行修改,修改了loss的函数

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    #载入数据集
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
    
    #每个批次的大小
    batch_size = 100
    #计算一共有多少个批次
    n_batch = mnist.train.num_examples // batch_size
    
    #定义两个placeholder
    x = tf.placeholder(tf.float32,[None,784]) #图片
    y = tf.placeholder(tf.float32,[None,10]) #标签
    
    #创建一个简单的神经网络
    w = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    prediction = tf.nn.softmax(tf.matmul(x,w)+b)
    
    #二次代价函数
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
    #梯度下降算法
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    
    #初始化变量
    init = tf.global_variables_initializer()
    
    #结果存放在一个bool类型的列表中,argmax()返回一维张量中最大值所在的位置,equal函数判断两者是否相等
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
    #求准确率,cast:将bool型转换为0,1然后求平均正好算到是准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(21):
            for batch in range(batch_size):#next_batch:不断获取下一组数据
                batch_xs,batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
    
            acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
            print("Iter "+str(epoch)+",Testing Accuracy "+str(acc))

    得到结果:

    Iter 0,Testing Accuracy 0.7051
    Iter 1,Testing Accuracy 0.8002
    Iter 2,Testing Accuracy 0.8105
    Iter 3,Testing Accuracy 0.8183
    Iter 4,Testing Accuracy 0.8218
    Iter 5,Testing Accuracy 0.8268
    Iter 6,Testing Accuracy 0.8585
    Iter 7,Testing Accuracy 0.8724
    Iter 8,Testing Accuracy 0.8866
    Iter 9,Testing Accuracy 0.8899
    Iter 10,Testing Accuracy 0.8936
    Iter 11,Testing Accuracy 0.8968
    Iter 12,Testing Accuracy 0.8979
    Iter 13,Testing Accuracy 0.8985
    Iter 14,Testing Accuracy 0.8994
    Iter 15,Testing Accuracy 0.9007
    Iter 16,Testing Accuracy 0.9015
    Iter 17,Testing Accuracy 0.9024
    Iter 18,Testing Accuracy 0.9047
    Iter 19,Testing Accuracy 0.9041
    Iter 20,Testing Accuracy 0.9057

    哇,准确率提高很多诶!

     ——

  • 相关阅读:
    2017-2018-1 20155204 《信息安全系统设计基础》第七周学习总结
    2017-2018-1 20155203 20155204 实验二 固件程序设计
    2017-2018-1 20155204 《信息安全系统设计基础》第六周学习总结
    20155204《信息安全系统设计》第六周课下作业:缓冲区溢出漏洞实验
    《信息安全技术》实验2——Windows口令破解
    2017-2018-1 20155204 《信息安全系统设计基础》第五周学习总结
    实现mypwd
    2017-2018-1 20155331 《信息安全系统设计基础》第九周学习总结
    2017-2018-1 20155331 《信息安全系统设计基础》第八周课堂测试
    2017-2018-1 20155331 《信息安全系统设计基础》第八周学习总结
  • 原文地址:https://www.cnblogs.com/cunyusup/p/8636255.html
Copyright © 2011-2022 走看看