zoukankan      html  css  js  c++  java
  • tensorflow-softmax

    之前在softmax多分类中讲到多用交叉熵作为损失函数,这里顺便写个例子,tensorlflow练手。

    # encoding:utf-8
    import tensorflow as tf
    import input_data
    
    ### softmax 回归
    
    # 自动下载安装数据集
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    
    # 图片    28 * 28 = 784
    x=tf.placeholder('float',[None,784])    # 特征数 784
    # 初始化参数
    w=tf.Variable(tf.zeros([784,10]))       # 10是输出维度,0-9数字的独热编码
    b=tf.Variable(tf.zeros([10]))
    
    # 模型
    y=tf.nn.softmax(tf.matmul(x,w)+b)
    
    
    ### 训练模型
    y_=tf.placeholder('float',[None,10])     # 10维
    
    # 损失函数  交叉熵
    cross_entropy=-tf.reduce_sum(y_*tf.log(y))      # reduce_sum 计算张量的所有元素之和    所有图片的交叉熵综合
    
    # 优化算法 梯度下降
    train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)      # 0.01 学习率  最小化 损失函数
    
    # 初始化变量
    init=tf.initialize_all_variables()
    
    # 启动图  会话
    sess=tf.Session()
    sess.run(init)      # 初始化变量
    
    # 迭代,训练参数
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)            # 训练集
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    
    
    # 模型评估 准确率
    correct_predict = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))       # 输出布尔值,即预测与真实是否一样
    accuracy = tf.reduce_mean(tf.cast(correct_predict, 'float'))        # 将布尔值转化成浮点数,然后求平均, 正确/总数
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))      # 测试集   0.9194
  • 相关阅读:
    CSS3笔记
    HTML5新标签
    前端工程师面试题JavaScript部分(第五季)
    前端工程师面试题JavaScript部分(第四季)
    前端工程师面试题JavaScript部分(第三季)
    前端组件开发方式(二)
    前端组件开发方式(一)
    面向对象的代码研究(一)
    ServiceDemo,ClientDemo Socket chat
    Socket(java基础)
  • 原文地址:https://www.cnblogs.com/yanshw/p/10460175.html
Copyright © 2011-2022 走看看