zoukankan      html  css  js  c++  java
  • TensorFlow(十三):模型的保存与载入

    一:保存

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    #载入数据集
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
    
    #每个批次100张照片
    batch_size = 100
    #计算一共有多少个批次
    n_batch = mnist.train.num_examples // batch_size
    
    #定义两个placeholder
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    
    #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    prediction = tf.nn.softmax(tf.matmul(x,W)+b)
    
    #二次代价函数
    # loss = tf.reduce_mean(tf.square(y-prediction))
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
    #使用梯度下降法
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    
    #初始化变量
    init = tf.global_variables_initializer()
    
    #结果存放在一个布尔型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
    #求准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(11):
            for batch in range(n_batch):
                batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
                sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
            
            acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
            print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
        #保存模型
        saver.save(sess,'net/my_net.ckpt')

    结果:

    Iter 0,Testing Accuracy 0.8252
    Iter 1,Testing Accuracy 0.8916
    Iter 2,Testing Accuracy 0.9008
    Iter 3,Testing Accuracy 0.906
    Iter 4,Testing Accuracy 0.9091
    Iter 5,Testing Accuracy 0.9104
    Iter 6,Testing Accuracy 0.911
    Iter 7,Testing Accuracy 0.9127
    Iter 8,Testing Accuracy 0.9145
    Iter 9,Testing Accuracy 0.9166
    Iter 10,Testing Accuracy 0.9177

    二:载入

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    #载入数据集
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
    
    #每个批次100张照片
    batch_size = 100
    #计算一共有多少个批次
    n_batch = mnist.train.num_examples // batch_size
    
    #定义两个placeholder
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    
    #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    prediction = tf.nn.softmax(tf.matmul(x,W)+b)
    
    #二次代价函数
    # loss = tf.reduce_mean(tf.square(y-prediction))
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
    #使用梯度下降法
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    
    #初始化变量
    init = tf.global_variables_initializer()
    
    #结果存放在一个布尔型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
    #求准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(init)
        # 未载入模型时的识别率
        print('未载入识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
        saver.restore(sess,'net/my_net.ckpt')
        # 载入模型后的识别率
        print('载入后识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))

    结果:

    未载入识别率 0.098
    INFO:tensorflow:Restoring parameters from net/my_net.ckpt
    载入后识别率 0.9177
  • 相关阅读:
    bzoj 1444: [Jsoi2009]有趣的游戏【AC自动机+dp+高斯消元】
    bzoj 3270: 博物馆【dp+高斯消元】
    bzoj 3105: [cqoi2013]新Nim游戏【线性基+贪心】
    bzoj 1923: [Sdoi2010]外星千足虫【高斯消元】
    bzoj 3629: [JLOI2014]聪明的燕姿【线性筛+dfs】
    bzoj 1296: [SCOI2009]粉刷匠【dp+背包dp】
    bzoj 3329: Xorequ【数位dp+矩阵乘法】
    bzoj 1306: [CQOI2009]match循环赛【dfs+剪枝】
    bzoj 4720: [Noip2016]换教室【期望dp】
    bzoj 2257: [Jsoi2009]瓶子和燃料【裴蜀定理+gcd】
  • 原文地址:https://www.cnblogs.com/felixwang2/p/9190692.html
Copyright © 2011-2022 走看看