zoukankan      html  css  js  c++  java
  • 用 TensorFlow 实现简单的卷积神经网络

     1 # MNIST训练
     2 
     3 import tensorflow as tf
     4 import matplotlib.pyplot as plt
     5 from tensorflow.examples.tutorials.mnist import input_data
     6 import numpy as np
     7 
     8 mnist = input_data.read_data_sets('MNIST_data/',one_hot=True)
     9 
    10 def weight_variable(shape):
    11     initial = tf.truncated_normal(shape,stddev=0.1)
    12     return tf.Variable(initial)
    13 
    14 def bias_variable(shape):
    15     initial = tf.constant(0.1,shape=shape)
    16     return tf.Variable(initial)
    17 
    18 def conv(x,w):
    19     return tf.nn.conv2d(x,w,strides=[1,1,1,1], padding='SAME')
    20 
    21 def max_pool(x):
    22     return tf.nn.max_pool(x,ksize = [1, 2, 2, 1],strides=[1,2,2,1],padding='SAME')
    23 
    24 x = tf.placeholder(tf.float32,shape=[None,784])
    25 y_ = tf.placeholder(tf.float32,shape=[None,10])
    26 x_image = tf.reshape(x,[-1,28,28,1])
    27 
    28 #卷积层1-池化层1
    29 w_conv1 = weight_variable([5,5,1,32])
    30 b_conv1 = bias_variable([32])
    31 h_conv1 = tf.nn.relu(conv(x_image,w_conv1)+b_conv1)
    32 h_pool1 = max_pool(h_conv1)
    33 
    34 #卷积层2-池化层2
    35 w_conv2 = weight_variable([5,5,32,64])
    36 b_conv2 = bias_variable([64])
    37 h_conv2 = tf.nn.relu(conv(h_pool1,w_conv2)+b_conv2)
    38 h_pool2 = max_pool(h_conv2)
    39 
    40 #全连接层
    41 w_fc1 = weight_variable([7 * 7 *64,1024])
    42 b_fc1 = bias_variable([1024])
    43 h_pool_flat = tf.reshape(h_pool2,[-1,7 * 7 *64])
    44 h_fc1 = tf.nn.relu(tf.matmul(h_pool_flat,w_fc1)+b_fc1)
    45 
    46 #dropout层
    47 keep_drop = tf.placeholder(tf.float32)
    48 h_fc1_drop = tf.nn.dropout(h_fc1,keep_drop)
    49 
    50 #softmax层
    51 w_fc2 = weight_variable([1024,10])
    52 b_fc2 = bias_variable([10])
    53 y = tf.matmul(h_fc1_drop,w_fc2)+b_fc2
    54 
    55 #loss
    56 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y,labels=y_))
    57 train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)
    58 #计算模型预测的准确率
    59 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    60 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    61 
    62 sess = tf.InteractiveSession()
    63 init = tf.global_variables_initializer()
    64 sess.run(init)
    65 losses = []
    66 acc = []
    67 for i in range(2000):
    68     batch = mnist.train.next_batch(50)
    69     if i % 100 == 0:
    70         train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_drop:1.0})
    71         print('step %d,training accuracy %g' %(i,train_accuracy))
    72         acc.append(train_accuracy)
    73         loss_tmp = sess.run(loss,feed_dict={x:batch[0],y_:batch[1],keep_drop:1.0})
    74         losses.append(loss_tmp)
    75     sess.run(train_step,feed_dict={x: batch[0], y_: batch[1], keep_drop: 0.5})
    76 print("test accuracy",accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_drop:1.0}))

    参考文章:

    1.https://www.cnblogs.com/willnote/p/6874699.html

    作者:舟华520

    出处:https://www.cnblogs.com/xfzh193/

    本文以学习、分享、研究交流为主，欢迎转载，请注明作者及出处！

  • 相关阅读:
    移动开发 Native APP、Hybrid APP和Web APP介绍
    urllib与urllib2的学习总结(python2.7.X)
    fiddler及postman讲解
    接口测试基础
    UiAutomator2.0 和1.x 的区别
    adb shell am instrument 命令详解
    GT问题记录
    HDU 2492 Ping pong (树状数组)
    CF 567C Geometric Progression
    CF 545E Paths and Trees
  • 原文地址:https://www.cnblogs.com/xfzh193/p/13890305.html
Copyright © 2011-2022 走看看