zoukankan      html  css  js  c++  java
  • convolution-卷积神经网络

    训练mnist数据集

    结构组成:

    input_image --> convolution1 --> pool1 --> convolution2 --> pool2 --> full_connecion1 --> full_connection2
    # 卷积
    import tensorflow as tf
    
    import input_data
    
    # 加载mnist数据集
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    
    
    # 构建多层卷积网络
    # 权重及偏置初始化, ReLU神经元 用一个较小的正数来初始化偏置项来打破对称性以及避免0梯度
    def weight_variable(shape):
        """
        :param shape:二维tensor,第一个维度代表层中权重变量所连接(connect from)的单元数目,
        第二个维度代表层中权重变量所连接(connect to)到的单元数量
        :return: W
        """
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)
    
    
    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)
    
    
    # 卷积及池化
    def conv2d(x, W):
        """
        卷积
        :param x:
        :param W:
        :return:
        """
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")
    
    
    def max_pool_2x2(x):
        """
        最大池化
        :param x:
        :return:
        """
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    
    
    def compute_accuracy(v_xs, v_ys):
        """ 计算的准确率 """
        global prediction  # prediction value
        y_pre = sess.run(prediction, feed_dict={xs: v_xs})
        # 与期望的值比较 bool
        correct_pre = tf.equal(tf.argmax(y_pre, 1), tf.argmax(ys, 1))
        # 将bools转化为数字
        accuracy = tf.reduce_mean(tf.cast(correct_pre, tf.float32))
        result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
        return result
    
    
    # 数据图片
    xs = tf.placeholder("float", shape=[None, 784])  # size 28 * 28 =784
    # 预期概率
    ys = tf.placeholder("float", shape=[None, 10])  # 10: 矩阵维度(分类)
    keep_prob = tf.placeholder(tf.float32)
    x_image = tf.reshape(xs, [-1, 28, 28, 1])  # -1: 任意数量的图片; 28*28:图片的长宽; 1:灰色图片为1
    
    # layer1
    W_conv1 = weight_variable([5, 5, 1, 32])  # 5*5:patch过滤长宽, 1:起始输入一张图片, 32:out_size
    b_conv1 = bias_variable([32])  # 32:上层输入的out_size
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)  # output_size=28*28*32
    h_pool1 = max_pool_2x2(h_conv1)  # output_size=14*14*32 pool_strdes=2
    
    # layer2
    W_conv2 = weight_variable([5, 5, 32, 64])  # 64是训练中不断增加的高度,自定义
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)  # output_size=14*14*64
    h_pool2 = max_pool_2x2(h_conv2)  # output_size=7*7*64 池化的步长为[2,2]
    
    # func1 layer
    W_fc1 = weight_variable([7*7*64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])  # 将pool2铺平为7*7*64
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)  # 矩阵相乘
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)  # 防止过拟合
    
    # func2 layer
    W_fc2 = weight_variable([1024, 10])  # 传入的1024, 判断0-9的数字one-hot,10来代表每个数字
    b_fc2 = bias_variable([10])
    prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)  # softmax分类
    
    # the loss between prediction and really
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction), reduction_indices=[1]))
    tf.summary.scalar('loss', cross_entropy)  # 字符串类型的标量张量,包含一个Summaryprotobuf  1.1记录标量
    # training
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)  # AdamOptimizer使用复杂模型
    
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    
    
    # start training
    for i in range(1000):
        batch_x, batch_y = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={xs: batch_x, ys: batch_y, keep_prob: 0.1})
        if i % 50 == 0:
            print(compute_accuracy(mnist.test.images, mnist.test.labels))
    print("Training Finished !!!")
  • 相关阅读:
    记第一场省选
    POJ 2083 Fractal 分形
    CodeForces 605A Sorting Railway Cars 思维
    FZU 1896 神奇的魔法数 dp
    FZU 1893 内存管理 模拟
    FZU 1894 志愿者选拔 单调队列
    FZU 1920 Left Mouse Button 简单搜索
    FZU 2086 餐厅点餐
    poj 2299 Ultra-QuickSort 逆序对模版题
    COMP9313 week4a MapReduce
  • 原文地址:https://www.cnblogs.com/tangpg/p/9214087.html
Copyright © 2011-2022 走看看