zoukankan      html  css  js  c++  java
  • Tensorflow卷积神经网络

    '''
    ##卷积神经网络,两个卷积层:32和64个特征平面,两个全连接层1024和10个神经元
    '''
    #加载数据,设定batch
    mnist = input_data.read_data_sets('MNIST_data/',one_hot=True)
    batch_size = 100
    n_batch = mnist.train.num_examples // batch_size
    
    #初始化权值
    def weight_var(shape):
        return tf.Variable(tf.truncated_normal(shape,stddev=0.1))
    
    #初始化偏移值
    def bias_var(shape):
        return tf.Variable(tf.constant(0.1,shape=shape))
    
    #卷积操作
    def conv2d(x,W):
        return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')
    
    #池化操作
    def max_pool_2x2(x):
        return tf.nn.max_pool(x,strides=[1,2,2,1],ksize=[1,2,2,1],padding='SAME')
    
    #定义三个占位符,数据,标签和dropout
    x = tf.placeholder(tf.float32,shape=[None,784])
    y = tf.placeholder(tf.float32,shape=[None,10])
    keep_prob = tf.placeholder(tf.float32)
    
    #把x变成一个4d向量,其第2、第3维对应图片的宽、高,最后一维代表图片的颜色通道数
    x_image = tf.reshape(x,[-1,28,28,1])
    
    #卷积层1,32个特征平面,卷积操作后[-1,28,28,32],池化操作后[-1,14,14,32]
    W_conv1 = weight_var([5,5,1,32])
    b_conv1 = bias_var([32])
    h_conv1 = tf.nn.relu(conv2d(x_image,W_conv1)+b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
    
    #卷积层2,64个特征平面,卷积操作后[-1,14,14,64],池化操作后[-1,7,7,64]
    W_conv2 = weight_var([5,5,32,64])
    b_conv2 = bias_var([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)
    
    #全连接层1,1024个神经元,先将卷积层2的输出扁平化处理
    h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
    W_fc1 = weight_var([7*7*64,1024])
    b_fc1 = bias_var([1024])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1) + b_fc1)
    h_fc1 = tf.nn.dropout(h_fc1,keep_prob)
    
    #全连接层2,10个神经元
    W_fc2 = weight_var([1024,10])
    b_fc2 = bias_var([10])
    y_ = tf.nn.softmax(tf.matmul(h_fc1,W_fc2)+b_fc2)
    
    #交叉熵损失函数,Adam优化器
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=y_))
    train = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    
    #准确率
    correct = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
    
    #变量初始化
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        for iteration in range(21):
            for batch in range(n_batch):
                train_xs,train_ys = mnist.train.next_batch(batch_size)
                sess.run(train,feed_dict={x:train_xs,y:train_ys,keep_prob:0.5})
                
            print('iter: ',iteration,'accuracy: ',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1}))
    
  • 相关阅读:
    操作系统读书笔记01
    k-mean鸢尾花分类
    利用numpy完成波士顿房价预测任务
    软件过程管理读书笔记01
    软件测试读书笔记01
    数据分析与数据挖掘
    oracle 导出导入操作
    oracle降低高水位操作
    dubbo工程刚初始化报错明明找得到jar包还是报错
    get请求参数中带有url
  • 原文地址:https://www.cnblogs.com/54hys/p/10233754.html
Copyright © 2011-2022 走看看