zoukankan      html  css  js  c++  java
  • TensorFlow基础入门(四)

    注意:本部分的ppt来源于中国大学mooc网站:https://www.icourse163.org/learn/ZUCC-1206146808?tid=1206445215&from=study#/learn/content?type=detail&id=1211168244&cid=1213754001

    #MNIST手写数字识别数据集
    import tensorflow as tf 
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    import numpy as np 
    import matplotlib.pyplot as plt
    
    mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
    #了解MNIST手写数字识别数据集
    print("训练集train数量:",mnist.train.num_examples,
          ",验证集 validation数量:",mnist.validation.num_examples,
          ",测试集 test 数量:",mnist.test.num_examples)
    print("train image shape:",mnist.train.images.shape,
          "labels shape:",mnist.train.labels.shape)

    全部源码:

    #MNIST手写数字识别数据集
    import tensorflow as tf 
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    import numpy as np 
    import matplotlib.pyplot as plt
    import os
    #读取相关的数据
    mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
    #定义待输入数据的占位符
    #mnist中每张图片共有28*28=784个像素点
    x=tf.placeholder(tf.float32,[None,784],name="X")
    #0-9一共10个数字====》10个类别
    y=tf.placeholder(tf.float32,[None,10],name="y")
    #定义模型变量
    '''
    在本案例中,以正态分布的随机数初始化权重W,以常数0初始化偏置b
    '''
    #定义变量
    w=tf.Variable(tf.random_normal([784,10]),name="w")
    b=tf.Variable(tf.zeros([10]),name="b")
    #用单个神经元构建神经网络
    forward=tf.matmul(x,w)+b#前向计算
    pred=tf.nn.softmax(forward)#softmax分类
    #设置训练参数
    train_epochs=100#训练轮数
    batch_size=100#单次训练样本数(批次大小)
    total_batch=int(mnist.train.num_examples/batch_size)#一轮训练有多少批次
    display_step=1#显示粒度
    learning_rate=0.01#学习率
    #定义损失函数(定义交叉商的损失函数)
    loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
    #梯度下降优化器
    optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
    #检查预测类别tf.argmax(ored,1)与实际类别tf.argmax(y,1)的匹配情况
    correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
    #准确率,将布尔值转化为浮点数,并计算平均值
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    sess=tf.Session()#声明会话
    init=tf.global_variables_initializer()#变量初始化
    sess.run(init)
    
    #训练模型的保存
    #储存模型的粒子
    save_step=5
    #创建保存模型文件的目录
    ckpt_dir="./ckpt_dir/"
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    
    #声明完所有 变量之后,使用tf.train.Saver()
    saver=tf.train.Saver()
    
    
    #模型训练
    #开始训练
    for epoch in range(train_epochs):
        for batch in range(total_batch):
            xs,ys=mnist.train.next_batch(batch_size)#读取批次数据
            sess.run(optimizer,feed_dict={x:xs,y:ys})#执行批次训练
        #total_batch个批次训练完成后,使用验证数据计算误差与准确率:验证没有分批
        loss,acc=sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
    
        #打印训练过程中的详细信息
        if(epoch+1)%display_step==0:
            print("Train Epoch:",'%02d'%(epoch+1),"Loss=","{:.9}".format(loss),"Accuracy=","{:.4f}".format(acc))
        if(epoch+1)%save_step==0:
            saver.save(sess,os.path.join(ckpt_dir,'mnist_h256_model_{:06d}.ckpt'.format(epoch+1)))
            print('mnist_h256_model_{:06d}.ckpt'.format(epoch+1))
    #对训练的模型进行保存
    saver.save(sess,os.path.join(ckpt_dir,'mnist_h256_model_ckpt'))    
    print("Train Finished")
    
    #评估模型
    #完成训练之后,在测试集上评估模型的准确率
    def accu_test():
        accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Test Accuracy:",accu_test)
    
    def acc_validation():
        #完成训练之后在验证集上评估模型的准确率
        acc_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
        print("Validation Accuracy:",acc_validation)
    
    def acc_train():
        #完成训练之后,在训练集上评估模型的准确率
        acc_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
        print("Train Accuracy:",acc_train)
    
    #定义数据可视化
    def plot_image_labels_prediction(images,labels,prediction,index,num=10):
        '''
        image:图像列表
        labels:标签列表
        prediction:预测值列表
        index:从第index个开始显示
        num:一次显示多少副图片,缺省的话一次显示10个
        '''
        fig=plt.gcf()#获取当前图表,Get Current Figure
        fig.set_size_inches(10,12)#1英寸等于1.54cm
        if num>25:
            num=25#设置最多显示25个子图
        for i in range(0,num):
            ax=plt.subplot(5,5,i+1)#获取当前要处理的子图
            ax.imshow(np.reshape(images[index],(28,28)),cmap="binary")
            title="label="+str(np.argmax(labels[index]))#构建该图上要显示的title信息
            if len(prediction)>0:
                title+=",predict="+str(prediction[index])
    
            ax.set_title(title,fontsize=10)#显示图上的title信息
            ax.set_xticks([])#不显示坐标轴
            ax.set_yticks([])
            index+=1
        plt.show()

    独热标码:

    #MNIST手写数字识别数据集
    import tensorflow as tf 
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    import numpy as np 
    import matplotlib.pyplot as plt
    
    mnist=input_data.read_data_sets("MNST_data/",one_hot=True)
    #独热编码如何取值
    print(mnist.train.labels[1])
    #argmax()取出独热编码中最大值的下标
    print(np.argmax(mnist.train.labels[1]))
    一纸高中万里风,寒窗读破华堂空。 莫道长安花看尽,由来枝叶几相同?
  • 相关阅读:
    推荐一款国内首个开源全链路压测平台
    redis 你真的懂了吗?
    吊炸天的可视化安全框架,轻松搭建自己的认证授权平台!
    一条简单的更新语句,MySQL是如何加锁的?
    mysql 表删除一半数据,表空间会变小吗?
    调研字节码插桩技术,用于系统监控设计和实现
    这个开源工具把网页变成本地应用程序
    20160924-2——mysql常见问题集锦
    20160924-1——mysql存储引擎
    20160916-4:数据恢复
  • 原文地址:https://www.cnblogs.com/byczyz/p/12079660.html
Copyright © 2011-2022 走看看