zoukankan      html  css  js  c++  java
  • 机器学习——TensorFlow之数字体识别流程

    import tensorflow as tf
    # 导入mnist数据集
    # 分析mnist样本特点以及定义变量
    # 构建模型
    # 训练模型并输出中间状态参数
    # 测试模型
    # 保存模型
    # 读取模型
    
    
    # 导入mnist数据集
    from tensorflow.examples.tutorials.mnist import input_data
    mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)
    
    # 分析图片的特点,定义变量
    x=tf.placeholder(tf.float32,shape=[None,784])
    y=tf.placeholder(tf.float32,shape=[None,10])
    
    # 构建模型
    W=tf.Variable(tf.zeros([784,10]))
    
    b=tf.Variable(tf.zeros([10]))
    
    # z表示证据
    z=tf.matmul(x,W)+b
    # pred表示是每个数字的可能
    pred=tf.nn.softmax(z)
    # 损失函数,交叉熵,定义反向传播的结构
    loss=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
    
    learn_rate=0.01
    
    # 优化器,梯度下降法
    optimizer=tf.train.GradientDescentOptimizer(learn_rate).minimize(loss)
    
    # 训练次数
    epochs=25
    
    # 批次大小
    batch_size=100
    
    # 把中间具体信息显示出来
    display_step=1
    
    with tf.Session() as sess:
        # 初始化全局变量
        sess.run(tf.global_variables_initializer())
        # 开始训练
        for epoch in range(epochs):
            # 取值大小
            avg_loss=0
            total_loss=0
            total_batch=int(mnist.train.images.shape[0]/batch_size)
            for i in range(total_batch):
                # 从数据集中按照batch_size大小取值
                batch_xs,batch_ys=mnist.train.next_batch(batch_size)
                # 运行优化器
                _,c=sess.run([optimizer,loss],feed_dict={x:batch_xs,y:batch_ys})
                # 计算损失值得平均值
                total_loss+=c
            avg_loss=total_loss/total_batch
            if((epoch+1)%display_step==0):
                print('Epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_loss))
        print('########################Finished!#############################
    ')
        # 测试模型
        print('########################Begin Test############################
    ')
        correct_predict=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
        accuracy=tf.reduce_mean(tf.cast(correct_predict,tf.float32))
        print('Accuracy:',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
        print('########################Save Model############################
    ')
        saver= tf.train.Saver()
        save_path='log/'
        saver.save(sess,save_path)
        print('saved Successfully at :',save_path)
    # 保存模型
    
    
    

    在这里插入图片描述
    在这里插入图片描述

  • 相关阅读:
    [转+]C语言复杂声明
    c和c++数组初始化一点小区别
    [转]Linux ftp命令的使用方法
    Ubuntu 12.04 英文版中文输入法设置
    [转]Android手机中获取手机号码和运营商信息
    把google地圖放在Crm Entity中
    为什么报表里面记录的创建时间 比我们电脑客户端的世界时间 隔8个小时?这个是什么原因?
    print style Iframe
    取出MSCRM父窗口的欄位的值
    Display Fetch in IFRAME – Part 2
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13309445.html
Copyright © 2011-2022 走看看