zoukankan      html  css  js  c++  java
  • 深度学习tensorflow实战笔记(1)全连接神经网络(FCN)训练自己的数据(从txt文件中读取)

    1、准备数据

     把数据放进txt文件中(数据量大的话,就写一段程序自己把数据自动的写入txt文件中,任何语言都能实现),数据之间用逗号隔开,最后一列标注数据的标签(用于分类),比如0,1。每一行表示一个训练样本。如下图所示。

     其中前三列表示数据(特征),最后一列表示数据(特征)的标签。注意:标签需要从0开始编码!

    2、实现全连接网络

     这个过程我就不多说了,如何非常简单,就是普通的代码实现,本篇博客的重点在于使用自己的数据,有些需要注意的地方我在后面会做注释。直接上代码

     1 #隐含层参数设置
     2 in_units=3  #输入神经元个数
     3 h1_units=5  #隐含层输出神经元个数
     4  
     5 #第二个隐含层神经元个数
     6 h2_units=6
     7  
     8  
     9 W1=tf.Variable(tf.truncated_normal([in_units,h1_units],stddev=0.1)) #隐含层权重,W初始化为截断正态分布
    10 b1=tf.Variable(tf.zeros([h1_units]))  #隐含层偏执设置为0
    11 W2=tf.Variable(tf.truncated_normal([h1_units,h2_units],stddev=0.1)) #第二个隐含层权重,W初始化为截断正态分布
    12 b2=tf.Variable(tf.zeros([h2_units]))  #第二个隐含层偏执设置为0
    13  
    14 W3=tf.Variable(tf.zeros([h2_units,2])) #输出层权重和偏执都设置为0
    15 b3=tf.Variable(tf.zeros([2]))
    16  
    17 #定义输入变量x和dropout比率
    18 x=tf.placeholder(tf.float32,[None,3]) #列是
    19 keep_prob=tf.placeholder(tf.float32)
    20  
    21 #定义一个隐含层
    22 hidden1=tf.nn.relu(tf.matmul(x,W1)+b1)
    23 hidden1_drop=tf.nn.dropout(hidden1,keep_prob)
    24  
    25 #定义第二个隐藏层
    26 hidden2=tf.nn.relu(tf.matmul(hidden1_drop,W2)+b2)
    27 hidden2_drop=tf.nn.dropout(hidden2,keep_prob)

    需要注意的地方

    in_units=3  #输入神经元个数,和特征的维度对应起来

    x=tf.placeholder(tf.float32,[None,3]) #和特征的维度对应起来

    3、实现损失函数

          标准的softmax和交叉熵,不多说了。    

    1 y=tf.nn.softmax(tf.matmul(hidden2_drop,W3)+b3)
    2  
    3 #定义损失函数和选择优化器
    4 y_=tf.placeholder(tf.float32,[None,2])  #列是2,表示两类,行表示输入的训练样本个数,None表示不定
    5  
    6 corss_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
    7 train_step=tf.train.AdagradOptimizer(0.3).minimize(corss_entropy)

     需要注意的地方:

    y_=tf.placeholder(tf.float32,[None,2])  #有几类就写几,我写的是两类,所以就是2

    4、从txt中读取数据,并做处理

        重点来了,首先从txt中把数据读取出来,然后对标签进行独热编码,什么是独热编码?索引表示类别,是哪个类别这一维就是非零(用1)。代码实现:

     1 data=np.loadtxt('txt.txt',dtype='float',delimiter=',')
     2  
     3 #将样本标签转换成独热编码
     4 def label_change(before_label):
     5     label_num=len(before_label)
     6     change_arr=np.zeros((label_num,2))  #2表示有两类
     7     for i in range(label_num):
     8         #该样本标签数据要求从0开始
     9             change_arr[i,int(before_label[i])]=1
    10     return change_arr
    11  
    12 #用于提取数据
    13 def train(data):
    14     data_train_x=data[:7,:3]   #取前几行作为训练数据,7表示前7行,3表示取前三列,排除数据标签
    15     data_train_y=label_change(data[:7,-1])
    16     return data_train_x,data_train_y
    17  
    18  
    19 data_train_x,data_train_y=train(data)

    需要注意的地方在代码中我都做了注释,不再赘述。

    5、开始训练和测试

    训练部分

     1 for i in range(5):  #迭代,取batch进行训练
     2    img_batch, label_batch = tf.train.shuffle_batch([data_train_x, data_train_y],   #随机取样本
     3                                                     batch_size=2,
     4                                                     num_threads=2,
     5                                                     capacity=7,
     6                                                     min_after_dequeue=2,
     7                                                     enqueue_many=True)
     8    coord = tf.train.Coordinator()  
     9    threads = tf.train.start_queue_runners(coord=coord, sess=sess) 
    10  
    11  
    12    img_batch,label_batch=sess.run([img_batch,label_batch])
    13  
    14    train_step.run({x:img_batch,y_:label_batch,keep_prob:0.75}    
    1 #预测部分
    2 correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    3 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    4 print(accuracy.eval({x:data_train_x,y_:data_train_y,keep_prob:1.0}))   

    这样就全部流程完成。其中网络结构可以做相应的修改,核心在于如何从txt中读取自己的数据输入到全连接神经网络(多层感知机)中进行训练和测试。

    当然,也可以在定义变量的时候直接输入,不用从txt中读取。即:

    1 image=[[1.0,2.0,3.0],[9,8,5],[9,5,6],[7,5,3],[6,12,7],[8,3,6],[2,8,71]]  
    2 label=[[0,1],[1,0],[1,0],[1,0],[1,0],[0,1],[0,1]]        
    3 image_test=[[9,9,9]]     
    4 label_test=[[0,1]] 

    直接定于数据的话,适合小数据量的情况,大数据量的情况并不适用。

      好了,本篇博客介绍到此结束。下一篇介绍如何处理图像数据。

    以上便是本章分享内容,有问题,可以进群871458817交流在群内下载资料学习。最后,感谢观看!
  • 相关阅读:
    vivado操作基本问题
    IIC通信控制的AD5259------在调试过程中遇到的奇葩问题
    FPGA基础架构总结
    PLL到底是个啥么东西呢?
    CSS-3 Transform 的使用
    CSS-3 box-shadow 的使用
    一些CSS3的乐趣
    CSS-3 文字阴影—text-shadow 的使用
    Jquery 较好的效果
    如何关闭输入法
  • 原文地址:https://www.cnblogs.com/pypypy/p/11829700.html
Copyright © 2011-2022 走看看