zoukankan      html  css  js  c++  java
  • 学习进度笔记13

    观看 TensorFlow 案例实战视频课程 13：模型的保存和读取

    # Save demo: create two randomly initialised variables, print their
    # values, and write them to a checkpoint at save/model.ckpt.
    import os

    import tensorflow as tf
    
    v1=tf.Variable(tf.random_normal([1,2]),name="v1")
    v2=tf.Variable(tf.random_normal([2,3]),name="v2")
    init_op=tf.global_variables_initializer()
    # Constructed without a var_list, so by default it covers the global
    # variables in the graph (here v1 and v2).
    saver=tf.train.Saver()
    # tf.train.Saver.save does not create missing parent directories (TF 1.x
    # raises "Parent directory of ... doesn't exist"), so make sure it exists.
    os.makedirs("save", exist_ok=True)
    with tf.Session() as sess:
        sess.run(init_op)
        print("V1:",sess.run(v1))
        print("V2:",sess.run(v2))
        # save() returns the path prefix of the checkpoint it wrote.
        saver_path=saver.save(sess,"save/model.ckpt")
        print("Model saved in file:",saver_path)
    
    # Restore demo: rebuild variables with the same names ("v1", "v2") and
    # load their values from the checkpoint written by the save demo.
    import tensorflow as tf

    # Start from an empty default graph. If the save demo already ran in this
    # process, its v1/v2 are still in the graph: the variables below would be
    # renamed "v1_1"/"v2_1", the Saver would cover all four, and restore()
    # would fail on keys the checkpoint does not contain.
    tf.reset_default_graph()
    v1=tf.Variable(tf.random_normal([1,2]),name="v1")
    v2=tf.Variable(tf.random_normal([2,3]),name="v2")
    saver=tf.train.Saver()
    
    # No initializer is run: restore() assigns the saved values to the
    # variables the Saver covers, so the printed values match the save demo.
    with tf.Session() as sess:
        saver.restore(sess,"save/model.ckpt")
        print("V1:",sess.run(v1))
        print("V2:",sess.run(v2))
        print("Model restored")
    
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    import input_data

    # Begin the CNN example from a clean graph so variables left over from
    # the save/restore demos (if run in the same process) are not picked up
    # by the Saver and global initializer built later.
    tf.reset_default_graph()
    
    # Load MNIST with one-hot labels via the local input_data helper.
    mnist=input_data.read_data_sets('data/',one_hot=True)
    trainimg=mnist.train.images
    # Was misspelled "lables", which raised AttributeError.
    trainlabel=mnist.train.labels
    testimg=mnist.test.images
    testlabel=mnist.test.labels
    print("MNIST ready")
    
    n_input=784   # 28*28: conv_basic reshapes each row to a 28x28x1 image
    n_output=10   # number of classes (width of the one-hot labels)
    # Conv kernels use tf.nn.conv2d's [height, width, in_channels, out_channels]
    # layout. wd1's 7*7*128 rows match the feature map left after two stride-2
    # SAME max-pools of a 28x28 input (28 -> 14 -> 7) with 128 channels.
    weights={
        'wc1':tf.Variable(tf.random_normal([3,3,1,64],stddev=0.1)),
        'wc2':tf.Variable(tf.random_normal([3,3,64,128],stddev=0.1)),
        'wd1':tf.Variable(tf.random_normal([7*7*128,1024],stddev=0.1)),
        'wd2':tf.Variable(tf.random_normal([1024,n_output],stddev=0.1))
        }
    biases={
        'bc1':tf.Variable(tf.random_normal([64],stddev=0.1)),
        'bc2':tf.Variable(tf.random_normal([128],stddev=0.1)),
        'bd1':tf.Variable(tf.random_normal([1024],stddev=0.1)),
        'bd2':tf.Variable(tf.random_normal([n_output],stddev=0.1))
        }
    
    def conv_basic(_input,_w,_b,_keepratio):
        """Build the forward graph of a two-conv / two-dense CNN.

        Args:
            _input: batch of flattened images; reshaped here to [-1, 28, 28, 1].
            _w: dict of weight tensors keyed 'wc1', 'wc2', 'wd1', 'wd2'.
            _b: dict of bias tensors keyed 'bc1', 'bc2', 'bd1', 'bd2'.
            _keepratio: value passed to every tf.nn.dropout call.

        Returns:
            dict mapping layer names to intermediate tensors; 'out' holds the
            final logits (matmul + bias, no softmax applied).
        """
        def conv_stage(layer_in, w_key, b_key):
            # conv (stride 1, SAME) -> bias -> ReLU -> 2x2 max-pool (stride 2)
            # -> dropout. Returns the ReLU output, the pooled map, and the
            # dropped-out map.
            activated = tf.nn.relu(tf.nn.bias_add(
                tf.nn.conv2d(layer_in, _w[w_key], strides=[1, 1, 1, 1], padding='SAME'),
                _b[b_key]))
            pooled = tf.nn.max_pool(activated, ksize=[1, 2, 2, 1],
                                    strides=[1, 2, 2, 1], padding='SAME')
            return activated, pooled, tf.nn.dropout(pooled, _keepratio)

        image = tf.reshape(_input, shape=[-1, 28, 28, 1])
        conv1, pool1, drop1 = conv_stage(image, 'wc1', 'bc1')
        conv2, pool2, drop2 = conv_stage(drop1, 'wc2', 'bc2')

        # Flatten to the row length that wd1 expects (its first dimension).
        flat_len = _w['wd1'].get_shape().as_list()[0]
        flat = tf.reshape(drop2, [-1, flat_len])

        hidden = tf.nn.relu(tf.add(tf.matmul(flat, _w['wd1']), _b['bd1']))
        hidden_drop = tf.nn.dropout(hidden, _keepratio)
        logits = tf.add(tf.matmul(hidden_drop, _w['wd2']), _b['bd2'])

        # Key names (including 'pool1_dr1' and 'densel') are part of the
        # returned contract and are kept exactly as callers expect.
        return {
            'input_r': image,
            'conv1': conv1,
            'pool1': pool1,
            'pool1_dr1': drop1,
            'conv2': conv2,
            'pool2': pool2,
            'pool_dr2': drop2,
            'densel': flat,
            'fc1': hidden,
            'fc_dr1': hidden_drop,
            'out': logits,
        }
    import os

    print("CNN READY")
    
    # Graph inputs: flattened images, one-hot labels, and the dropout keep
    # probability (fed 0.7 for training steps and 1.0 for evaluation below).
    x=tf.placeholder(tf.float32,[None,n_input])
    y=tf.placeholder(tf.float32,[None,n_output])
    keepratio=tf.placeholder(tf.float32)
    
    #FUNCTIONS
    
    _pred=conv_basic(x,weights,biases,keepratio)['out']
    # TF 1.x rejects positional arguments here ("Only call
    # `softmax_cross_entropy_with_logits` with named arguments"), so pass
    # logits and labels by keyword.
    cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=_pred,labels=y))
    optm=tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
    # Accuracy: fraction of rows whose arg-max logit matches the one-hot label.
    _corr=tf.equal(tf.argmax(_pred,1),tf.argmax(y,1))
    accr=tf.reduce_mean(tf.cast(_corr,tf.float32))
    init=tf.global_variables_initializer()
    
    #SAVER
    save_step=1
    # Only the 3 most recent checkpoints are kept on disk.
    saver=tf.train.Saver(max_to_keep=3)
    
    print("GRAPH READY")
    
    # 1 = train and checkpoint every save_step epochs;
    # 0 = restore the final epoch's checkpoint and evaluate on the test set.
    #do_train=1
    do_train=0
    sess=tf.Session()
    sess.run(init)
    
    training_epochs=15
    batch_size=16
    display_step=1
    if do_train==1:
        # Saver.save does not create missing parent directories in TF 1.x.
        os.makedirs("save/nets", exist_ok=True)
        for epoch in range(training_epochs):
            avg_cost=0
            # Demo setting: only 10 mini-batches per epoch, not a full pass.
            #total_batch=int(mnist.train.num_examples/batch_size)
            total_batch=10
            #Loop over all batches
            for i in range(total_batch):
                batch_xs,batch_ys=mnist.train.next_batch(batch_size)
                # One optimisation step with dropout active.
                sess.run(optm,feed_dict={x:batch_xs,y:batch_ys,keepratio:0.7})
                # Accumulate the mean batch loss, measured with dropout off.
                avg_cost+=sess.run(cost,feed_dict={x:batch_xs,y:batch_ys,keepratio:1.})/total_batch
        
            #Display logs per epoch step
            if epoch % display_step==0:
                print("Epoch:%03d/%03d cost:%.9f" % (epoch,training_epochs,avg_cost))
                # Accuracy on the last mini-batch only, not the whole training set.
                train_acc=sess.run(accr,feed_dict={x:batch_xs,y:batch_ys,keepratio:1.})
                print("Training accuracy:%.3f" % (train_acc))
                #test_acc=sess.run(accr,feed_dict={x:testimg,y:testlabel,keepratio:1.})
                #print("Test accuracy:%.3f" % (test_acc))
            
            #Save Net
            if epoch % save_step==0:
                saver.save(sess,"save/nets/cnn_mnist_basic.ckpt-"+str(epoch))
    
    print("OPTIMIZATION FINISHED")
    
    if do_train==0:
        # Load the checkpoint written at the last epoch. With save_step=1 and
        # max_to_keep=3 it is among the checkpoints retained on disk.
        epoch=training_epochs-1
        saver.restore(sess,"save/nets/cnn_mnist_basic.ckpt-"+str(epoch))
        
        test_acc=sess.run(accr,feed_dict={x:testimg,y:testlabel,keepratio:1.})
        print("TEST ACCURACY:%.3f" % (test_acc))
  • 相关阅读:
    Nginx如何配置基础缓存
    Websocket消息过长自动断开连接?
    Docker错误删除Postgresql容器如何恢复?
    Docker安装带中文全文搜索插件zhparser的Postgresql数据库
    Postgresql数据库安装中文全文搜索插件zhparser的问题
    Presto通过RESTful接口新增Connector
    在windows的IDEA运行Presto
    Druid.io通过NiFi摄取流数据
    Druid.io SQL乱码问题
    Druid.io启用SQL支持
  • 原文地址:https://www.cnblogs.com/zql-42/p/14624757.html
Copyright © 2011-2022 走看看