zoukankan      html  css  js  c++  java
  • TensorFlow(十三):模型的保存与载入

    一:保存

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    #载入数据集
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
    
    #每个批次100张照片
    batch_size = 100
    #计算一共有多少个批次
    n_batch = mnist.train.num_examples // batch_size
    
    #定义两个placeholder
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    
    #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    prediction = tf.nn.softmax(tf.matmul(x,W)+b)
    
    #二次代价函数
    # loss = tf.reduce_mean(tf.square(y-prediction))
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
    #使用梯度下降法
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    
    #初始化变量
    init = tf.global_variables_initializer()
    
    #结果存放在一个布尔型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
    #求准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(11):
            for batch in range(n_batch):
                batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
                sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
            
            acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
            print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
        #保存模型
        saver.save(sess,'net/my_net.ckpt')

    结果:

    Iter 0,Testing Accuracy 0.8252
    Iter 1,Testing Accuracy 0.8916
    Iter 2,Testing Accuracy 0.9008
    Iter 3,Testing Accuracy 0.906
    Iter 4,Testing Accuracy 0.9091
    Iter 5,Testing Accuracy 0.9104
    Iter 6,Testing Accuracy 0.911
    Iter 7,Testing Accuracy 0.9127
    Iter 8,Testing Accuracy 0.9145
    Iter 9,Testing Accuracy 0.9166
    Iter 10,Testing Accuracy 0.9177

    二:载入

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    #载入数据集
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
    
    #每个批次100张照片
    batch_size = 100
    #计算一共有多少个批次
    n_batch = mnist.train.num_examples // batch_size
    
    #定义两个placeholder
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    
    #创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    prediction = tf.nn.softmax(tf.matmul(x,W)+b)
    
    #二次代价函数
    # loss = tf.reduce_mean(tf.square(y-prediction))
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=prediction))
    #使用梯度下降法
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    
    #初始化变量
    init = tf.global_variables_initializer()
    
    #结果存放在一个布尔型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一维张量中最大的值所在的位置
    #求准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(init)
        # 未载入模型时的识别率
        print('未载入识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
        saver.restore(sess,'net/my_net.ckpt')
        # 载入模型后的识别率
        print('载入后识别率',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))

    结果:

    未载入识别率 0.098
    INFO:tensorflow:Restoring parameters from net/my_net.ckpt
    载入后识别率 0.9177
  • 相关阅读:
    表达式计算 java 后缀表达式
    动态规划略有所得 数字三角形(POJ1163)
    SharedPreferences的基本数据写入和读取
    安卓 io流 写入文件,再读取的基本使用
    SqLite的基本使用
    安卓手机开机开启指定Activity
    Calling startActivity() from outside of an Activity context requires the FLAG_ACTIVITY_NEW_TASK fla
    广播的基本使用,判断是否有可用网络,并弹出设置窗口
    AsyncTask下载网络图片的简单应用
    Intellij_idea-14官方快捷键中文版
  • 原文地址:https://www.cnblogs.com/felixwang2/p/9190692.html
Copyright © 2011-2022 走看看