zoukankan      html  css  js  c++  java
  • Mnist字符识别-神经网络实现(TF框架)

    Mnist字符识别-神经网络实现(TF框架)

    该段代码即贴即用，先贴一下代码，有空的时候再写注释解析。这是大三时写的代码，特别适合新手入门，现在大家基本都用 PyTorch 了。

    运行环境为 TensorFlow 1.13.1（代码使用 tf.placeholder、tf.Session 等 TensorFlow 1.x 的 API，无法直接在 TensorFlow 2.x 下运行），用 CPU 跑也挺快。之前用 GPU 跑了半小时，准确率能达到 98% 左右。

    代码

    # -*- coding:utf-8 -*-
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    from matplotlib import pyplot
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Download (if needed) and load MNIST; one_hot=True gives 10-way one-hot labels.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Seed both NumPy (used by np.random.shuffle in getBatch) and TensorFlow's
    # graph-level RNG (used by tf.random_normal weight initialization); seeding
    # only NumPy leaves the initial weights different on every run.
    seed = 547
    np.random.seed(seed)
    tf.set_random_seed(seed)

    epoch_time = 20   # number of training epochs
    ALPHY = 0.5       # learning rate for gradient descent
    batch_size = 10   # samples per mini-batch

    # Mini-batches per epoch when training on the full training split ...
    n_batch_all = mnist.train.num_examples // batch_size
    # ... and when training on the 1000-sample subset.
    n_batch = 1000 // batch_size

    # Inputs: 784 pixel values per image (MNIST images are 28x28) and
    # 10-way one-hot labels.
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    
    def xavier_init(size):
        """Return a random-normal initial-value tensor with shape ``size``.

        The standard deviation is ``1 / sqrt(fan_in / 2)`` (equivalently
        ``sqrt(2 / fan_in)``), where ``fan_in = size[0]``. Despite the
        function's name, this is the He/Kaiming normal scale rather than
        Glorot/Xavier's ``sqrt(2 / (fan_in + fan_out))``.
        """
        half_fan_in = size[0] / 2.
        return tf.random_normal(size, stddev=1. / tf.sqrt(half_fan_in))
    
    # Hidden layer: 784 inputs -> 30 sigmoid units.
    W1 =tf.Variable(xavier_init([784, 30]))
    B1 = tf.Variable(tf.zeros([30]))
    
    L1 =  tf.nn.sigmoid(tf.matmul(x,W1) + B1)
    
    # Output layer: 30 hidden units -> 10 logits, one per digit class.
    W2 =tf.Variable(xavier_init([30, 10]))
    B2 = tf.Variable(tf.zeros([10]))
    
    logit_prediction = tf.matmul(L1,W2) + B2
    # Per-class sigmoid scores. In this file they feed the argmax in the
    # accuracy metric below (and the disabled MSE loss); sigmoid is monotonic,
    # so the argmax is the same as on the raw logits.
    prediction = tf.nn.sigmoid(logit_prediction)
    # MSE loss (alternative, disabled)
    # loss = tf.reduce_mean(tf.square(y - prediction))
    
    # Cross-entropy loss: an independent sigmoid cross-entropy per class.
    # NOTE(review): this tensor is NOT reduced -- it holds one value per
    # (sample, class). minimize() differentiates the sum of its entries, so the
    # effective step size grows with the batch size. Wrapping it in
    # tf.reduce_mean (or switching to softmax cross-entropy for these mutually
    # exclusive classes) changes the effective learning rate; retune ALPHY if so.
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_prediction,labels=y)
    
    # Plain stochastic gradient descent with learning rate ALPHY.
    train_setup = tf.train.GradientDescentOptimizer(ALPHY).minimize(loss)
    
    init = tf.global_variables_initializer()
    
    # A sample counts as correct when the highest-scoring class equals the
    # position of the 1 in its one-hot label.
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
    
    # Fraction of correctly classified samples in the fed batch.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    def getBatch(inputs, size=10, n_features=784, n_classes=10):
        """Draw a random mini-batch from a combined feature/label array.

        Each row of ``inputs`` is one sample laid out as
        ``[features | one-hot labels]``. The array is shuffled in place along
        its first axis, so successive calls see different samples. The first
        ``size`` rows are then split into features and labels.

        Args:
            inputs: 2-D NumPy array with at least ``n_features + n_classes``
                columns. It is mutated (row order changes).
            size: rows per batch. The default 10 is the same value as the
                module-level ``batch_size``. If ``inputs`` has fewer rows,
                all rows are returned.
            n_features: number of leading feature columns (784 for MNIST).
            n_classes: number of label columns that follow the features.

        Returns:
            ``(features, labels)`` with shapes ``(k, n_features)`` and
            ``(k, n_classes)``, where ``k = min(size, len(inputs))``.
        """
        # In-place row shuffle keeps each feature row paired with its label.
        np.random.shuffle(inputs)
        batch = inputs[:size]
        fina_x = batch[:, :n_features]
        fina_y = batch[:, n_features:n_features + n_classes]
        return fina_x, fina_y
    
    def draw(train, text, path='accuracy.jpg', dpi=900):
        """Plot per-epoch test accuracy of both training runs and save it to disk.

        Args:
            train: per-epoch test accuracies from the 1000-sample run
                (plotted with the legend label 'train_1000').
            text: per-epoch test accuracies from the full-training-set run
                (legend label 'train_all'). Must have the same length as
                ``train``.
            path: output image file (default 'accuracy.jpg').
            dpi: resolution of the saved image (default 900).
        """
        # Derive the epoch axis from the recorded data rather than the global
        # epoch_time, so the x-axis always matches the values being plotted.
        epochs = range(len(train))
        names = [str(e) for e in epochs]

        plt.plot(epochs, train, marker='o', mec='r', mfc='w', label='train_1000')
        plt.plot(epochs, text, marker='*', ms=10, label='train_all')
        plt.legend()
        plt.xticks(epochs, names, rotation=1)
        plt.margins(0)
        plt.subplots_adjust(bottom=0.10)
        plt.xlabel('epoch')
        plt.ylabel("accuracy")
        plt.yticks([0, 0.5, 1])
        plt.savefig(path, dpi=dpi)
        # Close the figure so a later call starts on a blank canvas instead of
        # drawing on top of these curves.
        plt.close()
    
    def train_1000():
        """Train the network on one fixed random subset of 1000 training samples.

        Re-initializes all variables in the module-level ``sess`` and draws
        1000 samples once with ``mnist.train.next_batch``. Then, for each of
        ``epoch_time`` epochs, it runs ``n_batch`` SGD steps on random
        mini-batches from that subset (via ``getBatch``) and evaluates
        accuracy on the full test set.

        Returns:
            NumPy float32 array of length ``epoch_time`` holding the test
            accuracy after each epoch.
        """
        sess.run(init)
        X_mb, Y_mb = mnist.train.next_batch(1000)
        # Join features and labels column-wise so getBatch's in-place shuffle
        # keeps each image aligned with its label. Doing this in NumPy (instead
        # of tf.concat + eval) avoids embedding the 1000x794 array as a constant
        # in the TF graph on every call.
        inputs = np.concatenate([X_mb, Y_mb.astype(np.float32)], axis=1)
        train = np.zeros(epoch_time, dtype=np.float32)
        for epoch in range(epoch_time):
            for _ in range(n_batch):
                fina_x, fina_y = getBatch(inputs)
                sess.run(train_setup, feed_dict={x: fina_x, y: fina_y})
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
            train[epoch] = acc
            print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
        return train
    
    def train_all():
        """Train the network on the full MNIST training split.

        Re-initializes all variables in the module-level ``sess``. Then, for
        each of ``epoch_time`` epochs, it runs ``n_batch_all`` SGD steps on
        ``batch_size``-sample mini-batches from ``mnist.train.next_batch`` and
        evaluates accuracy on the full test set.

        Returns:
            NumPy float32 array of length ``epoch_time`` holding the test
            accuracy after each epoch.
        """
        sess.run(init)
        # Plain NumPy buffer; no need to build and evaluate a TF op for zeros.
        text = np.zeros(epoch_time, dtype=np.float32)
        for epoch in range(epoch_time):
            for _ in range(n_batch_all):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_setup, feed_dict={x: batch_xs, y: batch_ys})
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
            text[epoch] = acc
            print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
        # Return the history so the caller can plot it; without this the
        # function returns None and the second curve cannot be drawn.
        return text
    
    if __name__ == "__main__":
        # The training helpers read this module-level `sess` directly, so it
        # must stay at module scope (an `if` block does not create a new scope).
        with tf.Session() as sess:
            p1 = train_1000()
            p2 = train_all()
            draw(p1, p2)
    

    结果（前 20 行为 train_1000 仅用 1000 个训练样本训练时每个 epoch 的测试准确率，后 20 行为 train_all 使用全部训练集训练时的测试准确率）

    Iter0, Testing Accuracy=0.0982 
    Iter1, Testing Accuracy=0.2913 
    Iter2, Testing Accuracy=0.2973 
    Iter3, Testing Accuracy=0.3493 
    Iter4, Testing Accuracy=0.4311 
    Iter5, Testing Accuracy=0.3789 
    Iter6, Testing Accuracy=0.49   
    Iter7, Testing Accuracy=0.4547 
    Iter8, Testing Accuracy=0.4079 
    Iter9, Testing Accuracy=0.4748 
    Iter10, Testing Accuracy=0.564 
    Iter11, Testing Accuracy=0.5026
    Iter12, Testing Accuracy=0.6053
    Iter13, Testing Accuracy=0.6379
    Iter14, Testing Accuracy=0.5863
    Iter15, Testing Accuracy=0.6443
    Iter16, Testing Accuracy=0.6487
    Iter17, Testing Accuracy=0.5809
    Iter18, Testing Accuracy=0.6616
    Iter19, Testing Accuracy=0.6465
    Iter0, Testing Accuracy=0.7625 
    Iter1, Testing Accuracy=0.864  
    Iter2, Testing Accuracy=0.8596 
    Iter3, Testing Accuracy=0.8694 
    Iter4, Testing Accuracy=0.9028 
    Iter5, Testing Accuracy=0.9046 
    Iter6, Testing Accuracy=0.902  
    Iter7, Testing Accuracy=0.9021 
    Iter8, Testing Accuracy=0.8874 
    Iter9, Testing Accuracy=0.9192 
    Iter10, Testing Accuracy=0.9175
    Iter11, Testing Accuracy=0.9226
    Iter12, Testing Accuracy=0.9233
    Iter13, Testing Accuracy=0.9156
    Iter14, Testing Accuracy=0.93  
    Iter15, Testing Accuracy=0.9251
    Iter16, Testing Accuracy=0.9232
    Iter17, Testing Accuracy=0.9176
    Iter18, Testing Accuracy=0.9287
    Iter19, Testing Accuracy=0.9273
    
  • 相关阅读:
    第十二周作业
    第十一周作业
    第十周作业
    第九周作业*
    #**第八周作业+预习作业**
    第七周作业
    Linux 日志查看常用命令
    Linux tar命令
    Java 数组
    设计模式 观察者模式(Observer)
  • 原文地址:https://www.cnblogs.com/lwp-nicol/p/15262544.html
Copyright © 2011-2022 走看看