zoukankan      html  css  js  c++  java
  • tensorflow-cnn-mnist

    #coding=utf-8
    import tensorflow as tf
    import numpy as np
    import matplotlib .pyplot as plt
    from tensorflow .examples .tutorials .mnist import input_data



    # --- Dataset -------------------------------------------------------------
    # Load MNIST from the local download directory with one-hot encoded labels.
    mnist = input_data.read_data_sets("/home/nvidia/Downloads/", one_hot=True)

    # --- Training arguments --------------------------------------------------
    # Mini-batch size used for every training step.
    batch_zize = 20
    # Number of mini-batches per epoch. `np.int` was deprecated in NumPy 1.20
    # and removed in 1.24, so use integer floor division instead (identical
    # result for a positive example count and batch size).
    # NOTE(review): `iter` shadows the builtin; kept because the training loop
    # reads it by this name.
    iter = mnist.train.images.shape[0] // batch_zize
    print(iter)


    # --- Learning-rate schedule ----------------------------------------------
    # Staircase exponential decay: start at LEARNING_RATE_BASE and multiply by
    # LEARNING_RATE_DECAY once every LEARNING_RATE_STEP optimizer steps.
    LEARNING_RATE_BASE = 0.001
    LEARNING_RATE_DECAY = 0.99
    LEARNING_RATE_STEP = 100

    # Non-trainable step counter; it is passed to the optimizer's minimize()
    # call so the schedule advances as training progresses.
    global_step = tf.Variable(0, trainable=False)

    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        LEARNING_RATE_STEP,
        LEARNING_RATE_DECAY,
        staircase=True,
    )



    # --- Layer helpers -------------------------------------------------------

    def Weight_V(shape, stddev=0.1):
        """Create a trainable weight variable with truncated-normal init.

        Args:
            shape: Shape of the weight tensor.
            stddev: Standard deviation of the truncated normal initializer.
                Defaults to 0.1, so existing callers are unaffected.

        Returns:
            A ``tf.Variable`` holding the initial weights.
        """
        initial = tf.truncated_normal(shape=shape, stddev=stddev)
        return tf.Variable(initial)


    def bias_V(shape, value=0.1):
        """Create a trainable bias variable filled with a constant.

        Args:
            shape: Shape of the bias tensor.
            value: Initial value of every element. Defaults to 0.1, so
                existing callers are unaffected.

        Returns:
            A ``tf.Variable`` holding the initial biases.
        """
        initial = tf.constant(shape=shape, value=value)
        return tf.Variable(initial)


    def conv2d_(x, w):
        """Convolve ``x`` with filter ``w`` using unit stride and SAME padding.

        Args:
            x: Input feature map (presumably NHWC, tf.nn.conv2d's default).
            w: Filter variable.

        Returns:
            The convolution output tensor.
        """
        unit_strides = [1, 1, 1, 1]
        return tf.nn.conv2d(x, filter=w, strides=unit_strides, padding="SAME")


    def max_pool(x):
        """Apply 2x2 max pooling with stride 2 and SAME padding to ``x``."""
        window = [1, 2, 2, 1]
        return tf.nn.max_pool(x, ksize=window, strides=window, padding="SAME")



    # --- Network inputs ------------------------------------------------------
    # Flattened 784-pixel images and one-hot labels over 10 classes.
    x_input = tf.placeholder(tf.float32, [None, 784])
    y_input = tf.placeholder(tf.float32, [None, 10])

    # Restore the image layout as (batch, 28, 28, 1) for the conv layers.
    x = tf.reshape(x_input, [-1, 28, 28, 1])



    # --- Conv layer 1: 5x5 kernels, 1 -> 32 channels, then 2x2 max pool -----
    w_conv1 = Weight_V(shape=[5, 5, 1, 32])
    b_conv1 = bias_V(shape=[32])
    c_conv1 = tf.nn.relu(tf.nn.bias_add(conv2d_(x, w_conv1), b_conv1))
    m_conv1 = max_pool(c_conv1)  # -> (batch, 14, 14, 32)

    # --- Conv layer 2: 5x5 kernels, 32 -> 64 channels, then 2x2 max pool ----
    w_conv2 = Weight_V(shape=[5, 5, 32, 64])
    b_conv2 = bias_V(shape=[64])
    c_conv2 = tf.nn.relu(tf.nn.bias_add(conv2d_(m_conv1, w_conv2), b_conv2))
    m_conv2 = max_pool(c_conv2)  # -> (batch, 7, 7, 64)

    # --- Fully connected layer: flattened 7*7*64 features -> 1024 ReLU units -
    w_fc1 = Weight_V(shape=[7 * 7 * 64, 1024])
    b_fc1 = bias_V(shape=[1024])
    c_fc1 = tf.reshape(m_conv2, [-1, 7 * 7 * 64])
    fc1 = tf.nn.relu(tf.nn.bias_add(tf.matmul(c_fc1, w_fc1), b_fc1))

    # --- Output layer: 1024 -> 10 class probabilities via softmax ------------
    w_fc2 = Weight_V(shape=[1024, 10])
    b_fc2 = bias_V(shape=[10])
    prediction = tf.nn.softmax(tf.nn.bias_add(tf.matmul(fc1, w_fc2), b_fc2))


    # --- Evaluation metric ---------------------------------------------------
    # A sample counts as correct when the arg-max of the predicted
    # probabilities equals the arg-max of its one-hot label; accuracy is the
    # mean of those booleans cast to float.
    correct_accurcy = tf.equal(tf.argmax(prediction, axis=1), tf.argmax(y_input, axis=1))
    accurcy = tf.reduce_mean(tf.cast(correct_accurcy, dtype=tf.float32))



    # --- Loss and optimizer --------------------------------------------------
    # Categorical cross-entropy: sum -y*log(p) over the class axis for each
    # example, then average over the batch. Averaging over every element
    # instead would divide the loss (and its gradients) by the number of
    # classes. Probabilities are clipped away from 0 so tf.log cannot return
    # -inf (and NaN gradients) when the softmax saturates.
    # NOTE(review): tf.nn.softmax_cross_entropy_with_logits_v2 on the
    # pre-softmax logits would be the numerically preferred formulation.
    crosss_entropy = -tf.reduce_mean(
        tf.reduce_sum(y_input * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)), axis=1)
    )
    # Passing global_step makes minimize() increment it, advancing the decay.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        crosss_entropy, global_step=global_step
    )




    # Op that initializes every global variable defined above (weights, biases
    # and global_step); the session runs it once before training starts.

    init=tf.global_variables_initializer ()


    # --- Training session ----------------------------------------------------
    EPOCHS = 21            # full passes over the training set
    EVAL_BATCH_SIZE = 100  # test images fed per accuracy evaluation run

    with tf.Session() as sess:
        sess.run(init)
        for _ in range(EPOCHS):
            # One epoch of mini-batch gradient descent.
            for _ in range(iter):
                xt, yt = mnist.train.next_batch(batch_zize)
                sess.run(train_step, feed_dict={x_input: xt, y_input: yt})

            # Score the whole test set rather than one random 100-image batch,
            # which gives a noisy estimate. Feed it in chunks to bound memory,
            # and weight each chunk's mean accuracy by its size.
            test_images = mnist.test.images
            test_labels = mnist.test.labels
            n_test = test_images.shape[0]
            n_correct = 0.0
            for start in range(0, n_test, EVAL_BATCH_SIZE):
                stop = start + EVAL_BATCH_SIZE
                chunk_x = test_images[start:stop]
                chunk_y = test_labels[start:stop]
                chunk_acc = sess.run(accurcy, feed_dict={x_input: chunk_x, y_input: chunk_y})
                n_correct += chunk_acc * len(chunk_x)
            acc = n_correct / n_test
            print(acc)

  • 相关阅读:
    error LNK2019: 无法解析的外部符号 _WinMain@16,该符号在函数 ___tmainCRTStartup 中被引用
    unity官方换装教程Character Customization 学习笔记
    python中执行os.system(),程序处于堵塞状态,debug报pydev debugger: process 11152 is connecting
    python中安装pywinauto成功,执行时报如下错误的解决办法
    jmeter之Ramp-up Period(in seconds)
    jmeter之HTTP信息管理器、正则表达式联合使用(获取登录session)
    linux之crontab定时器
    python之删除指定目录指定日期下的日志文件
    python2含有中文路径报错解决办法[\xe4\xbf\xa1\xe6\x81\xaf]
    性能测试之指标参考标准
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/11119925.html
Copyright © 2011-2022 走看看