zoukankan      html  css  js  c++  java
  • tensorflow多层CNN代码分析

    tf,reshape(tensor,shape,name=None)
    #其中shape为一个列表形式,特殊的一点是列表中可以存在-1。-1代表的含义是不用我们自己#指定这一维的大小,函数会自动计算,但列表中只能存在一个-1。
    #思想:将矩阵t变为一维矩阵,然后再对矩阵的形式更改

    2.

    c = tf.truncated_normal(shape=[10,10], mean=0, stddev=1)  
    #shape表示生成张量的维度,mean是均值,stddev是标准差,产生正态分布
    #这个函数产生的随机数与均值的差距不会超过标准差的两倍

     3.

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import tensorflow as tf
    import numpy as np
    import math
    import gzip
    import os
    import tempfile
    from tensorflow.examples.tutorials.mnist import input_data
    flags = tf.app.flags
    FLAGS = flags.FLAGS
    flags.DEFINE_string('data_dir', '/Users/guoym/Desktop/models-master', 'Directory for storing data')
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    x = tf.placeholder(tf.float32, [None, 784]) # 占位符
    x_image=tf.reshape(x,[-1,28,28,1])
    sess=tf.InteractiveSession()
    # 第一层
    # 卷积核(filter)的尺寸是5*5, 通道数为1,输出通道为32,即feature map 数目为32
    # 又因为strides=[1,1,1,1] 所以单个通道的输出尺寸应该跟输入图像一样。即总的卷积输出应该为?*28*28*32
    # 也就是单个通道输出为28*28,共有32个通道,共有?个批次
    # 在池化阶段,ksize=[1,2,2,1] 那么卷积结果经过池化以后的结果,其尺寸应该是?*14*14*32
    def weight_variable(shape):
        initial=tf.truncated_normal(shape,stddev=0.1)
        return tf.Variable(initial)
    def bias_variable(shape):
        initial=tf.constant(0.1,shape=shape)
        return tf.Variable(initial)
    def conv2d(x,w):
        '''
        tf.nn.conv2d的功能:给定4维的input和filter,计算出一个2维的卷积结果。
        参数:
        input:[batch,in_height,in_width,in_channel]
        filter:[filter_height,filter_width,in_channel,out_channel]
        strides :一个长为4的list,表示每次卷积之后在input中滑动的距离
        padding: SAME保留不完全卷积的部分,VALID
        '''
        return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')
    def max_pool_2x2(x):
        '''
        tf.nn.max_pool进行最大值池化操作,avg_pool进行平均值池化操作
        value:4d张量[batch,height,width,channels]
        ksize: 长为4的list,表示池化窗口的尺寸
        strides:窗口的滑动值
        padding:
        '''
        return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    w_conv1=weight_variable([5,5,1,32])#卷积核的大小,输入通道的数目,输出通道的数目
    b_conv1=bias_variable([32])
    h_conv1=tf.nn.elu(conv2d(x_image,w_conv1)+b_conv1)
    h_pool1=max_pool_2x2(h_conv1)
    #第二层
    #卷积核5*5,输入通道为32,输出通道为64
    #卷积前为?*14*14*32 卷积后为?*14*14*64
    #池化后,输出的图像尺寸为?*7*7*64
    w_conv2=weight_variable([5,5,32,64])#卷积核的大小,输入通道的数目,输出通道的数目
    b_conv2=bias_variable([64])
    h_conv2=tf.nn.elu(conv2d(h_pool1,w_conv2)+b_conv2)
    h_pool2=max_pool_2x2(h_conv2)
    #第三层,全连接层,输入维数是7*7*64,输出维数是1024
    w_fc1=weight_variable([7*7*64,1024])
    b_fc1=bias_variable([1024])
    h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
    h_fc1=tf.nn.elu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)
    keep_prob=tf.placeholder(tf.float32)#这里使用了dropout,即随机安排一些cell输出值为0
    h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)
    #第四层 输入1024维,输出10维
    w_fc2=weight_variable([1024,10])
    b_fc2=bias_variable([10])
    y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2)
    y_=tf.placeholder(tf.float32,[None,10])
    
    cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
    train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)#使用adam优化
    correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))#计算准确度
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    sess.run(tf.initialize_all_variables())
    for i in range(20000):
        batch=mnist.train.next_batch(50)
        if i%100==0:
            train_accuracy=accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1})
            print (i,train_accuracy)
        train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})
    print("test accuracy %g"%accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
  • 相关阅读:
    Vue--会员管理列表页面,抽取BASE_URL
    Vue--系统权限拦截
    写译-冲刺班
    看到一篇有收获的博文【关于外挂生涯的忠告】(转载)
    笔记管理-vscode-印象笔记-git-博客园
    1.4条件和循环
    1.3撰写表达式
    1.2对象定义与初始化
    1.1如何写一个c++程序
    send()函数 recv()函数
  • 原文地址:https://www.cnblogs.com/qniguoym/p/7762644.html
Copyright © 2011-2022 走看看