zoukankan      html  css  js  c++  java
  • 【TF-3-2】Tensorflow-mnist的手写识别

    一、代码

    # Author:yifan
    import matplotlib as mpl
    from tensorflow.examples.tutorials.mnist import input_data   #从tensorflow中导入数据
    import  tensorflow as tf
    ## 设置属性防止中文乱码及拦截异常信息
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    mpl.rcParams['axes.unicode_minus'] = False
    
    #加载数据,从tensorflow中导入到本地
    mnist = input_data.read_data_sets("data/",one_hot=True)
    # print(mnist.train.num_examples)   #55000
    # print(mnist.train.labels.shape)   #(55000, 10)
    
    #构建神经网络(4层,1 input,2 hidden 1 output)
    n_unit_hidden_1 = 256  #第一层的神经元数目
    n_unit_hidden_2 = 128  #第二层的神经元数目
    n_input = 784  #输入是28*28像素的。
    n_classes = 10
    
    #定义占位符
    x = tf.placeholder(tf.float32,shape=[None,n_input],name='x')  #None 表示输入不限制
    y = tf.placeholder(tf.float32,shape=[None,n_classes],name='y')
    
    #构建初始化
    weights = {
        "w1" : tf.Variable(tf.random.normal(shape=[n_input,n_unit_hidden_1],stddev=0.1)),
        "w2" : tf.Variable(tf.random.normal(shape=[n_unit_hidden_1,n_unit_hidden_2],stddev=0.1)),
        "out" : tf.Variable(tf.random.normal(shape=[n_unit_hidden_2,n_classes],stddev=0.1))
    }
    
    biases = {
        "b1" : tf.Variable(tf.random.normal(shape=[n_unit_hidden_1],stddev=0.1)),#注意,这里shape只有一个参数
        "b2" : tf.Variable(tf.random.normal(shape=[n_unit_hidden_2],stddev=0.1)),
        "out" : tf.Variable(tf.random.normal(shape=[n_classes],stddev=0.1))
    }
    
    def multiplayer_perceotron(_X,_weights,_biases):
        #第一层 ->第二层,  input-- hidden1
        layer1 = tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),_biases['b1']))
        #第二层 ->第三层,  hidden1-- hidden2
        layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1,_weights['w2']),_biases['b2']))
        return tf.matmul(layer2,_weights['out'])+ _biases['out']
    
    # 获取预测值
    pred = multiplayer_perceotron(x,weights,biases)
    #构建损失函数,计算softmax中的每个样本交叉熵,logits指定预测值,lables指定实际值
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
    #使用梯度下降,最小化误差损失
    train = tf.train.GradientDescentOptimizer(learning_rate= 0.01).minimize(cost)
    #得到预测类别是哪一个
    predict = tf.equal(tf.argmax(pred ,axis=1),tf.argmax(y,axis=1))
    #正确率
    acc = tf.reduce_mean(tf.cast(predict,tf.float32))
    #初始化
    init = tf.global_variables_initializer()
    #执行模型
    batch_size = 100
    display_step = 4
    
    with tf.Session() as  sess:
        sess.run(init)
        #保存
        saver = tf.train.Saver()
        epoch  = 0
        while True:
            avg_cost = 0
            #计算出中的批次
            total_batch = int(mnist.train.num_examples / batch_size)
            #迭代更新
            for i in range(total_batch):
                batch_xs,batch_ys = mnist.train.next_batch(batch_size)
                feeds = {x:batch_xs ,y:batch_ys}
                #模型训练
                sess.run(train,feed_dict=feeds)
                #获取损失函数值
                avg_cost += sess.run(cost,feed_dict=feeds)
            avg_cost = avg_cost / total_batch #计算均值损失
    
            if (epoch + 1) % display_step == 0 :
                print("批次:%03d 损失函数值:%.9f" %(epoch , avg_cost))
                feeds = {x:mnist.train.images , y:mnist.train.labels}
                train_acc = sess.run(acc,feed_dict= feeds)
                print("训练准确度:%.3f" %train_acc)
                feeds = {x:mnist.test.images,y:mnist.test.labels}
                test_acc = sess.run(acc,feed_dict=feeds)
                print("测试准确度:%.3f" %test_acc)
    
                if train_acc > 0.8 and test_acc > 0.8:
                    saver.save(sess,"./mn/model")
                    break
            epoch += 1
        #模型可视化输出
        writer = tf.summary.FileWriter("./mn/graph",tf.get_default_graph())
        writer.close()
    

      

    二、结果

    三、图形中展示

    3.1 Copy 图形的路径

    3.2 打开命令行,tensorboard --logdir 3.1的路径

    3.3 在浏览器中输入:

    3.4 结果:

    我在代码中加入这些:(最早的代码中没有,显示的为上图)

    上图将变成这样:

  • 相关阅读:
    解决Cell重绘导致 重复的问题
    给Cell间隔颜色
    NSUserDefault 保存自定义对象
    xcode6 下载
    unrecognized selector sent to instance
    16进制颜色转换
    local unversioned, incoming add upon update问题
    应用崩溃邮件通知
    TabBar变透明
    代码手写UI,xib和StoryBoard间的博弈,以及Interface Builder的一些小技巧
  • 原文地址:https://www.cnblogs.com/yifanrensheng/p/12594367.html
Copyright © 2011-2022 走看看