zoukankan      html  css  js  c++  java
  • 利用Tensorflow实现逻辑回归模型

     官方mnist代码:

    #下载Mnist数据集
    import tensorflow.examples.tutorials.mnist.input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    #Tensorflow实现回归模型
    import tensorflow as tf
    
    #定义变量为float型,行因为不确定先给无穷大None;列给28*28=784
    x = tf.placeholder("float", [None, 784])
    y_ = tf.placeholder("float", [None,10])
    
    #向量相乘y = wx + b,w的行即为x的列,否则无法相乘;输出大小给10(因为是一个10分类任务)
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    
    #逻辑回归模型
    #nn模块下的softmax解决多分类问题,参数:是一个预测值,即wx+b会计算出一个分值
    #softmax 完成归一化操作
    #得到的y是一个预测结果
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    
    #计算损失值:-log(p);求均值:reduce_mean
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y)), reduction_indices=1))
    
    #训练模型
    #优化器使用梯度下降
    learning_rate = 0.01    #学习率
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    
    #评估模型
    #比较一下预测值和这个标记的Label值,如果一致返回true,否则返回false
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    #计算准确率tf.cast
    #计算均值tf.reduce_mean
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

    常用函数:

    arr = np.array([
        [31,23,4,24,27,34],[18,3,25,0,6,35],[28,14,33,22,20,8]
    ])
    #按列找出每列的最大值的索引 0按列 1按行
    tf.argmax(arr, 0).eval()
    #计算矩阵的维数
    tf.rank(arr).eval()
    #计算矩阵的行和列
    tf.shape(arr).eval()
  • 相关阅读:
    BZOJ2219数论之神——BSGS+中国剩余定理+原根与指标+欧拉定理+exgcd
    Luogu 3690 Link Cut Tree
    CF1009F Dominant Indices
    CF600E Lomsat gelral
    bzoj 4303 数列
    CF1114F Please, another Queries on Array?
    CF1114B Yet Another Array Partitioning Task
    bzoj 1858 序列操作
    bzoj 4852 炸弹攻击
    bzoj 3564 信号增幅仪
  • 原文地址:https://www.cnblogs.com/hunttown/p/6807455.html
Copyright © 2011-2022 走看看