zoukankan      html  css  js  c++  java
  • 由浅入深之TensorFlow(2)----logistic regression实现

    import tensorflow as tf
    import numpy as np
    
    from tensorflow.examples.tutorials.mnist import input_data
    
    def initWeights(shape):
        """Create a trainable weight variable of the given shape.

        Initial values are drawn from a normal distribution with
        mean 0 and standard deviation 0.1.

        Args:
            shape: shape of the variable, e.g. [784, 10].

        Returns:
            A new tf.Variable holding the random initial values.
        """
        initial_values = tf.random_normal(shape, stddev=0.1)
        return tf.Variable(initial_values)
    
    def initBiases(shape):
        """Create a trainable bias variable of the given shape.

        Like the weights, the biases start from normal random values
        (mean 0, standard deviation 0.1) rather than zeros.

        Args:
            shape: shape of the variable, e.g. [10].

        Returns:
            A new tf.Variable holding the random initial values.
        """
        start = tf.random_normal(shape=shape, stddev=0.1)
        bias_var = tf.Variable(start)
        return bias_var
    
    def model(X, weights, baises):
        """Compute the affine scores X @ weights + baises.

        No activation is applied here; the raw scores are returned.

        Args:
            X: input batch tensor (rows are examples).
            weights: weight matrix multiplied on the right of X.
            baises: bias term added to every row (the parameter name's
                spelling is kept as-is for caller compatibility).
        """
        scores = tf.matmul(X, weights)
        scores = scores + baises
        return scores
    
    # Read the MNIST train/test splits from MNIST_data/ with one-hot encoded labels.
    mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
    trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

    # Inputs: 784 features per example; targets: one-hot labels over 10 classes.
    X = tf.placeholder(tf.float32, [None, 784])
    Y = tf.placeholder(tf.float32, [None, 10])

    learning_rate = 0.05
    epochs = 100
    batch_size = 128

    weights = initWeights([784, 10])
    biases = initBiases([10])

    y_ = model(X, weights, biases)
    # Pass logits/labels by keyword: TF >= 1.0 rejects positional arguments for this
    # op, and keywords make it impossible to swap logits and labels by accident.
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_, labels=Y))
    train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    predict_op = tf.argmax(y_, 1)

    # The true test classes never change, so compute them once, outside the epoch loop.
    teY_classes = np.argmax(teY, axis=1)

    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated in favor of this initializer.
        tf.global_variables_initializer().run()
        for i in range(epochs):
            # Step through the training set in mini-batches. The final batch may be
            # smaller than batch_size; slicing past the end is safe, so no samples
            # are dropped (the original zip-of-ranges skipped the trailing partial batch).
            for start in range(0, len(trX), batch_size):
                end = start + batch_size
                sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})
            # Report test-set accuracy (fraction of correct argmax predictions) per epoch.
            print(i, np.mean(teY_classes == sess.run(predict_op, feed_dict={X: teX})))
  • 相关阅读:
    Android控制软键盘的显示与隐藏
    Android调用手机浏览器
    DatePicker隐藏年/月/日
    Eclipse中设置字符编码
    Git问题总结
    Git的简单使用
    资源
    equals和==
    class文件查看
    Class file collision
  • 原文地址:https://www.cnblogs.com/upright/p/6136199.html
Copyright © 2011-2022 走看看